Update base hparams for tpu
PiperOrigin-RevId: 195761510
Niki Parmar authored and lukaszkaiser committed May 8, 2018
commit 6be28ffe148e925e65cdbbc5f4fee72080e7402d
19 changes: 14 additions & 5 deletions tensor2tensor/models/image_transformer.py
@@ -667,14 +667,16 @@ def update_hparams_for_tpu(hparams):
 
 @registry.register_hparams
 def imagetransformer_base_tpu():
-  hparams = imagetransformer_base()
+  """Transformer base params for cifar-10."""
+  hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
   update_hparams_for_tpu(hparams)
   hparams.batch_size = 4
   hparams.num_heads = 4   # heads are expensive on tpu
-  hparams.hidden_size = 256
-  hparams.filter_size = 512
-  hparams.num_hidden_layers = 8
-  hparams.sampling_method = "random"
+  hparams.num_decoder_layers = 12
+  hparams.block_length = 128
+  hparams.layer_preprocess_sequence = "none"
+  hparams.layer_postprocess_sequence = "dan"
+  hparams.layer_prepostprocess_dropout = 0.3
   return hparams
 
 
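For context, a minimal sketch of how the updated set is consumed, assuming a tensor2tensor checkout from around this commit is importable; the printed values are the overrides introduced above:

# Minimal sketch: build the TPU base hparams registered above and
# inspect the fields this commit changes. Assumes tensor2tensor from
# around this commit is installed.
from tensor2tensor.models import image_transformer

hparams = image_transformer.imagetransformer_base_tpu()
print(hparams.num_heads)                   # 4: heads are expensive on tpu
print(hparams.num_decoder_layers)          # 12
print(hparams.block_length)                # 128
print(hparams.layer_postprocess_sequence)  # "dan"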
@@ -691,11 +693,16 @@ def imagetransformer_sep_channels_8l_tpu():
 
 @registry.register_hparams
 def imagetransformer_b10l_4h_big_uncond_dr03_tpu():
+  """Small model for tpu cifar 10."""
   hparams = imagetransformer_bas8l_8h_big_uncond_dr03_imgnet()
   update_hparams_for_tpu(hparams)
   hparams.batch_size = 4
   hparams.num_heads = 4   # heads are expensive on tpu
   hparams.num_decoder_layers = 10
   hparams.block_length = 128
+  hparams.hidden_size = 256
+  hparams.filter_size = 1024
+  hparams.learning_rate = 0.2
+  hparams.layer_preprocess_sequence = "none"
   hparams.layer_postprocess_sequence = "dan"
   return hparams
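Registered sets can also be tweaked at run time without editing this file; a hedged sketch, assuming t2t hparams of this era are tf.contrib.training.HParams, whose parse() accepts comma-separated name=value overrides (the override values below are illustrative, not from this commit):

# Hedged sketch: override a registered hparams set from a string.
from tensor2tensor.models import image_transformer

hparams = image_transformer.imagetransformer_b10l_4h_big_uncond_dr03_tpu()
hparams.parse("batch_size=2,learning_rate=0.1")  # tf HParams string syntax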
@@ -740,6 +747,8 @@ def imagetransformer_b12l_4h_big_uncond_dr03_tpu():
   hparams.num_heads = 4   # heads are expensive on tpu
   hparams.num_decoder_layers = 12
   hparams.block_length = 128
+  hparams.hidden_size = 512
+  hparams.filter_size = 1024
   hparams.layer_preprocess_sequence = "none"
   hparams.layer_postprocess_sequence = "dan"
   hparams.layer_prepostprocess_dropout = 0.3
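All three functions in this diff follow the same derive-and-override pattern; a hypothetical sketch of registering a further variant on top of the 12-layer set (the set name and batch-size override below are invented for illustration, not part of this commit):

# Hypothetical sketch of the derive-and-override registration pattern
# used throughout this diff; the name and override value are invented.
from tensor2tensor.models import image_transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def imagetransformer_b12l_4h_big_uncond_dr03_tpu_b2():  # hypothetical
  hparams = image_transformer.imagetransformer_b12l_4h_big_uncond_dr03_tpu()
  hparams.batch_size = 2  # halve the per-core batch
  return hparams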