@@ -27,16 +27,23 @@ class ResNetBackbone(FeaturePyramidBackbone):
2727 This class implements a ResNet backbone as described in [Deep Residual
2828 Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
2929 CVPR 2016), [Identity Mappings in Deep Residual Networks](
30- https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
30+ https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
3131 improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
32- NeurIPS 2021 Workshop).
32+ NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
33+ Convolutional Neural Networks](https://arxiv.org/abs/1812.01187)(CVPR 2019).
3334
3435 The difference between ResNet and ResNetV2 lies in the structure of their
3536 individual building blocks. In ResNetV2, the batch normalization and
3637 ReLU activation precede the convolution layers, as opposed to ResNet where
3738 the batch normalization and ReLU activation are applied after the
3839 convolution layers.
3940
41+ ResNetVd introduces two key modifications to the standard ResNet. First,
42+ the initial convolutional layer is replaced by a series of three
43+ successive convolutional layers. Second, shortcut connections use an
44+ additional pooling operation rather than performing downsampling within
45+ the convolutional layers themselves.
46+
4047 Note that `ResNetBackbone` expects the inputs to be images with a value
4148 range of `[0, 255]` when `include_rescaling=True`.
4249
@@ -51,6 +58,7 @@ class ResNetBackbone(FeaturePyramidBackbone):
5158 Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
5259 use_pre_activation: boolean. Whether to use pre-activation or not.
5360 `True` for ResNetV2, `False` for ResNet.
61+ use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
5462 include_rescaling: boolean. If `True`, rescale the input using
5563 `Rescaling` and `Normalization` layers. If `False`, do nothing.
5664 Defaults to `True`.
@@ -106,6 +114,7 @@ def __init__(
106114 stackwise_num_strides ,
107115 block_type ,
108116 use_pre_activation = False ,
117+ use_vd_pooling = False ,
109118 include_rescaling = True ,
110119 input_image_shape = (None , None , 3 ),
111120 pooling = "avg" ,
@@ -133,7 +142,12 @@ def __init__(
133142 '`block_type` must be either `"basic_block"` or '
134143 f'`"bottleneck_block"`. Received block_type={ block_type } .'
135144 )
136- version = "v1" if not use_pre_activation else "v2"
145+ if use_vd_pooling :
146+ version = "vd"
147+ elif use_pre_activation :
148+ version = "v2"
149+ else :
150+ version = "v1"
137151 data_format = standardize_data_format (data_format )
138152 bn_axis = - 1 if data_format == "channels_last" else 1
139153 num_stacks = len (stackwise_num_filters )
@@ -155,21 +169,21 @@ def __init__(
155169 # The padding between torch and tensorflow/jax differs when `strides>1`.
156170 # Therefore, we need to manually pad the tensor.
157171 x = layers .ZeroPadding2D (
158- 3 ,
172+ 1 if use_vd_pooling else 3 ,
159173 data_format = data_format ,
160174 dtype = dtype ,
161175 name = "conv1_pad" ,
162176 )(x )
163- x = layers . Conv2D (
164- 64 ,
165- 7 ,
166- strides = 2 ,
167- data_format = data_format ,
168- use_bias = False ,
169- dtype = dtype ,
170- name = "conv1_conv" ,
171- )( x )
172- if not use_pre_activation :
177+ if use_vd_pooling :
178+ x = layers . Conv2D (
179+ 32 ,
180+ 3 ,
181+ strides = 2 ,
182+ data_format = data_format ,
183+ use_bias = False ,
184+ dtype = dtype ,
185+ name = "conv1_conv" ,
186+ )( x )
173187 x = layers .BatchNormalization (
174188 axis = bn_axis ,
175189 epsilon = 1e-5 ,
@@ -178,6 +192,57 @@ def __init__(
178192 name = "conv1_bn" ,
179193 )(x )
180194 x = layers .Activation ("relu" , dtype = dtype , name = "conv1_relu" )(x )
195+ x = layers .Conv2D (
196+ 32 ,
197+ 3 ,
198+ strides = 1 ,
199+ padding = "same" ,
200+ data_format = data_format ,
201+ use_bias = False ,
202+ dtype = dtype ,
203+ name = "conv2_conv" ,
204+ )(x )
205+ x = layers .BatchNormalization (
206+ axis = bn_axis ,
207+ epsilon = 1e-5 ,
208+ momentum = 0.9 ,
209+ dtype = dtype ,
210+ name = "conv2_bn" ,
211+ )(x )
212+ x = layers .Activation ("relu" , dtype = dtype , name = "conv2_relu" )(x )
213+ x = layers .Conv2D (
214+ 64 ,
215+ 3 ,
216+ strides = 1 ,
217+ padding = "same" ,
218+ data_format = data_format ,
219+ use_bias = False ,
220+ dtype = dtype ,
221+ name = "conv3_conv" ,
222+ )(x )
223+ else :
224+ x = layers .Conv2D (
225+ 64 ,
226+ 7 ,
227+ strides = 2 ,
228+ data_format = data_format ,
229+ use_bias = False ,
230+ dtype = dtype ,
231+ name = "conv1_conv" ,
232+ )(x )
233+ if not use_pre_activation :
234+ x = layers .BatchNormalization (
235+ axis = bn_axis ,
236+ epsilon = 1e-5 ,
237+ momentum = 0.9 ,
238+ dtype = dtype ,
239+ name = "conv3_bn" if use_vd_pooling else "conv1_bn" ,
240+ )(x )
241+ x = layers .Activation (
242+ "relu" ,
243+ dtype = dtype ,
244+ name = "conv3_relu" if use_vd_pooling else "conv1_relu" ,
245+ )(x )
181246
182247 if use_pre_activation :
183248 # A workaround for ResNetV2: we need -inf padding to prevent zeros
@@ -210,8 +275,11 @@ def __init__(
210275 stride = stackwise_num_strides [stack_index ],
211276 block_type = block_type ,
212277 use_pre_activation = use_pre_activation ,
278+ use_vd_pooling = use_vd_pooling ,
213279 first_shortcut = (
214- block_type == "bottleneck_block" or stack_index > 0
280+ block_type == "bottleneck_block"
281+ or stack_index > 0
282+ or use_vd_pooling
215283 ),
216284 data_format = data_format ,
217285 dtype = dtype ,
@@ -253,6 +321,7 @@ def __init__(
253321 self .stackwise_num_strides = stackwise_num_strides
254322 self .block_type = block_type
255323 self .use_pre_activation = use_pre_activation
324+ self .use_vd_pooling = use_vd_pooling
256325 self .include_rescaling = include_rescaling
257326 self .input_image_shape = input_image_shape
258327 self .pooling = pooling
@@ -267,6 +336,7 @@ def get_config(self):
267336 "stackwise_num_strides" : self .stackwise_num_strides ,
268337 "block_type" : self .block_type ,
269338 "use_pre_activation" : self .use_pre_activation ,
339+ "use_vd_pooling" : self .use_vd_pooling ,
270340 "include_rescaling" : self .include_rescaling ,
271341 "input_image_shape" : self .input_image_shape ,
272342 "pooling" : self .pooling ,
@@ -282,6 +352,7 @@ def apply_basic_block(
282352 stride = 1 ,
283353 conv_shortcut = False ,
284354 use_pre_activation = False ,
355+ use_vd_pooling = False ,
285356 data_format = None ,
286357 dtype = None ,
287358 name = None ,
@@ -299,6 +370,7 @@ def apply_basic_block(
299370 `False`.
300371 use_pre_activation: boolean. Whether to use pre-activation or not.
301372 `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
373+ use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
302374 data_format: `None` or str. The ordering of the dimensions in the
303375 inputs. Can be `"channels_last"`
304376 (`(batch_size, height, width, channels)`) or `"channels_first"`
@@ -327,16 +399,27 @@ def apply_basic_block(
327399 )(x_preact )
328400
329401 if conv_shortcut :
330- x = x_preact if x_preact is not None else x
402+ if x_preact is not None :
403+ shortcut = x_preact
404+ elif use_vd_pooling and stride > 1 :
405+ shortcut = layers .AveragePooling2D (
406+ 2 ,
407+ strides = stride ,
408+ data_format = data_format ,
409+ dtype = dtype ,
410+ padding = "same" ,
411+ )(x )
412+ else :
413+ shortcut = x
331414 shortcut = layers .Conv2D (
332415 filters ,
333416 1 ,
334- strides = stride ,
417+ strides = 1 if use_vd_pooling else stride ,
335418 data_format = data_format ,
336419 use_bias = False ,
337420 dtype = dtype ,
338421 name = f"{ name } _0_conv" ,
339- )(x )
422+ )(shortcut )
340423 if not use_pre_activation :
341424 shortcut = layers .BatchNormalization (
342425 axis = bn_axis ,
@@ -407,6 +490,7 @@ def apply_bottleneck_block(
407490 stride = 1 ,
408491 conv_shortcut = False ,
409492 use_pre_activation = False ,
493+ use_vd_pooling = False ,
410494 data_format = None ,
411495 dtype = None ,
412496 name = None ,
@@ -424,6 +508,7 @@ def apply_bottleneck_block(
424508 `False`.
425509 use_pre_activation: boolean. Whether to use pre-activation or not.
426510 `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
511+ use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
427512 data_format: `None` or str. The ordering of the dimensions in the
428513 inputs. Can be `"channels_last"`
429514 (`(batch_size, height, width, channels)`) or `"channels_first"`
@@ -452,16 +537,27 @@ def apply_bottleneck_block(
452537 )(x_preact )
453538
454539 if conv_shortcut :
455- x = x_preact if x_preact is not None else x
540+ if x_preact is not None :
541+ shortcut = x_preact
542+ elif use_vd_pooling and stride > 1 :
543+ shortcut = layers .AveragePooling2D (
544+ 2 ,
545+ strides = stride ,
546+ data_format = data_format ,
547+ dtype = dtype ,
548+ padding = "same" ,
549+ )(x )
550+ else :
551+ shortcut = x
456552 shortcut = layers .Conv2D (
457553 4 * filters ,
458554 1 ,
459- strides = stride ,
555+ strides = 1 if use_vd_pooling else stride ,
460556 data_format = data_format ,
461557 use_bias = False ,
462558 dtype = dtype ,
463559 name = f"{ name } _0_conv" ,
464- )(x )
560+ )(shortcut )
465561 if not use_pre_activation :
466562 shortcut = layers .BatchNormalization (
467563 axis = bn_axis ,
@@ -548,6 +644,7 @@ def apply_stack(
548644 stride ,
549645 block_type ,
550646 use_pre_activation ,
647+ use_vd_pooling = False ,
551648 first_shortcut = True ,
552649 data_format = None ,
553650 dtype = None ,
@@ -565,6 +662,7 @@ def apply_stack(
565662 Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
566663 use_pre_activation: boolean. Whether to use pre-activation or not.
567664 `True` for ResNetV2, `False` for ResNet and ResNeXt.
665+ use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
568666 first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
569667 use an identity or pooling shortcut based on the stride. Defaults to
570668 `True`.
@@ -580,7 +678,12 @@ def apply_stack(
580678 Output tensor for the stacked blocks.
581679 """
582680 if name is None :
583- version = "v1" if not use_pre_activation else "v2"
681+ if use_vd_pooling :
682+ version = "vd"
683+ elif use_pre_activation :
684+ version = "v2"
685+ else :
686+ version = "v1"
584687 name = f"{ version } _stack"
585688
586689 if block_type == "basic_block" :
@@ -605,6 +708,7 @@ def apply_stack(
605708 stride = stride ,
606709 conv_shortcut = conv_shortcut ,
607710 use_pre_activation = use_pre_activation ,
711+ use_vd_pooling = use_vd_pooling ,
608712 data_format = data_format ,
609713 dtype = dtype ,
610714 name = f"{ name } _block{ str (i )} " ,
0 commit comments