This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8c0736a

marcvanzee authored and copybara-github committed
Fix bug in universal transformer hyperparameter range.
PiperOrigin-RevId: 278638337
1 parent b679a88 · commit 8c0736a

1 file changed: +2 −2 lines

tensor2tensor/models/research/universal_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -811,7 +811,7 @@ def universal_transformer_base_range(rhp):
   rhp.set_discrete("hidden_size", [1024, 2048, 4096])
   rhp.set_discrete("filter_size", [2048, 4096, 8192])
   rhp.set_discrete("num_heads", [8, 16, 32])
-  rhp.set_discrete("transformer_ffn_type", ["sepconv", "fc"])
+  rhp.set_categorical("transformer_ffn_type", ["sepconv", "fc"])
   rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE)
   rhp.set_float("weight_decay", 0.0, 2.0)


@@ -825,6 +825,6 @@ def adaptive_universal_transformer_base_range(rhp):
   rhp.set_discrete("hidden_size", [1024, 2048, 4096])
   rhp.set_discrete("filter_size", [2048, 4096, 8192])
   rhp.set_discrete("num_heads", [8, 16, 32])
-  rhp.set_discrete("transformer_ffn_type", ["sepconv", "fc"])
+  rhp.set_categorical("transformer_ffn_type", ["sepconv", "fc"])
   rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE)
   rhp.set_float("weight_decay", 0.0, 2.0)

0 commit comments
