aboutsummaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r--main.py248
1 files changed, 248 insertions, 0 deletions
diff --git a/main.py b/main.py
new file mode 100644
index 00000000..ca87788d
--- /dev/null
+++ b/main.py
@@ -0,0 +1,248 @@
+import os
+import sys
+import argparse
+import re
+from datetime import datetime
+import tensorflow as tf
+
+import hyperparameters as hp
+from models import YourModel, VGGModel
+from preprocess import Datasets
+from skimage.transform import resize
+from tensorboard_utils import \
+ ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver
+
+from skimage.io import imread
+from lime import lime_image
+from skimage.segmentation import mark_boundaries
+from matplotlib import pyplot as plt
+import numpy as np
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+
+def parse_args():
+ """ Perform command-line argument parsing. """
+
+ parser = argparse.ArgumentParser(
+ description="Let's train some neural nets!",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ '--task',
+ required=True,
+ choices=['1', '3'],
+ help='''Which task of the assignment to run -
+ training from scratch (1), or fine tuning VGG-16 (3).''')
+ parser.add_argument(
+ '--data',
+ default='..'+os.sep+'data'+os.sep,
+ help='Location where the dataset is stored.')
+ parser.add_argument(
+ '--load-vgg',
+ default='vgg16_imagenet.h5',
+ help='''Path to pre-trained VGG-16 file (only applicable to
+ task 3).''')
+ parser.add_argument(
+ '--load-checkpoint',
+ default=None,
+ help='''Path to model checkpoint file (should end with the
+ extension .h5). Checkpoints are automatically saved when you
+ train your model. If you want to continue training from where
+ you left off, this is how you would load your weights.''')
+ parser.add_argument(
+ '--confusion',
+ action='store_true',
+ help='''Log a confusion matrix at the end of each
+ epoch (viewable in Tensorboard). This is turned off
+ by default as it takes a little bit of time to complete.''')
+ parser.add_argument(
+ '--evaluate',
+ action='store_true',
+ help='''Skips training and evaluates on the test set once.
+ You can use this to test an already trained model by loading
+ its checkpoint.''')
+ parser.add_argument(
+ '--lime-image',
+ default='test/Bedroom/image_0003.jpg',
+ help='''Name of an image in the dataset to use for LIME evaluation.''')
+
+ return parser.parse_args()
+
+
+def LIME_explainer(model, path, preprocess_fn):
+ """
+ This function takes in a trained model and a path to an image and outputs 5
+ visual explanations using the LIME model
+ """
+
+ def image_and_mask(title, positive_only=True, num_features=5,
+ hide_rest=True):
+ temp, mask = explanation.get_image_and_mask(
+ explanation.top_labels[0], positive_only=positive_only,
+ num_features=num_features, hide_rest=hide_rest)
+ plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))
+ plt.title(title)
+ plt.show()
+
+ image = imread(path)
+ if len(image.shape) == 2:
+ image = np.stack([image, image, image], axis=-1)
+ image = preprocess_fn(image)
+ image = resize(image, (hp.img_size, hp.img_size, 3))
+
+ explainer = lime_image.LimeImageExplainer()
+
+ explanation = explainer.explain_instance(
+ image.astype('double'), model.predict, top_labels=5, hide_color=0,
+ num_samples=1000)
+
+ # The top 5 superpixels that are most positive towards the class with the
+ # rest of the image hidden
+ image_and_mask("Top 5 superpixels", positive_only=True, num_features=5,
+ hide_rest=True)
+
+ # The top 5 superpixels with the rest of the image present
+ image_and_mask("Top 5 with the rest of the image present",
+ positive_only=True, num_features=5, hide_rest=False)
+
+ # The 'pros and cons' (pros in green, cons in red)
+ image_and_mask("Pros(green) and Cons(red)",
+ positive_only=False, num_features=10, hide_rest=False)
+
+ # Select the same class explained on the figures above.
+ ind = explanation.top_labels[0]
+ # Map each explanation weight to the corresponding superpixel
+ dict_heatmap = dict(explanation.local_exp[ind])
+ heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
+ plt.imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max())
+ plt.colorbar()
+ plt.title("Map each explanation weight to the corresponding superpixel")
+ plt.show()
+
+
+def train(model, datasets, checkpoint_path, logs_path, init_epoch):
+ """ Training routine. """
+
+ # Keras callbacks for training
+ callback_list = [
+ tf.keras.callbacks.TensorBoard(
+ log_dir=logs_path,
+ update_freq='batch',
+ profile_batch=0),
+ ImageLabelingLogger(logs_path, datasets),
+ CustomModelSaver(checkpoint_path, ARGS.task, hp.max_num_weights)
+ ]
+
+ # Include confusion logger in callbacks if flag set
+ if ARGS.confusion:
+ callback_list.append(ConfusionMatrixLogger(logs_path, datasets))
+
+ # Begin training
+ model.fit(
+ x=datasets.train_data,
+ validation_data=datasets.test_data,
+ epochs=hp.num_epochs,
+ batch_size=None,
+ callbacks=callback_list,
+ initial_epoch=init_epoch,
+ )
+
+
+def test(model, test_data):
+ """ Testing routine. """
+
+ # Run model on test set
+ model.evaluate(
+ x=test_data,
+ verbose=1,
+ )
+
+
+def main():
+ """ Main function. """
+
+ time_now = datetime.now()
+ timestamp = time_now.strftime("%m%d%y-%H%M%S")
+ init_epoch = 0
+
+ # If loading from a checkpoint, the loaded checkpoint's directory
+ # will be used for future checkpoints
+ if ARGS.load_checkpoint is not None:
+ ARGS.load_checkpoint = os.path.abspath(ARGS.load_checkpoint)
+
+ # Get timestamp and epoch from filename
+ regex = r"(?:.+)(?:\.e)(\d+)(?:.+)(?:.h5)"
+ init_epoch = int(re.match(regex, ARGS.load_checkpoint).group(1)) + 1
+ timestamp = os.path.basename(os.path.dirname(ARGS.load_checkpoint))
+
+ # If paths provided by program arguments are accurate, then this will
+ # ensure they are used. If not, these directories/files will be
+ # set relative to the directory of run.py
+ if os.path.exists(ARGS.data):
+ ARGS.data = os.path.abspath(ARGS.data)
+ if os.path.exists(ARGS.load_vgg):
+ ARGS.load_vgg = os.path.abspath(ARGS.load_vgg)
+
+ # Run script from location of run.py
+ os.chdir(sys.path[0])
+
+ datasets = Datasets(ARGS.data, ARGS.task)
+
+ if ARGS.task == '1':
+ model = YourModel()
+ model(tf.keras.Input(shape=(hp.img_size, hp.img_size, 3)))
+ checkpoint_path = "checkpoints" + os.sep + \
+ "your_model" + os.sep + timestamp + os.sep
+ logs_path = "logs" + os.sep + "your_model" + \
+ os.sep + timestamp + os.sep
+
+ # Print summary of model
+ model.summary()
+ else:
+ model = VGGModel()
+ checkpoint_path = "checkpoints" + os.sep + \
+ "vgg_model" + os.sep + timestamp + os.sep
+ logs_path = "logs" + os.sep + "vgg_model" + \
+ os.sep + timestamp + os.sep
+ model(tf.keras.Input(shape=(224, 224, 3)))
+
+ # Print summaries for both parts of the model
+ model.vgg16.summary()
+ model.head.summary()
+
+ # Load base of VGG model
+ model.vgg16.load_weights(ARGS.load_vgg, by_name=True)
+
+ # Load checkpoints
+ if ARGS.load_checkpoint is not None:
+ if ARGS.task == '1':
+ model.load_weights(ARGS.load_checkpoint, by_name=False)
+ else:
+ model.head.load_weights(ARGS.load_checkpoint, by_name=False)
+
+ # Make checkpoint directory if needed
+ if not ARGS.evaluate and not os.path.exists(checkpoint_path):
+ os.makedirs(checkpoint_path)
+
+ # Compile model graph
+ model.compile(
+ optimizer=model.optimizer,
+ loss=model.loss_fn,
+ metrics=["sparse_categorical_accuracy"])
+
+ if ARGS.evaluate:
+ test(model, datasets.test_data)
+
+ # TODO: change the image path to be the image of your choice by changing
+ # the lime-image flag when calling run.py to investigate
+ # i.e. python run.py --evaluate --lime-image test/Bedroom/image_003.jpg
+ path = ARGS.data + os.sep + ARGS.lime_image
+ LIME_explainer(model, path, datasets.preprocess_fn)
+ else:
+ train(model, datasets, checkpoint_path, logs_path, init_epoch)
+
+
+# Make arguments global
+ARGS = parse_args()
+
+main()