import argparse
import os
import sys

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from skimage import transform
from skimage.io import imread, imsave

import hyperparameters as hp
from model import YourModel


def parse_args():
    """Perform command-line argument parsing.

    Returns:
        argparse.Namespace with ``content``, ``style``, ``savefile`` and
        ``load`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--content',
        required=True,
        help='''Content image filepath''')
    parser.add_argument(
        '--style',
        required=True,
        help='Style image filepath')
    parser.add_argument(
        '--savefile',
        required=True,
        help='Filename to save image')
    parser.add_argument(
        '--load',
        required=False,
        default='N',
        help='Y if you want to load the most recent weights')
    return parser.parse_args()


def _to_uint8_image(x):
    """Return ``x`` with size-1 dimensions squeezed out, as a uint8 array.

    ``saturate=True`` clips out-of-range values instead of letting the
    float -> uint8 cast wrap around; the optimized image is not guaranteed to
    stay inside [0, 1]. NOTE(review): assumes ``x`` is a float image whose
    nominal range is [0, 1] -- confirm against YourModel.
    """
    image = tf.squeeze(x)
    image = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
    return image.numpy()


def train(model: YourModel, snapshot_every: int = 50):
    """Run ``hp.num_epochs`` calls to ``model.train_step``.

    Every ``snapshot_every`` iterations (starting at iteration 0, before that
    iteration's step), the current image ``model.x`` is written to
    ``save_images/epoch<i>.jpg`` and its raw values to ``checkpoint.npy``,
    which ``main`` can reload with ``--load Y``.

    Args:
        model: the model whose ``x`` is being optimized.
        snapshot_every: positive snapshot/checkpoint interval in iterations.
    """
    # imsave does not create missing parent directories.
    os.makedirs('save_images', exist_ok=True)
    for i in range(hp.num_epochs):
        if i % snapshot_every == 0:
            imsave(f'save_images/epoch{i}.jpg', _to_uint8_image(model.x))
            np.save('checkpoint.npy', np.asarray(model.x))
        model.train_step(i)


def main():
    """Run style transfer using the module-level command-line ``ARGS``."""
    # Resolve user-supplied paths against the caller's working directory
    # before switching to the script directory below.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)
    # Otherwise the output would silently land in the script directory.
    ARGS.savefile = os.path.abspath(ARGS.savefile)

    # save_images/ and checkpoint.npy are kept next to this script.
    # (sys.path[0] can be '' in some launch modes, so derive it from __file__.)
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    print('this is', ARGS.content)

    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)
    style_image = transform.resize(style_image, content_image.shape,
                                   anti_aliasing=True)
    my_model = YourModel(content_image=content_image, style_image=style_image)

    if ARGS.load == 'Y':
        checkpoint = np.load('checkpoint.npy')
        # Restore into the existing variable when possible so anything already
        # holding a reference to model.x observes the restored values.
        if isinstance(my_model.x, tf.Variable):
            my_model.x.assign(checkpoint)
        else:
            my_model.x = tf.Variable(initial_value=checkpoint)

    train(my_model)

    # Convert the optimized tensor into an image without clobbering model.x.
    final_image = _to_uint8_image(my_model.x)
    imsave(ARGS.savefile, final_image)
    plt.imshow(final_image)
    plt.show()


if __name__ == '__main__':
    ARGS = parse_args()
    main()