aboutsummaryrefslogtreecommitdiff
path: root/losses.py
diff options
context:
space:
mode:
authorBenjamin Fiske <bffiske@gmail.com>2022-05-04 15:23:51 -0400
committerBenjamin Fiske <bffiske@gmail.com>2022-05-04 15:23:51 -0400
commit137d14b2ceffa95407cacfc82d1222cb9ef35072 (patch)
tree198f9918018f2147bb0668e3d9c5311ddc2df054 /losses.py
parent10edad30706e0c6f11c0e48563e0b6d5c5bbca64 (diff)
parent4ad25cde30edf19f6bd128d6f54c6396cab49773 (diff)
Merge branch 'main' of https://github.com/bfiske/cv-cartoon
Diffstat (limited to 'losses.py')
-rw-r--r--losses.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/losses.py b/losses.py
index f4a4f757..36dc91cc 100644
--- a/losses.py
+++ b/losses.py
@@ -3,13 +3,18 @@ from tensorflow.keras.layers import \
Conv2D, MaxPool2D, Dropout, Flatten, Dense, AveragePool2D
import numpy as np
+
+from skimage import transform
import hyperparameters as hp
class YourModel(tf.keras.Model):
""" Your own neural network model. """
- def __init__(self):
+ def __init__(self, content_image, style_image): #normalize these images to float values
super(YourModel, self).__init__()
-
+
+ self.content_image = content_image
+ self.style_image = transform.resize(style_image, np.shape(style_image), anti_aliasing=True)
+ self.x = tf.Variable(tf.random.uniform(np.shape(content_image)), trainable=True)
self.alpha = 1
self.beta = 1
@@ -53,6 +58,8 @@ class YourModel(tf.keras.Model):
activation="relu", name="block5_conv3"),
AveragePool2D(2, name="block5_pool"),
]
+ for layer in self.vgg16:
+ layer.trainable = False
self.head = [
# Dropout(.2),