@@ -34,6 +34,7 @@ def jaccard_metric(y_pred, y_true, n_classes, one_hot=False):
3434 y_true = T .argmax (y_true , axis = 1 )
3535
3636 # Compute confusion matrix
37+ # cm = T.nnet.confusion_matrix(y_pred, y_true)
3738 cm = T .zeros ((n_classes , n_classes ))
3839 for i in range (n_classes ):
3940 for j in range (n_classes ):
@@ -156,12 +157,16 @@ def train(dataset, learn_step=0.005,
156157 if batch_size is not None :
157158 bs = batch_size
158159 else :
159- bs = [10 , 1 , 1 ]
160+ bs = [1 , 1 , 1 ]
160161
161162 train_iter , val_iter , test_iter = \
162163 load_data (dataset , data_augmentation ,
163164 one_hot = False , batch_size = bs , return_0_255 = train_from_0_255 )
164165
166+ batch = train_iter .next ()
167+ input_dim = (np .shape (batch [0 ])[2 ], np .shape (batch [0 ])[3 ])
168+ print 'batch size ' , np .shape (batch [0 ])
169+
165170 n_batches_train = train_iter .nbatches
166171 n_batches_val = val_iter .nbatches
167172 n_batches_test = test_iter .nbatches if test_iter is not None else 0
@@ -176,29 +181,30 @@ def train(dataset, learn_step=0.005,
176181 #
177182 # Build network
178183 #
179- convmodel = build_UNet (nb_in_channels , input_var , n_classes = n_classes ,
180- void_labels = void_labels , trainable = True ,
181- load_weights = resume , pascal = True , layer = ['probs' ])
182184
185+ net = build_UNet (n_input_channels = nb_in_channels ,# BATCH_SIZE = batch_size,
186+ num_output_classes = n_classes , base_n_filters = 64 , do_dropout = False ,
187+ input_dim = input_dim ) #(512,512))
188+
189+ output_layer = net ["output_flattened" ]
183190 #
184191 # Define and compile theano functions
185192 #
186193 print "Defining and compiling training functions"
187- prediction = lasagne .layers .get_output (convmodel )[ 0 ]
194+ prediction = lasagne .layers .get_output (output_layer , input_var )
188195 loss = crossentropy_metric (prediction , target_var , void_labels )
189196
190197 if weight_decay > 0 :
191- weightsl2 = regularize_network_params (
192- convmodel , lasagne .regularization .l2 )
198+ weightsl2 = regularize_network_params (output_layer , lasagne .regularization .l2 )
193199 loss += weight_decay * weightsl2
194200
195- params = lasagne .layers .get_all_params (convmodel , trainable = True )
201+ params = lasagne .layers .get_all_params (output_layer , trainable = True )
196202 updates = lasagne .updates .adam (loss , params , learning_rate = learn_step )
197203
198204 train_fn = theano .function ([input_var , target_var ], loss , updates = updates )
199205
200206 print "Defining and compiling test functions"
201- test_prediction = lasagne .layers .get_output (convmodel , deterministic = True )[ 0 ]
207+ test_prediction = lasagne .layers .get_output (output_layer , input_var , deterministic = True )
202208 test_loss = crossentropy_metric (test_prediction , target_var , void_labels )
203209 test_acc = accuracy_metric (test_prediction , target_var , void_labels )
204210 test_jacc = jaccard_metric (test_prediction , target_var , n_classes )
@@ -221,8 +227,12 @@ def train(dataset, learn_step=0.005,
221227 start_time = time .time ()
222228 cost_train_tot = 0
223229
230+ n_batches_train = 2
231+ n_batches_val = 2
224232 # Train
233+ print 'Training steps '
225234 for i in range (n_batches_train ):
235+ print i
226236 # Get minibatch
227237 X_train_batch , L_train_batch = train_iter .next ()
228238 L_train_batch = np .reshape (L_train_batch , np .prod (L_train_batch .shape ))
@@ -238,7 +248,10 @@ def train(dataset, learn_step=0.005,
238248 cost_val_tot = 0
239249 acc_val_tot = 0
240250 jacc_val_tot = np .zeros ((2 , n_classes ))
251+
252+ print 'Validation steps'
241253 for i in range (n_batches_val ):
254+ print i
242255 # Get minibatch
243256 X_val_batch , L_val_batch = val_iter .next ()
244257 L_val_batch = np .reshape (L_val_batch , np .prod (L_val_batch .shape ))
@@ -277,12 +290,12 @@ def train(dataset, learn_step=0.005,
277290 elif epoch > 1 and jacc_valid [epoch ] > best_jacc_val :
278291 best_jacc_val = jacc_valid [epoch ]
279292 patience = 0
280- np .savez (os .path .join (savepath , 'new_unet_model_best.npz' ), * lasagne .layers .get_all_param_values (convmodel ))
293+ np .savez (os .path .join (savepath , 'new_unet_model_best.npz' ), * lasagne .layers .get_all_param_values (output_layer ))
281294 np .savez (os .path .join (savepath + "unet_errors_best.npz" ),
282295 err_valid , err_train , acc_valid , jacc_valid )
283296 else :
284297 patience += 1
285- np .savez (os .path .join (savepath , 'new_unet_model_last.npz' ), * lasagne .layers .get_all_param_values (convmodel ))
298+ np .savez (os .path .join (savepath , 'new_unet_model_last.npz' ), * lasagne .layers .get_all_param_values (output_layer ))
286299 np .savez (os .path .join (savepath + "unet_errors_last.npz" ),
287300 err_valid , err_train , acc_valid , jacc_valid )
288301 # Finish training if patience has expired or max nber of epochs
@@ -292,8 +305,8 @@ def train(dataset, learn_step=0.005,
292305 # Load best model weights
293306 with np .load (os .path .join (savepath , 'new_unet_model_best.npz' )) as f :
294307 param_values = [f ['arr_%d' % i ] for i in range (len (f .files ))]
295- nlayers = len (lasagne .layers .get_all_params (convmodel ))
296- lasagne .layers .set_all_param_values (convmodel , param_values [:nlayers ])
308+ nlayers = len (lasagne .layers .get_all_params (output_layer ))
309+ lasagne .layers .set_all_param_values (output_layer , param_values [:nlayers ])
297310 # Test
298311 cost_test_tot = 0
299312 acc_test_tot = 0
@@ -318,13 +331,11 @@ def train(dataset, learn_step=0.005,
318331 jacc_test = np .mean (jacc_num_test_tot / jacc_denom_test_tot )
319332
320333 out_str = "FINAL MODEL: err test % f, acc test %f, jacc test %f"
321- out_str = out_str % (err_test ,
322- acc_test ,
323- jacc_test )
334+ out_str = out_str % (err_test , acc_test , jacc_test )
324335 print out_str
325- # if savepath != loadpath:
326- # print('Copying model and other training files to {}'.format(loadpath))
327- # copy_tree(savepath, loadpath)
336+ if savepath != loadpath :
337+ print ('Copying model and other training files to {}' .format (loadpath ))
338+ copy_tree (savepath , loadpath )
328339
329340 # End
330341 return
@@ -353,11 +364,18 @@ def main():
353364 help = 'Max patience' )
354365 parser .add_argument ('-batch_size' ,
355366 type = int ,
356- default = [10 , 1 , 1 ],
367+ default = [1 , 1 , 1 ],
357368 help = 'Batch size [train, val, test]' )
358369 parser .add_argument ('-data_augmentation' ,
359370 type = dict ,
360- default = {'crop_size' : (224 , 224 ), 'horizontal_flip' : True , 'fill_mode' :'constant' },
371+ default = {'rotation_range' :25 ,
372+ 'shear_range' :0.41 ,
373+ 'horizontal_flip' :True ,
374+ 'vertical_flip' :True ,
375+ 'fill_mode' :'reflect' ,
376+ 'spline_warp' :True ,
377+ 'warp_sigma' :10 ,
378+ 'warp_grid_size' :3 },
361379 help = 'use data augmentation' )
362380 parser .add_argument ('-early_stop_class' ,
363381 type = int ,
0 commit comments