aboutsummaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
authorMichael Foiani <sotech117@michaels-mbp-4.devices.brown.edu>2022-05-09 02:02:42 -0400
committerMichael Foiani <sotech117@michaels-mbp-4.devices.brown.edu>2022-05-09 02:02:42 -0400
commitc8386be631f38cde31391124423f77b7479d1bad (patch)
tree3615b99b52fa2f620bc88118bf4aee9e697361bf /main.py
parent13826824ec58e85557efb3a12fb0456ffb20e46d (diff)
added examples and fixed the io for checkpoints
Diffstat (limited to 'main.py')
-rw-r--r--main.py71
1 files changed, 48 insertions, 23 deletions
diff --git a/main.py b/main.py
index 527394cb..e60f14d1 100644
--- a/main.py
+++ b/main.py
@@ -31,54 +31,79 @@ def parse_args():
'--savefile',
required=True,
help='Filename to save image')
- parser.add_argument(
- '--load',
- required=False,
- default='N',
- help='Y if you want to load the most recent weights'
- )
return parser.parse_args()
+def save_tensor_as_image(tensor, path, img_name):
+ # make copy of tensor
+ copy = tf.identity(tensor)
+ copy = tf.squeeze(copy)
+ # convert tensor to back to uint8 image
+ copy = tf.image.convert_image_dtype(copy, tf.uint8)
+
+ # save image (make path if it doesn't exist)
+ if not os.path.exists(path):
+ os.makedirs(path)
+ imsave(path + img_name, copy)
+
+ # return copy if used
+ return copy
+
+
def train(model: YourModel):
+ # do as many epochs from hyperparameters.py
for i in range(hp.num_epochs):
- if i % 50 == 0:
- copy = tf.identity(model.x)
- copy = tf.squeeze(copy)
- copy = tf.image.convert_image_dtype(copy, tf.uint8)
- imsave('save_images/epoch' + str(i) + '.jpg', copy)
- np.save('checkpoint.npy', model.x)
+ # save a checkpoint every 100 epochs
+ if i % 100 == 0:
+ save_tensor_as_image(model.x, 'out/checkpoints/cp-{}/'.format(ARGS.savefile),
+ '{}_epoch.{}'.format(i, ARGS.savefile.split('.')[-1]))
+
+ # do the training step
model.train_step(i)
def main():
""" Main function. """
+ # --------------------------------------------------------------------------------------------------------------
+ # PART 1 : parse the arguments #
+ # --------------------------------------------------------------------------------------------------------------
if os.path.exists(ARGS.content):
ARGS.content = os.path.abspath(ARGS.content)
if os.path.exists(ARGS.style):
ARGS.style = os.path.abspath(ARGS.style)
+ if os.path.exists('out/checkpoints/cp-{}/'.format(ARGS.savefile)):
+ # print an error to the console if the checkpoint directory already exists
+ print('Error: out/checkpoints/cp-{}/ already exists. Please choose a different name.'.format(ARGS.savefile))
+ return
os.chdir(sys.path[0])
- print('this is', ARGS.content)
+ # --------------------------------------------------------------------------------------------------------------
+ # PART 2 : read and process the style and content images #
+ # --------------------------------------------------------------------------------------------------------------
+ # 1) read content and style images
content_image = imread(ARGS.content)
style_image = imread(ARGS.style)
+ # 2) make the style image the same size as the content image
style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True)
+ # --------------------------------------------------------------------------------------------------------------
+ # PART 3 : make and train our model #
+ # --------------------------------------------------------------------------------------------------------------
+ # 1) initialize our model class
my_model = YourModel(content_image=content_image, style_image=style_image)
-
- if (ARGS.load == 'Y'):
- checkpoint = np.load('checkpoint.npy')
- image = tf.Variable(initial_value=checkpoint)
-
+ # 2) train the model calling the helper
train(my_model)
- # convert the tensor into an image
- my_model.x = tf.squeeze(my_model.x)
- final_image = tf.image.convert_image_dtype(my_model.x, tf.uint8)
-
- imsave(ARGS.savefile, final_image)
+ # --------------------------------------------------------------------------------------------------------------
+ # PART 4 : save and show result from final epoch #
+ # --------------------------------------------------------------------------------------------------------------
+ # model.x is the most recent image created by the model
+ result_tensor = my_model.x
+ # save image in output folder
+ final_image = save_tensor_as_image(result_tensor, 'out/results/', ARGS.savefile)
+ # show the final result :)
plt.imshow(final_image)
plt.show()