
Conversation

@nv-lschneider
Collaborator

@nv-lschneider nv-lschneider commented Sep 22, 2025

Summary by CodeRabbit

  • New Features

    • Added AllReduce strategy NCCL_DEVICE with device-side fused AllReduce (residual RMSNorm), multimem support, and launch-configuration caching; exposed in Python and runtime config; userbuffers updated to support it.
  • Tests

    • Added multi-GPU NCCL device tests and extended microbenchmark with only-ub option and improved size reporting.
  • Documentation

    • API references updated to include NCCL_DEVICE.
  • Chores

    • Updated spell-check and pre-commit configuration.

Description

This PR introduces a new kernel launch mechanism to support kernels that use the NCCL device API.
It implements one kernel to start with: RESIDUAL_RMS_NORM for fp16 types.

This new kernel is meant to replace or improve on the performance of AllReduce using the stable NCCL API.
It is the first of potentially more kernel variations aimed at best performance; the default AR selection strategy is not affected yet.

It is designed to be low latency for small to medium message sizes.

This PR uses the existing NCCLUBAllocator and extends it to hold the necessary persistent resources, such as NCCL-registered memory windows and device communicators.
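
For illustration, the launch-config caching the allocator gains could look roughly like the sketch below. LaunchConfigKey, mLaunchConfigCache, and getCachedNCCLDeviceLaunchConfig are the names from the walkthrough further down; the key fields and surrounding types are my assumptions, not the PR's exact code.

// Hedged sketch of the per-allocator launch-config cache (C++). LaunchConfig
// itself would hold the resolved kernel pointer and launch geometry.
#include <map>
#include <memory>
#include <tuple>

namespace tensorrt_llm::kernels::nccl_device
{
class LaunchConfig; // resolved kernel pointer + grid/block configuration
} // namespace tensorrt_llm::kernels::nccl_device

struct LaunchConfigKey
{
    int dtype;     // element type of the message (assumed field)
    int hiddenDim; // per-token hidden size (assumed field)
    bool oneShot;  // kernel flavor (assumed field)

    bool operator<(LaunchConfigKey const& o) const
    {
        return std::tie(dtype, hiddenDim, oneShot) < std::tie(o.dtype, o.hiddenDim, o.oneShot);
    }
};

class NCCLUserBufferAllocator
{
public:
    // Returns a cached LaunchConfig, building it on first use so the occupancy
    // and validity checks run only once per (dtype, shape, flavor) combination.
    std::shared_ptr<tensorrt_llm::kernels::nccl_device::LaunchConfig> getCachedNCCLDeviceLaunchConfig(
        LaunchConfigKey const& key);

private:
    std::map<LaunchConfigKey, std::shared_ptr<tensorrt_llm::kernels::nccl_device::LaunchConfig>> mLaunchConfigCache;
};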

The AllReduce operation is implemented as a new AllReduceStrategy and launched from the AllReduce op, cpp/tensorrt_llm/thop/allreduceOp.cpp.
It launches its own new kernels at cpp/tensorrt_llm/kernels/nccl_device.
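
Concretely, the new strategy is just one more enum value that the rest of the stack routes on (NCCL_DEVICE = 9 per the walkthrough below), mirrored into the Python bindings and IntEnum. A rough sketch of the C++ side, with the existing members elided and details assumed:

// Sketch only: the existing strategies keep their values; NCCL_DEVICE is appended.
enum class AllReduceStrategyType
{
    // ... existing strategies (NCCL, ONESHOT, TWOSHOT, AUTO, UB, NCCL_SYMMETRIC, ...) ...
    NCCL_DEVICE = 9, // device-side fused AllReduce via the NCCL device API
};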

The kernel itself is highly templated to be flexible for future demands without impeding runtime performance.
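
As a rough illustration of that templating (a sketch under assumptions, not the actual kernels.h contents): the element type, the vector width, and the one-shot/two-shot flavor become compile-time parameters, so all dispatch happens once on the host when the launch configuration is built. fusedAllReduceRMSNorm is the kernel name from the walkthrough below; the parameter list and DevCommHandle are placeholders.

#include <cuda_fp16.h>

struct DevCommHandle
{
}; // stand-in for the NCCL device communicator handle

template <typename T, // element type, e.g. __half
    int kVecSize,     // elements handled per thread per iteration
    bool kOneShot>    // one-shot vs. two-shot flavor
__global__ void fusedAllReduceRMSNorm(T const* mcInput, T const* residual, T const* gamma, T* output, int numTokens,
    int hiddenDim, float eps, DevCommHandle devComm)
{
    // Grid-stride loop over tokens so a fixed grid size covers any message size.
    for (int token = blockIdx.x; token < numTokens; token += gridDim.x)
    {
        // The real kernel would, per token and vectorized over kVecSize elements:
        //   1) all-reduce the token's hidden vector across ranks (multimem on SM90+),
        //   2) add the residual,
        //   3) apply RMSNorm with gamma and eps,
        //   4) write the result back to the output buffer.
    }
}

The host-side LaunchConfig then only has to pick the right instantiation and launch geometry at runtime.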

This PR implements the new kernel as a two-shot fp16 implementation first.
It is already competitive in this form; after adoption of this kernel, further modifications and additions can follow:

  1. a one-shot flavor
  2. fp8 and fp4 support in the future.

Test Coverage

  • Python unit test: tests/unittest/_torch/multi_gpu/test_nccl_device.py
  • Microbenchmark: tests/microbenchmarks/all_reduce.py

The microbenchmark has been updated slightly: it now includes the new strategy and, optionally, UB for comparison.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

Important Caveat

This change requires NCCL 2.28 to run successfully.
Since the current dev container of TRT-LLM does not use 2.28 yet, I would like to gather some feedback before 2.28 becomes available.
A real test will only be possible when version 2.28 is included in the dev container.
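
For what it is worth, the 2.28 requirement can also be checked at runtime so the op can fall back gracefully on older containers. A minimal sketch; ncclDeviceApiAvailable is a hypothetical helper name, only ncclGetVersion and NCCL_VERSION are the real NCCL API:

#include <nccl.h>

// Returns true when the loaded NCCL runtime is new enough for the device API
// used by the NCCL_DEVICE path; callers would fall back to another strategy otherwise.
static bool ncclDeviceApiAvailable()
{
    int version = 0;
    if (ncclGetVersion(&version) != ncclSuccess)
    {
        return false;
    }
    return version >= NCCL_VERSION(2, 28, 0); // 2.28.0 or newer
}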

@nv-lschneider nv-lschneider requested review from a team as code owners September 22, 2025 16:53
@coderabbitai coderabbitai bot changed the title [None] @coderabbit title [None] [feat] Add NCCL device kernels; enable NCCL_DEVICE all-reduce title Sep 22, 2025
@nv-lschneider nv-lschneider force-pushed the introducing-nccl-device-ar branch 2 times, most recently from 6e1c6cd to 39b2e16 on September 22, 2025 17:08
@coderabbitai
Contributor

coderabbitai bot commented Sep 22, 2025

📝 Walkthrough

Adds a new NCCL_DEVICE all-reduce strategy and device-side fusion module (nccl_device): build targets, CUDA kernels, multimem/vector helpers, launch-config factory, runtime dispatch and allocator support, Python/enum plumbing, benchmarks, and tests.

Changes

  • Tooling: codespell & pre-commit (.codespellignore, .pre-commit-config.yaml): Add "commIter" to .codespellignore and pass .codespellignore to the codespell pre-commit hook.
  • Build: enable nccl_device module (cpp/tensorrt_llm/kernels/CMakeLists.txt, cpp/tensorrt_llm/kernels/nccl_device/CMakeLists.txt): Add nccl_device subdirectory and new CUDA library target tensorrt_llm_nccl_device with include paths, CUDA properties, link, and install rules.
  • Enum & bindings (cpp/tensorrt_llm/kernels/customAllReduceKernels.h, cpp/tensorrt_llm/pybind/runtime/bindings.cpp, tensorrt_llm/functional.py): Add NCCL_DEVICE = 9 to AllReduceStrategyType and expose it in Python bindings and the Python IntEnum.
  • nccl_device public headers & constants (cpp/tensorrt_llm/kernels/nccl_device/constants.h, .../vector_types.h, .../multimem.h, .../kernels.h): Add device constants, vector wrapper types, architecture-gated multimem load/store intrinsics, warp/block reduce helpers, and the fusedAllReduceRMSNorm kernel templates.
  • nccl_device launch-config implementation (cpp/tensorrt_llm/kernels/nccl_device/config.h, cpp/tensorrt_llm/kernels/nccl_device/config.cu): Add LaunchConfig base, TypedLaunchConfig, factory makeLaunchConfig, validity/occupancy checks, NCCL-version gating, kernel pointer resolution, and type-specialized kernel launch paths.
  • Allocator: NCCL device comm & LaunchConfig cache (cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h, .../ub_allocator.cpp): Add destructor, device communicator create/destroy symbol resolution, getters for new NCCL symbols, per-block dev-comm map, LaunchConfigKey and mLaunchConfigCache, and getCachedNCCLDeviceLaunchConfig + getNCCLDevComm.
  • Runtime op dispatch (C++) (cpp/tensorrt_llm/thop/allreduceOp.cpp): Add handling for NCCL_DEVICE, device-fusion path runNCCLAllReduceDeviceFusion, UB symmetry handling, logging and fallback behavior.
  • Python runtime wiring (tensorrt_llm/_torch/model_config.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/llmapi/llm_args.py): Map string "NCCL_DEVICE" to the enum, treat NCCL_DEVICE like NCCL_SYMMETRIC for enabling user buffers, and extend allowed Literal values to include "NCCL_DEVICE".
  • Plugins: runtime strategy support (cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp): Treat NCCL_DEVICE alongside NCCL_SYMMETRIC in format checks, enqueue, initialization gating, and status/logging branches.
  • Benchmarks (tests/microbenchmarks/all_reduce.py): Add only_ub mode, extend the CLI, adjust benchmark loops and metrics, and print message size in bytes.
  • Tests: multi-GPU (tests/unittest/_torch/multi_gpu/test_nccl_device.py): New multi-GPU test validating the UB + NCCL device RMSNorm all-reduce path with per-rank checks and an MPI executor.
  • API stability reference (tests/unittest/api_stability/references/llm.yaml): Allow NCCL_DEVICE in the allreduce_strategy Literal.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant PY as Python API
  participant ME as ModelEngine
  participant OP as AllreduceOp (C++)
  participant UB as NCCLUserBufferAllocator
  participant KD as nccl_device::LaunchConfig
  participant KRN as nccl_device Kernel
  participant NCCL as NCCL (host+device)

  PY->>ME: request AllReduce with strategy="NCCL_DEVICE"
  ME->>OP: execute AllReduce (inputs, fusion attrs)
  OP->>UB: getNCCLDevComm(numBarriers)
  UB->>NCCL: resolve/create device communicator
  UB-->>OP: ncclDevComm
  OP->>UB: getCachedNCCLDeviceLaunchConfig(dtype, dims, flags)
  UB-->>OP: LaunchConfig (KD)
  OP->>KRN: KD.launchRMSNorm(..., devComm, stream)
  KRN->>NCCL: device-side allreduce (multimem ld/st)
  KRN-->>OP: outputs written
  OP-->>ME: return tensors
  ME-->>PY: result
sequenceDiagram
  autonumber
  participant OP as AllreduceOp (C++)
  participant SYM as UB Symmetric Buffers
  participant F as Fallback Path

  OP->>SYM: Verify symmetric UB buffer
  alt buffer missing
    OP->>SYM: Create symmetric input and copy data
  end
  alt device fusion unsupported or invalid
    OP->>F: fallbackRunSubsequentOps(...)
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • nv-guomingz
  • liji-nv
  • shaharmor98
  • Superjomn

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 9.43%, which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title accurately summarizes the main change: adding NCCL device kernels for AllReduce with RMS normalization fusion.
  • Description check ✅ Passed: The description provides a clear explanation of what the PR does, why it's needed, test coverage, and relevant caveats about the NCCL 2.28 requirement.


@nv-lschneider
Collaborator Author

@CodeRabbit review

@Tabrizian Tabrizian changed the title [None] [feat] Add NCCL device kernels; enable NCCL_DEVICE all-reduce title [None][feat] Add NCCL device kernels for AR+RMS fusion Nov 4, 2025
@nv-lschneider nv-lschneider requested a review from a team as a code owner November 5, 2025 18:00
@nv-lschneider nv-lschneider force-pushed the introducing-nccl-device-ar branch from 9ead5ce to ca6dd60 on November 5, 2025 22:57
Collaborator Author

@nv-lschneider nv-lschneider left a comment


I rebased the code and addressed your comments. Thanks for the patience;
rebasing takes a while.

Comment on lines 430 to 564
goto default_case;
}
Collaborator Author


Yes, using intentional fallthrough instead of goto.
(If we support more cases, we will have to refactor the default case out.)
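
For readers outside this thread, the C++17 way to spell that intent looks roughly like this; the names below are placeholders, not the PR's actual switch:

enum class Strategy
{
    NCCL_DEVICE,
    OTHER
};

void runDeviceFusion();
void runDefaultPath();

void dispatch(Strategy strategy, bool deviceFusionSupported)
{
    switch (strategy)
    {
    case Strategy::NCCL_DEVICE:
        if (deviceFusionSupported)
        {
            runDeviceFusion();
            break;
        }
        [[fallthrough]]; // intentional: unsupported requests drop into the default path
    default:
        runDefaultPath();
        break;
    }
}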

Comment on lines 66 to 81
k_chunk_size = a.size(1) // tensor_parallel_size
b.size(0) // tensor_parallel_size
Collaborator Author


Removing the unused statement.

* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
Collaborator Author


I updated the license headers of the two files and double-checked that the other files are OK too.

@nv-lschneider nv-lschneider force-pushed the introducing-nccl-device-ar branch from a7b677d to 6cc2722 on November 7, 2025 21:00
Member

@Tabrizian Tabrizian left a comment


LGTM, merging is blocked until the PyTorch container upgrades to NCCL 2.28.

protected:
bool oneShot;
int token_per_rank;
int start_token;
Collaborator


A nit.
But it seems the style mixes "a_b" and "aB" in this file. Can we make it unified in one file/component? Thx

Collaborator Author


That is fair, I'll change it now to be consistently camelCase, and do the same for the kernels.cuh file.

* limitations under the License.
*/

#ifndef TRTLLM_NCCL_DEVICE_CONSTANTS_H
Collaborator


A nit. Normally we use "pragma once" in TRTLLM headers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out; I'll switch my files from include guards to #pragma once now.

#define TRTLLM_NCCL_DEVICE_MULTIMEM_H

#include <cuda_fp16.h>
#if CUDART_VERSION >= 11000
Collaborator


I think this is not needed; we only support CUDA 12.9 and 13 now.

Collaborator Author


Ok, I can remove it. It seems like we don't need to be backward compatible.


// Base template for multimemLoadSum - device assert for SM < 90
template <typename ptrT, typename valT>
__device__ __forceinline__ valT multimemLoadSum(ptrT const* addr)
Collaborator


Shall this file have the ".cuh" suffix, since it contains __device__ functions?

Collaborator Author


Yes, good point: these are all device functions, so we should rename them to .cuh.
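
For context on what the multimem helpers do, here is a hedged sketch of an SM90 specialization in the spirit of multimemLoadSum; the PTX pattern mirrors the existing userbuffers kernels, but the exact wrapper shape in this PR may differ:

#include <cuda_fp16.h>

// Packs 8 half values (4x f16x2 registers) so one multimem instruction moves 16 bytes.
struct PackedHalf8
{
    unsigned int x, y, z, w;
};

// Loads from a multicast (multimem) address and returns the element-wise sum of
// all ranks' copies in a single instruction. Only meaningful on SM90+; the
// generic base template this would specialize asserts on older architectures.
__device__ __forceinline__ PackedHalf8 multimemLoadSumHalf8(__half const* mcAddr)
{
    PackedHalf8 v;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
                 : "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w)
                 : "l"(mcAddr)
                 : "memory");
#else
    v.x = v.y = v.z = v.w = 0u; // multimem is unavailable below SM90
    __trap();
#endif
    return v;
}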

for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(kFinalMask, val[i], mask, kWarpSize);
}
return (T) (0.0f);
Collaborator


I don't fully understand this one. Is this return value used/needed? Since it's always 0.

Collaborator Author


I am not 100% sure.
I took this part directly from the user buffers.
My assumption was that it was needed so the compiler would not optimize the code out.
But I will test that and remove the return if it is unnecessary here.
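
For reference, the butterfly reduction should not need a dummy return value to survive optimization: __shfl_xor_sync has warp-synchronizing side effects the compiler will not drop, and the results land in val[] on every lane. A generic void version (a sketch, not the PR's helper):

constexpr int kWarpSize = 32;
constexpr unsigned int kFinalMask = 0xffffffffu;

// Warp-level all-reduce (sum) over N values per thread using the butterfly
// shuffle pattern from the snippet above. After the loops, every lane of the
// warp holds the warp-wide sums in val[], so nothing needs to be returned.
template <typename T, int N>
__device__ __forceinline__ void warpReduceSum(T (&val)[N])
{
#pragma unroll
    for (int i = 0; i < N; ++i)
    {
#pragma unroll
        for (int mask = kWarpSize / 2; mask > 0; mask >>= 1)
        {
            val[i] += __shfl_xor_sync(kFinalMask, val[i], mask, kWarpSize);
        }
    }
}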

nv-lschneider and others added 17 commits December 1, 2025 09:01
… and NEW NCCL device API to use NCCL to fuse RMS Norm with AllReduce.

Signed-off-by: Ludwig Schneider <[email protected]>
…) (NVIDIA#7900)

Signed-off-by: Yan Chunwei <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>

pre-commit changes

Signed-off-by: Ludwig Schneider <[email protected]>

clang formatting

Signed-off-by: Ludwig Schneider <[email protected]>

safe guarding NCCL 2.27 build

Signed-off-by: Ludwig Schneider <[email protected]>

fixing precommit formatting

Signed-off-by: Ludwig Schneider <[email protected]>

most of code rabbit comments

Signed-off-by: Ludwig Schneider <[email protected]>

adding missing semi-colon

Signed-off-by: Ludwig Schneider <[email protected]>

removing unused comment lines

Signed-off-by: Ludwig Schneider <[email protected]>

Clarifying the test on how to compare residual chunked and unchunked.

Signed-off-by: Ludwig Schneider <[email protected]>

fixing pre-commit

Signed-off-by: Ludwig Schneider <[email protected]>

fixing pre-commit

Signed-off-by: Ludwig Schneider <[email protected]>

fixing missing variable, rebase complete and tested

Signed-off-by: Ludwig Schneider <[email protected]>

using a grid stride loop with fewer blocks launched for large message sizes

Signed-off-by: Ludwig Schneider <[email protected]>

using functioning grid stride loop for NCCL_DEVICE. It helps with better performance at larger message sizes

Signed-off-by: Ludwig Schneider <[email protected]>

initial oneshot implementation

Signed-off-by: Ludwig Schneider <[email protected]>

minor tweaks to include one shot

fixes

Signed-off-by: Ludwig Schneider <[email protected]>

enabling grid stride loop, but no perf benefit.

Signed-off-by: Ludwig Schneider <[email protected]>

addressing review feedback

Signed-off-by: Ludwig Schneider <[email protected]>

fix formatting

Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>

better UB init handling

Signed-off-by: Ludwig Schneider <[email protected]>

accept multiple strategies

Signed-off-by: Ludwig Schneider <[email protected]>

test to debug mnnvl

Signed-off-by: Ludwig Schneider <[email protected]>

rebasing and addressing comments

Signed-off-by: Ludwig Schneider <[email protected]>

remove unneeded type decl

Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
Signed-off-by: Ludwig Schneider <[email protected]>
@nv-lschneider nv-lschneider force-pushed the introducing-nccl-device-ar branch from 6cc2722 to ef1cf92 on December 1, 2025 15:01
@nv-lschneider nv-lschneider requested a review from a team as a code owner December 1, 2025 15:01
Signed-off-by: Ludwig Schneider <[email protected]>
@nv-lschneider nv-lschneider marked this pull request as draft December 11, 2025 16:40
@nv-lschneider
Collaborator Author

Converted to draft.
After #9314 is merged, this needs to be refactored.

