Fix for incompatible ReduceOp.PREMUL_SUM on XPU devices #2332

saforem2 wants to merge 13 commits into pytorch:main

Conversation
<details closed><summary>Traceback:</summary>

```bash
[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 823, in <module>
[rank0]:     main(Trainer)
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 808, in main
[rank0]:     trainer.train()
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
[rank0]:     return f(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 714, in train
[rank0]:     self.train_step(data_iterator)
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 599, in train_step
[rank0]:     loss = self.forward_backward_step(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datascience/foremans/projects/saforem2/torchtitan/torchtitan/experiments/ezpz/train.py", line 557, in forward_backward_step
[rank0]:     loss.backward()
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 900, in backward
[rank0]:     ctx.param_group.post_backward()
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 566, in post_backward
[rank0]:     ) = foreach_reduce(
[rank0]:         ^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 543, in foreach_reduce
[rank0]:     reduce_scatter_comm(
[rank0]:   File "/lus/tegu/projects/datasets/software/26.26.0/wheelforge/envs/frameworks_install/frameworks_2025.3.1-RC3/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 125, in __call__
[rank0]:     return dist.reduce_scatter_tensor(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: Traceback (most recent call last):
```

</details>
Pull request overview
This PR fixes a ValueError that occurs when using FSDP with XPU devices and the xccl backend. The error "Cannot use ReduceOp.PREMUL_SUM with XCCL" is addressed by adding fallback logic to force sum reduction for backends that don't support the PREMUL_SUM operation.
Changes:
- Added backend detection logic in `disable_fsdp_gradient_division` to check if the current backend is NCCL
- For non-NCCL backends, the function now calls `set_force_sum_reduction_for_comms(True)` on FSDP modules to use regular SUM operations instead of PREMUL_SUM
- Added a `torch.distributed` import to support backend detection
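The decision the PR makes can be sketched as a standalone helper (the function name here is hypothetical; the PR inlines this check in `disable_fsdp_gradient_division`):

```python
def should_force_sum_reduction(backend):
    """Decide whether FSDP must avoid ReduceOp.PREMUL_SUM.

    PREMUL_SUM relies on NCCL-specific machinery, so for any other
    backend (e.g. xccl on XPU) we fall back to a plain SUM
    reduce-scatter and let the training loop apply the scaling.
    """
    return bool(backend) and backend.lower() != "nccl"
```

As written this treats every non-NCCL backend the same, which is the breadth the review comments below push back on.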
```python
force_sum_reduction = False
if torch.distributed.is_available() and torch.distributed.is_initialized():
    backend = torch.distributed.get_backend()
    if backend and backend.lower() != "nccl":
```
The logic `backend.lower() != "nccl"` enables `force_sum_reduction` for ALL non-NCCL backends (e.g., gloo, fake, etc.), not just xccl. This is overly broad and may force sum reduction unnecessarily for backends that support PREMUL_SUM. Consider being more specific by checking if the backend is in a list of known incompatible backends, such as `backend.lower() in ("xccl", "gloo")`, or more specifically `backend.lower() == "xccl"` if only xccl is known to have this limitation.
Suggested change:

```diff
-    if backend and backend.lower() != "nccl":
+    if backend and backend.lower() == "xccl":
```
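The narrower check suggested here could be sketched as an explicit deny-list (the set contents are illustrative, taken from the review comment; only xccl is confirmed incompatible in this thread):

```python
# Backends believed to reject ReduceOp.PREMUL_SUM; "gloo" appears here
# only as an example from the review comment, not a verified fact.
PREMUL_SUM_INCOMPATIBLE = {"xccl", "gloo"}

def needs_sum_fallback(backend):
    # Membership in a known-bad set, rather than "anything non-NCCL",
    # keeps backends that do support PREMUL_SUM on the fast path.
    return backend.lower() in PREMUL_SUM_INCOMPATIBLE
```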
```python
force_sum_reduction = False
if torch.distributed.is_available() and torch.distributed.is_initialized():
    backend = torch.distributed.get_backend()
    if backend and backend.lower() != "nccl":
        force_sum_reduction = True
```
Add a comment explaining why force_sum_reduction is needed for certain backends. The comment should clarify that some backends (like xccl) do not support ReduceOp.PREMUL_SUM and require falling back to regular SUM operations. This will help future maintainers understand the purpose of this workaround.
| """ | ||
| Disable FSDP's automatic gradient division for all FSDP modules. | ||
|
|
||
| Set gradient_divide_factor=1.0 to disable FSDP's automatic gradient division. | ||
| We handle gradient scaling ourselves in the training loop with global token count. | ||
|
|
||
| Args: | ||
| model: The model containing FSDP-wrapped modules | ||
| """ |
The docstring for disable_fsdp_gradient_division should be updated to document the new behavior of forcing sum reduction for non-NCCL backends. This is a significant change to the function's behavior that should be documented for users.
```python
if force_sum_reduction:
    module.set_force_sum_reduction_for_comms(True)
```
Consider adding a logger statement to inform users when `force_sum_reduction` is enabled due to backend incompatibility. This would help with debugging and make the behavior more transparent. For example: `logger.info(f"Forcing sum reduction for FSDP communications due to backend: {backend}")`
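One way to sketch the suggested logging, wrapped in a hypothetical helper so the log line and the workaround stay together (`apply_sum_fallback` is not an existing torchtitan function):

```python
import logging

logger = logging.getLogger(__name__)

def apply_sum_fallback(module, backend, force_sum_reduction):
    # Log before flipping the switch so backend-specific behavior is
    # visible when debugging reduction numerics.
    if force_sum_reduction:
        logger.info(
            "Forcing sum reduction for FSDP communications due to backend: %s",
            backend,
        )
        module.set_force_sum_reduction_for_comms(True)
```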
tianyu-l left a comment:
this logic seems too hardware-specific to land in torchtitan parallelize.py
cc @weifengpy if you want to change the default in FSDP2 based on distributed backend
@saforem2 we can do this in fsdp2: just add "xpu" next to "mtia".

btw, a lot of fsdp2 + xpu tests are disabled. maybe consider enabling them to truly make fsdp2 work for xpu: pytorch/pytorch@98e9440#diff-f18e800e9f3510cb444d863248a2f10f6896a2861ff76352b41d58b38e081b7bR419
I realized after creating this that the fix for XCCL supporting PreMul Sum was merged in, so this should be resolved with a nightly build of PyTorch on XPU; in that case, happy to close this.
@weifengpy, the root cause of the issue is the introduction of `disable_fsdp_gradient_division(model)` in TorchTitan, which ends up calling `torch.distributed._make_nccl_premul_sum(1 / factor)`.
are you talking about this `_make_nccl_premul_sum`? https://github.com/pytorch/pytorch/blob/5ef2e503b46dadeee246f566a64510c7592975a4/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L715-L739

I thought `disable_fsdp_gradient_division(model)` is a must-have in titan. FSDP2 made an easy assumption that grads are uniformly all-reduced, but that can be wrong in titan's case.
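For context, PREMUL_SUM folds the gradient scaling into the collective: each rank's contribution is pre-multiplied by `1/factor` during the reduce, which is numerically equivalent to reducing with plain SUM and dividing afterwards. A toy sketch with Python floats (using world size as the factor for illustration; the PR's actual factor comes from the training loop's global token count):

```python
factor = 4  # e.g. data-parallel world size (illustrative assumption)
shards = [1.0, 2.0, 3.0, 4.0]  # one gradient value contributed per rank

# PREMUL_SUM: scale each contribution by 1/factor inside the reduction.
premul_sum = sum(s * (1.0 / factor) for s in shards)

# SUM fallback: plain SUM collective, with the division applied afterwards
# by whoever owns gradient scaling (here, the training loop).
sum_then_divide = sum(shards) / factor

assert abs(premul_sum - sum_then_divide) < 1e-12
```

This equivalence is why forcing SUM reductions is a safe workaround on backends that lack PREMUL_SUM, as long as the division is still applied exactly once.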
Add fallback to enable `module.set_force_sum_reduction_for_comms` when on XPU devices with the `xccl` backend.

Full Traceback: