
Fix for incompatible ReduceOp.PREMUL_SUM on XPU devices #2332

Closed

saforem2 wants to merge 13 commits into pytorch:main from saforem2:saforem2/fix-reduce-op-xpu

Conversation


@saforem2 saforem2 commented Feb 5, 2026

Add a fallback that enables `module.set_force_sum_reduction_for_comms` when running on XPU devices with the xccl backend.

Full Traceback:
ezpz launch python3 -m torchtitan.experiments.ezpz.train --job.config_file torchtitan/models/llama3/train_configs/debug_model.toml --training.steps=100


[2026-02-05 11:32:52,798396][I][ezpz/launch:403:launch] ----[🍋 ezpz.launch][started][2026-02-05-113252]----
[2026-02-05 11:32:52,804619][I][ezpz/launch:66:_log_json_log_file] Logs available at: /lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/logs/torchtitan.experiments.ezpz.train/2026-02-05-113252-rank0.jsonl
[2026-02-05 11:32:54,175673][I][ezpz/launch:424:launch] Job ID: 12460005
[2026-02-05 11:32:54,176510][I][ezpz/launch:425:launch] nodelist: ['x1922c7s2b0n0', 'x1922c7s3b0n0']
[2026-02-05 11:32:54,176940][I][ezpz/launch:426:launch] hostfile: /var/spool/pbs/aux/12460005.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov
[2026-02-05 11:32:54,177616][I][ezpz/pbs:267:get_pbs_launch_cmd] ✅ Using [24/24] GPUs [2 hosts] x [12 GPU/host]
[2026-02-05 11:32:54,178533][I][ezpz/launch:374:build_executable] Building command to execute by piecing together:
[2026-02-05 11:32:54,178967][I][ezpz/launch:375:build_executable] (1.) launch_cmd: mpiexec --envall --np=24 --ppn=12 --hostfile=/var/spool/pbs/aux/12460005.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov --no-vni --cpu-bind=verbose,list:2-4:10-12:18-20:26-28:34-36:42-44:54-56:62-64:70-72:78-80:86-88:94-96
[2026-02-05 11:32:54,179640][I][ezpz/launch:376:build_executable] (2.) cmd_to_launch: python3 -m torchtitan.experiments.ezpz.train --job.config_file torchtitan/models/llama3/train_configs/debug_model.toml --training.steps=100
[2026-02-05 11:32:54,180386][I][ezpz/launch:441:launch] Took: 1.97 seconds to build command.
[2026-02-05 11:32:54,180761][I][ezpz/launch:444:launch] Executing:
mpiexec
  --envall
  --np=24
  --ppn=12
  --hostfile=/var/spool/pbs/aux/12460005.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov
  --no-vni
  --cpu-bind=verbose,list:2-4:10-12:18-20:26-28:34-36:42-44:54-56:62-64:70-72:78-80:86-88:94-96
  python3
  -m
  torchtitan.experiments.ezpz.train
  --job.config_file
  torchtitan/models/llama3/train_configs/debug_model.toml
  --training.steps=100
[2026-02-05 11:32:54,182160][I][ezpz/launch:451:launch] Execution started @ 2026-02-05-113254...
[2026-02-05 11:32:54,182656][I][ezpz/launch:146:run_command] Running command:
 mpiexec --envall --np=24 --ppn=12 --hostfile=/var/spool/pbs/aux/12460005.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov --no-vni --cpu-bind=verbose,list:2-4:10-12:18-20:26-28:34-36:42-44:54-56:62-64:70-72:78-80:86-88:94-96 python3 -m torchtitan.experiments.ezpz.train --job.config_file torchtitan/models/llama3/train_configs/debug_model.toml --training.steps=100
cpubind:list x1922c7s2b0n0 pid 100049 rank 0 0: mask 0x1c
cpubind:list x1922c7s2b0n0 pid 100050 rank 1 1: mask 0x1c00
cpubind:list x1922c7s2b0n0 pid 100051 rank 2 2: mask 0x1c0000
cpubind:list x1922c7s2b0n0 pid 100052 rank 3 3: mask 0x1c000000
cpubind:list x1922c7s2b0n0 pid 100053 rank 4 4: mask 0x1c00000000
cpubind:list x1922c7s2b0n0 pid 100054 rank 5 5: mask 0x1c0000000000
cpubind:list x1922c7s2b0n0 pid 100055 rank 6 6: mask 0x1c0000000000000
cpubind:list x1922c7s2b0n0 pid 100056 rank 7 7: mask 0x1c000000000000000
cpubind:list x1922c7s2b0n0 pid 100057 rank 8 8: mask 0x1c00000000000000000
cpubind:list x1922c7s2b0n0 pid 100058 rank 9 9: mask 0x1c0000000000000000000
cpubind:list x1922c7s2b0n0 pid 100059 rank 10 10: mask 0x1c000000000000000000000
cpubind:list x1922c7s2b0n0 pid 100060 rank 11 11: mask 0x1c00000000000000000000000
cpubind:list x1922c7s3b0n0 pid 18912 rank 12 0: mask 0x1c
cpubind:list x1922c7s3b0n0 pid 18913 rank 13 1: mask 0x1c00
cpubind:list x1922c7s3b0n0 pid 18914 rank 14 2: mask 0x1c0000
cpubind:list x1922c7s3b0n0 pid 18915 rank 15 3: mask 0x1c000000
cpubind:list x1922c7s3b0n0 pid 18916 rank 16 4: mask 0x1c00000000
cpubind:list x1922c7s3b0n0 pid 18917 rank 17 5: mask 0x1c0000000000
cpubind:list x1922c7s3b0n0 pid 18918 rank 18 6: mask 0x1c0000000000000
cpubind:list x1922c7s3b0n0 pid 18919 rank 19 7: mask 0x1c000000000000000
cpubind:list x1922c7s3b0n0 pid 18920 rank 20 8: mask 0x1c00000000000000000
cpubind:list x1922c7s3b0n0 pid 18921 rank 21 9: mask 0x1c0000000000000000000
cpubind:list x1922c7s3b0n0 pid 18922 rank 22 10: mask 0x1c000000000000000000000
cpubind:list x1922c7s3b0n0 pid 18923 rank 23 11: mask 0x1c00000000000000000000000
[2026-02-05 11:33:03,860171][I][ezpz/dist:1226:get_local_rank] Local rank env vars unset; falling back to rank modulo GPUs. Checked LOCAL_RANK, PMI_LOCAL_RANK, OMPI_COMM_WORLD_LOCAL_RANK, MPI_LOCALRANKID, MPICH_LOCALRANKID, SLURM_LOCAL_ID.
[2026-02-05 11:33:03,863249][I][ezpz/dist:1565:setup_torch_distributed] Using device=xpu with backend=xccl
[2026-02-05 11:33:03,864010][I][ezpz/dist:1430:setup_torch_DDP] Caught MASTER_PORT=49717 from environment!
[2026-02-05 11:33:03,864728][I][ezpz/dist:1446:setup_torch_DDP] Using torch.distributed.init_process_group with
- master_addr='x1922c7s2b0n0'
- master_port='49717'
- world_size=24
- rank=0
- local_rank=0
- timeout=datetime.timedelta(seconds=3600)
- backend='xccl'
[2026-02-05 11:33:03,865657][I][ezpz/dist:1007:init_process_group] Calling torch.distributed.init_process_group_with: rank=0 world_size=24 backend=xccl
[2026-02-05 11:33:04,515510][I][ezpz/dist:1822:setup_torch] Using device='xpu' with backend='xccl' + 'xccl' for distributed training.
[2026-02-05 11:33:04,516334][W][ezpz/dist:529:print_dist_setup] Using [24 / 24] available "xpu" devices !!
[2026-02-05 11:33:04,516843][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=00/11][rank=00/23]
[2026-02-05 11:33:04,516859][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=02/11][rank=02/23]
[2026-02-05 11:33:04,516861][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=06/11][rank=06/23]
[2026-02-05 11:33:04,516859][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=01/11][rank=01/23]
[2026-02-05 11:33:04,516863][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=03/11][rank=03/23]
[2026-02-05 11:33:04,516859][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=04/11][rank=04/23]
[2026-02-05 11:33:04,516862][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=05/11][rank=05/23]
[2026-02-05 11:33:04,516864][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=07/11][rank=07/23]
[2026-02-05 11:33:04,516851][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=00/11][rank=12/23]
[2026-02-05 11:33:04,516859][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=08/11][rank=08/23]
[2026-02-05 11:33:04,516864][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=09/11][rank=09/23]
[2026-02-05 11:33:04,516855][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=01/11][rank=13/23]
[2026-02-05 11:33:04,516861][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=10/11][rank=10/23]
[2026-02-05 11:33:04,516854][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=02/11][rank=14/23]
[2026-02-05 11:33:04,516862][I][ezpz/dist:1871:setup_torch] ['x1922c7s2b0n0'][device='xpu'][node=0/1][local_rank=11/11][rank=11/23]
[2026-02-05 11:33:04,516860][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=03/11][rank=15/23]
[2026-02-05 11:33:04,516856][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=04/11][rank=16/23]
[2026-02-05 11:33:04,516857][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=06/11][rank=18/23]
[titan] 2026-02-05 11:33:04,520 - root - INFO - torchtitan version: 0.0.0+unknown (0.0.0 means __version__ is not defined correctly).
[2026-02-05 11:33:04,516855][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=11/11][rank=23/23]
[2026-02-05 11:33:04,516858][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=05/11][rank=17/23]
[2026-02-05 11:33:04,516860][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=07/11][rank=19/23]
[2026-02-05 11:33:04,516855][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=08/11][rank=20/23]
[2026-02-05 11:33:04,516856][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=09/11][rank=21/23]
[2026-02-05 11:33:04,516857][I][ezpz/dist:1871:setup_torch] ['x1922c7s3b0n0'][device='xpu'][node=1/1][local_rank=10/11][rank=22/23]
[titan] 2026-02-05 11:33:04,619 - root - INFO - Starting job: Llama 3 debug training
[titan] 2026-02-05 11:33:04,619 - root - INFO - Building device mesh with parallelism: pp=1, dp_replicate=1, dp_shard=24, cp=1, tp=1, ep=1, etp=1
[titan] 2026-02-05 11:33:04,628 - root - INFO - Successfully created meshes with active dimensions: ['batch', 'loss', 'fsdp', 'efsdp']
[titan] 2026-02-05 11:33:04,628 - root - INFO - [GC] Initial GC collection took 0.00 seconds
2026:02:05-11:33:04:100049 |CCL_WARN| value of CCL_OP_SYNC changed to be 1 (default:0)
2026:02:05-11:33:04:100049 |CCL_WARN| value of CCL_PROCESS_LAUNCHER changed to be pmix (default:hydra)
[titan] 2026-02-05 11:33:05,409 - root - INFO - Loading tokenizer from tokenizer.json
[titan] 2026-02-05 11:33:05,412 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[titan] 2026-02-05 11:33:05,890 - root - INFO - Building llama3 debugmodel with {
  "_enforced": "This field is used to enforce all fields have defaults.",
  "dim": 256,
  "n_layers": 6,
  "n_heads": 16,
  "n_kv_heads": null,
  "vocab_size": 2048,
  "multiple_of": 256,
  "ffn_dim_multiplier": null,
  "norm_eps": 1e-05,
  "rope_theta": 500000,
  "rope_scaling_args": {
    "scaling_factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192
  },
  "max_seq_len": 2048,
  "depth_init": true,
  "attn_type": "sdpa",
  "attn_mask_type": "causal",
  "eos_id": 0
}
[titan] 2026-02-05 11:33:05,905 - root - INFO - XPU capacity: Intel(R) Data Center GPU Max 1550 with 63.98GiB memory
[titan] 2026-02-05 11:33:06,295 - root - INFO - Model llama3 debugmodel size: 6,163,712 total parameters
[titan] 2026-02-05 11:33:06,295 - root - INFO - Applied selective activation checkpointing to the model
[titan] 2026-02-05 11:33:06,315 - root - INFO - Applied FSDP to the model
[titan] 2026-02-05 11:33:06,782 - root - INFO - Peak FLOPS used for computing MFU: 2.982e+14
[titan] 2026-02-05 11:33:06,782 - root - INFO - XPU memory usage for model: 0.00GiB(0.00%)
[titan] 2026-02-05 11:33:06,783 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json.                     Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[titan] 2026-02-05 11:33:06,786 - torchtitan.experiments.ezpz.distributed.utils - INFO - Mixed precision training is handled by fully_shard
[titan] 2026-02-05 11:33:06,786 - root - INFO - Trainer is initialized with local batch size 8, global batch size 192, gradient accumulation steps 1, sequence length 2048, total steps 100 (warmup 2)
[titan] 2026-02-05 11:33:06,786 - root - INFO - Training starts at step 1
[rank5]: Traceback (most recent call last):
[rank5]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank5]:   File "<frozen runpy>", line 88, in _run_code
[rank5]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 823, in <module>
[rank5]:     main(Trainer)
[rank5]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 808, in main
[rank5]:     trainer.train()
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
[rank5]:     return f(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 714, in train
[rank5]:     self.train_step(data_iterator)
[rank5]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 599, in train_step
[rank5]:     loss = self.forward_backward_step(
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 557, in forward_backward_step
[rank5]:     loss.backward()
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank5]:     torch.autograd.backward(
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank5]:     _engine_run_backward(
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank5]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank5]:     return user_fn(self, *args)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 900, in backward
[rank5]:     ctx.param_group.post_backward()
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 566, in post_backward
[rank5]:     ) = foreach_reduce(
[rank5]:         ^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank5]:     return func(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 543, in foreach_reduce
[rank5]:     reduce_scatter_comm(
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 125, in __call__
[rank5]:     return dist.reduce_scatter_tensor(
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank5]:     return func(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 4591, in reduce_scatter_tensor
[rank5]:     work = group._reduce_scatter_base(output, input, opts)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: ValueError: Cannot use ReduceOp.PREMUL_SUM with XCCL

<details closed><summary>Traceback:</summary>

```bash
[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 823, in <module>
[rank0]:     main(Trainer)
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 808, in main
[rank0]:     trainer.train()
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
[rank0]:     return f(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 714, in train
[rank0]:     self.train_step(data_iterator)
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 599, in train_step
[rank0]:     loss = self.forward_backward_step(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 557, in forward_backward_step
[rank0]:     loss.backward()
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 900, in backward
[rank0]:     ctx.param_group.post_backward()
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 566, in post_backward
[rank0]:     ) = foreach_reduce(
[rank0]:         ^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 543, in foreach_reduce
[rank0]:     reduce_scatter_comm(
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 125, in __call__
[rank0]:     return dist.reduce_scatter_tensor(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^[rank1]: Traceback (most recent call last):
```

</details>
Copilot AI review requested due to automatic review settings February 5, 2026 17:53

@meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) Feb 5, 2026

Copilot AI left a comment:

Pull request overview

This PR fixes a ValueError that occurs when using FSDP with XPU devices and the xccl backend. The error "Cannot use ReduceOp.PREMUL_SUM with XCCL" is addressed by adding fallback logic to force sum reduction for backends that don't support the PREMUL_SUM operation.

Changes:

  • Added backend detection logic in disable_fsdp_gradient_division to check if the current backend is NCCL
  • For non-NCCL backends, the function now calls set_force_sum_reduction_for_comms(True) on FSDP modules to use regular SUM operations instead of PREMUL_SUM
  • Added torch.distributed import to support backend detection


```python
force_sum_reduction = False
if torch.distributed.is_available() and torch.distributed.is_initialized():
    backend = torch.distributed.get_backend()
    if backend and backend.lower() != "nccl":
```

Copilot AI Feb 5, 2026


The logic backend.lower() != "nccl" enables force_sum_reduction for ALL non-NCCL backends (e.g., gloo, fake, etc.), not just xccl. This is overly broad and may force sum reduction unnecessarily for backends that support PREMUL_SUM. Consider being more specific by checking if backend is in a list of known incompatible backends, such as checking backend.lower() in ("xccl", "gloo") or more specifically backend.lower() == "xccl" if only xccl is known to have this limitation.

Suggested change:

```diff
-if backend and backend.lower() != "nccl":
+if backend and backend.lower() == "xccl":
```

Comment on lines +283 to +287:

```python
force_sum_reduction = False
if torch.distributed.is_available() and torch.distributed.is_initialized():
    backend = torch.distributed.get_backend()
    if backend and backend.lower() != "nccl":
        force_sum_reduction = True
```

Copilot AI Feb 5, 2026


Add a comment explaining why force_sum_reduction is needed for certain backends. The comment should clarify that some backends (like xccl) do not support ReduceOp.PREMUL_SUM and require falling back to regular SUM operations. This will help future maintainers understand the purpose of this workaround.

Comment on lines 274 to 282:

```python
"""
Disable FSDP's automatic gradient division for all FSDP modules.

Set gradient_divide_factor=1.0 to disable FSDP's automatic gradient division.
We handle gradient scaling ourselves in the training loop with global token count.

Args:
    model: The model containing FSDP-wrapped modules
"""
```

Copilot AI Feb 5, 2026


The docstring for disable_fsdp_gradient_division should be updated to document the new behavior of forcing sum reduction for non-NCCL backends. This is a significant change to the function's behavior that should be documented for users.

Comment on lines +291 to +292:

```python
if force_sum_reduction:
    module.set_force_sum_reduction_for_comms(True)
```

Copilot AI Feb 5, 2026


Consider adding a logger statement to inform users when force_sum_reduction is enabled due to backend incompatibility. This would help with debugging and make the behavior more transparent. For example: logger.info(f"Forcing sum reduction for FSDP communications due to backend: {backend}")

@tianyu-l (Contributor) left a comment:

this logic seems too hardware-specific to land in torchtitan parallelize.py
cc @weifengpy if you want to change the default in FSDP2 based on distributed backend

@weifengpy (Contributor) commented Feb 5, 2026

> this logic seems too hardware-specific to land in torchtitan parallelize.py cc @weifengpy if you want to change the default in FSDP2 based on distributed backend

@saforem2 we can do this in fsdp2. just add "xpu" next to "mtia"

https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L715-L716

btw, a lot of fsdp2 + xpu tests are disabled. maybe consider enable them to truly make fsdp2 work for xpu: pytorch/pytorch@98e9440#diff-f18e800e9f3510cb444d863248a2f10f6896a2861ff76352b41d58b38e081b7bR419
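The FSDP2 internal check weifengpy points at could be sketched as below; `supports_premul_sum` is an illustrative name, not the actual function in `_fsdp_collectives.py`:

```python
def supports_premul_sum(device_type: str) -> bool:
    # FSDP2 falls back to plain SUM reduction on device types known to
    # lack ReduceOp.PREMUL_SUM; weifengpy's suggestion amounts to adding
    # "xpu" alongside "mtia" in a check like this one.
    return device_type not in ("mtia", "xpu")
```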

@saforem2 (Author) commented Feb 6, 2026

I realized after creating this that the fix for XCCL supporting PreMul Sum was merged in pytorch/pytorch#172298, so this should be resolved with a nightly build of PyTorch on XPU; in that case, happy to close this.

@saforem2 saforem2 closed this Feb 6, 2026
@githubsgi (Contributor) commented Feb 11, 2026

@weifengpy, the root cause of the issue is the introduction of `disable_fsdp_gradient_division(model)` in TorchTitan, which ends up calling `torch.distributed._make_nccl_premul_sum(1 / factor)`.

@weifengpy (Contributor) replied:

> `_make_nccl_premul_sum(1 / factor)`

are you talking about this _make_nccl_premul_sum? https://github.com/pytorch/pytorch/blob/5ef2e503b46dadeee246f566a64510c7592975a4/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L715-L739

I thought disable_fsdp_gradient_division(model) is a must-have in titan. FSDP2 made an easy assumption that grads are uniformly all-reduced, but that can be wrong in titan's case.
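For context, PREMUL_SUM scales each rank's contribution by a factor before summing; with factor = 1/world_size it is equivalent to a plain SUM followed by a local divide, which is what the SUM fallback relies on. A plain-Python sketch of that equivalence (illustrative values, not from the run above):

```python
# PREMUL_SUM applies a pre-multiply factor to each rank's tensor before
# the sum; with factor = 1 / world_size this reproduces a mean reduction.
world_size = 4
grads = [1.0, 2.0, 3.0, 6.0]  # one (scalar) gradient per rank, illustrative
factor = 1.0 / world_size

premul_sum = sum(g * factor for g in grads)  # what PREMUL_SUM computes
sum_then_divide = sum(grads) / world_size    # plain SUM + local divide

assert premul_sum == sum_then_divide  # both give 3.0
```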

