@@ -293,10 +293,12 @@ def train(dataset, learn_step=0.005,
293293
294294
295295 out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f" + \
296- ", acc val %f, jacc val %f took %f s"
296+ ", acc val %f, jacc val class 0 % f, jacc val class 1 %f, jacc val %f took %f s"
297297 out_str = out_str % (epoch , err_train [epoch ],
298298 err_valid [epoch ],
299299 acc_valid [epoch ],
300+ jacc_perclass_valid [0 ],
301+ jacc_perclass_valid [1 ],
300302 jacc_valid [epoch ],
301303 time .time ()- start_time )
302304 print out_str
@@ -329,28 +331,28 @@ def train(dataset, learn_step=0.005,
329331 # Test
330332 cost_test_tot = 0
331333 acc_test_tot = 0
332- jacc_num_test_tot = np .zeros ((1 , n_classes ))
333- jacc_denom_test_tot = np .zeros ((1 , n_classes ))
334+ jacc_test_tot = np .zeros ((2 , n_classes ))
334335 for i in range (n_batches_test ):
335336 # Get minibatch
336337 X_test_batch , L_test_batch = test_iter .next ()
337338 L_test_batch = np .reshape (L_test_batch , np .prod (L_test_batch .shape ))
338339
339340 # Test step
340341 cost_test , acc_test , jacc_test = val_fn (X_test_batch , L_test_batch )
341- jacc_num_test , jacc_denom_test = jacc_test
342342
343343 acc_test_tot += acc_test
344344 cost_test_tot += cost_test
345- jacc_num_test_tot += jacc_num_test
346- jacc_denom_test_tot += jacc_denom_test
345+ jacc_test_tot += jacc_test
347346
348347 err_test = cost_test_tot / n_batches_test
349348 acc_test = acc_test_tot / n_batches_test
350- jacc_test = np .mean (jacc_num_test_tot / jacc_denom_test_tot )
349+ jacc_test_perclass = jacc_test_tot [0 , :] / jacc_test_tot [1 , :]
350+ jacc_test = np .mean (jacc_test_perclass )
351351
352- out_str = "FINAL MODEL: err test % f, acc test %f, jacc test %f"
353- out_str = out_str % (err_test , acc_test , jacc_test )
352+ out_str = "FINAL MODEL: err test % f, acc test %f, " + \
353+ "jacc test class 0 %f, jacc test class 1 %f, jacc test %f"
354+ out_str = out_str % (err_test , acc_test , jacc_test_perclass [0 ],
355+ jacc_test_perclass [1 ], jacc_test )
354356 print out_str
355357 if savepath != loadpath :
356358 print ('Copying model and other training files to {}' .format (loadpath ))
0 commit comments