Skip to content

Commit acc0594

Browse files
Allow setting models to non-trainable via the factory function
NOTE: it is particularly important that this command also sets the batch-normalization layers to non-trainable, which now seems to be the standard with Tensorflow 2 + Keras, but is not yet handled well by, e.g., the models from `segmentation_models`. Cf. `freeze_model` from `segmentation_models/models/_utils.py` and, e.g., https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute and keras-team/keras#9965.
1 parent aa3fa3e commit acc0594

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

src/bfseg/cl_models/base_cl_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,14 @@ def _build_model(self):
5252
assert (self.run.config['cl_params']['cl_framework']
5353
in ["ewc", "finetune"
5454
]), "Currently, only EWC and fine-tuning are supported."
55+
# NOTE: by default the model is created as trainable. CL frameworks that
56+
# require a fixed, non-trainable network from which to distill the
57+
# information (e.g., in distillation experiments) should create additional
58+
# models by overloading this method and calling `super()._build_model()` in
59+
# the overload.
5560
self.encoder, self.model = create_model(
5661
model_name=self.run.config['network_params']['architecture'],
62+
trainable=True,
5763
**self.run.config['network_params']['model_params'])
5864
self.new_model = keras.Model(
5965
inputs=self.model.input,

src/bfseg/cl_models/distillation_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self, run, root_output_dir):
4343
"Distillation model requires the CL parameter `distillation_type` "
4444
"to be specified.")
4545

46-
super(DistillationModel, self).__init__(run=run, root_output_dir=root_output_dir)
46+
super(DistillationModel, self).__init__(run=run,
47+
root_output_dir=root_output_dir)
4748

4849
self._started_training_new_task = False
4950

src/bfseg/utils/models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
def create_model(model_name,
1010
image_h,
1111
image_w,
12+
trainable,
1213
log_params_used=True,
1314
**model_params):
1415
r"""Factory function that creates a model with the given parameters.
@@ -18,6 +19,7 @@ def create_model(model_name,
1819
"fast_scnn", "unet".
1920
image_h (int): Image height.
2021
image_w (int): Image width.
22+
trainable (bool): Whether or not the model should be trainable.
2123
log_params_used (bool): If True, the complete list of parameters used to
2224
instantiate the model is printed.
2325
---
@@ -91,4 +93,17 @@ def create_model(model_name,
9193

9294
encoder, model = model_fn(**model_params)
9395

94-
return encoder, model
96+
# Optionally set the model as non-trainable.
97+
if (not trainable):
98+
# NOTE: it is particularly important that this command also sets the
99+
# batch-normalization layers to non-trainable, which now seems to be the
100+
# standard with Tensorflow 2 + Keras, but is not yet handled well by, e.g.,
101+
# the models from `segmentation_models`.
102+
# Cf. `freeze_model` from `segmentation_models/models/_utils.py` and, e.g.,
103+
# https://keras.io/getting_started/faq/#whats-the-difference-between-the-
104+
# training-argument-in-call-and-the-trainable-attribute and
105+
# https://github.com/keras-team/keras/pull/9965.
106+
encoder.trainable = False
107+
model.trainable = False
108+
109+
return encoder, model

src/nyu_pretraining.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def pretrain_nyu(_run,
3838
_, model = create_model(model_name="fast_scnn",
3939
image_h=image_h,
4040
image_w=image_w,
41+
trainable=True,
4142
num_downsampling_layers=2)
4243

4344
model.compile(

0 commit comments

Comments
 (0)