import os
import sys
import argparse
import re
from datetime import datetime
import tensorflow as tf
import hyperparameters as hp
from losses import YourModel
from preprocess import Datasets
from skimage.transform import resize
# from tensorboard_utils import \
#     ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver
from skimage.io import imread, imsave
from lime import lime_image
from skimage.segmentation import mark_boundaries
from matplotlib import pyplot as plt
import numpy as np


def parse_args():
    """Parse the command-line arguments.

    All three options are required:
      --content   path of the content image
      --style     path of the style image
      --savefile  path the generated image is written to

    Returns:
        argparse.Namespace with attributes ``content``, ``style`` and
        ``savefile``.
    """
    parser = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--content',
        required=True,
        help='''Content image filepath''')
    parser.add_argument(
        '--style',
        required=True,
        help='Style image filepath')
    parser.add_argument(
        '--savefile',
        required=True,
        help='Filename to save image')
    return parser.parse_args()


def train(model):
    """Call ``model.train_step()`` once per epoch, ``hp.num_epochs`` times.

    Presumably each step updates the generated image held by the model
    (read back as ``model.x`` in ``main``) -- confirm in ``losses.YourModel``.
    """
    for _ in range(hp.num_epochs):
        model.train_step()


def main():
    """Build the model from the two input images, train it, save the result.

    Reads the module-level ``ARGS`` produced by ``parse_args()``.
    """
    # Resolve user-supplied paths against the caller's working directory
    # BEFORE changing directory below. Input paths that do not exist here are
    # left as given, so they are looked up relative to the script directory
    # after the chdir.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)
    # The output file normally does not exist yet, so resolve it
    # unconditionally; otherwise a relative --savefile would be written into
    # the script directory rather than where the user ran the command.
    ARGS.savefile = os.path.abspath(ARGS.savefile)

    # Switch to the directory containing this script (presumably so relative
    # paths used by the project modules resolve -- confirm). __file__ is used
    # instead of sys.path[0], which is not the script directory under
    # `python -P` / PYTHONSAFEPATH or `python -c`.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    print('this is', ARGS.content)

    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)

    my_model = YourModel(content_image=content_image, style_image=style_image)
    train(my_model)

    final_image = my_model.x
    # NOTE(review): no plt.show()/plt.savefig() follows, so this figure is
    # never displayed or saved by this script.
    plt.imshow(final_image)
    # NOTE(review): the type/dtype/value range of my_model.x is defined in
    # losses.YourModel -- confirm imsave accepts it as-is (e.g. a tf.Variable
    # or float data outside [0, 1] may need conversion first).
    imsave(ARGS.savefile, final_image)


if __name__ == "__main__":
    ARGS = parse_args()
    main()