"""Command-line driver for neural style transfer.

Reads a content image and a style image, optimizes the image held in
``YourModel.x`` for ``hp.num_epochs`` steps, and writes the result to disk.
Periodic snapshots go to ``save_images/`` and a resumable checkpoint goes to
``checkpoint.npy``; both live in the script's directory.
"""
import argparse
import os
import sys

import numpy as np
import PIL.Image
import tensorflow as tf
from matplotlib import pyplot as plt
from skimage import transform
from skimage.io import imread, imsave

import hyperparameters as hp
from losses import YourModel
# from tensorboard_utils import \
#     ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver

# Paths are relative to the script directory (main() chdirs there).
CHECKPOINT_PATH = 'checkpoint.npy'
SNAPSHOT_DIR = 'save_images'


def parse_args():
    """Perform command-line argument parsing."""
    parser = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--content',
        required=True,
        help='''Content image filepath''')
    parser.add_argument(
        '--style',
        required=True,
        help='Style image filepath')
    parser.add_argument(
        '--savefile',
        required=True,
        help='Filename to save image')
    parser.add_argument(
        '--load',
        required=False,
        default='N',
        help='Y if you want to load the most recent weights'
    )
    return parser.parse_args()


def _to_uint8_image(image):
    """Return *image* squeezed of size-1 dims and converted to a uint8 NumPy array.

    For floating-point input, tf.image.convert_image_dtype treats values as
    lying in [0, 1] and scales them to [0, 255].
    """
    image = tf.squeeze(image)
    # saturate=True clips out-of-range values. Without it, optimized pixels
    # slightly outside [0, 1] would wrap around in the float->uint8 cast.
    image = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
    return image.numpy()


def train(model: YourModel, snapshot_every: int = 50) -> None:
    """Run ``hp.num_epochs`` optimization steps on *model*.

    Every *snapshot_every* epochs (including epoch 0), before that epoch's
    step, the current image is saved to ``save_images/epoch<i>.jpg``. The raw
    ``model.x`` values are also saved to ``checkpoint.npy`` so a later run can
    resume with ``--load Y``.

    Args:
        model: The style-transfer model; its ``x`` attribute is the image
            being optimized and ``train_step(i)`` performs one update.
        snapshot_every: Snapshot/checkpoint interval in epochs.

    Raises:
        ValueError: If *snapshot_every* is not positive.
    """
    if snapshot_every <= 0:
        raise ValueError(f'snapshot_every must be positive, got {snapshot_every}')

    # imsave does not create missing parent directories.
    os.makedirs(SNAPSHOT_DIR, exist_ok=True)

    for i in range(hp.num_epochs):
        if i % snapshot_every == 0:
            snapshot = _to_uint8_image(model.x)
            imsave(os.path.join(SNAPSHOT_DIR, f'epoch{i}.jpg'), snapshot)
            np.save(CHECKPOINT_PATH, np.asarray(model.x))
        model.train_step(i)


def main():
    """Run style transfer using the parsed command-line arguments in ``ARGS``."""
    # Resolve image paths against the caller's working directory *before*
    # changing directory. Paths that do not exist here are left unchanged and
    # are therefore resolved relative to the script directory after the chdir.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)

    # Work from the script's directory, so checkpoint.npy, save_images/ and a
    # relative --savefile all land next to this file. __file__ is used rather
    # than sys.path[0], which is '' in interactive sessions and would make
    # os.chdir fail.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    print('this is', ARGS.content)
    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)
    # Match the style image to the content image's shape. transform.resize
    # returns floats, whereas content_image keeps imread's dtype.
    # NOTE(review): confirm YourModel normalizes both inputs consistently.
    style_image = transform.resize(style_image, content_image.shape,
                                   anti_aliasing=True)

    my_model = YourModel(content_image=content_image, style_image=style_image)

    if ARGS.load.upper() == 'Y':
        # Resume from the image saved by train(). Previously the loaded array
        # was bound to an unused local, so --load silently had no effect.
        checkpoint = np.load(CHECKPOINT_PATH)
        if hasattr(my_model.x, 'assign'):
            # Update the existing variable in place so anything holding a
            # reference to it (e.g. an optimizer) keeps working.
            # Assumes matching shape/dtype; raises otherwise.
            my_model.x.assign(checkpoint)
        else:
            my_model.x = tf.Variable(initial_value=checkpoint)

    train(my_model)

    # Use a local rather than overwriting my_model.x, keeping the model intact.
    final_image = _to_uint8_image(my_model.x)
    imsave(ARGS.savefile, final_image)
    plt.imshow(final_image)
    plt.show()


if __name__ == '__main__':
    ARGS = parse_args()
    main()