diff options
Diffstat (limited to 'main.py')
-rw-r--r-- | main.py | 13 |
1 files changed, 4 insertions, 9 deletions
@@ -5,12 +5,8 @@ import argparse import tensorflow as tf from skimage import transform -import PIL.Image - import hyperparameters as hp -from losses import YourModel -# from tensorboard_utils import \ -# ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver +from model import YourModel from skimage.io import imread, imsave from matplotlib import pyplot as plt @@ -42,9 +38,9 @@ def parse_args(): help='Y if you want to load the most recent weights' ) - return parser.parse_args() + def train(model: YourModel): for i in range(hp.num_epochs): if i % 50 == 0: @@ -55,6 +51,7 @@ def train(model: YourModel): np.save('checkpoint.npy', model.x) model.train_step(i) + def main(): """ Main function. """ if os.path.exists(ARGS.content): @@ -63,13 +60,11 @@ def main(): ARGS.style = os.path.abspath(ARGS.style) os.chdir(sys.path[0]) print('this is', ARGS.content) - - content_image = imread(ARGS.content) style_image = imread(ARGS.style) style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True) - + my_model = YourModel(content_image=content_image, style_image=style_image) if (ARGS.load == 'Y'): |