2 changes: 1 addition & 1 deletion keras_nlp/models/t5/t5_backbone.py

@@ -82,7 +82,7 @@ def __init__(
         activation="relu",
         use_gated_activation=True,
         layer_norm_epsilon=1e-06,
-        tie_embedding_weights=False,
+        tie_embedding_weights=True,
         **kwargs,
     ):
         # Encoder inputs
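For context, `tie_embedding_weights=True` means the decoder's output projection reuses the token embedding matrix instead of learning a separate `lm_head`, which is how the original T5 checkpoints were trained (the presets below set it to `False` only for the Flan variants). A minimal numpy sketch of the idea, for illustration only, not the KerasNLP implementation:

import numpy as np

vocab_size, hidden_dim = 32, 8
embeddings = np.random.rand(vocab_size, hidden_dim).astype("float32")

# Forward direction: token ids -> hidden vectors.
token_ids = np.array([3, 7, 1])
hidden = embeddings[token_ids]

# Reverse direction: the same matrix, transposed, maps hidden
# vectors back to one logit per vocabulary entry.
logits = hidden @ embeddings.T
assert logits.shape == (3, vocab_size)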
18 changes: 12 additions & 6 deletions keras_nlp/models/t5/t5_presets.py

@@ -36,10 +36,11 @@
             "activation": "relu",
             "use_gated_activation": False,
             "layer_norm_epsilon": 1e-06,
+            "tie_embedding_weights": True,
[Inline review comment on the line above]
Member: let's update the default for T5Backbone to True
Collaborator (Author): Changed the default, thanks!
         },
         "preprocessor_config": {},
         "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/model.weights.h5",
-        "weights_hash": "5a241ea61142eaf96ac1805898a2f2d1",
+        "weights_hash": "2e10b5f72405d464ee55026b07e60741",
         "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/vocab.spm",
         "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
@@ -64,10 +65,11 @@
             "activation": "relu",
             "use_gated_activation": False,
             "layer_norm_epsilon": 1e-06,
+            "tie_embedding_weights": True,
         },
         "preprocessor_config": {},
         "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/model.weights.h5",
-        "weights_hash": "9bef4c6650d91d1ea438ee4a2bea47ad",
+        "weights_hash": "bed6ef276cfe83d1323467051211978d",
         "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/vocab.spm",
         "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
@@ -92,10 +94,11 @@
             "activation": "relu",
             "use_gated_activation": False,
             "layer_norm_epsilon": 1e-06,
+            "tie_embedding_weights": True,
         },
         "preprocessor_config": {},
         "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/model.weights.h5",
-        "weights_hash": "eab8eee1bad033e65324a71cd6e5a8e9",
+        "weights_hash": "7854a05c2e6812899bf6f0f104792cda",
         "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/vocab.spm",
         "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
@@ -121,10 +124,11 @@
             "activation": "keras_nlp>gelu_approximate",
             "use_gated_activation": True,
             "layer_norm_epsilon": 1e-06,
+            "tie_embedding_weights": False,
         },
         "preprocessor_config": {},
         "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/model.weights.h5",
-        "weights_hash": "4e39b0bab56606a9ab2b8e52a6bc7a9f",
+        "weights_hash": "aa0fbaddb1759ef313bbc4f9e4f1e197",
         "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/vocab.spm",
         "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
@@ -149,10 +153,11 @@
             "activation": "keras_nlp>gelu_approximate",
             "use_gated_activation": True,
             "layer_norm_epsilon": 1e-06,
+            "tie_embedding_weights": False,
         },
         "preprocessor_config": {},
         "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/model.weights.h5",
-        "weights_hash": "b529270c5361db36d359a46403532b5c",
+        "weights_hash": "84a10bec83fd093931bb2a6264115d31",
         "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/vocab.spm",
         "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
@@ -177,10 +182,11 @@
             "activation": "keras_nlp>gelu_approximate",
             "use_gated_activation": True,
             "layer_norm_epsilon": 1e-06,
+            "tie_embedding_weights": False,
         },
         "preprocessor_config": {},
         "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/model.weights.h5",
-        "weights_hash": "50b8d3c88fc10db07e495d79ff29a1b6",
+        "weights_hash": "513f530ce790efa7e261c0ef965f3697",
         "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/vocab.spm",
         "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
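With the flag recorded in each preset, the tied/untied choice travels with the checkpoint. A hedged usage sketch, with preset names taken from the URLs above and the import path taken from this PR's file layout (exact public import paths may differ across keras_nlp versions):

from keras_nlp.models.t5.t5_backbone import T5Backbone

# t5_* presets tie the decoder output projection to the token embedding...
backbone = T5Backbone.from_preset("t5_small_multi")
print(backbone.tie_embedding_weights)  # expected: True

# ...while the flan_* presets keep a separate lm_head.
flan = T5Backbone.from_preset("flan_base_multi")
print(flan.tie_embedding_weights)  # expected: False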
36 changes: 29 additions & 7 deletions tools/checkpoint_conversion/convert_t5_checkpoints.py

@@ -81,6 +81,12 @@ def convert_checkpoints(hf_model):
     keras_nlp_model.get_layer("token_embedding").embeddings.assign(
         hf_wts[f"{section}.embed_tokens.weight"]
     )
+    if not keras_nlp_model.tie_embedding_weights:
+        keras_nlp_model.get_layer(
+            "token_embedding"
+        ).reverse_embeddings.assign(
+            hf_wts["lm_head.weight"].transpose(1, 0).numpy()
+        )
 
     # Query, key, value, and output projectors in self-attention
     keras_nlp_model.get_layer(
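The `.transpose(1, 0)` exists because the two libraries store the output projection differently: PyTorch's `nn.Linear` keeps weights as `(out_features, in_features)`, so HF's `lm_head.weight` is `(vocab_size, hidden_dim)`, while the Keras reverse embedding multiplies hidden states on the left and therefore wants `(hidden_dim, vocab_size)`. A small sketch of just the shape convention:

import torch

vocab_size, hidden_dim = 32, 8
lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)

print(lm_head.weight.shape)                  # torch.Size([32, 8])
print(lm_head.weight.transpose(1, 0).shape)  # torch.Size([8, 32])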
@@ -308,17 +314,18 @@ def check_output(
         print(k, v)
 
     # Forward pass
-    keras_outputs = keras_model(keras_inputs)
-    hf_outputs = hf_model(**hf_inputs)
+    keras_out = keras_model(keras_inputs)
+    hf_out = hf_model(**hf_inputs, output_hidden_states=True)
 
     # Only compare non-padded token ids.
-    keras_outputs = keras_outputs["decoder_sequence_output"]
+    keras_hidden_states = keras_out["decoder_sequence_output"]
+    hf_hidden_states = hf_out.decoder_hidden_states[-1]
 
     keras_outputs = ops.take_along_axis(
-        keras_outputs, ops.where(decoder_padding_mask)
+        keras_hidden_states, ops.where(decoder_padding_mask)
     )
-    hf_outputs = hf_outputs.last_hidden_state
     hf_outputs = ops.take_along_axis(
-        hf_outputs, ops.where(decoder_padding_mask)
+        hf_hidden_states, ops.where(decoder_padding_mask)
     )
 
     print("-> KerasNLP output:", keras_outputs[0:5])
@@ -327,6 +334,21 @@
         keras_outputs.detach().numpy(), hf_outputs.detach().numpy(), atol=1e-5
     )
 
+    if keras_model.tie_embedding_weights:
+        keras_hidden_states = keras_hidden_states * (
+            keras_model.hidden_dim**-0.5
+        )
+
+    keras_logits = keras_model.token_embedding(
+        keras_hidden_states, reverse=True
+    )
+    hf_logits = hf_out.logits
+    print("-> KerasNLP logits:", keras_logits[0:5])
+    print("-> HF logits:", hf_logits[0:5])
+    np.testing.assert_allclose(
+        keras_logits.detach().numpy(), hf_logits.detach().numpy(), atol=1e-3
+    )
+
 
 def count_params(weights):
     shapes = [v.shape for v in weights]
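The `hidden_dim**-0.5` factor mirrors Hugging Face's T5, which rescales the final decoder hidden states by `d_model**-0.5` before the shared projection when embeddings are tied; without it the tied-weights logits would be off by a constant factor. A hedged numeric sketch of the path being checked:

import numpy as np

hidden_dim, vocab_size = 8, 32
embeddings = np.random.rand(vocab_size, hidden_dim).astype("float32")
hidden = np.random.rand(2, hidden_dim).astype("float32")

# Tied path: rescale, then project through the shared embedding matrix.
logits = (hidden * hidden_dim**-0.5) @ embeddings.T
print(logits.shape)  # (2, 32)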
@@ -339,7 +361,7 @@ def main(_):
    os.mkdir(f"./{FLAGS.preset}")
 
    print("\n-> Convert weights.")
-    hf_model = transformers.AutoModel.from_pretrained(hf_id)
+    hf_model = transformers.T5ForConditionalGeneration.from_pretrained(hf_id)
    keras_model = convert_checkpoints(hf_model)
 
    # Save the model.
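Switching the loader from `AutoModel` to `T5ForConditionalGeneration` matters because `AutoModel` resolves to the bare `T5Model`, which has no `lm_head` for the untied-weights copy above and whose outputs carry no `logits` for the new check. A quick way to see the difference (hedged; `t5-small` used as a stand-in checkpoint id):

import transformers

bare = transformers.T5Model.from_pretrained("t5-small")
full = transformers.T5ForConditionalGeneration.from_pretrained("t5-small")

print(hasattr(bare, "lm_head"))  # False: no way to read lm_head.weight
print(hasattr(full, "lm_head"))  # True: exposes lm_head and output logits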