diff options
Diffstat (limited to 'losses.py')
-rw-r--r-- | losses.py | 11 |
1 files changed, 9 insertions, 2 deletions
@@ -3,13 +3,18 @@ from tensorflow.keras.layers import \ Conv2D, MaxPool2D, Dropout, Flatten, Dense, AveragePool2D import numpy as np + +from skimage import transform import hyperparameters as hp class YourModel(tf.keras.Model): """ Your own neural network model. """ - def __init__(self): + def __init__(self, content_image, style_image): #normalize these images to float values super(YourModel, self).__init__() - + + self.content_image = content_image + self.style_image = transform.resize(style_image, np.shape(style_image), anti_aliasing=True) + self.x = tf.Variable(tf.random.uniform(np.shape(content_image)), trainable=True) self.alpha = 1 self.beta = 1 @@ -53,6 +58,8 @@ class YourModel(tf.keras.Model): activation="relu", name="block5_conv3"), AveragePool2D(2, name="block5_pool"), ] + for layer in self.vgg16: + layer.trainable = False self.head = [ # Dropout(.2), |