aboutsummaryrefslogtreecommitdiff
path: root/main.py
blob: 063670b8287c1d9a97ae4889586a099200c05fd6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import sys
import argparse
import re
from datetime import datetime
import tensorflow as tf

import hyperparameters as hp
from losses import YourModel
from preprocess import Datasets
from skimage.transform import resize
# from tensorboard_utils import \
#         ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver

from skimage.io import imread, imsave
from lime import lime_image
from skimage.segmentation import mark_boundaries
from matplotlib import pyplot as plt
import numpy as np


def parse_args(argv=None):
    """Parse the command-line options naming the input images and output file.

    Args:
        argv: List of argument strings to parse, excluding the program name.
            ``None`` (the default) parses ``sys.argv[1:]``, exactly as before;
            passing an explicit list lets other code or tests call this
            without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the string attributes ``content``,
        ``style`` and ``savefile``. All three options are required.

    Raises:
        SystemExit: raised by argparse (after printing usage to stderr) when
            a required option is missing or an unknown option is given.
    """
    parser = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--content',
        required=True,
        help='Content image filepath')
    parser.add_argument(
        '--style',
        required=True,
        help='Style image filepath')
    parser.add_argument(
        '--savefile',
        required=True,
        help='Filename to save image')

    return parser.parse_args(argv)

def train(model, num_epochs=None):
    """Optimize ``model`` by calling its ``train_step`` once per epoch.

    Args:
        model: Any object exposing a zero-argument ``train_step()`` method
            (in this script, the ``YourModel`` instance built in ``main``).
        num_epochs: Number of ``train_step`` calls to make. ``None`` (the
            default) uses ``hp.num_epochs``, so existing callers behave
            exactly as before.
    """
    if num_epochs is None:
        num_epochs = hp.num_epochs
    for _ in range(num_epochs):
        model.train_step()

def main():
    """Run one optimization from the command-line options stored in ``ARGS``.

    Reads the content and style images, builds ``YourModel`` from them,
    runs :func:`train` on it, and writes the model's ``x`` attribute to
    ``ARGS.savefile``.

    Side effects:
        * Rewrites ``ARGS.content`` / ``ARGS.style`` to absolute paths when
          they exist relative to the caller's working directory, and always
          rewrites ``ARGS.savefile`` to an absolute path.
        * Changes the process working directory to this script's directory.
          NOTE(review): presumably so project files loaded by relative path
          (e.g. model weights) resolve; confirm against losses/preprocess.
    """
    # Resolve user-supplied paths against the caller's cwd *before* the
    # chdir below. Input paths that don't exist here are left relative, so
    # they are then looked up relative to the script directory instead.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)
    # The output file normally doesn't exist yet, so resolve it
    # unconditionally; otherwise it would silently land next to the script
    # once the working directory changes.
    ARGS.savefile = os.path.abspath(ARGS.savefile)

    # sys.path[0] can be '' (e.g. under -c / interactive), which makes
    # os.chdir raise; the module file's directory is always well defined.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)
    my_model = YourModel(content_image=content_image, style_image=style_image)
    train(my_model)

    # NOTE(review): ``x`` is presumably the optimized image held by
    # YourModel; its dtype, value range and shape (e.g. a leading batch
    # axis) are defined in losses.py and not visible here -- confirm it is
    # directly savable by ``imsave``.
    final_image = my_model.x

    imsave(ARGS.savefile, final_image)


# Only parse arguments and run when executed as a script, so importing this
# module (e.g. to reuse ``train``) neither reads sys.argv nor starts a run.
# ``ARGS`` is still bound at module scope here, which is where ``main`` reads it.
if __name__ == "__main__":
    ARGS = parse_args()
    main()