diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py
index 0397be2a..cee0f3c1 100644
--- a/src/chronos/chronos2/model.py
+++ b/src/chronos/chronos2/model.py
@@ -10,9 +10,17 @@
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
+from packaging import version
+from transformers import __version__ as transformers_version
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput
+
+if version.parse(transformers_version) >= version.parse("5.0.0.dev0"):
+    from transformers import initialization as init
+else:
+    from torch.nn import init
+
 from chronos.chronos_bolt import InstanceNorm, Patch
 
 from .config import Chronos2CoreConfig, Chronos2ForecastingConfig
 
@@ -268,49 +276,53 @@ def _init_weights(self, module):
         """Initialize the weights"""
         factor = self.config.initializer_factor
         if isinstance(module, Chronos2LayerNorm):
-            module.weight.data.fill_(factor * 1.0)
+            init.constant_(module.weight, factor * 1.0)
         elif isinstance(module, MLP):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
             # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
-            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+            init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
             if hasattr(module.wi, "bias") and module.wi.bias is not None:
-                module.wi.bias.data.zero_()
-            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+                init.zeros_(module.wi.bias)
+            init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.wo, "bias") and module.wo.bias is not None:
-                module.wo.bias.data.zero_()
+                init.zeros_(module.wo.bias)
         elif isinstance(module, MHA):
             # Mesh TensorFlow attention initialization to avoid scaling before softmax
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
             d_model = self.config.d_model
             kv_proj_dim = self.config.d_kv
             n_heads = self.config.num_heads
-            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
-            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
-            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
-            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
+            init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
+            init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
+            init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
+            init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
         elif isinstance(module, (Chronos2Model)):
-            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
         elif isinstance(module, ResidualBlock):
-            module.hidden_layer.weight.data.normal_(
+            init.normal_(
+                module.hidden_layer.weight,
                 mean=0.0,
                 std=factor * (module.hidden_layer.weight.size(-1) ** -0.5),
             )
             if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
-                module.hidden_layer.bias.data.zero_()
+                init.zeros_(module.hidden_layer.bias)
 
-            module.residual_layer.weight.data.normal_(
+            init.normal_(
+                module.residual_layer.weight,
                 mean=0.0,
                 std=factor * (module.residual_layer.weight.size(-1) ** -0.5),
             )
             if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
-                module.residual_layer.bias.data.zero_()
+                init.zeros_(module.residual_layer.bias)
 
-            module.output_layer.weight.data.normal_(
-                mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5)
+            init.normal_(
+                module.output_layer.weight,
+                mean=0.0,
+                std=factor * (module.output_layer.weight.size(-1) ** -0.5),
             )
             if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
-                module.output_layer.bias.data.zero_()
+                init.zeros_(module.output_layer.bias)
 
     def _validate_input(
         self,
diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py
index 743ec06b..84e9405a 100644
--- a/src/chronos/chronos_bolt.py
+++ b/src/chronos/chronos_bolt.py
@@ -13,7 +13,8 @@
 
 import torch
 import torch.nn as nn
-from transformers import AutoConfig
+from packaging import version
+from transformers import AutoConfig, __version__ as transformers_version
 from transformers.models.t5.modeling_t5 import (
     ACT2FN,
     T5Config,
@@ -28,6 +29,29 @@
 
 logger = logging.getLogger(__file__)
 
+_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0")
+
+# In transformers v5, use guarded init functions that check _is_hf_initialized
+# to avoid re-initializing weights loaded from checkpoint
+if _TRANSFORMERS_V5:
+    from transformers import initialization as init
+else:
+    from torch.nn import init
+
+
+def _create_t5_stack(config: T5Config, embed_tokens: nn.Embedding) -> T5Stack:
+    """
+    Create a T5Stack with the given config and embed_tokens.
+
+    This helper function provides backward compatibility between transformers v4 and v5.
+    In v4, T5Stack.__init__ accepts (config, embed_tokens).
+    In v5, T5Stack.__init__ only accepts (config), and embed_tokens must be set separately.
+ """ + if _TRANSFORMERS_V5: + return T5Stack(config) + else: + return T5Stack(config, embed_tokens) + @dataclass class ChronosBoltConfig: @@ -150,7 +174,15 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): r"output_patch_embedding\.", ] _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore + # In transformers v5, _tied_weights_keys changed from list to dict {target: source} + _tied_weights_keys = ( # type: ignore + { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } + if _TRANSFORMERS_V5 + else ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + ) def __init__(self, config: T5Config): assert hasattr(config, "chronos_config"), "Not a Chronos config file" @@ -188,7 +220,7 @@ def __init__(self, config: T5Config): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - self.encoder = T5Stack(encoder_config, self.shared) + self.encoder = _create_t5_stack(encoder_config, self.shared) self._init_decoder(config) @@ -217,25 +249,27 @@ def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, (self.__class__)): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0) elif isinstance(module, ResidualBlock): - module.hidden_layer.weight.data.normal_( + init.normal_( + module.hidden_layer.weight, mean=0.0, std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5), ) if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None: - module.hidden_layer.bias.data.zero_() + init.zeros_(module.hidden_layer.bias) - module.residual_layer.weight.data.normal_( + init.normal_( + module.residual_layer.weight, mean=0.0, std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5), ) if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None: - module.residual_layer.bias.data.zero_() + init.zeros_(module.residual_layer.bias) - module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + init.normal_(module.output_layer.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None: - module.output_layer.bias.data.zero_() + init.zeros_(module.output_layer.bias) def encode( self, context: torch.Tensor, mask: Optional[torch.Tensor] = None @@ -359,7 +393,7 @@ def _init_decoder(self, config): decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) + self.decoder = _create_t5_stack(decoder_config, self.shared) def decode( self, diff --git a/test/dummy-chronos-model/model.safetensors b/test/dummy-chronos-model/model.safetensors new file mode 100644 index 00000000..40bef767 Binary files /dev/null and b/test/dummy-chronos-model/model.safetensors differ