
Conversation

Achazwl (Author) commented Oct 29, 2025

Add checks for weight tying in state_dict processing

meta-cla bot commented Oct 29, 2025

Hi @Achazwl!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

meta-cla bot added the CLA Signed label Oct 29, 2025
meta-cla bot commented Oct 29, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

tianyu-l requested a review from shuhuayu October 29, 2025 17:27
if key not in to_hf_map:
    continue
# Skip output.weight if weight tying is enabled (HF checkpoint won't have lm_head.weight)
if self.model_args.enable_weight_tying and key == "output.weight":
wwwjn (Contributor) commented Oct 29, 2025
Checking the 0.6B and 1.7B model weights, they do have separate entries for embed_tokens and lm_head, and I assumed these two weights are the same (please correct me if I am wrong), so loading the same weights twice is ok here.

1.7B: https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/model.safetensors.index.json
0.6B: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/model.safetensors

I see your change makes sense. Our previous code would fail when loading the 4B model weights: the 4B model doesn't have "lm_head.weight" in its checkpoint files, but our translated hf_state_dict would still have the key lm_head.weight. Did you verify that the updated code is still on par with HF forward? cc @shuhuayu
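For context, a minimal sketch of what the to_hf path presumably looks like with this check in place; only the skip itself comes from the diff, while the surrounding loop, the return value, and the bare to_hf_map name are assumptions for illustration, not the exact repository code:

# Hypothetical sketch of the torchtitan -> HF key translation with the
# weight-tying skip from this PR; the loop structure is assumed.
def to_hf(self, state_dict: dict) -> dict:
    hf_state_dict = {}
    for key, value in state_dict.items():
        if key not in to_hf_map:
            continue
        # When weight tying is enabled, the HF checkpoint (e.g. Qwen3-4B)
        # ships no lm_head.weight, so emit no HF key for output.weight.
        if self.model_args.enable_weight_tying and key == "output.weight":
            continue
        hf_state_dict[to_hf_map[key]] = value
    return hf_state_dict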

Contributor commented:

Thanks for catching the bug when loading the Qwen3 4B model. I did a forward parity check, and it works well.
[screenshot of the forward parity check results attached]
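For reference, a forward parity check of this kind can be approximated as below. The HF side uses the standard transformers API; load_converted_model is a placeholder for however the converted model is built and loaded, not a function in this repository:

# Hypothetical forward parity check: compare logits from the HF reference
# model against the converted model on the same prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def parity_check(model_name="Qwen/Qwen3-4B", atol=1e-3):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    inputs = tokenizer("weight tying parity check", return_tensors="pt")
    with torch.no_grad():
        hf_logits = hf_model(**inputs).logits

    converted = load_converted_model(model_name)  # placeholder, see note above
    converted.eval()
    with torch.no_grad():
        converted_logits = converted(inputs["input_ids"])

    max_diff = (hf_logits - converted_logits).abs().max().item()
    print(f"max |logit diff| = {max_diff:.2e}")
    assert max_diff < atol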

# copy from embed_tokens.weight
if self.model_args.enable_weight_tying and "lm_head.weight" not in hf_state_dict:
    if "model.embed_tokens.weight" in hf_state_dict:
        hf_state_dict = dict(hf_state_dict)  # Make a copy to avoid modifying original
wwwjn (Contributor) commented Oct 29, 2025
Do you need to make a shallow copy of the dict? Can you elaborate more on "avoid modifying original"?

Achazwl (Author) commented:
Without dict(hf_state_dict), the line hf_state_dict["lm_head.weight"] = ... would directly mutate the dictionary object provided by the caller function. I'm not sure if the caller expects the input dictionary to be modified, so I made a copy to avoid any potential side effects.
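A tiny generic illustration of the difference (plain Python, not the adapter code):

# Without a copy, the caller's dict gains the new key as a side effect:
original = {"model.embed_tokens.weight": "tensor"}
alias = original
alias["lm_head.weight"] = alias["model.embed_tokens.weight"]
assert "lm_head.weight" in original  # caller sees the mutation

# With dict(...), only the shallow copy gains the key; values are still shared:
original = {"model.embed_tokens.weight": "tensor"}
copied = dict(original)
copied["lm_head.weight"] = copied["model.embed_tokens.weight"]
assert "lm_head.weight" not in original  # caller's dict is unchanged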

Achazwl (Author) commented:
If it's not necessary, I can revert this line.

Contributor commented:
Thanks, the input hf_state_dict will not be used after calling the from_hf() function:

state_dict = self.sd_adapter.from_hf(hf_state_dict)

It should be fine to mutate the dictionary object (hf_state_dict).

Achazwl (Author) commented:
> Thanks, the input hf_state_dict will not be used after calling the from_hf() function:
>
> state_dict = self.sd_adapter.from_hf(hf_state_dict)
>
> It should be fine to mutate the dictionary object (hf_state_dict).

Ok, I've removed the shallow copy.
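Presumably the from_hf side then ends up along these lines after dropping the copy; this is a sketch only, and the from_hf_map name plus the key-translation loop are assumptions rather than the repository's actual code:

# Hypothetical sketch of the HF -> torchtitan direction after this thread:
# mutate hf_state_dict in place (the caller does not reuse it) and fill in
# the tied lm_head.weight from the embedding table when it is missing.
def from_hf(self, hf_state_dict: dict) -> dict:
    if (
        self.model_args.enable_weight_tying
        and "lm_head.weight" not in hf_state_dict
        and "model.embed_tokens.weight" in hf_state_dict
    ):
        # copy from embed_tokens.weight (e.g. Qwen3-4B ships no lm_head.weight)
        hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"]

    state_dict = {}
    for hf_key, value in hf_state_dict.items():
        if hf_key in from_hf_map:  # assumed reverse of to_hf_map
            state_dict[from_hf_map[hf_key]] = value
    return state_dict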

shuhuayu (Contributor) left a comment
Thanks for fixing this bug! LGTM.

wwwjn (Contributor) commented Nov 3, 2025

Hi @Achazwl, thanks for the contribution! Do you want to fix the lint and re-run CI before we merge it?
