import os import sys import argparse import re from datetime import datetime import tensorflow as tf import hyperparameters as hp from losses import YourModel from preprocess import Datasets from skimage.transform import resize # from tensorboard_utils import \ # ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver from skimage.io import imread from lime import lime_image from skimage.segmentation import mark_boundaries from matplotlib import pyplot as plt import numpy as np os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' def parse_args(): """ Perform command-line argument parsing. """ parser = argparse.ArgumentParser( description="Let's train some neural nets!", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( '--task', required=True, choices=['1', '3'], help='''Which task of the assignment to run - training from scratch (1), or fine tuning VGG-16 (3).''') parser.add_argument( '--data', default='..'+os.sep+'data'+os.sep, help='Location where the dataset is stored.') parser.add_argument( '--load-vgg', default='vgg16_imagenet.h5', help='''Path to pre-trained VGG-16 file (only applicable to task 3).''') parser.add_argument( '--load-checkpoint', default=None, help='''Path to model checkpoint file (should end with the extension .h5). Checkpoints are automatically saved when you train your model. If you want to continue training from where you left off, this is how you would load your weights.''') parser.add_argument( '--confusion', action='store_true', help='''Log a confusion matrix at the end of each epoch (viewable in Tensorboard). This is turned off by default as it takes a little bit of time to complete.''') parser.add_argument( '--evaluate', action='store_true', help='''Skips training and evaluates on the test set once. You can use this to test an already trained model by loading its checkpoint.''') parser.add_argument( '--lime-image', default='test/Bedroom/image_0003.jpg', help='''Name of an image in the dataset to use for LIME evaluation.''') return parser.parse_args() def LIME_explainer(model, path, preprocess_fn): """ This function takes in a trained model and a path to an image and outputs 5 visual explanations using the LIME model """ def image_and_mask(title, positive_only=True, num_features=5, hide_rest=True): temp, mask = explanation.get_image_and_mask( explanation.top_labels[0], positive_only=positive_only, num_features=num_features, hide_rest=hide_rest) plt.imshow(mark_boundaries(temp / 2 + 0.5, mask)) plt.title(title) plt.show() image = imread(path) if len(image.shape) == 2: image = np.stack([image, image, image], axis=-1) image = preprocess_fn(image) image = resize(image, (hp.img_size, hp.img_size, 3)) explainer = lime_image.LimeImageExplainer() explanation = explainer.explain_instance( image.astype('double'), model.predict, top_labels=5, hide_color=0, num_samples=1000) # The top 5 superpixels that are most positive towards the class with the # rest of the image hidden image_and_mask("Top 5 superpixels", positive_only=True, num_features=5, hide_rest=True) # The top 5 superpixels with the rest of the image present image_and_mask("Top 5 with the rest of the image present", positive_only=True, num_features=5, hide_rest=False) # The 'pros and cons' (pros in green, cons in red) image_and_mask("Pros(green) and Cons(red)", positive_only=False, num_features=10, hide_rest=False) # Select the same class explained on the figures above. ind = explanation.top_labels[0] # Map each explanation weight to the corresponding superpixel dict_heatmap = dict(explanation.local_exp[ind]) heatmap = np.vectorize(dict_heatmap.get)(explanation.segments) plt.imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max()) plt.colorbar() plt.title("Map each explanation weight to the corresponding superpixel") plt.show() def train(model, datasets, checkpoint_path, logs_path, init_epoch): """ Training routine. """ # Keras callbacks for training callback_list = [ tf.keras.callbacks.TensorBoard( log_dir=logs_path, update_freq='batch', profile_batch=0) # ImageLabelingLogger(logs_path, datasets), # CustomModelSaver(checkpoint_path, ARGS.task, hp.max_num_weights) ] # Include confusion logger in callbacks if flag set if ARGS.confusion: callback_list.append(ConfusionMatrixLogger(logs_path, datasets)) # Begin training model.fit( x=datasets.train_data, validation_data=datasets.test_data, epochs=hp.num_epochs, batch_size=None, callbacks=callback_list, initial_epoch=init_epoch, ) def test(model, test_data): """ Testing routine. """ # Run model on test set model.evaluate( x=test_data, verbose=1, ) def main(): """ Main function. """ time_now = datetime.now() timestamp = time_now.strftime("%m%d%y-%H%M%S") init_epoch = 0 # If loading from a checkpoint, the loaded checkpoint's directory # will be used for future checkpoints if ARGS.load_checkpoint is not None: ARGS.load_checkpoint = os.path.abspath(ARGS.load_checkpoint) # Get timestamp and epoch from filename regex = r"(?:.+)(?:\.e)(\d+)(?:.+)(?:.h5)" init_epoch = int(re.match(regex, ARGS.load_checkpoint).group(1)) + 1 timestamp = os.path.basename(os.path.dirname(ARGS.load_checkpoint)) # If paths provided by program arguments are accurate, then this will # ensure they are used. If not, these directories/files will be # set relative to the directory of run.py if os.path.exists(ARGS.data): ARGS.data = os.path.abspath(ARGS.data) if os.path.exists(ARGS.load_vgg): ARGS.load_vgg = os.path.abspath(ARGS.load_vgg) # Run script from location of run.py os.chdir(sys.path[0]) datasets = Datasets(ARGS.data, ARGS.task) if ARGS.task == '1': model = YourModel() model(tf.keras.Input(shape=(hp.img_size, hp.img_size, 3))) checkpoint_path = "checkpoints" + os.sep + \ "your_model" + os.sep + timestamp + os.sep logs_path = "logs" + os.sep + "your_model" + \ os.sep + timestamp + os.sep # Print summary of model model.summary() else: model = VGGModel() checkpoint_path = "checkpoints" + os.sep + \ "vgg_model" + os.sep + timestamp + os.sep logs_path = "logs" + os.sep + "vgg_model" + \ os.sep + timestamp + os.sep model(tf.keras.Input(shape=(224, 224, 3))) # Print summaries for both parts of the model model.vgg16.summary() model.head.summary() # Load base of VGG model model.vgg16.load_weights(ARGS.load_vgg, by_name=True) # Load checkpoints if ARGS.load_checkpoint is not None: if ARGS.task == '1': model.load_weights(ARGS.load_checkpoint, by_name=False) else: model.head.load_weights(ARGS.load_checkpoint, by_name=False) # Make checkpoint directory if needed if not ARGS.evaluate and not os.path.exists(checkpoint_path): os.makedirs(checkpoint_path) # Compile model graph model.compile( optimizer=model.optimizer, loss=model.loss_fn, metrics=["sparse_categorical_accuracy"]) if ARGS.evaluate: test(model, datasets.test_data) # TODO: change the image path to be the image of your choice by changing # the lime-image flag when calling run.py to investigate # i.e. python run.py --evaluate --lime-image test/Bedroom/image_003.jpg path = ARGS.data + os.sep + ARGS.lime_image LIME_explainer(model, path, datasets.preprocess_fn) else: train(model, datasets, checkpoint_path, logs_path, init_epoch) # Make arguments global ARGS = parse_args() main()