aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--losses.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/losses.py b/losses.py
index 6ebba671..caad3067 100644
--- a/losses.py
+++ b/losses.py
@@ -88,9 +88,14 @@ class YourModel(tf.keras.Model):
_, art_layers = self.forward_pass(a)
_, input_layers = self.forward_pass(x)
+ content_l = self.content_loss(photo_layers, input_layers)
+ style_l = self.style_loss(art_layers, input_layers)
+ # Equation 7
+ return (self.alpha * content_l) + (self.beta * style_l)
+
- def content_loss(photo_layers, input_layers):
+ def content_loss(self, photo_layers, input_layers):
L_content = tf.reduce_mean(tf.square(photo_layers - input_layers))
return L_content