pytorch/torchtitan#2154

Description
- This is fixed by moving to transformers v5, which no longer uses `kwargs` for attention (a minimal sketch of the call-site difference follows below).
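To make the fix concrete, here is a runnable toy sketch of the two call styles, assuming an illustrative `ToyAttention` module (a hypothetical stand-in, not the transformers code):

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, hidden_states):
        return hidden_states * 2

attn = ToyAttention()
x = torch.randn(2, 4)

# v4-style call site: hidden_states travels as a keyword argument, so a
# TP pre-hook that needs to see it must be registered with with_kwargs=True.
out_kwargs = attn(hidden_states=x)

# v5-style call site: purely positional, so the plain (no-kwargs)
# pre-hook path suffices and no kwargs-related guard is installed.
out_positional = attn(x)

assert torch.equal(out_kwargs, out_positional)
```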
Bug description
```
[rank3]:/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
[rank3]: warnings.warn(
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] torch._dynamo hit config.recompile_limit (8)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] function: 'forward' (/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:145)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] last reason: 0/7: ___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs) # if hook_id in self._forward_pre_hooks_with_kwargs: # nn/modules/module.py:1815 in inner
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
[rank3]:[rank3]: Traceback (most recent call last):
```
Explanation
- When `torch.compile` traces your model, it creates a compiled graph along with guards. Guards are conditions that must hold for that graph to be reused; if a guard fails, `torch.compile` recompiles.
- In `modeling_llama.py`, `self.attn(hidden_states=hidden_states)` is called with `kwargs`.
- In torchtitan, applying TP registers a forward pre-hook via `register_forward_pre_hook`. Depending on whether `kwargs` are used, a different registration path is taken (cf. https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576).
  - In our case, it calls `module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=True)`.
- Calling this function is problematic because it triggers `if hook_id in self._forward_pre_hooks_with_kwargs:` (cf. https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808).
  - This means that using `kwargs` results in a different `hook_id` per layer, hence the failing guard `___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)`.
- When we don't use `kwargs`, `self._forward_pre_hooks_with_kwargs` stays empty (cf. https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48), so the `if` check is never hit, every attention layer produces the same guard, and no recompile is triggered (see the sketch after this list).
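A minimal sketch of the hook-id behavior described above, assuming a toy `Block` module stack (illustrative names, not torchtitan's code):

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

with_kwargs_blocks = nn.ModuleList(Block() for _ in range(4))
plain_blocks = nn.ModuleList(Block() for _ in range(4))

# kwargs path: each registration receives a fresh, globally unique hook id,
# which is inserted as a key into _forward_pre_hooks_with_kwargs.
for b in with_kwargs_blocks:
    b.register_forward_pre_hook(
        lambda mod, args, kwargs: (args, kwargs), with_kwargs=True
    )

# no-kwargs path: _forward_pre_hooks_with_kwargs is never touched.
for b in plain_blocks:
    b.register_forward_pre_hook(lambda mod, args: args)

# The per-layer keys differ, so a guard like
# ___dict_contains(<hook_id>, ..._forward_pre_hooks_with_kwargs)
# can only hold for the one layer it was compiled against.
print([list(b._forward_pre_hooks_with_kwargs) for b in with_kwargs_blocks])
# e.g. [[0], [1], [2], [3]]
print([list(b._forward_pre_hooks_with_kwargs) for b in plain_blocks])
# [[], [], [], []]
```

Since every layer trips the `___dict_contains` guard with a different key, the recompile counter climbs until `config.recompile_limit` (8) is hit, which is exactly the warning in the log above.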
Versions
python -m torch.utils.collect_env
<frozen runpy>:128: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
Collecting environment information...
PyTorch version: 2.10.0.dev20251113+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 10.5.0-1ubuntu1~20.04) 10.5.0
Clang version: 18.1.8
CMake version: version 3.27.7
Libc version: glibc-2.31
Python version: 3.12.9 (main, Mar 11 2025, 17:26:57) [Clang 20.1.0 ] (64-bit runtime)
Python platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version: 575.57.08
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7R13 Processor
Stepping: 1
CPU MHz: 3580.180
BogoMIPS: 5299.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3 MiB
L1i cache: 3 MiB
L2 cache: 48 MiB
L3 cache: 384 MiB
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
uv pip list
Using Python 3.12.9 environment at: /fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official
Package Version Editable project location
------------------------ ------------------------ ------------------------------------------------------------------
absl-py 2.3.1
aiohappyeyeballs 2.6.1
aiohttp 3.13.2
aiosignal 1.4.0
anyio 4.11.0
attrs 25.4.0
certifi 2025.11.12
cfgv 3.4.0
charset-normalizer 3.4.4
click 8.3.0
contourpy 1.3.3
cycler 0.12.1
datasets 4.4.1
debugpy 1.8.17
debugpy-run 1.16
dill 0.4.0
distlib 0.4.0
docstring-parser 0.17.0
fastcore 1.8.16
filelock 3.20.0
fonttools 4.60.1
frozenlist 1.8.0
fsspec 2025.9.0
grpcio 1.76.0
h11 0.16.0
hf-xet 1.2.0
httpcore 1.0.9
httpx 0.28.1
huggingface-hub 0.36.0
identify 2.6.15
idna 3.11
jinja2 3.1.6
kiwisolver 1.4.9
lovely-numpy 0.2.17
lovely-tensors 0.1.19
markdown 3.10
markdown-it-py 4.0.0
markupsafe 3.0.2
matplotlib 3.10.7
mdurl 0.1.2
mpmath 1.3.0
multidict 6.7.0
multiprocess 0.70.18
networkx 3.5
nodeenv 1.9.1
numpy 2.3.4
nvidia-cublas-cu12 12.6.4.1
nvidia-cuda-cupti-cu12 12.6.80
nvidia-cuda-nvrtc-cu12 12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12 9.10.2.21
nvidia-cufft-cu12 11.3.0.4
nvidia-cufile-cu12 1.11.1.6
nvidia-curand-cu12 10.3.7.77
nvidia-cusolver-cu12 11.7.1.2
nvidia-cusparse-cu12 12.5.4.2
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.27.5
nvidia-nvjitlink-cu12 12.6.85
nvidia-nvshmem-cu12 3.4.5
nvidia-nvtx-cu12 12.6.77
objprint 0.3.0
packaging 25.0
pandas 2.3.3
pillow 12.0.0
platformdirs 4.5.0
pre-commit 4.4.0
propcache 0.4.1
protobuf 6.33.1
psutil 7.1.3
pyarrow 22.0.0
pygments 2.19.2
pyparsing 3.2.5
python-dateutil 2.9.0.post0
pytorch-triton 3.5.1+gitbfeb0668
pytz 2025.2
pyyaml 6.0.3
regex 2025.11.3
requests 2.32.5
rich 14.2.0
safetensors 0.6.2
setuptools 78.1.0
shellingham 1.5.4
shtab 1.7.2
six 1.17.0
sniffio 1.3.1
sympy 1.14.0
tensorboard 2.20.0
tensorboard-data-server 0.7.2
tokenizers 0.22.1
toml 0.10.2
tomli 2.3.0
torch 2.10.0.dev20251113+cu126
torchdata 0.11.0
torchtitan 0.2.0 /fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan
tqdm 4.67.1
transformers 4.57.1 /fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/transformers
typeguard 4.4.4
typer-slim 0.20.0
typing-extensions 4.15.0
tyro 0.9.35
tzdata 2025.2
urllib3 2.5.0
virtualenv 20.35.4
viztracer 1.1.1
werkzeug 3.1.3
xxhash 3.6.0