use from torch.nn import init
kashif committed Dec 5, 2025
commit ac46167d7467a989d3d1ae77375a42c5416999da
src/chronos/chronos2/model.py (55 changes: 36 additions & 19 deletions)
@@ -10,9 +10,18 @@
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
+from packaging import version
+from transformers import __version__ as transformers_version
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput
 
+# In transformers v5, use guarded init functions that check _is_hf_initialized
+# to avoid re-initializing weights loaded from checkpoint
+if version.parse(transformers_version) >= version.parse("5.0.0.dev0"):
+    from transformers import initialization as init
+else:
+    from torch.nn import init
+
 from chronos.chronos_bolt import InstanceNorm, Patch
 
 from .config import Chronos2CoreConfig, Chronos2ForecastingConfig
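An aside on the version gate above: the boundary is parsed as `5.0.0.dev0` rather than `5.0.0` because `packaging` orders pre-releases before the final release, so dev and rc builds of transformers v5 already take the guarded-import branch. A quick check (version strings chosen for illustration):

```python
from packaging import version

v5_boundary = version.parse("5.0.0.dev0")

assert version.parse("5.0.0.dev20251205") >= v5_boundary  # v5 nightly: guarded init
assert version.parse("5.0.0rc1") >= v5_boundary           # v5 release candidate: guarded init
assert version.parse("5.0.0") >= v5_boundary              # v5 final: guarded init
assert version.parse("4.57.1") < v5_boundary              # any v4 release: torch.nn.init fallback
```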
@@ -264,53 +273,61 @@ def __init__(self, config: Chronos2CoreConfig):
         self.post_init()
 
     def _init_weights(self, module):
-        super()._init_weights(module)
-        """Initialize the weights"""
+        """Initialize the weights.
+
+        Uses transformers.initialization functions which are guarded against
+        re-initializing weights that have already been loaded from checkpoint
+        (they check the _is_hf_initialized flag on each parameter).
+        """
         factor = self.config.initializer_factor
         if isinstance(module, Chronos2LayerNorm):
-            module.weight.data.fill_(factor * 1.0)
+            init.constant_(module.weight, factor * 1.0)
         elif isinstance(module, MLP):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
             # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
-            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+            init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
             if hasattr(module.wi, "bias") and module.wi.bias is not None:
-                module.wi.bias.data.zero_()
-            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+                init.zeros_(module.wi.bias)
+            init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.wo, "bias") and module.wo.bias is not None:
-                module.wo.bias.data.zero_()
+                init.zeros_(module.wo.bias)
         elif isinstance(module, MHA):
             # Mesh TensorFlow attention initialization to avoid scaling before softmax
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
             d_model = self.config.d_model
             kv_proj_dim = self.config.d_kv
             n_heads = self.config.num_heads
-            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
-            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
-            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
-            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
+            init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * kv_proj_dim) ** -0.5))
+            init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
+            init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
+            init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * kv_proj_dim) ** -0.5))
         elif isinstance(module, (Chronos2Model)):
-            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
         elif isinstance(module, ResidualBlock):
-            module.hidden_layer.weight.data.normal_(
+            init.normal_(
+                module.hidden_layer.weight,
                 mean=0.0,
                 std=factor * (module.hidden_layer.weight.size(-1) ** -0.5),
             )
             if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
-                module.hidden_layer.bias.data.zero_()
+                init.zeros_(module.hidden_layer.bias)
 
-            module.residual_layer.weight.data.normal_(
+            init.normal_(
+                module.residual_layer.weight,
                 mean=0.0,
                 std=factor * (module.residual_layer.weight.size(-1) ** -0.5),
             )
             if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
-                module.residual_layer.bias.data.zero_()
+                init.zeros_(module.residual_layer.bias)
 
-            module.output_layer.weight.data.normal_(
-                mean=0.0, std=factor * (module.output_layer.weight.size(-1) ** -0.5)
+            init.normal_(
+                module.output_layer.weight,
+                mean=0.0,
+                std=factor * (module.output_layer.weight.size(-1) ** -0.5),
             )
             if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
-                module.output_layer.bias.data.zero_()
+                init.zeros_(module.output_layer.bias)
 
     def _validate_input(
         self,
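An aside on the MHA branch above: the Mesh TensorFlow scheme draws `q` with std `factor * (d_model * d_kv) ** -0.5` and `k` with std `factor * d_model ** -0.5`, so for unit-variance inputs the query entries have variance about `1 / d_kv`, the key entries about `1`, and their `d_kv`-term dot product about `1`. That is what "avoid scaling before softmax" means: the logits come out well-scaled without the usual `1 / sqrt(d_kv)` factor. A numerical sanity check of that claim (shapes and seed are arbitrary):

```python
import torch

torch.manual_seed(0)
d_model, d_kv, n = 512, 64, 10_000

# Mesh-TF-style init with factor = 1.0
W_q = torch.randn(d_kv, d_model) * (d_model * d_kv) ** -0.5
W_k = torch.randn(d_kv, d_model) * d_model**-0.5

x_q = torch.randn(n, d_model)  # unit-variance activations
x_k = torch.randn(n, d_model)
q, k = x_q @ W_q.T, x_k @ W_k.T

logits = (q * k).sum(-1)    # per-position attention scores, no 1/sqrt(d_kv)
print(q.var().item())       # ~ 1 / d_kv
print(k.var().item())       # ~ 1
print(logits.var().item())  # ~ 1: softmax input already well-scaled
```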
src/chronos/chronos_bolt.py (32 changes: 23 additions & 9 deletions)
@@ -32,8 +32,16 @@
 # Transformers v5 introduced breaking changes:
 # - T5Stack.__init__ no longer accepts embed_tokens argument
 # - _tied_weights_keys changed from list to dict format
+# - _init_weights needs guarded init functions to avoid re-initializing loaded weights
 _TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0")
 
+# In transformers v5, use guarded init functions that check _is_hf_initialized
+# to avoid re-initializing weights loaded from checkpoint
+if _TRANSFORMERS_V5:
+    from transformers import initialization as init
+else:
+    from torch.nn import init
+
 
 def _create_t5_stack(config: T5Config, embed_tokens: nn.Embedding) -> T5Stack:
     """
@@ -243,29 +251,35 @@ def __init__(self, config: T5Config):
         self.device_map = None
 
     def _init_weights(self, module):
-        super()._init_weights(module)
-        """Initialize the weights"""
+        """Initialize the weights.
+
+        Uses transformers.initialization functions which are guarded against
+        re-initializing weights that have already been loaded from checkpoint
+        (they check the _is_hf_initialized flag on each parameter).
+        """
         factor = self.config.initializer_factor
         if isinstance(module, (self.__class__)):
-            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
         elif isinstance(module, ResidualBlock):
-            module.hidden_layer.weight.data.normal_(
+            init.normal_(
+                module.hidden_layer.weight,
                 mean=0.0,
                 std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
             )
             if hasattr(module.hidden_layer, "bias") and module.hidden_layer.bias is not None:
-                module.hidden_layer.bias.data.zero_()
+                init.zeros_(module.hidden_layer.bias)
 
-            module.residual_layer.weight.data.normal_(
+            init.normal_(
+                module.residual_layer.weight,
                 mean=0.0,
                 std=factor * ((self.chronos_config.input_patch_size * 2) ** -0.5),
             )
             if hasattr(module.residual_layer, "bias") and module.residual_layer.bias is not None:
-                module.residual_layer.bias.data.zero_()
+                init.zeros_(module.residual_layer.bias)
 
-            module.output_layer.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+            init.normal_(module.output_layer.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.output_layer, "bias") and module.output_layer.bias is not None:
-                module.output_layer.bias.data.zero_()
+                init.zeros_(module.output_layer.bias)
 
     def encode(
         self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
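Finally, the guard mechanism both docstrings describe can be pictured as a thin wrapper around the torch initializers: the checkpoint loader flags each parameter it fills in, and the v5 functions return early for flagged tensors, so the post-load initialization pass no longer clobbers loaded weights. Only the flag name `_is_hf_initialized` comes from this diff; the wrapper below is an illustrative sketch, not the actual `transformers.initialization` implementation:

```python
import torch
from torch.nn import init as torch_init

def guarded_normal_(tensor, mean=0.0, std=1.0):
    """Sketch: skip tensors already initialized from a checkpoint."""
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor  # loaded weight: leave untouched
    return torch_init.normal_(tensor, mean=mean, std=std)

w = torch.empty(8, 8)
guarded_normal_(w, std=0.02)   # fresh tensor: randomly initialized
w._is_hf_initialized = True    # pretend the checkpoint loader filled this in
before = w.clone()
guarded_normal_(w, std=0.02)   # now a no-op
assert torch.equal(w, before)
```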