50 changes: 31 additions & 19 deletions src/chronos/chronos2/model.py
@@ -10,9 +10,17 @@
import torch
import torch.nn as nn
from einops import rearrange, repeat
+from packaging import version
+from transformers import __version__ as transformers_version
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput


+if version.parse(transformers_version) >= version.parse("5.0.0.dev0"):
+    from transformers import initialization as init
+else:
+    from torch.nn import init

from chronos.chronos_bolt import InstanceNorm, Patch

from .config import Chronos2CoreConfig, Chronos2ForecastingConfig
@@ -264,53 +272,57 @@ def __init__(self, config: Chronos2CoreConfig):
self.post_init()

def _init_weights(self, module):
-super()._init_weights(module)
-"""Initialize the weights"""
+"""Initialize the weights.
+"""
factor = self.config.initializer_factor
if isinstance(module, Chronos2LayerNorm):
-module.weight.data.fill_(factor * 1.0)
+init.constant_(module.weight, factor * 1.0)
elif isinstance(module, MLP):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
-module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi, "bias") and module.wi.bias is not None:
-module.wi.bias.data.zero_()
-module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+init.zeros_(module.wi.bias)
+init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
-module.wo.bias.data.zero_()
+init.zeros_(module.wo.bias)
elif isinstance(module, MHA):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
d_model = self.config.d_model
kv_proj_dim = self.config.d_kv
n_heads = self.config.num_heads
-module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
-module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
-module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
-module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
+init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
+init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
+init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
+init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
elif isinstance(module, (Chronos2Model)):
-module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
elif isinstance(module, ResidualBlock):
-module.hidden_layer.weight.data.normal_(
+init.normal_(
+module.hidden_layer.weight,
mean=0.0,
std=factor * (module.hidden_layer.weight.size(-1) ** -0.5),
)
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
-module.hidden_layer.bias.data.zero_()
+init.zeros_(module.hidden_layer.bias)

-module.residual_layer.weight.data.normal_(
+init.normal_(
+module.residual_layer.weight,
mean=0.0,
std=factor * (module.residual_layer.weight.size(-1) ** -0.5),
)
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
-module.residual_layer.bias.data.zero_()
+init.zeros_(module.residual_layer.bias)

-module.output_layer.weight.data.normal_(
-mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5)
+init.normal_(
+module.output_layer.weight,
+mean=0.0,
+std=factor * (module.output_layer.weight.size(-1) ** -0.5),
)
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
-module.output_layer.bias.data.zero_()
+init.zeros_(module.output_layer.bias)

def _validate_input(
self,
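The import gate at the top of model.py is the crux of these changes: on transformers >= 5.0.0 the module-level name `init` resolves to `transformers.initialization`, whose helpers skip weights that were already loaded from a checkpoint, while older versions fall back to `torch.nn.init` with identical call sites. Below is a minimal sketch of the idea behind such a guarded helper, assuming it mirrors the `torch.nn.init` signatures and consults the `_is_hf_initialized` flag named in the comment added to chronos_bolt.py; where the flag actually lives and how the real helpers are written is not taken from this diff, so treat it purely as an illustration.

```python
import torch
import torch.nn as nn


def guarded_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
    """Illustrative stand-in for a transformers-v5 guarded init helper."""
    # Skip tensors flagged as already initialized from a checkpoint;
    # the exact location of the flag (parameter vs. module) is an assumption here.
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    return nn.init.normal_(tensor, mean=mean, std=std)


layer = nn.Linear(8, 8)
layer.weight._is_hf_initialized = True    # pretend this weight was loaded from a checkpoint
guarded_normal_(layer.weight, std=0.02)   # left untouched
guarded_normal_(layer.bias, std=0.02)     # (re)initialized as usual
```

Under v4 the same call sites simply resolve to the unguarded `torch.nn.init` functions, so `_init_weights` behaves exactly as the old `.data.normal_()` / `.data.zero_()` code did on freshly constructed modules.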
62 changes: 49 additions & 13 deletions src/chronos/chronos_bolt.py
@@ -13,7 +13,8 @@

import torch
import torch.nn as nn
-from transformers import AutoConfig
+from packaging import version
+from transformers import AutoConfig, __version__ as transformers_version
from transformers.models.t5.modeling_t5 import (
ACT2FN,
T5Config,
@@ -28,6 +29,31 @@

logger = logging.getLogger(__file__)

+_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0")
+
+# In transformers v5, use guarded init functions that check _is_hf_initialized
+# to avoid re-initializing weights loaded from checkpoint
+if _TRANSFORMERS_V5:
+    from transformers import initialization as init
+else:
+    from torch.nn import init
+
+
+def _create_t5_stack(config: T5Config, embed_tokens: nn.Embedding) -> T5Stack:
+    """
+    Create a T5Stack with the given config and embed_tokens.
+
+    This helper function provides backward compatibility between transformers v4 and v5.
+    In v4, T5Stack.__init__ accepts (config, embed_tokens).
+    In v5, T5Stack.__init__ only accepts (config), and embed_tokens must be set separately.
+    """
+    if _TRANSFORMERS_V5:
+        stack = T5Stack(config)
+        stack.set_input_embeddings(embed_tokens)
+        return stack
+    else:
+        return T5Stack(config, embed_tokens)


@dataclass
class ChronosBoltConfig:
Expand Down Expand Up @@ -150,7 +176,15 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
r"output_patch_embedding\.",
]
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore
# In transformers v5, _tied_weights_keys changed from list to dict {target: source}
_tied_weights_keys = ( # type: ignore
{
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
if _TRANSFORMERS_V5
else ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
)

def __init__(self, config: T5Config):
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
@@ -188,7 +222,7 @@ def __init__(self, config: T5Config):
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
-self.encoder = T5Stack(encoder_config, self.shared)
+self.encoder = _create_t5_stack(encoder_config, self.shared)

self._init_decoder(config)

@@ -213,29 +247,31 @@ def __init__(self, config: T5Config):
self.device_map = None

def _init_weights(self, module):
-super()._init_weights(module)
-"""Initialize the weights"""
+"""Initialize the weights.
+"""
factor = self.config.initializer_factor
if isinstance(module, (self.__class__)):
-module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
elif isinstance(module, ResidualBlock):
-module.hidden_layer.weight.data.normal_(
+init.normal_(
+module.hidden_layer.weight,
mean=0.0,
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
)
if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
-module.hidden_layer.bias.data.zero_()
+init.zeros_(module.hidden_layer.bias)

-module.residual_layer.weight.data.normal_(
+init.normal_(
+module.residual_layer.weight,
mean=0.0,
std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
)
if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
-module.residual_layer.bias.data.zero_()
+init.zeros_(module.residual_layer.bias)

-module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+init.normal_(module.output_layer.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
-module.output_layer.bias.data.zero_()
+init.zeros_(module.output_layer.bias)

def encode(
self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
@@ -359,7 +395,7 @@ def _init_decoder(self, config):
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.num_layers = config.num_decoder_layers
-self.decoder = T5Stack(decoder_config, self.shared)
+self.decoder = _create_t5_stack(decoder_config, self.shared)

def decode(
self,
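The chronos_bolt.py side isolates the remaining v4/v5 differences behind two small shims: `_create_t5_stack`, because `T5Stack.__init__` no longer accepts `embed_tokens` in v5, and the dict form of `_tied_weights_keys`, which maps each tied target back to its source (`shared.weight`) instead of listing the tied names flat. Here is a short sketch exercising the helper with a deliberately tiny `T5Config` (the sizes are illustrative, not Chronos-Bolt's real ones); on either transformers version the stack's input embeddings should end up being the shared table, which is exactly the relationship the v5 mapping records.

```python
import copy

import torch.nn as nn
from transformers.models.t5.modeling_t5 import T5Config, T5Stack

from chronos.chronos_bolt import _create_t5_stack

# Tiny, illustrative config so the stack builds instantly.
config = T5Config(vocab_size=32, d_model=16, d_kv=4, d_ff=32, num_layers=1, num_heads=2)
shared = nn.Embedding(config.vocab_size, config.d_model)

encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False

# v4 path: T5Stack(config, shared); v5 path: T5Stack(config) + set_input_embeddings(shared).
encoder = _create_t5_stack(encoder_config, shared)

# The stack's embeddings are the shared table itself, i.e. the "shared.weight" source
# that the v5 dict {"encoder.embed_tokens.weight": "shared.weight"} points at.
assert encoder.get_input_embeddings() is shared
```

Keeping the branch inside one helper also means the two construction sites (`__init__` for the encoder and `_init_decoder` for the decoder) stay identical across versions.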
Binary file added test/dummy-chronos-model/model.safetensors