Fix torch.compile recompilation issue with HF modeling + TP #2130
3outeille wants to merge 7 commits into pytorch:main

Conversation
flavors = {
    "debugperf": HFTransformerModelArgs(
what's the difference between debugperf / debugperf_large and debugmodel? Can we just keep one of them?
Do we need to ship this folder to fix the issue? It's about 2k LoC complexity.
+1, I think we could remove these test scripts to keep the code simple
wwwjn left a comment:
Thanks for finding this! To check my understanding, the bug is:
the function call with kwargs will return a new object id for the hook -> causing a recompile.
Is this correct?
flavors = {
    "debugperf": HFTransformerModelArgs(
Should we remove these 2 test models?
+1, I think we could remove these test scripts to keep the code simple
llama3_args = {
    "debugperf": TransformerModelArgs(
Same here, could we remove these 2 models?
class HFTransformers:
    model: str = ""
    """HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')"""
    tie_word_embeddings: bool = False
Putting tie_word_embeddings into the job config is a little confusing, and seems unrelated to this error?
IIUC this field is decided by the model architecture, not by each training run. So previously we put Qwen3's weight-tying config into model_args:
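For illustration, a minimal sketch of the suggestion, assuming a dataclass-style HFTransformerModelArgs (field names beyond those in the hunk above, and the flavor name, are hypothetical):

from dataclasses import dataclass

@dataclass
class HFTransformerModelArgs:
    model: str = ""
    # Architecture-level flag: decided by the model, not by each training
    # run, so it belongs in the model args rather than in the job config.
    tie_word_embeddings: bool = False

# Weight tying is then fixed per flavor, alongside the model it describes.
flavors = {
    "qwen3_4b": HFTransformerModelArgs(  # hypothetical flavor name
        model="Qwen/Qwen3-4B-Instruct-2507",
        tie_word_embeddings=True,  # per the upstream Qwen3-4B config
    ),
}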
It is a known issue: cf pytorch/pytorch#170110. In this PR, we make changes so that we bypass the use of kwargs.
Fixes the bug huggingface#6.
TODO: the same change needs to be applied in transformers v5. That requires waiting for v5 to be a bit more stable before switching the torchtitan transformers modeling backend to v5 (for now, it relies on 4.57.1).

Issue
torch.compile recompiles on every attention layer when the HF modeling code runs with TP. To reproduce: ./tooling_dev/debug_local.sh debugperf_large --compile

Fix

Patch transformers at modeling_llama.py: change the self.attn(hidden_states=hidden_states) call to pass hidden_states positionally instead of as a kwarg.
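A minimal runnable sketch of the idea (ToyAttention and Block are stand-ins, not the real transformers classes; the actual call site in modeling_llama.py takes more arguments):

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    # Stand-in for LlamaAttention.
    def forward(self, hidden_states):
        return hidden_states

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = ToyAttention()

    def forward(self, hidden_states):
        # Before: self.self_attn(hidden_states=hidden_states) -- the kwarg
        # call forces TP to hook the module with with_kwargs=True, which
        # installs a per-layer guard and triggers recompiles.
        # After (this PR): positional call, so the kwargs hook path is
        # never taken and no per-layer guard is installed.
        return self.self_attn(hidden_states)

out = Block()(torch.randn(2, 8))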
Explanation

When torch.compile traces your model, it creates a compiled graph along with guards. Guards are conditions that must be true for that graph to be reused. If a guard fails, torch.compile recompiles.

In modeling_llama.py, self.attn(hidden_states=hidden_states) is called with kwargs. Depending on whether kwargs are used or not, TP registers a different pre-forward hook (cf https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576); with kwargs it registers module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=True).

Registering with with_kwargs=True records the hook id in self._forward_pre_hooks_with_kwargs, which Module.__call__ checks via if hook_id in self._forward_pre_hooks_with_kwargs: (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808). Each layer's kwargs hook gets a different hook_id, hence per-layer guard failures such as ___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs).

Without kwargs, self._forward_pre_hooks_with_kwargs will always be empty (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48), so the if check is never triggered and every attention layer hits the same code path, thus no recompile.
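As a standalone illustration of where the guard comes from (plain PyTorch, not torchtitan code): hooks registered with with_kwargs=True populate _forward_pre_hooks_with_kwargs with a fresh id per registration, while plain hooks leave it empty. This is the dict the ___dict_contains guard above inspects.

import torch.nn as nn

kwarg_layers = [nn.Linear(4, 4) for _ in range(3)]
plain_layers = [nn.Linear(4, 4) for _ in range(3)]

# with_kwargs=True: the handle id is recorded in
# _forward_pre_hooks_with_kwargs, and every registration gets a new id.
for layer in kwarg_layers:
    layer.register_forward_pre_hook(
        lambda mod, args, kwargs: (args, kwargs), with_kwargs=True
    )

# Plain registration: the dict stays empty, so the
# `hook_id in self._forward_pre_hooks_with_kwargs` branch never fires.
for layer in plain_layers:
    layer.register_forward_pre_hook(lambda mod, args: args)

print([dict(l._forward_pre_hooks_with_kwargs) for l in kwarg_layers])
# e.g. [{1: True}, {2: True}, {3: True}] -> ids differ per layer, which is
# what the per-layer ___dict_contains guards key on.
print([dict(l._forward_pre_hooks_with_kwargs) for l in plain_layers])
# [{}, {}, {}] -> identical across layers, one compiled graph suffices.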