Skip to content

Commit ff2a09c

Browse files
committed
changing pooling output
1 parent 4cf6586 commit ff2a09c

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

keras_nlp/models/bert/bert_backbone.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,12 @@ def __init__(
162162
# Construct the two BERT outputs. The pooled output is a dense layer on
163163
# top of the [CLS] token.
164164
sequence_output = x
165-
x = keras.layers.Dense(
165+
pooled_output = keras.layers.Dense(
166166
hidden_dim,
167167
kernel_initializer=bert_kernel_initializer(),
168168
activation="tanh",
169169
name="pooled_dense",
170-
)(x)
171-
pooled_output = x[:, cls_token_index, :]
170+
)(x[:, cls_token_index, :])
172171

173172
# Instantiate using Functional API Model constructor
174173
super().__init__(

keras_nlp/models/electra/electra_backbone.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,15 @@ def __init__(
163163
)(x, padding_mask=padding_mask)
164164

165165
sequence_output = x
166-
x = keras.layers.Dense(
166+
# Construct the two ELECTRA outputs. The pooled output is a dense layer on
167+
# top of the [CLS] token.
168+
sequence_output = x
169+
pooled_output = keras.layers.Dense(
167170
hidden_dim,
168171
kernel_initializer=electra_kernel_initializer(),
169172
activation="tanh",
170173
name="pooled_dense",
171-
)(x)
172-
pooled_output = x[:, cls_token_index, :]
174+
)(x[:, cls_token_index, :])
173175

174176
# Instantiate using Functional API Model constructor
175177
super().__init__(

0 commit comments

Comments
 (0)