Skip to content

Commit 28fa556

Browse files
adri-romsornotoraptor
authored andcommitted
added per class jaccard
1 parent e0d53ba commit 28fa556

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

code/fcn_2D_segm/train_fcn8.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,12 @@ def train(dataset, learn_step=0.005,
299299

300300

301301
out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f" +\
302-
", acc val %f, jacc val %f took %f s"
302+
", acc val %f, jacc val class 0 %f, jacc val class 1 %f, jacc val %f took %f s"
303303
out_str = out_str % (epoch, err_train[epoch],
304304
err_valid[epoch],
305305
acc_valid[epoch],
306+
jacc_perclass_valid[0],
307+
jacc_perclass_valid[1],
306308
jacc_valid[epoch],
307309
time.time()-start_time)
308310
print out_str
@@ -335,29 +337,30 @@ def train(dataset, learn_step=0.005,
335337
# Test
336338
cost_test_tot = 0
337339
acc_test_tot = 0
338-
jacc_num_test_tot = np.zeros((1, n_classes))
339-
jacc_denom_test_tot = np.zeros((1, n_classes))
340+
jacc_test_tot = np.zeros((2, n_classes))
340341
for i in range(n_batches_test):
341342
# Get minibatch
342343
X_test_batch, L_test_batch = test_iter.next()
343344
L_test_batch = np.reshape(L_test_batch, np.prod(L_test_batch.shape))
344345

345346
# Test step
346347
cost_test, acc_test, jacc_test = val_fn(X_test_batch, L_test_batch)
347-
jacc_num_test, jacc_denom_test = jacc_test
348348

349349
acc_test_tot += acc_test
350350
cost_test_tot += cost_test
351-
jacc_num_test_tot += jacc_num_test
352-
jacc_denom_test_tot += jacc_denom_test
351+
jacc_test_tot += jacc_test
353352

354353
err_test = cost_test_tot/n_batches_test
355354
acc_test = acc_test_tot/n_batches_test
356-
jacc_test = np.mean(jacc_num_test_tot / jacc_denom_test_tot)
355+
jacc_test_perclass = jacc_test_tot[0, :] / jacc_test_tot[1, :]
356+
jacc_test = np.mean(jacc_test_perclass)
357357

358-
out_str = "FINAL MODEL: err test % f, acc test %f, jacc test %f"
358+
out_str = "FINAL MODEL: err test % f, acc test %f"
359+
out_str += "jacc test class 0 % f, jacc test class 1 %f, jacc test %f"
359360
out_str = out_str % (err_test,
360361
acc_test,
362+
jacc_test_perclass[0],
363+
jacc_test_perclass[1],
361364
jacc_test)
362365
print out_str
363366
# if savepath != loadpath:

code/unet/train_unet.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,12 @@ def train(dataset, learn_step=0.005,
293293

294294

295295
out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f" +\
296-
", acc val %f, jacc val %f took %f s"
296+
", acc val %f, jacc val class 0 % f, jacc val class 1 %f, jacc val %f took %f s"
297297
out_str = out_str % (epoch, err_train[epoch],
298298
err_valid[epoch],
299299
acc_valid[epoch],
300+
jacc_perclass_valid[0],
301+
jacc_perclass_valid[1],
300302
jacc_valid[epoch],
301303
time.time()-start_time)
302304
print out_str
@@ -329,28 +331,28 @@ def train(dataset, learn_step=0.005,
329331
# Test
330332
cost_test_tot = 0
331333
acc_test_tot = 0
332-
jacc_num_test_tot = np.zeros((1, n_classes))
333-
jacc_denom_test_tot = np.zeros((1, n_classes))
334+
jacc_test_tot = np.zeros((2, n_classes))
334335
for i in range(n_batches_test):
335336
# Get minibatch
336337
X_test_batch, L_test_batch = test_iter.next()
337338
L_test_batch = np.reshape(L_test_batch, np.prod(L_test_batch.shape))
338339

339340
# Test step
340341
cost_test, acc_test, jacc_test = val_fn(X_test_batch, L_test_batch)
341-
jacc_num_test, jacc_denom_test = jacc_test
342342

343343
acc_test_tot += acc_test
344344
cost_test_tot += cost_test
345-
jacc_num_test_tot += jacc_num_test
346-
jacc_denom_test_tot += jacc_denom_test
345+
jacc_test_tot += jacc_test
347346

348347
err_test = cost_test_tot/n_batches_test
349348
acc_test = acc_test_tot/n_batches_test
350-
jacc_test = np.mean(jacc_num_test_tot / jacc_denom_test_tot)
349+
jacc_test_perclass = jacc_test_tot[0, :] / jacc_test_tot[1, :]
350+
jacc_test = np.mean(jacc_test_perclass)
351351

352-
out_str = "FINAL MODEL: err test % f, acc test %f, jacc test %f"
353-
out_str = out_str % (err_test, acc_test, jacc_test)
352+
out_str = "FINAL MODEL: err test % f, acc test %f, " +\
353+
"jacc test class 0 %f, jacc test class 1 %f, jacc test %f"
354+
out_str = out_str % (err_test, acc_test, jacc_test_perclass[0],
355+
jacc_test_perclass[1], jacc_test)
354356
print out_str
355357
if savepath != loadpath:
356358
print('Copying model and other training files to {}'.format(loadpath))

0 commit comments

Comments
 (0)