HF modeling with torch.compile doesn't work when used with Tensor Parallel #6

@3outeille

Description

  • This is fixed by moving to transformers v5 (which no longer uses kwargs for attention)
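A quick way to tell which side of the fix an environment is on is to check the installed transformers major version. This is a minimal sketch using only the standard library; the helper names `major_version` and `transformers_is_v5_or_later` are hypothetical and not part of transformers or torchtitan.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def major_version(version_string: str) -> int:
    """Return the leading integer of a version string such as '4.57.1'."""
    return int(version_string.split(".")[0])


def transformers_is_v5_or_later() -> Optional[bool]:
    """Return True/False for the installed transformers, or None if absent."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        # transformers is not installed in this environment
        return None
    return major_version(installed) >= 5


if __name__ == "__main__":
    print("transformers >= 5:", transformers_is_v5_or_later())
```

With the `transformers 4.57.1` install listed under Versions, the check would report `False`, which matches the environment where the recompilation warning was observed.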

Bug description

[rank3]:/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
[rank3]:  warnings.warn(
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] torch._dynamo hit config.recompile_limit (8)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8]    function: 'forward' (/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:145)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8]    last reason: 0/7: ___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)  # if hook_id in self._forward_pre_hooks_with_kwargs:  # nn/modules/module.py:1815 in inner
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
[rank3]:[rank3]: Traceback (most recent call last):
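
The warning above suggests setting `TORCH_LOGS="recompiles"` to log every recompilation reason. A minimal sketch of preparing that environment for a launcher subprocess follows; the helper name `with_recompile_logging` is hypothetical, and it assumes `TORCH_LOGS` holds a comma-separated list of log names.

```python
import os
from typing import Dict, Optional


def with_recompile_logging(env: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Return a copy of env with 'recompiles' added to TORCH_LOGS."""
    new_env = dict(os.environ if env is None else env)
    # Keep any log names that are already requested.
    parts = [p for p in new_env.get("TORCH_LOGS", "").split(",") if p]
    if "recompiles" not in parts:
        parts.append("recompiles")
    new_env["TORCH_LOGS"] = ",".join(parts)
    return new_env


if __name__ == "__main__":
    # Pass the returned dict as env= to subprocess.run when launching training.
    print(with_recompile_logging({})["TORCH_LOGS"])
```

The variable has to be set before the training processes start, which is why the sketch builds an environment for the launcher rather than mutating it after `torch` has been imported.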

Explanation

Versions

  • python -m torch.utils.collect_env
<frozen runpy>:128: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
Collecting environment information...
PyTorch version: 2.10.0.dev20251113+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 10.5.0-1ubuntu1~20.04) 10.5.0
Clang version: 18.1.8
CMake version: version 3.27.7
Libc version: glibc-2.31

Python version: 3.12.9 (main, Mar 11 2025, 17:26:57) [Clang 20.1.0 ] (64-bit runtime)
Python platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: 
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 575.57.08
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             96
On-line CPU(s) list:                0-95
Thread(s) per core:                 1
Core(s) per socket:                 48
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              1
Model name:                         AMD EPYC 7R13 Processor
Stepping:                           1
CPU MHz:                            3580.180
BogoMIPS:                           5299.99
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          3 MiB
L1i cache:                          3 MiB
L2 cache:                           48 MiB
L3 cache:                           384 MiB
NUMA node0 CPU(s):                  0-47
NUMA node1 CPU(s):                  48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
  • uv pip list
Using Python 3.12.9 environment at: /fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official
Package                  Version                  Editable project location
------------------------ ------------------------ ------------------------------------------------------------------
absl-py                  2.3.1
aiohappyeyeballs         2.6.1
aiohttp                  3.13.2
aiosignal                1.4.0
anyio                    4.11.0
attrs                    25.4.0
certifi                  2025.11.12
cfgv                     3.4.0
charset-normalizer       3.4.4
click                    8.3.0
contourpy                1.3.3
cycler                   0.12.1
datasets                 4.4.1
debugpy                  1.8.17
debugpy-run              1.16
dill                     0.4.0
distlib                  0.4.0
docstring-parser         0.17.0
fastcore                 1.8.16
filelock                 3.20.0
fonttools                4.60.1
frozenlist               1.8.0
fsspec                   2025.9.0
grpcio                   1.76.0
h11                      0.16.0
hf-xet                   1.2.0
httpcore                 1.0.9
httpx                    0.28.1
huggingface-hub          0.36.0
identify                 2.6.15
idna                     3.11
jinja2                   3.1.6
kiwisolver               1.4.9
lovely-numpy             0.2.17
lovely-tensors           0.1.19
markdown                 3.10
markdown-it-py           4.0.0
markupsafe               3.0.2
matplotlib               3.10.7
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.7.0
multiprocess             0.70.18
networkx                 3.5
nodeenv                  1.9.1
numpy                    2.3.4
nvidia-cublas-cu12       12.6.4.1
nvidia-cuda-cupti-cu12   12.6.80
nvidia-cuda-nvrtc-cu12   12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12        9.10.2.21
nvidia-cufft-cu12        11.3.0.4
nvidia-cufile-cu12       1.11.1.6
nvidia-curand-cu12       10.3.7.77
nvidia-cusolver-cu12     11.7.1.2
nvidia-cusparse-cu12     12.5.4.2
nvidia-cusparselt-cu12   0.7.1
nvidia-nccl-cu12         2.27.5
nvidia-nvjitlink-cu12    12.6.85
nvidia-nvshmem-cu12      3.4.5
nvidia-nvtx-cu12         12.6.77
objprint                 0.3.0
packaging                25.0
pandas                   2.3.3
pillow                   12.0.0
platformdirs             4.5.0
pre-commit               4.4.0
propcache                0.4.1
protobuf                 6.33.1
psutil                   7.1.3
pyarrow                  22.0.0
pygments                 2.19.2
pyparsing                3.2.5
python-dateutil          2.9.0.post0
pytorch-triton           3.5.1+gitbfeb0668
pytz                     2025.2
pyyaml                   6.0.3
regex                    2025.11.3
requests                 2.32.5
rich                     14.2.0
safetensors              0.6.2
setuptools               78.1.0
shellingham              1.5.4
shtab                    1.7.2
six                      1.17.0
sniffio                  1.3.1
sympy                    1.14.0
tensorboard              2.20.0
tensorboard-data-server  0.7.2
tokenizers               0.22.1
toml                     0.10.2
tomli                    2.3.0
torch                    2.10.0.dev20251113+cu126
torchdata                0.11.0
torchtitan               0.2.0                    /fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan
tqdm                     4.67.1
transformers             4.57.1                   /fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/transformers
typeguard                4.4.4
typer-slim               0.20.0
typing-extensions        4.15.0
tyro                     0.9.35
tzdata                   2025.2
urllib3                  2.5.0
virtualenv               20.35.4
viztracer                1.1.1
werkzeug                 3.1.3
xxhash                   3.6.0
