diff options
author | David Doan <daviddoan@davids-mbp-3.devices.brown.edu> | 2022-05-09 00:25:39 -0400 |
---|---|---|
committer | David Doan <daviddoan@davids-mbp-3.devices.brown.edu> | 2022-05-09 00:25:39 -0400 |
commit | a0870ac3f1f84278c5b9fe7f78f6b1af1d1f33e9 (patch) | |
tree | 440f9c17f23042e32c7c495c49d36bb838fcd73a /main.py | |
parent | 18f1f7bddcb63502120581f3fa24b980559ffa9f (diff) |
clean and refactor code for submission
Diffstat (limited to 'main.py')
-rw-r--r-- | main.py | 56 |
1 files changed, 31 insertions, 25 deletions
@@ -5,6 +5,8 @@ import argparse import tensorflow as tf from skimage import transform +import PIL.Image + import hyperparameters as hp from losses import YourModel # from tensorboard_utils import \ @@ -13,7 +15,6 @@ from losses import YourModel from skimage.io import imread, imsave from matplotlib import pyplot as plt import numpy as np -from skimage import transform def parse_args(): @@ -34,23 +35,25 @@ def parse_args(): '--savefile', required=True, help='Filename to save image') + parser.add_argument( + '--load', + required=False, + default='N', + help='Y if you want to load the most recent weights' + ) return parser.parse_args() -def train(model): +def train(model: YourModel): for i in range(hp.num_epochs): - if i % 100 == 0: - fn = f"checkpoint-images/img-epoch{i}.jpg" - save_image(fn, tf.squeeze(model.x)) - - print('batch', i) - model.train_step() - -def save_image(filename, image): - image = transform.resize(image, tf.shape(image), anti_aliasing=True).astype('uint8') - imsave(filename, image) - + if i % 50 == 0: + copy = tf.identity(model.x) + copy = tf.squeeze(copy) + copy = tf.image.convert_image_dtype(copy, tf.uint8) + imsave('save_images/epoch' + str(i) + '.jpg', copy) + np.save('checkpoint.npy', model.x) + model.train_step(i) def main(): """ Main function. """ @@ -59,27 +62,30 @@ def main(): if os.path.exists(ARGS.style): ARGS.style = os.path.abspath(ARGS.style) os.chdir(sys.path[0]) - print('this is',ARGS.content) + print('this is', ARGS.content) + + content_image = imread(ARGS.content) style_image = imread(ARGS.style) - style_image = transform.resize(style_image, content_image.shape) - my_model = YourModel(content_image=content_image, style_image=style_image) - my_model.vgg16.build([1, 255, 255, 3]) - my_model.vgg16.load_weights('vgg16_imagenet.h5', by_name=True) - train(my_model) + style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True) - final_image = tf.squeeze(my_model.x) - final_image = transform.resize(final_image, tf.shape(final_image), anti_aliasing=True).astype('uint8') + my_model = YourModel(content_image=content_image, style_image=style_image) - # convert image to uint8 - final_image = tf.cast(final_image, tf.uint8) + if (ARGS.load == 'Y'): + checkpoint = np.load('checkpoint.npy') + image = tf.Variable(initial_value=checkpoint) - plt.imshow(final_image) + train(my_model) + + # convert the tensor into an image + my_model.x = tf.squeeze(my_model.x) + final_image = tf.image.convert_image_dtype(my_model.x, tf.uint8) imsave(ARGS.savefile, final_image) - + plt.imshow(final_image) + plt.show() ARGS = parse_args() |