diff --git a/torchtitan/models/qwen3/model/state_dict_adapter.py b/torchtitan/models/qwen3/model/state_dict_adapter.py index 66e94a09c..1c75c3e9d 100644 --- a/torchtitan/models/qwen3/model/state_dict_adapter.py +++ b/torchtitan/models/qwen3/model/state_dict_adapter.py @@ -104,6 +104,8 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: if key not in to_hf_map: continue + if self.model_args.enable_weight_tying and key == "output.weight": + continue new_key = to_hf_map[key] hf_state_dict[new_key] = value @@ -118,6 +120,15 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: state_dict = {} expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} + if ( + self.model_args.enable_weight_tying + and "lm_head.weight" not in hf_state_dict + ): + if "model.embed_tokens.weight" in hf_state_dict: + hf_state_dict["lm_head.weight"] = hf_state_dict[ + "model.embed_tokens.weight" + ] + for key, value in hf_state_dict.items(): if "mlp.experts" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=2)