1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
import os
import sys
import argparse
import tensorflow as tf
from skimage import transform
import hyperparameters as hp
from model import YourModel
from skimage.io import imread, imsave
from matplotlib import pyplot as plt
import numpy as np
def parse_args():
    """Build the command-line interface and return the parsed namespace.

    Options:
        --content   path of the content image (required)
        --style     path of the style image (required)
        --savefile  filename the resulting image is written to (required)
        --load      'Y' to load the most recent weights; defaults to 'N'
    """
    cli = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # The three path options share the same shape; register them in order
    # so the generated --help output lists them as before.
    mandatory_options = (
        ('--content', 'Content image filepath'),
        ('--style', 'Style image filepath'),
        ('--savefile', 'Filename to save image'),
    )
    for flag, help_text in mandatory_options:
        cli.add_argument(flag, required=True, help=help_text)

    cli.add_argument(
        '--load',
        required=False,
        default='N',
        help='Y if you want to load the most recent weights',
    )
    return cli.parse_args()
def train(model: YourModel):
    """Optimise ``model.x`` for ``hp.num_epochs`` steps with periodic snapshots.

    Calls ``model.train_step(epoch)`` once per epoch. On every epoch divisible
    by 50 (starting with epoch 0, before that epoch's update), the current
    image is written to ``save_images/epoch<epoch>.jpg`` and the raw values
    of ``model.x`` to ``checkpoint.npy``. Both paths are relative to the
    current working directory.

    Args:
        model: style-transfer model exposing the image being optimised as
            ``model.x`` and a ``train_step(epoch)`` method.
    """
    # imsave does not create missing parent directories; without this the
    # very first snapshot raises FileNotFoundError on a fresh checkout.
    os.makedirs('save_images', exist_ok=True)
    for epoch in range(hp.num_epochs):
        if epoch % 50 == 0:
            # squeeze drops size-1 axes (presumably a batch axis — confirm).
            # saturate=True clips out-of-range floats instead of letting
            # them wrap around when cast to uint8.
            snapshot = tf.image.convert_image_dtype(
                tf.squeeze(model.x), tf.uint8, saturate=True)
            imsave('save_images/epoch' + str(epoch) + '.jpg', snapshot.numpy())
            np.save('checkpoint.npy', model.x)
        model.train_step(epoch)
def main():
    """Run style transfer driven by the module-level ``ARGS`` namespace.

    Reads the content and style images and builds ``YourModel``. With
    ``--load Y`` it restores the optimised image from ``checkpoint.npy``.
    It then runs ``train`` and writes the result to ``--savefile``, also
    displaying it with matplotlib.
    """
    # Resolve user-supplied paths against the caller's working directory
    # before the chdir below. Inputs are only rewritten when they exist here,
    # so a missing relative path can still resolve against the script
    # directory afterwards (kept from the original behaviour).
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)
    # The output file does not exist yet, so resolve it unconditionally.
    # Previously a relative --savefile silently ended up in the script dir.
    ARGS.savefile = os.path.abspath(ARGS.savefile)

    # Relative artefacts (checkpoint.npy, save_images/) live next to this
    # script. sys.path[0] is '' under `python -c` or interactive use, so
    # derive the directory from __file__ instead.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    print('this is', ARGS.content)

    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)
    # NOTE(review): transform.resize returns floats in [0, 1], while imread
    # typically yields uint8 in [0, 255]. Confirm YourModel normalises both
    # inputs consistently. Resizing to content_image.shape also forces the
    # style image to the content image's channel count.
    style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True)
    my_model = YourModel(content_image=content_image, style_image=style_image)

    if ARGS.load == 'Y':
        checkpoint = np.load('checkpoint.npy')
        # The original built a tf.Variable from the checkpoint and discarded
        # it, so --load had no effect. Write it into the model instead.
        # Assumes my_model.x is a tf.Variable (train() checkpoints it via
        # np.save); fall back to replacing the attribute otherwise.
        if hasattr(my_model.x, 'assign'):
            my_model.x.assign(checkpoint)
        else:
            my_model.x = tf.Variable(initial_value=checkpoint)

    train(my_model)

    # Convert the optimised tensor into a uint8 image. Use a local rather
    # than overwriting my_model.x. saturate=True clips out-of-range floats
    # instead of wrapping them during the uint8 cast.
    final_image = tf.image.convert_image_dtype(
        tf.squeeze(my_model.x), tf.uint8, saturate=True).numpy()
    imsave(ARGS.savefile, final_image)
    plt.imshow(final_image)
    plt.show()
if __name__ == '__main__':
    # Guarded so importing this module does not parse argv and start a long
    # training run. ARGS stays a module-level global because main() reads it.
    ARGS = parse_args()
    main()
|