diff options
author | Benjamin Fiske <bffiske@gmail.com> | 2022-05-04 15:23:51 -0400 |
---|---|---|
committer | Benjamin Fiske <bffiske@gmail.com> | 2022-05-04 15:23:51 -0400 |
commit | 137d14b2ceffa95407cacfc82d1222cb9ef35072 (patch) | |
tree | 198f9918018f2147bb0668e3d9c5311ddc2df054 /losses.py | |
parent | 10edad30706e0c6f11c0e48563e0b6d5c5bbca64 (diff) | |
parent | 4ad25cde30edf19f6bd128d6f54c6396cab49773 (diff) |
Merge branch 'main' of https://github.com/bfiske/cv-cartoon
Diffstat (limited to 'losses.py')
-rw-r--r-- | losses.py | 11 |
1 files changed, 9 insertions, 2 deletions
@@ -3,13 +3,18 @@ from tensorflow.keras.layers import \ Conv2D, MaxPool2D, Dropout, Flatten, Dense, AveragePool2D import numpy as np + +from skimage import transform import hyperparameters as hp class YourModel(tf.keras.Model): """ Your own neural network model. """ - def __init__(self): + def __init__(self, content_image, style_image): #normalize these images to float values super(YourModel, self).__init__() - + + self.content_image = content_image + self.style_image = transform.resize(style_image, np.shape(style_image), anti_aliasing=True) + self.x = tf.Variable(tf.random.uniform(np.shape(content_image)), trainable=True) self.alpha = 1 self.beta = 1 @@ -53,6 +58,8 @@ class YourModel(tf.keras.Model): activation="relu", name="block5_conv3"), AveragePool2D(2, name="block5_pool"), ] + for layer in self.vgg16: + layer.trainable = False self.head = [ # Dropout(.2), |