aboutsummaryrefslogtreecommitdiff
path: root/losses.py
diff options
context:
space:
mode:
authorBenjamin Fiske <bffiske@gmail.com>2022-05-04 23:17:41 -0400
committerBenjamin Fiske <bffiske@gmail.com>2022-05-04 23:17:41 -0400
commit9d87471579c80d1c8baff6711c1297dec8f0dcf4 (patch)
treedd8fc19cf4081cf265baaec9d7da85ffebc498e1 /losses.py
parent52674666126214aaa8dd54b6bf3e488823538e1a (diff)
adding vgg weights
Diffstat (limited to 'losses.py')
-rw-r--r--losses.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/losses.py b/losses.py
index bc241282..c0989ed1 100644
--- a/losses.py
+++ b/losses.py
@@ -64,15 +64,17 @@ class YourModel(tf.keras.Model):
for layer in self.vgg16:
layer.trainable = False
+ self.layer_to_filters = {layer.name: layer.filters for layer in self.vgg16 if "conv" in layer.name}
self.indexed_layers = [layer for layer in self.vgg16 if "conv1" in layer.name]
self.desired = [layer.name for layer in self.vgg16 if "conv1" in layer.name]
+ self.vgg16 = tf.keras.Sequential(self.vgg16, name="vgg")
+
# create a map of the layers to their corresponding number of filters if it is a convolutional layer
- self.layer_to_filters = {layer.name: layer.filters for layer in self.vgg16 if "conv" in layer.name}
def call(self, x):
layers = []
- for layer in self.vgg16:
+ for layer in self.vgg16.layers:
# pass the x through
x = layer(x)
# print("Sotech117 is so so sus")
@@ -159,6 +161,3 @@ class YourModel(tf.keras.Model):
print(type(self.x))
print(type(gradients))
self.optimizer.apply_gradients(zip(gradients, [self.x]))
-
-
-