aboutsummaryrefslogtreecommitdiff
path: root/main.py
blob: 527394cbf161977e57d2fb024439eec088a1d11e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import sys
import argparse

import tensorflow as tf
from skimage import transform

import hyperparameters as hp
from model import YourModel

from skimage.io import imread, imsave
from matplotlib import pyplot as plt
import numpy as np


def parse_args():
    """ Build the command-line parser and return the parsed arguments.

    --content, --style and --savefile are mandatory; --load is optional
    and defaults to 'N'.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Let's train some neural nets!",
    )

    # The mandatory options differ only by flag name and help text.
    mandatory = (
        ('--content', 'Content image filepath'),
        ('--style', 'Style image filepath'),
        ('--savefile', 'Filename to save image'),
    )
    for flag, description in mandatory:
        cli.add_argument(flag, required=True, help=description)

    cli.add_argument(
        '--load',
        default='N',
        required=False,
        help='Y if you want to load the most recent weights',
    )

    namespace = cli.parse_args()
    return namespace


def train(model: YourModel):
    """ Run hp.num_epochs training steps, snapshotting periodically.

    On every 50th iteration (starting with iteration 0), before the step is
    taken, the current model.x is written to save_images/epoch<i>.jpg as a
    uint8 image and dumped to checkpoint.npy in the working directory.

    Arguments:
        model: object exposing an `x` attribute and a `train_step(i)` method.
    """
    # The snapshot folder is not created anywhere else in this file; make
    # sure it exists so the very first save (i == 0) cannot fail on a
    # missing directory.
    os.makedirs('save_images', exist_ok=True)

    for i in range(hp.num_epochs):
        if i % 50 == 0:
            # Work on a copy so the conversion never touches model.x itself.
            copy = tf.identity(model.x)
            copy = tf.squeeze(copy)
            copy = tf.image.convert_image_dtype(copy, tf.uint8)
            imsave('save_images/epoch' + str(i) + '.jpg', copy)
            # Raw (unconverted) values, overwritten at each snapshot.
            np.save('checkpoint.npy', model.x)
        model.train_step(i)


def main():
    """ Main function.

    Reads the content and style images named on the command line, resizes
    the style image to the content image's shape, builds the model,
    optionally restores model.x from checkpoint.npy, trains, then saves and
    displays the resulting image.
    """
    # Make the input paths absolute before changing directory, otherwise
    # paths given relative to the caller's directory would stop resolving.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)
    # From here on relative paths (save_images/, checkpoint.npy and
    # ARGS.savefile) resolve against sys.path[0].
    os.chdir(sys.path[0])
    print('this is', ARGS.content)

    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)
    style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True)

    my_model = YourModel(content_image=content_image, style_image=style_image)

    if ARGS.load == 'Y':
        checkpoint = np.load('checkpoint.npy')
        # train() checkpoints model.x, so the checkpoint is restored into
        # model.x. Previously it was loaded into an unused local, which made
        # --load a no-op.
        # NOTE(review): assumes model.x is the image being optimised and that
        # the checkpoint matches its shape/dtype -- confirm against YourModel.
        if hasattr(my_model.x, 'assign'):
            # Update in place so anything already holding the variable
            # keeps pointing at the same object.
            my_model.x.assign(checkpoint)
        else:
            my_model.x = tf.Variable(initial_value=checkpoint)

    train(my_model)

    # convert the tensor into an image
    my_model.x = tf.squeeze(my_model.x)
    final_image = tf.image.convert_image_dtype(my_model.x, tf.uint8)

    imsave(ARGS.savefile, final_image)

    plt.imshow(final_image)
    plt.show()


# Only parse the command line and start the run when executed as a script,
# so that importing this module has no side effects. ARGS stays a
# module-level global because main() reads it.
if __name__ == '__main__':
    ARGS = parse_args()
    main()