aboutsummaryrefslogtreecommitdiff
path: root/losses.py
diff options
context:
space:
mode:
authorLogan Bauman <logan_bauman@brown.edu>2022-05-04 15:05:30 -0400
committerLogan Bauman <logan_bauman@brown.edu>2022-05-04 15:05:30 -0400
commit9608ec6a7bdf73d9d0d7fe406c575eb209cf50e0 (patch)
tree77c682c3ea238bb5320306afa0e47fd79e7ff745 /losses.py
parentd19f9ab05c189ce0cdc9271669d61b5f0e5db0fb (diff)
finish total loss
Diffstat (limited to 'losses.py')
-rw-r--r--losses.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/losses.py b/losses.py
index 6ebba671..caad3067 100644
--- a/losses.py
+++ b/losses.py
@@ -88,9 +88,14 @@ class YourModel(tf.keras.Model):
_, art_layers = self.forward_pass(a)
_, input_layers = self.forward_pass(x)
+ content_l = self.content_loss(photo_layers, input_layers)
+ style_l = self.style_loss(art_layers, input_layers)
+ # Equation 7
+ return (self.alpha * content_l) + (self.beta * style_l)
+
- def content_loss(photo_layers, input_layers):
+ def content_loss(self, photo_layers, input_layers):
L_content = tf.reduce_mean(tf.square(photo_layers - input_layers))
return L_content