aboutsummaryrefslogtreecommitdiff
path: root/losses.py
diff options
context:
space:
mode:
Diffstat (limited to 'losses.py')
-rw-r--r--losses.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/losses.py b/losses.py
index 93449962..6ebba671 100644
--- a/losses.py
+++ b/losses.py
@@ -63,13 +63,13 @@ class YourModel(tf.keras.Model):
# Dense(15, activation='softmax')
]
- self.vgg16 = tf.keras.Sequential(self.vgg16, name="vgg_base")
- self.head = tf.keras.Sequential(self.head, name="vgg_head")
+ # self.vgg16 = tf.keras.Sequential(self.vgg16, name="vgg_base")
+ # self.head = tf.keras.Sequential(self.head, name="vgg_head")
- self.indexed_layers = [layer for layer in self.vgg16 if layer.name.contains("conv1")]
- self.desired = [layer.name for layer in self.vgg16 if layer.name.contains("conv1")]
+ self.indexed_layers = [layer for layer in self.vgg16 if layer.name == "conv1"]
+ self.desired = [layer.name for layer in self.vgg16 if layer.name == "conv1"]
- def forward_pass(self, x):
+ def call(self, x):
layers = []
for layer in self.vgg16.layers:
# pass the x through
@@ -83,7 +83,7 @@ class YourModel(tf.keras.Model):
return x, np.array(layers)
- def loss_function(self, p, a, x):
+ def loss_fn(self, p, a, x):
_, photo_layers = self.forward_pass(p)
_, art_layers = self.forward_pass(a)
_, input_layers = self.forward_pass(x)