diff options
Diffstat (limited to 'losses.py')
-rw-r--r-- | losses.py | 12 |
1 files changed, 6 insertions, 6 deletions
@@ -63,13 +63,13 @@ class YourModel(tf.keras.Model): # Dense(15, activation='softmax') ] - self.vgg16 = tf.keras.Sequential(self.vgg16, name="vgg_base") - self.head = tf.keras.Sequential(self.head, name="vgg_head") + # self.vgg16 = tf.keras.Sequential(self.vgg16, name="vgg_base") + # self.head = tf.keras.Sequential(self.head, name="vgg_head") - self.indexed_layers = [layer for layer in self.vgg16 if layer.name.contains("conv1")] - self.desired = [layer.name for layer in self.vgg16 if layer.name.contains("conv1")] + self.indexed_layers = [layer for layer in self.vgg16 if layer.name == "conv1"] + self.desired = [layer.name for layer in self.vgg16 if layer.name == "conv1"] - def forward_pass(self, x): + def call(self, x): layers = [] for layer in self.vgg16.layers: # pass the x through @@ -83,7 +83,7 @@ class YourModel(tf.keras.Model): return x, np.array(layers) - def loss_function(self, p, a, x): + def loss_fn(self, p, a, x): _, photo_layers = self.forward_pass(p) _, art_layers = self.forward_pass(a) _, input_layers = self.forward_pass(x) |