Fix bugs in initial_load_in_hf when enable_weight_tying=true in Qwen3 #1964
Conversation
Add checks for weight tying in state_dict processing
Hi @Achazwl! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations, and the pull request will be tagged. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
    if key not in to_hf_map:
        continue
    # Skip output.weight if weight tying is enabled (HF checkpoint won't have lm_head.weight)
    if self.model_args.enable_weight_tying and key == "output.weight":
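For context, a minimal standalone sketch of the to_hf direction with this check; the function name, signature, and to_hf_map used here are illustrative assumptions, not the actual torchtitan adapter code:

    # Sketch (assumed names): convert a torchtitan-style state dict to HF naming,
    # skipping the tied output projection so the result matches HF checkpoints
    # that only ship model.embed_tokens.weight.
    def to_hf(state_dict, to_hf_map, enable_weight_tying):
        hf_state_dict = {}
        for key, value in state_dict.items():
            if key not in to_hf_map:
                continue
            # Skip output.weight if weight tying is enabled
            # (the HF checkpoint won't have lm_head.weight).
            if enable_weight_tying and key == "output.weight":
                continue
            hf_state_dict[to_hf_map[key]] = value
        return hf_state_dict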
Checking the 0.6B and 1.7B model weights, I see they do have separate weights for embed_tokens and lm_head, and I assumed these two weights are the same (please correct me if I am wrong), so loading the same weights twice is fine here.
1.7B: https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/model.safetensors.index.json
0.6B: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/model.safetensors
I see, your change makes sense. Our previous code would fail when loading the 4B model weights: the 4B model doesn't have "lm_head.weight" in its checkpoint files, but our translated hf_state_dict would still have the key lm_head.weight. Did you verify that the updated code is still on par with the HF forward pass? cc @shuhuayu
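As a side note, one rough way to confirm which tensors a given HF checkpoint actually ships is to read its safetensors index; this sketch assumes network access and that the repo provides a sharded model.safetensors.index.json (as Qwen3-1.7B does):

    # Rough sketch: inspect the weight_map of a sharded HF checkpoint to see
    # whether lm_head.weight is stored separately or tied (and thus absent).
    import json
    from huggingface_hub import hf_hub_download

    index_path = hf_hub_download("Qwen/Qwen3-1.7B", "model.safetensors.index.json")
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]

    print("lm_head.weight" in weight_map)             # is a separate lm_head stored?
    print("model.embed_tokens.weight" in weight_map)  # embedding is always present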
    # copy from embed_tokens.weight
    if self.model_args.enable_weight_tying and "lm_head.weight" not in hf_state_dict:
        if "model.embed_tokens.weight" in hf_state_dict:
            hf_state_dict = dict(hf_state_dict)  # Make a copy to avoid modifying original
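Pieced together, the from_hf side under discussion would look roughly like the sketch below; the function name, signature, and from_hf_map are illustrative assumptions, and only the weight-tying branch mirrors the diff above:

    # Sketch (assumed names): when the HF checkpoint was saved with tied weights,
    # lm_head.weight is missing, so it is filled in from model.embed_tokens.weight
    # before the key-by-key translation.
    def from_hf(hf_state_dict, from_hf_map, enable_weight_tying):
        if enable_weight_tying and "lm_head.weight" not in hf_state_dict:
            if "model.embed_tokens.weight" in hf_state_dict:
                # Shallow copy so the caller's dict is untouched; whether this
                # copy is needed is exactly what the review below discusses.
                hf_state_dict = dict(hf_state_dict)
                hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"]
        return {from_hf_map[k]: v for k, v in hf_state_dict.items() if k in from_hf_map}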
Do you need to make a shallow copy of the dict? Can you elaborate more on "avoid modifying original"?
Without dict(hf_state_dict), the line hf_state_dict["lm_head.weight"] = ... would directly mutate the dictionary object provided by the caller. I'm not sure whether the caller expects the input dictionary to be modified, so I made a copy to avoid any potential side effects.
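To illustrate the trade-off with a generic Python example (not the adapter code itself):

    # Mutating the caller's dict directly: the new key becomes visible to the caller.
    caller_dict = {"model.embed_tokens.weight": "W"}
    view = caller_dict
    view["lm_head.weight"] = "W"
    print("lm_head.weight" in caller_dict)  # True

    # With dict(...), only the adapter's shallow copy gains the new key; the
    # values themselves are still shared, only the key-to-value mapping is copied.
    caller_dict = {"model.embed_tokens.weight": "W"}
    copied = dict(caller_dict)
    copied["lm_head.weight"] = "W"
    print("lm_head.weight" in caller_dict)  # False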
If not necessary, revert this line.
Thanks, the input hf_state_dict will not be used after calling the from_hf() function:

    state_dict = self.sd_adapter.from_hf(hf_state_dict)

It should be fine to mutate the dictionary object (hf_state_dict) in place.
Ok, I've removed the shallow copy.
Thx for fixing this bug! lgtm.
Co-authored-by: Shuhua Yu <[email protected]>
Hi @Achazwl, thanks for the contribution! Do you wanna fix the lint and re-run CI so we can merge it?