
Commit 8285a9f

StephanieLarocque authored and notoraptor committed
first commit
1 parent 2026c59 commit 8285a9f

File tree

1 file changed

+379
-0
lines changed


code/unet/train_unet.py

Lines changed: 379 additions & 0 deletions
@@ -0,0 +1,379 @@
#!/usr/bin/env python2

import os
import argparse
import time
from getpass import getuser
from distutils.dir_util import copy_tree

import numpy as np
import theano
import theano.tensor as T
from theano import config
import lasagne
from lasagne.regularization import regularize_network_params

from data_loader import load_data
from Unet_lasagne_recipes import build_UNet
# from metrics import jaccard, accuracy, crossentropy


_FLOATX = config.floatX
_EPSILON = 10e-7

def jaccard_metric(y_pred, y_true, n_classes, one_hot=False):

    assert (y_pred.ndim == 2) or (y_pred.ndim == 1)

    # y_pred to indices
    if y_pred.ndim == 2:
        y_pred = T.argmax(y_pred, axis=1)

    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # Compute confusion matrix
    cm = T.zeros((n_classes, n_classes))
    for i in range(n_classes):
        for j in range(n_classes):
            cm = T.set_subtensor(
                cm[i, j], T.sum(T.eq(y_pred, i) * T.eq(y_true, j)))

    # Compute Jaccard Index: intersection (TP) over union (TP + FP + FN)
    TP_perclass = T.cast(cm.diagonal(), _FLOATX)
    FP_perclass = cm.sum(1) - TP_perclass
    FN_perclass = cm.sum(0) - TP_perclass

    num = TP_perclass
    denom = TP_perclass + FP_perclass + FN_perclass

    return T.stack([num, denom], axis=0)

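# NOTE: returning stacked numerators and denominators (rather than the
# ratio itself) lets the caller accumulate counts over minibatches and
# divide only once at the end, as the validation loop below does. A
# minimal NumPy sketch of that aggregation (toy values, not from this
# script):
#
#     batch_a = np.array([[3., 5.], [4., 8.]])  # rows: TP, TP + FP + FN
#     batch_b = np.array([[2., 1.], [3., 2.]])
#     totals = batch_a + batch_b              # accumulate counts first
#     jacc_per_class = totals[0] / totals[1]  # per-class Jaccard
#     print(np.mean(jacc_per_class))          # mean Jaccard across classes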

def accuracy_metric(y_pred, y_true, void_labels, one_hot=False):

    assert (y_pred.ndim == 2) or (y_pred.ndim == 1)

    # y_pred to indices
    if y_pred.ndim == 2:
        y_pred = T.argmax(y_pred, axis=1)

    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # Compute accuracy
    acc = T.eq(y_pred, y_true).astype(_FLOATX)

    # Create mask that zeroes out void-labelled pixels
    mask = T.ones_like(y_true, dtype=_FLOATX)
    for el in void_labels:
        mask = T.set_subtensor(mask[T.eq(y_true, el).nonzero()], 0.)

    # Apply mask
    acc *= mask
    acc = T.sum(acc) / T.sum(mask)

    return acc

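# NOTE: the void mask drops void-labelled pixels from both numerator and
# denominator, so accuracy is measured over labelled pixels only. The same
# masking idea in plain NumPy (hypothetical void label 2):
#
#     y_pred = np.array([0, 1, 1, 0])
#     y_true = np.array([0, 1, 2, 1])        # label 2 stands in for "void"
#     mask = (y_true != 2).astype(float)     # 0 on void pixels, 1 elsewhere
#     acc = (y_pred == y_true).astype(float) * mask
#     print(acc.sum() / mask.sum())          # accuracy over non-void pixels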

def crossentropy_metric(y_pred, y_true, void_labels, one_hot=False):
    # Clip predictions to avoid log(0)
    y_pred = T.clip(y_pred, _EPSILON, 1.0 - _EPSILON)

    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # Create mask
    mask = T.ones_like(y_true, dtype=_FLOATX)
    for el in void_labels:
        mask = T.set_subtensor(mask[T.eq(y_true, el).nonzero()], 0.)

    # Modify y_true temporarily: void pixels get label 0 so the
    # crossentropy is well defined; their loss is masked out below
    y_true_tmp = y_true * mask
    y_true_tmp = y_true_tmp.astype('int32')

    # Compute cross-entropy
    loss = T.nnet.categorical_crossentropy(y_pred, y_true_tmp)

    # Compute masked mean loss
    loss *= mask
    loss = T.sum(loss) / T.sum(mask)

    return loss

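# NOTE: each metric above is a plain symbolic expression, so it can be
# compiled and sanity-checked in isolation from the U-Net. A minimal
# sketch with toy shapes (uses this module's imports):
#
#     probs = T.matrix('probs')     # (n_pixels, n_classes) softmax outputs
#     labels = T.ivector('labels')  # flattened ground-truth class indices
#     f = theano.function([probs, labels],
#                         crossentropy_metric(probs, labels, void_labels=[]))
#     p = np.array([[0.9, 0.1], [0.2, 0.8]], dtype=_FLOATX)
#     print(f(p, np.array([0, 1], dtype='int32')))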

SAVEPATH = 'save_models/'
LOADPATH = SAVEPATH
WEIGHTS_PATH = SAVEPATH

def train(dataset, learn_step=0.005,
          weight_decay=1e-4, num_epochs=500,
          max_patience=100, data_augmentation={},
          savepath=None, loadpath=None,
          early_stop_class=None,
          batch_size=None,
          resume=False,
          train_from_0_255=False):

    #
    # Prepare load/save directories
    #
    exp_name = 'unet_' + ('data_aug' if bool(data_augmentation) else '')

    if savepath is None:
        raise ValueError('A saving directory must be specified')

    savepath = os.path.join(savepath, dataset, exp_name)
    # loadpath = os.path.join(loadpath, dataset, exp_name)
    print savepath
    # print loadpath

    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The following folder already exists {}. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_var = T.tensor4('input_var')
    target_var = T.ivector('target_var')

    #
    # Build dataset iterator
    #
    if batch_size is not None:
        bs = batch_size
    else:
        bs = [10, 1, 1]

    train_iter, val_iter, test_iter = \
        load_data(dataset, data_augmentation,
                  one_hot=False, batch_size=bs, return_0_255=train_from_0_255)

    n_batches_train = train_iter.nbatches
    n_batches_val = val_iter.nbatches
    n_batches_test = test_iter.nbatches if test_iter is not None else 0
    n_classes = train_iter.non_void_nclasses
    void_labels = train_iter.void_labels
    nb_in_channels = train_iter.data_shape[0]

    print "Batch. train: %d, val %d, test %d" % (n_batches_train, n_batches_val, n_batches_test)
    print "Nb of classes: %d" % (n_classes)
    print "Nb. of input channels: %d" % (nb_in_channels)

    #
    # Build network
    #
    convmodel = build_UNet(nb_in_channels, input_var, n_classes=n_classes,
                           void_labels=void_labels, trainable=True,
                           load_weights=resume, pascal=True, layer=['probs'])

    #
    # Define and compile theano functions
    #
    print "Defining and compiling training functions"
    prediction = lasagne.layers.get_output(convmodel)[0]
    loss = crossentropy_metric(prediction, target_var, void_labels)

    if weight_decay > 0:
        weightsl2 = regularize_network_params(
            convmodel, lasagne.regularization.l2)
        loss += weight_decay * weightsl2

    params = lasagne.layers.get_all_params(convmodel, trainable=True)
    updates = lasagne.updates.adam(loss, params, learning_rate=learn_step)

    train_fn = theano.function([input_var, target_var], loss, updates=updates)

    print "Defining and compiling test functions"
    test_prediction = lasagne.layers.get_output(convmodel, deterministic=True)[0]
    test_loss = crossentropy_metric(test_prediction, target_var, void_labels)
    test_acc = accuracy_metric(test_prediction, target_var, void_labels)
    test_jacc = jaccard_metric(test_prediction, target_var, n_classes)

    val_fn = theano.function([input_var, target_var], [test_loss, test_acc, test_jacc])

    #
    # Train
    #
    err_train = []
    err_valid = []
    acc_valid = []
    jacc_valid = []
    patience = 0

    # Training main loop
    print "Start training"
    for epoch in range(num_epochs):
        # Single epoch training and validation
        start_time = time.time()
        cost_train_tot = 0

        # Train
        for i in range(n_batches_train):
            # Get minibatch
            X_train_batch, L_train_batch = train_iter.next()
            L_train_batch = np.reshape(L_train_batch, np.prod(L_train_batch.shape))

            # Training step
            cost_train = train_fn(X_train_batch, L_train_batch)
            cost_train_tot += cost_train

        err_train += [cost_train_tot/n_batches_train]

        # Validation
        cost_val_tot = 0
        acc_val_tot = 0
        jacc_val_tot = np.zeros((2, n_classes))
        for i in range(n_batches_val):
            # Get minibatch
            X_val_batch, L_val_batch = val_iter.next()
            L_val_batch = np.reshape(L_val_batch, np.prod(L_val_batch.shape))

            # Validation step
            cost_val, acc_val, jacc_val = val_fn(X_val_batch, L_val_batch)

            acc_val_tot += acc_val
            cost_val_tot += cost_val
            jacc_val_tot += jacc_val

        err_valid += [cost_val_tot/n_batches_val]
        acc_valid += [acc_val_tot/n_batches_val]
        jacc_perclass_valid = jacc_val_tot[0, :] / jacc_val_tot[1, :]
        if early_stop_class is None:
            jacc_valid += [np.mean(jacc_perclass_valid)]
        else:
            jacc_valid += [jacc_perclass_valid[early_stop_class]]

        out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f" +\
            ", acc val %f, jacc val %f took %f s"
        out_str = out_str % (epoch, err_train[epoch],
                             err_valid[epoch],
                             acc_valid[epoch],
                             jacc_valid[epoch],
                             time.time()-start_time)
        print out_str

        with open(os.path.join(savepath, "unet_output.log"), "a") as f:
            f.write(out_str + "\n")

        # Early stopping and saving stuff
        if epoch == 0:
            best_jacc_val = jacc_valid[epoch]
        elif jacc_valid[epoch] > best_jacc_val:
            best_jacc_val = jacc_valid[epoch]
            patience = 0
            np.savez(os.path.join(savepath, 'new_unet_model_best.npz'),
                     *lasagne.layers.get_all_param_values(convmodel))
            np.savez(os.path.join(savepath, "unet_errors_best.npz"),
                     err_valid, err_train, acc_valid, jacc_valid)
        else:
            patience += 1
            np.savez(os.path.join(savepath, 'new_unet_model_last.npz'),
                     *lasagne.layers.get_all_param_values(convmodel))
            np.savez(os.path.join(savepath, "unet_errors_last.npz"),
                     err_valid, err_train, acc_valid, jacc_valid)

        # Finish training if patience has expired or max number of epochs
        # reached
        if patience == max_patience or epoch == num_epochs-1:
            if test_iter is not None:
                # Load best model weights
                with np.load(os.path.join(savepath, 'new_unet_model_best.npz')) as f:
                    param_values = [f['arr_%d' % i] for i in range(len(f.files))]
                nlayers = len(lasagne.layers.get_all_params(convmodel))
                lasagne.layers.set_all_param_values(convmodel, param_values[:nlayers])
                # Test
                cost_test_tot = 0
                acc_test_tot = 0
                jacc_num_test_tot = np.zeros((1, n_classes))
                jacc_denom_test_tot = np.zeros((1, n_classes))
                for i in range(n_batches_test):
                    # Get minibatch
                    X_test_batch, L_test_batch = test_iter.next()
                    L_test_batch = np.reshape(L_test_batch, np.prod(L_test_batch.shape))

                    # Test step
                    cost_test, acc_test, jacc_test = val_fn(X_test_batch, L_test_batch)
                    jacc_num_test, jacc_denom_test = jacc_test

                    acc_test_tot += acc_test
                    cost_test_tot += cost_test
                    jacc_num_test_tot += jacc_num_test
                    jacc_denom_test_tot += jacc_denom_test

                err_test = cost_test_tot/n_batches_test
                acc_test = acc_test_tot/n_batches_test
                jacc_test = np.mean(jacc_num_test_tot / jacc_denom_test_tot)

                out_str = "FINAL MODEL: err test %f, acc test %f, jacc test %f"
                out_str = out_str % (err_test,
                                     acc_test,
                                     jacc_test)
                print out_str
            # if savepath != loadpath:
            #     print('Copying model and other training files to {}'.format(loadpath))
            #     copy_tree(savepath, loadpath)

            # End
            return

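# NOTE: train() can also be driven directly from Python instead of the CLI
# in main() below; a sketch with illustrative argument values mirroring
# main()'s defaults:
#
#     train('em', learn_step=0.0001, weight_decay=0.0,
#           num_epochs=750, max_patience=100,
#           data_augmentation={'crop_size': (224, 224),
#                              'horizontal_flip': True,
#                              'fill_mode': 'constant'},
#           batch_size=[10, 1, 1],
#           savepath=SAVEPATH, loadpath=LOADPATH)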

def main():
    parser = argparse.ArgumentParser(description='U-Net model training')
    parser.add_argument('-dataset',
                        default='em',
                        help='Dataset.')
    parser.add_argument('-learning_rate',
                        default=0.0001,
                        help='Learning Rate')
    parser.add_argument('-penal_cst',
                        default=0.0,
                        help='regularization constant')
    parser.add_argument('--num_epochs',
                        '-ne',
                        type=int,
                        default=750,
                        help='Optional. Int to indicate the max '
                        'number of epochs.')
    parser.add_argument('-max_patience',
                        type=int,
                        default=100,
                        help='Max patience')
    # Note: the three flags below are effectively default-only: type=int
    # would yield a single int rather than a [train, val, test] triple,
    # type=dict cannot parse a command-line string, and type=bool treats
    # any non-empty string (including 'False') as True.
    parser.add_argument('-batch_size',
                        type=int,
                        default=[10, 1, 1],
                        help='Batch size [train, val, test]')
    parser.add_argument('-data_augmentation',
                        type=dict,
                        default={'crop_size': (224, 224),
                                 'horizontal_flip': True,
                                 'fill_mode': 'constant'},
                        help='use data augmentation')
    parser.add_argument('-early_stop_class',
                        type=int,
                        default=None,
                        help='class to early stop on')
    parser.add_argument('-train_from_0_255',
                        type=bool,
                        default=False,
                        help='Whether to train from images within 0-255 range')
    args = parser.parse_args()

    train(args.dataset, float(args.learning_rate),
          float(args.penal_cst), int(args.num_epochs), int(args.max_patience),
          data_augmentation=args.data_augmentation, batch_size=args.batch_size,
          early_stop_class=args.early_stop_class, savepath=SAVEPATH,
          train_from_0_255=args.train_from_0_255, loadpath=LOADPATH)


if __name__ == "__main__":
    main()
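# Typical invocation, using the single-dash flags defined above (values
# shown are the script's defaults):
#
#     python2 train_unet.py -dataset em -learning_rate 0.0001 \
#         -penal_cst 0.0 --num_epochs 750 -max_patience 100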
