Skip to content

Commit c00a74b

Browse files
lingvo-botcopybara-github
authored andcommitted
Add BatchNorm to supported normalizations.
PiperOrigin-RevId: 561187464
1 parent 74ac32f commit c00a74b

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

lingvo/core/conformer_layer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,8 @@ def _NormalizeStep(self, theta, inputs, paddings, state0, state1):
377377
inputs, paddings, norm_state1 = self.norm.StreamStep(
378378
theta.norm, inputs, paddings, state0.norm_state)
379379
state1.norm_state = norm_state1
380+
elif isinstance(self.norm, bn_layers.BatchNormLayer):
381+
inputs = self.norm.FProp(theta.norm, inputs)
380382
elif isinstance(self.norm, layers.LayerNorm):
381383
inputs = self.norm.FProp(theta.norm, inputs)
382384
else:

lingvo/core/conformer_layer_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def _GetParams(self, **kwargs):
7272
input_dim=input_dim, is_causal=True, kernel_size=kernel)
7373
if norm_type == 'ln':
7474
p.conv_norm_layer_tpl = lingvo_layers.LayerNorm.Params()
75+
elif norm_type == 'bn':
76+
p.conv_norm_layer_tpl = bn_layers.BatchNormLayer.Params()
7577
else:
7678
p.conv_norm_layer_tpl = bn_layers.GroupNormLayer.Params().Set(
7779
num_groups=2, cumulative=True)
@@ -90,12 +92,13 @@ def _GetFPropOutput(self, fprop_out):
9092
@parameterized.named_parameters(
9193
('Basic',),
9294
('BasicGN', False, 'gn'),
95+
('BasicBN', False, 'bn'),
9396
('SkipNorm', True),
9497
)
9598
def testLeftContext(self, testonly_skip_norm_layers=False, norm_type='ln'):
9699
with flagsaver.flagsaver(testonly_skip_norm_layers=testonly_skip_norm_layers
97100
), cluster_factory.SetEval(True):
98-
assert norm_type in ('ln', 'gn')
101+
assert norm_type in ('ln', 'gn', 'bn')
99102
input_dim, kernel = 2, 3
100103
self._TestStreamStepHelper(
101104
num_heads=2, input_dim=input_dim, kernel=kernel, norm_type=norm_type)

0 commit comments

Comments
 (0)