#!/usr/bin/env python2

import os
import argparse
import ast
import time
from distutils.dir_util import copy_tree

import numpy as np
import theano
import theano.tensor as T
from theano import config
import lasagne
from lasagne.regularization import regularize_network_params

from data_loader import load_data
from Unet_lasagne_recipes import build_UNet
# from metrics import jaccard, accuracy, crossentropy


_FLOATX = config.floatX
_EPSILON = 10e-7


def jaccard_metric(y_pred, y_true, n_classes, one_hot=False):

    assert (y_pred.ndim == 2) or (y_pred.ndim == 1)

    # y_pred to indices
    if y_pred.ndim == 2:
        y_pred = T.argmax(y_pred, axis=1)

    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # Compute confusion matrix
    cm = T.zeros((n_classes, n_classes))
    for i in range(n_classes):
        for j in range(n_classes):
            cm = T.set_subtensor(
                cm[i, j], T.sum(T.eq(y_pred, i) * T.eq(y_true, j)))

    # Compute Jaccard Index
    TP_perclass = T.cast(cm.diagonal(), _FLOATX)
    FP_perclass = cm.sum(1) - TP_perclass
    FN_perclass = cm.sum(0) - TP_perclass

    num = TP_perclass
    denom = TP_perclass + FP_perclass + FN_perclass

    return T.stack([num, denom], axis=0)


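# Minimal usage sketch for jaccard_metric (not called during training; the
# toy shapes and values are illustrative assumptions). The metric returns
# the stacked per-class (numerator, denominator) counts rather than their
# ratio, so counts can be accumulated over minibatches and divided once.
def _example_jaccard():
    y_pred_var = T.matrix('y_pred')   # (n_pixels, n_classes) softmax output
    y_true_var = T.ivector('y_true')  # flattened integer labels
    jacc = jaccard_metric(y_pred_var, y_true_var, n_classes=2)
    fn = theano.function([y_pred_var, y_true_var], jacc)

    y_pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], dtype=_FLOATX)
    y_true = np.array([0, 1, 1], dtype='int32')
    num, denom = fn(y_pred, y_true)
    print "Per-class Jaccard:", num / denom

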
def accuracy_metric(y_pred, y_true, void_labels, one_hot=False):

    assert (y_pred.ndim == 2) or (y_pred.ndim == 1)

    # y_pred to indices
    if y_pred.ndim == 2:
        y_pred = T.argmax(y_pred, axis=1)

    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # Compute accuracy
    acc = T.eq(y_pred, y_true).astype(_FLOATX)

    # Create mask that zeroes out void labels. T.set_subtensor with an
    # empty index set is a no-op, so no emptiness check is needed (and
    # truth-testing a symbolic tensor is not supported anyway).
    mask = T.ones_like(y_true, dtype=_FLOATX)
    for el in void_labels:
        indices = T.eq(y_true, el).nonzero()
        mask = T.set_subtensor(mask[indices], 0.)

    # Apply mask
    acc *= mask
    acc = T.sum(acc) / T.sum(mask)

    return acc


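# Illustration of the void masking above (label values assumed): with
# y_true = [0, 1, 11] and void_labels = [11], the mask becomes [1, 1, 0],
# so the third pixel contributes to neither the numerator nor the
# denominator of the accuracy.

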
def crossentropy_metric(y_pred, y_true, void_labels, one_hot=False):
    # Clip predictions to avoid log(0)
    y_pred = T.clip(y_pred, _EPSILON, 1.0 - _EPSILON)

    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # Create mask that zeroes out void labels
    mask = T.ones_like(y_true, dtype=_FLOATX)
    for el in void_labels:
        mask = T.set_subtensor(mask[T.eq(y_true, el).nonzero()], 0.)

    # Temporarily map void labels to 0 so they are valid class indices
    y_true_tmp = y_true * mask
    y_true_tmp = y_true_tmp.astype('int32')

    # Compute cross-entropy
    loss = T.nnet.categorical_crossentropy(y_pred, y_true_tmp)

    # Compute masked mean loss
    loss *= mask
    loss = T.sum(loss) / T.sum(mask)

    return loss


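# Minimal usage sketch for crossentropy_metric (not called during training;
# shapes, values and the void label 2 are illustrative assumptions). The
# clipping above keeps log() finite, and the mask drops void pixels from
# the mean.
def _example_crossentropy():
    y_pred_var = T.matrix('y_pred')
    y_true_var = T.ivector('y_true')
    loss = crossentropy_metric(y_pred_var, y_true_var, void_labels=[2])
    fn = theano.function([y_pred_var, y_true_var], loss)

    y_pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]], dtype=_FLOATX)
    y_true = np.array([0, 1, 2], dtype='int32')  # last pixel is void
    # Mean of -log(0.9) and -log(0.8); the void pixel is excluded
    print "Masked cross-entropy:", fn(y_pred, y_true)

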
SAVEPATH = 'save_models/'
LOADPATH = SAVEPATH
WEIGHTS_PATH = SAVEPATH


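# Note: LOADPATH and WEIGHTS_PATH alias SAVEPATH here, which is why the
# final copy_tree(savepath, loadpath) step in train() is left commented out.
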
def train(dataset, learn_step=0.005,
          weight_decay=1e-4, num_epochs=500,
          max_patience=100, data_augmentation={},
          savepath=None, loadpath=None,
          early_stop_class=None,
          batch_size=None,
          resume=False,
          train_from_0_255=False):

    #
    # Prepare load/save directories
    #
    exp_name = 'unet_' + ('data_aug' if bool(data_augmentation) else '')

    if savepath is None:
        raise ValueError('A saving directory must be specified')

    savepath = os.path.join(savepath, dataset, exp_name)
    # loadpath = os.path.join(loadpath, dataset, exp_name)
    print savepath
    # print loadpath

    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The folder {} already exists. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_var = T.tensor4('input_var')
    target_var = T.ivector('target_var')

    #
    # Build dataset iterator
    #
    if batch_size is not None:
        bs = batch_size
    else:
        bs = [10, 1, 1]

    train_iter, val_iter, test_iter = \
        load_data(dataset, data_augmentation,
                  one_hot=False, batch_size=bs, return_0_255=train_from_0_255)

    n_batches_train = train_iter.nbatches
    n_batches_val = val_iter.nbatches
    n_batches_test = test_iter.nbatches if test_iter is not None else 0
    n_classes = train_iter.non_void_nclasses
    void_labels = train_iter.void_labels
    nb_in_channels = train_iter.data_shape[0]

    print "Batches. train: %d, val: %d, test: %d" % (n_batches_train,
                                                     n_batches_val,
                                                     n_batches_test)
    print "Nb of classes: %d" % n_classes
    print "Nb of input channels: %d" % nb_in_channels

    #
    # Build network
    #
    convmodel = build_UNet(nb_in_channels, input_var, n_classes=n_classes,
                           void_labels=void_labels, trainable=True,
                           load_weights=resume, pascal=True, layer=['probs'])

    #
    # Define and compile theano functions
    #
    print "Defining and compiling training functions"
    prediction = lasagne.layers.get_output(convmodel)[0]
    loss = crossentropy_metric(prediction, target_var, void_labels)

    if weight_decay > 0:
        weightsl2 = regularize_network_params(
            convmodel, lasagne.regularization.l2)
        loss += weight_decay * weightsl2

    params = lasagne.layers.get_all_params(convmodel, trainable=True)
    updates = lasagne.updates.adam(loss, params, learning_rate=learn_step)

    train_fn = theano.function([input_var, target_var], loss, updates=updates)

    print "Defining and compiling test functions"
    test_prediction = lasagne.layers.get_output(convmodel,
                                                deterministic=True)[0]
    test_loss = crossentropy_metric(test_prediction, target_var, void_labels)
    test_acc = accuracy_metric(test_prediction, target_var, void_labels)
    test_jacc = jaccard_metric(test_prediction, target_var, n_classes)

    val_fn = theano.function([input_var, target_var],
                             [test_loss, test_acc, test_jacc])

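    # val_fn returns three values per minibatch: the scalar masked
    # cross-entropy, the scalar masked accuracy, and a (2, n_classes) array
    # stacking the per-class Jaccard numerators and denominators; the loops
    # below accumulate these counts and divide only once per epoch.
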
    #
    # Train
    #
    err_train = []
    err_valid = []
    acc_valid = []
    jacc_valid = []
    patience = 0

    # Training main loop
    print "Start training"
    for epoch in range(num_epochs):
        # Single epoch of training and validation
        start_time = time.time()
        cost_train_tot = 0

        # Train
        for i in range(n_batches_train):
            # Get minibatch and flatten the labels
            X_train_batch, L_train_batch = train_iter.next()
            L_train_batch = np.reshape(L_train_batch,
                                       np.prod(L_train_batch.shape))

            # Training step
            cost_train = train_fn(X_train_batch, L_train_batch)
            cost_train_tot += cost_train

        err_train += [cost_train_tot / n_batches_train]

        # Validation
        cost_val_tot = 0
        acc_val_tot = 0
        jacc_val_tot = np.zeros((2, n_classes))
        for i in range(n_batches_val):
            # Get minibatch
            X_val_batch, L_val_batch = val_iter.next()
            L_val_batch = np.reshape(L_val_batch, np.prod(L_val_batch.shape))

            # Validation step
            cost_val, acc_val, jacc_val = val_fn(X_val_batch, L_val_batch)

            acc_val_tot += acc_val
            cost_val_tot += cost_val
            jacc_val_tot += jacc_val

        err_valid += [cost_val_tot / n_batches_val]
        acc_valid += [acc_val_tot / n_batches_val]
        jacc_perclass_valid = jacc_val_tot[0, :] / jacc_val_tot[1, :]
        if early_stop_class is None:
            jacc_valid += [np.mean(jacc_perclass_valid)]
        else:
            jacc_valid += [jacc_perclass_valid[early_stop_class]]

        out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f" \
                  ", acc val %f, jacc val %f took %f s"
        out_str = out_str % (epoch, err_train[epoch],
                             err_valid[epoch],
                             acc_valid[epoch],
                             jacc_valid[epoch],
                             time.time() - start_time)
        print out_str

        with open(os.path.join(savepath, "unet_output.log"), "a") as f:
            f.write(out_str + "\n")

        # Early stopping and saving stuff
        if epoch == 0 or jacc_valid[epoch] > best_jacc_val:
            best_jacc_val = jacc_valid[epoch]
            patience = 0
            np.savez(os.path.join(savepath, 'new_unet_model_best.npz'),
                     *lasagne.layers.get_all_param_values(convmodel))
            np.savez(os.path.join(savepath, "unet_errors_best.npz"),
                     err_valid, err_train, acc_valid, jacc_valid)
        else:
            patience += 1

        np.savez(os.path.join(savepath, 'new_unet_model_last.npz'),
                 *lasagne.layers.get_all_param_values(convmodel))
        np.savez(os.path.join(savepath, "unet_errors_last.npz"),
                 err_valid, err_train, acc_valid, jacc_valid)
        # Finish training if patience has expired or the max number of
        # epochs has been reached
        if patience == max_patience or epoch == num_epochs - 1:
            if test_iter is not None:
                # Load best model weights
                with np.load(os.path.join(savepath,
                                          'new_unet_model_best.npz')) as f:
                    param_values = [f['arr_%d' % i]
                                    for i in range(len(f.files))]
                nlayers = len(lasagne.layers.get_all_params(convmodel))
                lasagne.layers.set_all_param_values(convmodel,
                                                    param_values[:nlayers])
                # Test
                cost_test_tot = 0
                acc_test_tot = 0
                jacc_num_test_tot = np.zeros((1, n_classes))
                jacc_denom_test_tot = np.zeros((1, n_classes))
                for i in range(n_batches_test):
                    # Get minibatch
                    X_test_batch, L_test_batch = test_iter.next()
                    L_test_batch = np.reshape(L_test_batch,
                                              np.prod(L_test_batch.shape))

                    # Test step
                    cost_test, acc_test, jacc_test = val_fn(X_test_batch,
                                                            L_test_batch)
                    jacc_num_test, jacc_denom_test = jacc_test

                    acc_test_tot += acc_test
                    cost_test_tot += cost_test
                    jacc_num_test_tot += jacc_num_test
                    jacc_denom_test_tot += jacc_denom_test

                err_test = cost_test_tot / n_batches_test
                acc_test = acc_test_tot / n_batches_test
                jacc_test = np.mean(jacc_num_test_tot / jacc_denom_test_tot)

                out_str = "FINAL MODEL: err test %f, acc test %f, jacc test %f"
                out_str = out_str % (err_test,
                                     acc_test,
                                     jacc_test)
                print out_str
            # if savepath != loadpath:
            #     print('Copying model and other training files to {}'.format(loadpath))
            #     copy_tree(savepath, loadpath)

            # End
            return


def main():
    parser = argparse.ArgumentParser(description='U-Net model training')
    parser.add_argument('-dataset',
                        default='em',
                        help='Dataset.')
    parser.add_argument('-learning_rate',
                        type=float,
                        default=0.0001,
                        help='Learning rate')
    parser.add_argument('-penal_cst',
                        type=float,
                        default=0.0,
                        help='Regularization constant')
    parser.add_argument('--num_epochs',
                        '-ne',
                        type=int,
                        default=750,
                        help='Optional. Int to indicate the max '
                             'number of epochs.')
    parser.add_argument('-max_patience',
                        type=int,
                        default=100,
                        help='Max patience')
    parser.add_argument('-batch_size',
                        type=int,
                        nargs=3,
                        default=[10, 1, 1],
                        help='Batch sizes [train, val, test]')
    parser.add_argument('-data_augmentation',
                        type=ast.literal_eval,
                        default={'crop_size': (224, 224),
                                 'horizontal_flip': True,
                                 'fill_mode': 'constant'},
                        help='Dict of data augmentation parameters')
    parser.add_argument('-early_stop_class',
                        type=int,
                        default=None,
                        help='Class to early stop on')
    parser.add_argument('-train_from_0_255',
                        action='store_true',
                        help='Train from images within the 0-255 range')
    args = parser.parse_args()

    train(args.dataset, float(args.learning_rate),
          float(args.penal_cst), int(args.num_epochs), int(args.max_patience),
          data_augmentation=args.data_augmentation, batch_size=args.batch_size,
          early_stop_class=args.early_stop_class, savepath=SAVEPATH,
          train_from_0_255=args.train_from_0_255, loadpath=LOADPATH)


if __name__ == "__main__":
    main()
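
# Usage sketch (the filename train_unet.py is an assumption; use whatever
# name this script is saved under, with a dataset known to data_loader):
#   python2 train_unet.py -dataset em -learning_rate 0.0001 --num_epochs 750 \
#       -max_patience 100 -batch_size 10 1 1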