diff options
Diffstat (limited to 'main.py')
-rw-r--r-- | main.py | 71 |
1 files changed, 48 insertions, 23 deletions
@@ -31,54 +31,79 @@ def parse_args(): '--savefile', required=True, help='Filename to save image') - parser.add_argument( - '--load', - required=False, - default='N', - help='Y if you want to load the most recent weights' - ) return parser.parse_args() +def save_tensor_as_image(tensor, path, img_name): + # make copy of tensor + copy = tf.identity(tensor) + copy = tf.squeeze(copy) + # convert tensor to back to uint8 image + copy = tf.image.convert_image_dtype(copy, tf.uint8) + + # save image (make path if it doesn't exist) + if not os.path.exists(path): + os.makedirs(path) + imsave(path + img_name, copy) + + # return copy if used + return copy + + def train(model: YourModel): + # do as many epochs from hyperparameters.py for i in range(hp.num_epochs): - if i % 50 == 0: - copy = tf.identity(model.x) - copy = tf.squeeze(copy) - copy = tf.image.convert_image_dtype(copy, tf.uint8) - imsave('save_images/epoch' + str(i) + '.jpg', copy) - np.save('checkpoint.npy', model.x) + # save a checkpoint every 100 epochs + if i % 100 == 0: + save_tensor_as_image(model.x, 'out/checkpoints/cp-{}/'.format(ARGS.savefile), + '{}_epoch.{}'.format(i, ARGS.savefile.split('.')[-1])) + + # do the training step model.train_step(i) def main(): """ Main function. """ + # -------------------------------------------------------------------------------------------------------------- + # PART 1 : parse the arguments # + # -------------------------------------------------------------------------------------------------------------- if os.path.exists(ARGS.content): ARGS.content = os.path.abspath(ARGS.content) if os.path.exists(ARGS.style): ARGS.style = os.path.abspath(ARGS.style) + if os.path.exists('out/checkpoints/cp-{}/'.format(ARGS.savefile)): + # print an error to the console if the checkpoint directory already exists + print('Error: out/checkpoints/cp-{}/ already exists. Please choose a different name.'.format(ARGS.savefile)) + return os.chdir(sys.path[0]) - print('this is', ARGS.content) + # -------------------------------------------------------------------------------------------------------------- + # PART 2 : read and process the style and content images # + # -------------------------------------------------------------------------------------------------------------- + # 1) read content and style images content_image = imread(ARGS.content) style_image = imread(ARGS.style) + # 2) make the style image the same size as the content image style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True) + # -------------------------------------------------------------------------------------------------------------- + # PART 3 : make and train our model # + # -------------------------------------------------------------------------------------------------------------- + # 1) initialize our model class my_model = YourModel(content_image=content_image, style_image=style_image) - - if (ARGS.load == 'Y'): - checkpoint = np.load('checkpoint.npy') - image = tf.Variable(initial_value=checkpoint) - + # 2) train the model calling the helper train(my_model) - # convert the tensor into an image - my_model.x = tf.squeeze(my_model.x) - final_image = tf.image.convert_image_dtype(my_model.x, tf.uint8) - - imsave(ARGS.savefile, final_image) + # -------------------------------------------------------------------------------------------------------------- + # PART 4 : save and show result from final epoch # + # -------------------------------------------------------------------------------------------------------------- + # model.x is the most recent image created by the model + result_tensor = my_model.x + # save image in output folder + final_image = save_tensor_as_image(result_tensor, 'out/results/', ARGS.savefile) + # show the final result :) plt.imshow(final_image) plt.show() |