import os
import sys
import argparse
import tensorflow as tf
from skimage import transform
import hyperparameters as hp
from losses import YourModel
# from tensorboard_utils import \
#     ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver
from skimage.io import imread, imsave
from matplotlib import pyplot as plt
import numpy as np


def parse_args(argv=None):
    """Perform command-line argument parsing.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, exactly as
            before; passing a list makes the parser usable from code/tests.

    Returns:
        An ``argparse.Namespace`` with the required string attributes
        ``content``, ``style`` and ``savefile``.
    """
    parser = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--content',
        required=True,
        help='''Content image filepath''')
    parser.add_argument(
        '--style',
        required=True,
        help='Style image filepath')
    parser.add_argument(
        '--savefile',
        required=True,
        help='Filename to save image')
    return parser.parse_args(argv)


def train(model):
    """Call ``model.train_step()`` once per iteration, ``hp.num_epochs`` times.

    What a single step optimises is defined by ``YourModel`` (not visible
    here); this loop only drives it and prints the iteration index.
    """
    for i in range(hp.num_epochs):
        print('batch', i)
        model.train_step()


def main():
    """Load the two images, run style transfer and save the result.

    Reads the module-level ``ARGS`` namespace produced by ``parse_args``
    and rewrites its path attributes to absolute paths in place.
    """
    # Resolve user-supplied paths against the caller's working directory
    # BEFORE the chdir below. Content/style paths that do not exist relative
    # to the cwd are left untouched, so after the chdir they are looked up
    # relative to the script directory instead.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)
    # The output file does not exist yet, so it must be pinned to the
    # caller's cwd unconditionally; previously a relative --savefile was
    # silently written into the script directory after the chdir.
    ARGS.savefile = os.path.abspath(ARGS.savefile)

    # chdir so the relative 'vgg16_imagenet.h5' path below resolves next to
    # this script. realpath(__file__) gives the same directory sys.path[0]
    # holds for a script run (symlinks resolved) but is never '' (as
    # sys.path[0] is in interactive / -c use, where os.chdir('') raises).
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    print('this is', ARGS.content)

    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)
    # Match the style image to the content image's shape.
    # NOTE(review): without preserve_range=True, resize() returns floats in
    # [0, 1] while imread typically yields uint8 in [0, 255] for the content
    # image -- confirm YourModel expects this mix of value ranges.
    style_image = transform.resize(
        style_image, content_image.shape, anti_aliasing=True)

    my_model = YourModel(content_image=content_image,
                         style_image=style_image)
    # Build the VGG16 sub-model so its weights exist before loading them by
    # layer name. NOTE(review): 255x255 input is presumably only a build
    # shape -- confirm it is intended (VGG is usually 224x224).
    my_model.vgg16.build([1, 255, 255, 3])
    my_model.vgg16.load_weights('vgg16_imagenet.h5', by_name=True)

    train(my_model)

    # Drop size-1 axes (e.g. a leading batch dim) and hand plain NumPy data
    # to matplotlib / skimage rather than a TF tensor.
    # NOTE(review): the value range of my_model.x is defined by YourModel;
    # imsave may warn or fail if it is outside a valid image range.
    final_image = tf.squeeze(my_model.x).numpy()
    # Draws onto the current figure; nothing here calls plt.show().
    plt.imshow(final_image)
    imsave(ARGS.savefile, final_image)


if __name__ == "__main__":
    ARGS = parse_args()
    main()