diff options
Diffstat (limited to 'losses.py')
-rw-r--r-- | losses.py | 27 |
1 files changed, 11 insertions, 16 deletions
@@ -13,12 +13,14 @@ class YourModel(tf.keras.Model): super(YourModel, self).__init__() self.content_image = content_image + + #perhaps consider cropping to avoid distortion self.style_image = transform.resize(style_image, np.shape(style_image), anti_aliasing=True) self.x = tf.Variable(tf.random.uniform(np.shape(content_image)), trainable=True) self.alpha = 1 self.beta = 1 - self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=1e-4, momentum=0.01) + self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=hp.learning_rate, momentum=hp.momentum) self.vgg16 = [ # Block 1 @@ -61,18 +63,6 @@ class YourModel(tf.keras.Model): for layer in self.vgg16: layer.trainable = False - self.head = [ - # Dropout(.2), - # Dense(256, activation='silu'), - # Dense(512, activation='silu'), - # Dropout(.3), - # tf.keras.layers.GlobalAveragePooling2D(), - # Dense(15, activation='softmax') - ] - - # self.vgg16 = tf.keras.Sequential(self.vgg16, name="vgg_base") - # self.head = tf.keras.Sequential(self.head, name="vgg_head") - self.indexed_layers = [layer for layer in self.vgg16 if layer.name == "conv1"] self.desired = [layer.name for layer in self.vgg16 if layer.name == "conv1"] @@ -89,7 +79,6 @@ class YourModel(tf.keras.Model): return x, np.array(layers) - def loss_fn(self, p, a, x): _, photo_layers = self.call(p) _, art_layers = self.call(a) @@ -99,8 +88,6 @@ class YourModel(tf.keras.Model): style_l = self.style_loss(art_layers, input_layers) # Equation 7 return (self.alpha * content_l) + (self.beta * style_l) - - def content_loss(self, photo_layers, input_layers): L_content = tf.reduce_mean(tf.square(photo_layers - input_layers)) @@ -138,3 +125,11 @@ class YourModel(tf.keras.Model): L_style += self.layer_loss(art_layers, input_layers, layer) return L_style + def train_step(self): + with tf.GradientTape as tape: + loss = self.loss_fn(self.content_image, self.style_image, self.x) + gradients = tape.gradient(loss, self.x) + self.optimizer.apply_gradients(zip(gradients, self.x)) + + + |