aboutsummaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
authorDavid Doan <daviddoan@davids-mbp-3.devices.brown.edu>2022-05-09 00:25:39 -0400
committerDavid Doan <daviddoan@davids-mbp-3.devices.brown.edu>2022-05-09 00:25:39 -0400
commita0870ac3f1f84278c5b9fe7f78f6b1af1d1f33e9 (patch)
tree440f9c17f23042e32c7c495c49d36bb838fcd73a /main.py
parent18f1f7bddcb63502120581f3fa24b980559ffa9f (diff)
clean and refactor code for submission
Diffstat (limited to 'main.py')
-rw-r--r--main.py56
1 files changed, 31 insertions, 25 deletions
diff --git a/main.py b/main.py
index c36658f3..b4c1b228 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,8 @@ import argparse
import tensorflow as tf
from skimage import transform
+import PIL.Image
+
import hyperparameters as hp
from losses import YourModel
# from tensorboard_utils import \
@@ -13,7 +15,6 @@ from losses import YourModel
from skimage.io import imread, imsave
from matplotlib import pyplot as plt
import numpy as np
-from skimage import transform
def parse_args():
@@ -34,23 +35,25 @@ def parse_args():
'--savefile',
required=True,
help='Filename to save image')
+ parser.add_argument(
+ '--load',
+ required=False,
+ default='N',
+ help='Y if you want to load the most recent weights'
+ )
return parser.parse_args()
-def train(model):
+def train(model: YourModel):
for i in range(hp.num_epochs):
- if i % 100 == 0:
- fn = f"checkpoint-images/img-epoch{i}.jpg"
- save_image(fn, tf.squeeze(model.x))
-
- print('batch', i)
- model.train_step()
-
-def save_image(filename, image):
- image = transform.resize(image, tf.shape(image), anti_aliasing=True).astype('uint8')
- imsave(filename, image)
-
+ if i % 50 == 0:
+ copy = tf.identity(model.x)
+ copy = tf.squeeze(copy)
+ copy = tf.image.convert_image_dtype(copy, tf.uint8)
+ imsave('save_images/epoch' + str(i) + '.jpg', copy)
+ np.save('checkpoint.npy', model.x)
+ model.train_step(i)
def main():
""" Main function. """
@@ -59,27 +62,30 @@ def main():
if os.path.exists(ARGS.style):
ARGS.style = os.path.abspath(ARGS.style)
os.chdir(sys.path[0])
- print('this is',ARGS.content)
+ print('this is', ARGS.content)
+
+
content_image = imread(ARGS.content)
style_image = imread(ARGS.style)
- style_image = transform.resize(style_image, content_image.shape)
- my_model = YourModel(content_image=content_image, style_image=style_image)
- my_model.vgg16.build([1, 255, 255, 3])
- my_model.vgg16.load_weights('vgg16_imagenet.h5', by_name=True)
- train(my_model)
+ style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True)
- final_image = tf.squeeze(my_model.x)
- final_image = transform.resize(final_image, tf.shape(final_image), anti_aliasing=True).astype('uint8')
+ my_model = YourModel(content_image=content_image, style_image=style_image)
- # convert image to uint8
- final_image = tf.cast(final_image, tf.uint8)
+ if (ARGS.load == 'Y'):
+ checkpoint = np.load('checkpoint.npy')
+ image = tf.Variable(initial_value=checkpoint)
- plt.imshow(final_image)
+ train(my_model)
+
+ # convert the tensor into an image
+ my_model.x = tf.squeeze(my_model.x)
+ final_image = tf.image.convert_image_dtype(my_model.x, tf.uint8)
imsave(ARGS.savefile, final_image)
-
+ plt.imshow(final_image)
+ plt.show()
ARGS = parse_args()