blob: 0d84f216b6770726f57cdf78b751b4b9bdb38ca6 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
import os
import sys
import argparse
import re
from datetime import datetime
import tensorflow as tf
import hyperparameters as hp
from losses import YourModel
from preprocess import Datasets
from skimage.transform import resize
# from tensorboard_utils import \
# ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver
from skimage.io import imread, imsave
from lime import lime_image
from skimage.segmentation import mark_boundaries
from matplotlib import pyplot as plt
import numpy as np
def parse_args():
    """Build the CLI parser and parse ``sys.argv``.

    Three mandatory string options are registered, in this order:
    ``--content`` (content image path), ``--style`` (style image path) and
    ``--savefile`` (output image path).

    Returns:
        argparse.Namespace with attributes ``content``, ``style`` and
        ``savefile``.
    """
    cli = argparse.ArgumentParser(
        description="Let's train some neural nets!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, help text) pairs; every flag is required.
    required_options = (
        ('--content', 'Content image filepath'),
        ('--style', 'Style image filepath'),
        ('--savefile', 'Filename to save image'),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, required=True, help=help_text)

    return cli.parse_args()
def train(model, num_epochs=None):
    """Run the style-transfer optimisation loop.

    Calls ``model.train_step()`` once per epoch.

    Args:
        model: object exposing a zero-argument ``train_step()`` method
            (``YourModel`` in this script).
        num_epochs: number of steps to run. Defaults to
            ``hp.num_epochs`` when omitted, so existing callers are
            unaffected.
    """
    if num_epochs is None:
        num_epochs = hp.num_epochs
    # The epoch count is presumably an int; the previous
    # ``for _ in hp.num_epochs`` tried to iterate the int itself and raised
    # TypeError, so the count has to be wrapped in range().
    for _ in range(num_epochs):
        model.train_step()
def main():
    """Run style transfer using the command-line options in ``ARGS``.

    1. Reads the content and style images.
    2. Builds ``YourModel`` from them.
    3. Runs ``train`` on the model.
    4. Writes the model's ``x`` attribute to ``ARGS.savefile``.

    Relies on the module-level ``ARGS`` namespace produced by
    ``parse_args()``.
    """
    # Resolve user-supplied paths against the caller's working directory
    # *before* the chdir below. Otherwise relative paths would be
    # reinterpreted relative to the script's directory.
    if os.path.exists(ARGS.content):
        ARGS.content = os.path.abspath(ARGS.content)
    if os.path.exists(ARGS.style):
        ARGS.style = os.path.abspath(ARGS.style)
    # The output file normally does not exist yet, so resolve it
    # unconditionally. Previously it was left relative and the result was
    # written into the script's directory rather than the caller's.
    ARGS.savefile = os.path.abspath(ARGS.savefile)

    # Move into the directory containing this script (presumably so project
    # code can load resources by relative path -- confirm). __file__ is used
    # instead of sys.path[0], which is '' in interactive / `python -c` runs
    # and makes os.chdir fail. realpath mirrors sys.path[0]'s symlink
    # resolution.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    content_image = imread(ARGS.content)
    style_image = imread(ARGS.style)

    my_model = YourModel(content_image=content_image, style_image=style_image)
    train(my_model)

    # NOTE(review): the type/shape/value range of ``my_model.x`` is defined in
    # losses.YourModel (not visible here); confirm imsave accepts it as-is.
    final_image = my_model.x
    # NOTE(review): no plt.show()/plt.savefig follows, so this figure is not
    # displayed in a non-interactive run.
    plt.imshow(final_image)
    imsave(ARGS.savefile, final_image)
# Only parse arguments and run training when executed as a script, so that
# importing this module does not consume sys.argv or start a training run.
# ARGS is still bound at module level, where main() reads it.
if __name__ == "__main__":
    ARGS = parse_args()
    main()
|