@@ -213,6 +213,10 @@ def train(dataset, learn_step=0.005,
213213 jacc_valid = []
214214 patience = 0
215215
216+ n_batches_train = 1
217+ n_batches_val = 1
218+ n_batches_test = 1
219+ num_epochs = 1
216220 # Training main loop
217221 print "Start training"
218222 for epoch in range (num_epochs ):
@@ -222,10 +226,12 @@ def train(dataset, learn_step=0.005,
222226
223227 # Train
224228 for i in range (n_batches_train ):
229+ print 'Training batch ' , i
225230 # Get minibatch
226231 X_train_batch , L_train_batch = train_iter .next ()
227232 L_train_batch = np .reshape (L_train_batch , np .prod (L_train_batch .shape ))
228233
234+
229235 # Training step
230236 cost_train = train_fn (X_train_batch , L_train_batch )
231237 out_str = "cost %f" % (cost_train )
@@ -238,6 +244,7 @@ def train(dataset, learn_step=0.005,
238244 acc_val_tot = 0
239245 jacc_val_tot = np .zeros ((2 , n_classes ))
240246 for i in range (n_batches_val ):
247+ print 'Valid batch ' , i
241248 # Get minibatch
242249 X_val_batch , L_val_batch = val_iter .next ()
243250 L_val_batch = np .reshape (L_val_batch , np .prod (L_val_batch .shape ))
@@ -277,13 +284,12 @@ def train(dataset, learn_step=0.005,
277284 best_jacc_val = jacc_valid [epoch ]
278285 patience = 0
279286 np .savez (os .path .join (savepath , 'new_fcn8_model_best.npz' ), * lasagne .layers .get_all_param_values (convmodel ))
280- np .savez (os .path .join (savepath + "fcn8_errors_best.npz" ),
281- err_valid , err_train , acc_valid , jacc_valid )
287+ np .savez (os .path .join (savepath , "fcn8_errors_best.npz" ), err_valid , err_train , acc_valid , jacc_valid )
282288 else :
283289 patience += 1
284- np . savez ( os . path . join ( savepath , 'new_fcn8_model_last.npz' ), * lasagne . layers . get_all_param_values ( convmodel ))
285- np .savez (os .path .join (savepath + "fcn8_errors_last .npz" ),
286- err_valid , err_train , acc_valid , jacc_valid )
290+
291+ np .savez (os .path .join (savepath , 'new_fcn8_model_last .npz' ), * lasagne . layers . get_all_param_values ( convmodel ))
292+ np . savez ( os . path . join ( savepath , "fcn8_errors_last.npz" ), err_valid , err_train , acc_valid , jacc_valid )
287293 # Finish training if patience has expired or max nber of epochs
288294 # reached
289295 if patience == max_patience or epoch == num_epochs - 1 :
0 commit comments