1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
import os
import sys
import argparse
import tensorflow as tf
from skimage import transform
import hyperparameters as hp
from model import YourModel
from skimage.io import imread, imsave
from matplotlib import pyplot as plt
def parse_args(argv=None):
    """Parse the command-line arguments for a style-transfer run.

    Args:
        argv: Optional list of argument strings to parse. When left as
            ``None`` (the default) argparse reads the process's own command
            line, which is what the script entry point relies on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``content``, ``style``
        and ``savefile``; all three options are required.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unrecognized option is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--content',
        required=True,
        help='''Content image filepath''')
    parser.add_argument(
        '--style',
        required=True,
        help='Style image filepath')
    parser.add_argument(
        '--savefile',
        required=True,
        help='Filename to save image')
    # Passing None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def save_tensor_as_image(tensor, path, img_name):
    """Convert an image tensor to uint8, write it to disk and return it.

    Args:
        tensor: Image tensor to save. It is passed through ``tf.squeeze``,
            so every size-1 dimension is dropped before saving.
        path: Directory to write into. It is created (including parents)
            when missing. A trailing separator is optional. An empty string
            means the current working directory.
        img_name: File name, including extension, of the image inside
            ``path``.

    Returns:
        The squeezed uint8 tensor that was handed to ``imsave``.
    """
    # Work on a separate tensor so the caller's tensor is never rebound.
    image = tf.squeeze(tf.identity(tensor))
    # NOTE(review): convert_image_dtype is called with its default settings;
    # confirm the model keeps pixel values inside the range it expects.
    image = tf.image.convert_image_dtype(image, tf.uint8)

    # exist_ok avoids the check-then-create race; makedirs('') would raise,
    # so an empty path is simply treated as "current directory".
    if path:
        os.makedirs(path, exist_ok=True)

    # os.path.join gives the right file name whether or not `path` ends with
    # a separator (plain concatenation glued the two names together).
    imsave(os.path.join(path, img_name), image)
    return image
def train(model: YourModel, checkpoint_every=100):
    """Run the training loop, saving a checkpoint image at regular intervals.

    The loop runs ``hp.num_epochs`` times. Before the training step of every
    ``checkpoint_every``-th epoch (starting with epoch 0) the model's current
    image ``model.x`` is written to ``out/checkpoints/cp-<savefile>/``, where
    ``<savefile>`` is the global ``ARGS.savefile``.

    Args:
        model: Model exposing the current image as ``model.x`` and a
            ``train_step(epoch)`` method.
        checkpoint_every: Number of epochs between checkpoint images.
            Defaults to 100.

    Raises:
        ValueError: If ``checkpoint_every`` is smaller than 1.
    """
    if checkpoint_every < 1:
        raise ValueError('checkpoint_every must be a positive integer')

    # Both values are the same for every epoch, so compute them once.
    checkpoint_dir = 'out/checkpoints/cp-{}/'.format(ARGS.savefile)
    # Checkpoints reuse whatever follows the last '.' in the save file name.
    extension = ARGS.savefile.split('.')[-1]

    for epoch in range(hp.num_epochs):
        # The checkpoint is taken before the step, so it shows the image
        # produced by the previous epochs.
        if epoch % checkpoint_every == 0:
            save_tensor_as_image(model.x, checkpoint_dir,
                                 '{}_epoch.{}'.format(epoch, extension))
        model.train_step(epoch)
def main():
    """Run style transfer end to end for the images named in ``ARGS``.

    Steps: resolve the input paths, switch to the script directory, refuse
    to overwrite an existing checkpoint directory, load the content and
    style images, train the model, then save and display the final image.

    Reads the module-level ``ARGS`` namespace (``content``, ``style``,
    ``savefile``). Returns ``None``; returns early after printing an error
    when the checkpoint directory for ``ARGS.savefile`` already exists.
    """
    # --------------------------------------------------------------------------------------------------------------
    # PART 1 : parse the arguments
    # --------------------------------------------------------------------------------------------------------------
    # Make the input paths absolute while the caller's working directory is
    # still current; paths that do not exist there are left untouched and
    # are therefore resolved relative to the script directory below.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)

    # All output paths are relative, so switch to the script directory
    # (sys.path[0] when launched as a script) BEFORE checking for an existing
    # checkpoint directory; otherwise the check looks in the wrong place.
    os.chdir(sys.path[0])

    checkpoint_dir = 'out/checkpoints/cp-{}/'.format(ARGS.savefile)
    if os.path.exists(checkpoint_dir):
        # print an error to the console if the checkpoint directory already exists
        print('Error: out/checkpoints/cp-{}/ already exists. Please choose a different name.'.format(ARGS.savefile))
        return

    # --------------------------------------------------------------------------------------------------------------
    # PART 2 : read and process the style and content images
    # --------------------------------------------------------------------------------------------------------------
    # 1) read content and style images
    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)
    # 2) make the style image the same size as the content image
    # NOTE(review): resize may change the style image's dtype/value range
    # while the content image keeps what imread returned -- confirm that
    # YourModel normalizes both inputs.
    style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True)

    # --------------------------------------------------------------------------------------------------------------
    # PART 3 : make and train our model
    # --------------------------------------------------------------------------------------------------------------
    # 1) initialize our model class
    my_model = YourModel(content_image=content_image, style_image=style_image)
    # 2) train the model calling the helper
    train(my_model)

    # --------------------------------------------------------------------------------------------------------------
    # PART 4 : save and show result from final epoch
    # --------------------------------------------------------------------------------------------------------------
    # model.x is the most recent image created by the model
    result_tensor = my_model.x
    # save image in output folder
    final_image = save_tensor_as_image(result_tensor, 'out/results/', ARGS.savefile)
    print('\nDONE: saved the final result to out/results/{}.'.format(ARGS.savefile))
    # imshow only draws onto the current figure; show() is what actually
    # opens the window. It is called last because it may block until the
    # window is closed.
    plt.imshow(final_image)
    plt.show()
# Only parse arguments and start training when executed as a script, so that
# importing this module has no side effects. ARGS is still assigned at module
# level, so train() and main() keep reading it as a global.
if __name__ == '__main__':
    ARGS = parse_args()
    main()
|