author     Benjamin Fiske <bffiske@gmail.com>   2022-05-04 15:40:04 -0400
committer  Benjamin Fiske <bffiske@gmail.com>   2022-05-04 15:40:04 -0400
commit     df0a9240bac34d2bda0d3c7c836dbce2ce781344 (patch)
tree       4c76febae25c7deba20281406fe69ce1ad6dd4d8 /losses.py
parent     137d14b2ceffa95407cacfc82d1222cb9ef35072 (diff)
backprop and main
Diffstat (limited to 'losses.py')
-rw-r--r--   losses.py   27
1 file changed, 11 insertions(+), 16 deletions(-)
diff --git a/losses.py b/losses.py
index 36dc91cc..fd68e199 100644
--- a/losses.py
+++ b/losses.py
@@ -13,12 +13,14 @@ class YourModel(tf.keras.Model):
super(YourModel, self).__init__()
self.content_image = content_image
+
+ #perhaps consider cropping to avoid distortion
self.style_image = transform.resize(style_image, np.shape(style_image), anti_aliasing=True)
self.x = tf.Variable(tf.random.uniform(np.shape(content_image)), trainable=True)
self.alpha = 1
self.beta = 1
- self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=1e-4, momentum=0.01)
+ self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=hp.learning_rate, momentum=hp.momentum)
self.vgg16 = [
# Block 1
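
Note: the optimizer now reads hp.learning_rate and hp.momentum instead of the previously hard-coded values. The hyperparameters module itself is not part of this diff; a minimal sketch of what it could contain, reusing the values that were hard-coded before this change, might be:

    # hyperparameters.py -- hypothetical sketch; only the attribute names come from
    # the diff above, and the values shown are the ones removed from losses.py
    learning_rate = 1e-4   # RMSprop step size
    momentum = 0.01        # RMSprop momentum term
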
@@ -61,18 +63,6 @@ class YourModel(tf.keras.Model):
for layer in self.vgg16:
layer.trainable = False
- self.head = [
- # Dropout(.2),
- # Dense(256, activation='silu'),
- # Dense(512, activation='silu'),
- # Dropout(.3),
- # tf.keras.layers.GlobalAveragePooling2D(),
- # Dense(15, activation='softmax')
- ]
-
- # self.vgg16 = tf.keras.Sequential(self.vgg16, name="vgg_base")
- # self.head = tf.keras.Sequential(self.head, name="vgg_head")
-
self.indexed_layers = [layer for layer in self.vgg16 if layer.name == "conv1"]
self.desired = [layer.name for layer in self.vgg16 if layer.name == "conv1"]
@@ -89,7 +79,6 @@ class YourModel(tf.keras.Model):
return x, np.array(layers)
-
def loss_fn(self, p, a, x):
_, photo_layers = self.call(p)
_, art_layers = self.call(a)
@@ -99,8 +88,6 @@ class YourModel(tf.keras.Model):
style_l = self.style_loss(art_layers, input_layers)
# Equation 7
return (self.alpha * content_l) + (self.beta * style_l)
-
-
def content_loss(self, photo_layers, input_layers):
L_content = tf.reduce_mean(tf.square(photo_layers - input_layers))
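
Note: the "# Equation 7" comment refers to the total loss in Gatys et al., "A Neural Algorithm of Artistic Style" (2015), where the content and style terms are weighted by alpha and beta:

    L_total(p, a, x) = alpha * L_content(p, x) + beta * L_style(a, x)
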
@@ -138,3 +125,11 @@ class YourModel(tf.keras.Model):
L_style += self.layer_loss(art_layers, input_layers, layer)
return L_style
+    def train_step(self):
+        # self.x is a trainable tf.Variable, so the tape records it automatically
+        with tf.GradientTape() as tape:
+            loss = self.loss_fn(self.content_image, self.style_image, self.x)
+        gradients = tape.gradient(loss, self.x)
+        self.optimizer.apply_gradients([(gradients, self.x)])
+
+
+
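
Note: the commit message also mentions "main", which is outside this diff. A minimal sketch of a driver loop that could call train_step, assuming a hypothetical hp.num_epochs and already-loaded content and style images, might look like:

    # hypothetical driver sketch -- not part of this commit
    model = YourModel(content_image, style_image)
    for _ in range(hp.num_epochs):      # hp.num_epochs is an assumption
        model.train_step()              # one gradient step on the stylized image self.x
    stylized = model.x.numpy()          # the optimized image is stored in the variable x
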