Skip to content

Commit 119d10e

Browse files
StephanieLarocquenotoraptor
authored andcommitted
remove loading pretrained weights
1 parent ca52c8a commit 119d10e

File tree

1 file changed

+0
-49
lines changed

1 file changed

+0
-49
lines changed

code/fcn_2D_segm/fcn8.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -130,55 +130,6 @@ def buildFCN8(nb_in_channels, input_var,
130130
nonlinearity=softmax)
131131
# end-snippet-1
132132

133-
# Load weights
134-
if load_weights:
135-
if pascal:
136-
path_weights = '/data/lisatmp4/erraqabi/data/att-segm/' + \
137-
'pre_trained_weights/pascal-fcn8s-tvg-dag.mat'
138-
if 'tvg' in path_weights:
139-
str_filter = 'f'
140-
str_bias = 'b'
141-
else:
142-
str_filter = '_filter'
143-
str_bias = '_bias'
144-
145-
W = sio.loadmat(path_weights)
146-
147-
# Load the parameter values into the net
148-
num_params = W.get('params').shape[1]
149-
for i in range(num_params):
150-
# Get layer name from the saved model
151-
name = str(W.get('params')[0][i][0])[3:-2]
152-
# Get parameter value
153-
param_value = W.get('params')[0][i][1]
154-
155-
# Load weights
156-
if name.endswith(str_filter):
157-
raw_name = name[:-len(str_filter)]
158-
if 'score' not in raw_name and \
159-
'upsample' not in raw_name and \
160-
'final' not in raw_name and \
161-
'probs' not in raw_name:
162-
163-
# print 'Initializing layer ' + raw_name
164-
param_value = param_value.T
165-
param_value = np.swapaxes(param_value, 2, 3)
166-
net[raw_name].W.set_value(param_value)
167-
168-
# Load bias terms
169-
if name.endswith(str_bias):
170-
raw_name = name[:-len(str_bias)]
171-
if 'score' not in raw_name and \
172-
'upsample' not in raw_name and \
173-
'final' not in raw_name and \
174-
'probs' not in raw_name:
175-
176-
param_value = np.squeeze(param_value)
177-
net[raw_name].b.set_value(param_value)
178-
else:
179-
with np.load(path_weights) as f:
180-
param_values = [f['arr_%d' % i] for i in range(len(f.files))]
181-
lasagne.layers.set_all_param_values(net['probs'], param_values)
182133

183134
# Do not train
184135
if not trainable:

0 commit comments

Comments
 (0)