"""Command-line entry point for style transfer with ``YourModel``.

Reads a content image and a style image, calls ``YourModel.train_step``
``hp.num_epochs`` times, and writes ``my_model.x`` to ``--savefile``.
"""

import os
import sys
import argparse

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from skimage.io import imread, imsave
from skimage.transform import resize

import hyperparameters as hp
from losses import YourModel
# from tensorboard_utils import \
#     ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver

# Side length of the square images fed to the network; the VGG16 backbone
# is built below for input shape [1, IMAGE_SIZE, IMAGE_SIZE, 3].
IMAGE_SIZE = 255

# Pretrained VGG16 weights. This is a relative path, looked up after main()
# has chdir-ed into the script's directory.
WEIGHTS_FILE = 'vgg16_imagenet.h5'


def parse_args():
    """ Perform command-line argument parsing. """

    parser = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--content',
        required=True,
        help='Content image filepath')
    parser.add_argument(
        '--style',
        required=True,
        help='Style image filepath')
    parser.add_argument(
        '--savefile',
        required=True,
        help='Filename to save image')

    return parser.parse_args()


def train(model):
    """Run ``model.train_step()`` once per epoch for ``hp.num_epochs`` epochs."""
    for _ in range(hp.num_epochs):
        model.train_step()


def _load_image(path, size=IMAGE_SIZE):
    """Read the image at *path* and return it as a ``(size, size, 3)`` array.

    The image is resampled (not reshaped) to the target size. The original
    value range and dtype are kept. Grayscale inputs are replicated to three
    channels, and a trailing alpha channel is dropped.

    Raises:
        ValueError: if the image does not have 1-4 channels after loading.
    """
    image = imread(path)

    if image.ndim == 2:
        # Single-channel grayscale image without a channel axis.
        image = image[..., np.newaxis]
    if image.ndim == 3 and image.shape[-1] in (2, 4):
        # Gray+alpha or RGBA: discard the alpha channel.
        image = image[..., :-1]
    if image.ndim == 3 and image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(
            f'Unsupported image shape {image.shape} for {path!r}')

    # np.resize would merely repeat/truncate the flat pixel buffer; a real
    # resample is needed. preserve_range keeps e.g. uint8 values in 0-255
    # instead of rescaling them to [0, 1].
    resized = resize(image, (size, size, 3),
                     preserve_range=True, anti_aliasing=True)
    if np.issubdtype(image.dtype, np.integer):
        # Round rather than truncate when returning to an integer dtype.
        resized = np.rint(resized)
    return resized.astype(image.dtype)


def main():
    """ Main function.

    Loads the images named by the module-level ``ARGS``, builds and trains
    ``YourModel``, and saves ``my_model.x`` to ``ARGS.savefile``.
    """

    # Resolve every user-supplied path against the caller's working
    # directory *before* the chdir below. Otherwise relative paths, notably
    # --savefile, would be interpreted relative to the script directory.
    ARGS.content = os.path.abspath(ARGS.content)
    ARGS.style = os.path.abspath(ARGS.style)
    ARGS.savefile = os.path.abspath(ARGS.savefile)

    # When run as a script, sys.path[0] is the script's directory. Moving
    # there lets the relative WEIGHTS_FILE be found from any launch dir.
    os.chdir(sys.path[0])

    print('this is', ARGS.content)

    content_image = _load_image(ARGS.content)
    style_image = _load_image(ARGS.style)

    my_model = YourModel(content_image=content_image, style_image=style_image)
    my_model.vgg16.build([1, IMAGE_SIZE, IMAGE_SIZE, 3])
    my_model.vgg16.load_weights(WEIGHTS_FILE, by_name=True)

    train(my_model)

    # Drop size-1 axes (presumably the batch axis of my_model.x -- TODO
    # confirm). NOTE(review): imsave rejects floats outside [-1, 1]; verify
    # the value range of my_model.x matches what imsave expects.
    final_image = np.asarray(tf.squeeze(my_model.x))
    plt.imshow(final_image)
    imsave(ARGS.savefile, final_image)


if __name__ == '__main__':
    ARGS = parse_args()
    main()