[None][feat] Add NCCL device kernels for AR+RMS fusion #7910
base: main
Conversation
Force-pushed from 6e1c6cd to 39b2e16
📝 Walkthrough
Adds a new NCCL_DEVICE all-reduce strategy and device-side fusion module (nccl_device): build targets, CUDA kernels, multimem/vector helpers, launch-config factory, runtime dispatch and allocator support, Python/enum plumbing, benchmarks, and tests.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant PY as Python API
participant ME as ModelEngine
participant OP as AllreduceOp (C++)
participant UB as NCCLUserBufferAllocator
participant KD as nccl_device::LaunchConfig
participant KRN as nccl_device Kernel
participant NCCL as NCCL (host+device)
PY->>ME: request AllReduce with strategy="NCCL_DEVICE"
ME->>OP: execute AllReduce (inputs, fusion attrs)
OP->>UB: getNCCLDevComm(numBarriers)
UB->>NCCL: resolve/create device communicator
UB-->>OP: ncclDevComm
OP->>UB: getCachedNCCLDeviceLaunchConfig(dtype, dims, flags)
UB-->>OP: LaunchConfig (KD)
OP->>KRN: KD.launchRMSNorm(..., devComm, stream)
KRN->>NCCL: device-side allreduce (multimem ld/st)
KRN-->>OP: outputs written
OP-->>ME: return tensors
ME-->>PY: result
sequenceDiagram
autonumber
participant OP as AllreduceOp (C++)
participant SYM as UB Symmetric Buffers
participant F as Fallback Path
OP->>SYM: Verify symmetric UB buffer
alt buffer missing
OP->>SYM: Create symmetric input and copy data
end
alt device fusion unsupported or invalid
OP->>F: fallbackRunSubsequentOps(...)
end
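For readers skimming the diagrams, here is a self-contained mock of the same control flow. Every type and method below is a stand-in named after the diagram participants (getNCCLDevComm, getCachedNCCLDeviceLaunchConfig, launchRMSNorm, fallbackRunSubsequentOps); the real signatures in allreduceOp.cpp and the NCCL 2.28 device API differ.

```cuda
#include <cstdio>

struct DevComm { int numBarriers = 0; };   // stand-in for ncclDevComm

struct LaunchConfig                        // stand-in for nccl_device::LaunchConfig
{
    void launchRMSNorm(DevComm const&) const { std::puts("fused AR+RMSNorm kernel launched"); }
};

struct UserBufferAllocator                 // stand-in for NCCLUserBufferAllocator
{
    DevComm getNCCLDevComm(int numBarriers)
    {
        DevComm c;
        c.numBarriers = numBarriers;
        return c;
    }
    LaunchConfig getCachedNCCLDeviceLaunchConfig() const { return LaunchConfig{}; }
    bool hasSymmetricBuffer() const { return false; }
};

void allreduceOp(UserBufferAllocator& ub, bool deviceFusionSupported)
{
    // Second diagram: make sure a symmetric UB input exists before touching it.
    if (!ub.hasSymmetricBuffer())
    {
        std::puts("create symmetric input buffer and copy data");
    }
    // Second diagram: bail out to the host-side fallback if fusion cannot run.
    if (!deviceFusionSupported)
    {
        std::puts("fallbackRunSubsequentOps(...)");
        return;
    }
    // First diagram: resolve the device comm, fetch the cached launch config, launch.
    DevComm devComm = ub.getNCCLDevComm(/*numBarriers=*/1);
    LaunchConfig cfg = ub.getCachedNCCLDeviceLaunchConfig();
    cfg.launchRMSNorm(devComm);
}

int main()
{
    UserBufferAllocator ub;
    allreduceOp(ub, /*deviceFusionSupported=*/true);
    return 0;
}
```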
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Force-pushed from 9ead5ce to ca6dd60
nv-lschneider left a comment
I rebased the code and addressed your comments. Thanks for the patience; rebasing takes a while.
goto default_case;
}
Yes, I am using intentional fallthrough instead of goto. (If we support more cases, we will have to refactor the default case out.)
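For context, a minimal self-contained sketch of the intentional-fallthrough pattern referred to above; the enum, struct, and function names here are illustrative, not the PR's exact code.

```cuda
#include <cstdio>

enum class AllReduceStrategyType { NCCL, NCCL_DEVICE };

struct Params { bool deviceFusionSupported; };

static int runNcclDeviceFusion(Params const&) { std::puts("fused NCCL device path"); return 0; }
static int runDefaultAllReduce(Params const&) { std::puts("default path");           return 1; }

int dispatch(AllReduceStrategyType strategy, Params const& p)
{
    switch (strategy)
    {
    case AllReduceStrategyType::NCCL_DEVICE:
        if (p.deviceFusionSupported)
        {
            return runNcclDeviceFusion(p);
        }
        // Unsupported dtype/shape/world size: intentionally fall through to the
        // default handling instead of `goto default_case;`.
        [[fallthrough]];
    default:
        return runDefaultAllReduce(p);
    }
}
```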
k_chunk_size = a.size(1) // tensor_parallel_size
b.size(0) // tensor_parallel_size
Removing the unused line.
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
I updated the license headers of the two files and double-checked that the other files are OK too.
Force-pushed from a7b677d to 6cc2722
Tabrizian left a comment
LGTM. Merging is blocked until the PyTorch container upgrades to NCCL 2.28.
protected:
    bool oneShot;
    int token_per_rank;
    int start_token;
A nit: the naming style in this file mixes snake_case ("a_b") and camelCase ("aB"). Can we unify it within one file/component? Thanks.
That is fair; I have changed it to be consistently camelCase and done the same for the kernels.cuh file.
* limitations under the License.
*/

#ifndef TRTLLM_NCCL_DEVICE_CONSTANTS_H
A nit: we normally use "#pragma once" in TRT-LLM headers.
Thanks for pointing that out; I have switched my files to #pragma once instead of include guards.
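For reference, the two header-guard styles being discussed (the concrete macro name is taken from the snippet above):

```cuda
// Include-guard style used in the original header:
//   #ifndef TRTLLM_NCCL_DEVICE_CONSTANTS_H
//   #define TRTLLM_NCCL_DEVICE_CONSTANTS_H
//   ...header contents...
//   #endif // TRTLLM_NCCL_DEVICE_CONSTANTS_H

// Equivalent `#pragma once` style adopted after the review comment:
#pragma once
// ...header contents...
```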
#define TRTLLM_NCCL_DEVICE_MULTIMEM_H

#include <cuda_fp16.h>
#if CUDART_VERSION >= 11000
I think this is not needed; we only support CUDA 12.9 and 13 now.
Ok, I can remove it. It seems like we don't need to be backward compatible.

// Base template for multimemLoadSum - device assert for SM < 90
template <typename ptrT, typename valT>
__device__ __forceinline__ valT multimemLoadSum(ptrT const* addr)
Shall this file have the suffix ".cuh", since it contains __device__ functions?
Yes, good point. These are all device functions, so we should rename the files to .cuh.
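For readers unfamiliar with the multimem path, here is a minimal sketch of what a float specialization of such a helper could look like, assuming sm_90+ and a multicast-mapped pointer. The template signature is taken from the snippet above; the PTX body and the fallback behaviour are assumptions made for illustration only, not the PR's actual code.

```cuda
#include <cuda_runtime.h>

// Base template: specializations provide the actual multimem access.
template <typename ptrT, typename valT>
__device__ __forceinline__ valT multimemLoadSum(ptrT const* addr);

// float specialization: multimem.ld_reduce reads the element from every rank's
// mapping of the multicast address and returns the element-wise sum in one load.
template <>
__device__ __forceinline__ float multimemLoadSum<float, float>(float const* addr)
{
#if __CUDA_ARCH__ >= 900
    float val;
    asm volatile("multimem.ld_reduce.relaxed.sys.global.add.f32 %0, [%1];"
                 : "=f"(val)
                 : "l"(addr)
                 : "memory");
    return val;
#else
    // Multimem instructions require Hopper (sm_90) or newer.
    __trap();
    return 0.0f;
#endif
}
```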
for (int mask = 16; mask > 0; mask >>= 1)
    val[i] += __shfl_xor_sync(kFinalMask, val[i], mask, kWarpSize);
}
return (T) (0.0f);
I don't fully understand this one. Is this return value used or needed, since it is always 0?
I am not 100% sure. I took this part directly from the user buffers code. My assumption was that it is needed so the compiler does not optimize the code out, but I will test that and remove the return value if it is unnecessary here.
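For reference, a minimal sketch of the butterfly (XOR-shuffle) reduction the snippet comes from: after the loops every lane already holds the full sum in val[i], so the result is naturally returned through the array itself and a separate scalar return is not needed for correctness (whether the original (T)(0.0f) return guards against some compiler optimization is exactly what is being tested above).

```cuda
// In-place butterfly reduction: after the loops, val[i] on every lane of the
// warp contains the sum of that element across all 32 lanes.
template <typename T, int kN>
__device__ __forceinline__ void warpReduceSum(T (&val)[kN])
{
    constexpr int kWarpSize = 32;
    constexpr unsigned int kFinalMask = 0xffffffffu;
#pragma unroll
    for (int i = 0; i < kN; ++i)
    {
#pragma unroll
        for (int mask = kWarpSize / 2; mask > 0; mask >>= 1)
        {
            val[i] += __shfl_xor_sync(kFinalMask, val[i], mask, kWarpSize);
        }
    }
}
```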
Commit messages in this push (each Signed-off-by: Ludwig Schneider <[email protected]> unless noted):
- … and NEW NCCL device API to use NCCL to fuse RMS Norm with AllReduce.
- …) (NVIDIA#7900) (also Signed-off-by: Yan Chunwei <[email protected]>)
- pre-commit changes
- clang formatting
- safe guarding NCCL 2.27 build
- fixing pre-commit formatting
- most of code rabbit comments
- adding missing semi-colon
- removing unused comment lines
- Clarifying the test on how to compare residual chunked and unchunked.
- fixing pre-commit
- fixing pre-commit
- fixing missing variable, rebase complete and tested
- using a grid stride loop with fewer blocks launched for large message sizes
- using functioning grid stride loop for NCCL_DEVICE; it helps with better performance at larger message sizes
- initial oneshot implementation
- minor tweaks to include one shot fixes
- enabling grid stride loop, but no perf benefit
- addressing review feedback
- fix formatting
- better UB init handling
- accept multiple strategies
- test to debug mnnvl
- rebasing and addressing comments
- remove unneeded type decl
Force-pushed from 6cc2722 to ef1cf92
Converted to draft.
Summary by CodeRabbit
New Features
Tests
Documentation
Chores
Description
This PR introduces a new kernel launch mechanism to support kernels built on the NCCL device API.
It implements one kernel to start with: RESIDUAL_RMS_NORM for fp16 types.
This new kernel is meant to replace/enhance the performance of AllReduce using the stable NCCL API.
It is the first of potentially more kernel variations tuned for best performance; the default AR selection strategy is not impacted yet.
It is designed to be low latency for small to medium message sizes.
The PR uses the existing NCCLUBAllocator and extends it to hold the necessary persistent resources, such as NCCL-registered memory windows and device communicators.
The AllReduce operation is implemented as a new AllReduceStrategy and launched from the AllReduce plugin (cpp/tensorrt_llm/thop/allreduceOP.cpp). It launches its own new kernels in cpp/tensorrt_llm/kernels/nccl_device. The kernel itself is highly templated to be flexible for future demands without impeding runtime performance.
This PR implements the new kernel as a two-shot fp16 implementation first. It is already competitive in this form; after adoption of this kernel, further modifications and additions can be included to extend support in the future.
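To make the fusion idea concrete, below is a deliberately simplified single-kernel sketch of the pattern described above (allreduce the hidden states, add the residual, then RMS-normalize per token). It is illustration only: it is not the PR's two-shot kernel, it uses fp32 instead of fp16, one block per token, and a plain load as a stand-in where the real kernel would read through an NCCL multimem/multicast pointer.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Stand-in for the multimem load-sum helper: a plain load keeps this sketch
// compilable on one GPU. The real kernel would issue multimem.ld_reduce on a
// multicast mapping of the symmetric (user-buffer) input.
__device__ __forceinline__ float loadReducedAcrossRanks(float const* p)
{
    return *p;
}

// One block per token; blockDim.x threads (power of two, <= 256) cooperate on
// the `hidden` elements of that token.
__global__ void fusedAllReduceRMSNorm(float const* input, float const* residual,
    float const* gamma, float* normedOut, float* residualOut, int hidden, float eps)
{
    __shared__ float smem[256];
    int const token = blockIdx.x;
    float sumSq = 0.f;

    // Step 1: "allreduce" + residual add, accumulating the sum of squares.
    for (int i = threadIdx.x; i < hidden; i += blockDim.x)
    {
        float x = loadReducedAcrossRanks(input + token * hidden + i)
            + residual[token * hidden + i];
        residualOut[token * hidden + i] = x;
        sumSq += x * x;
    }

    // Step 2: block-wide reduction of the sum of squares for this token.
    smem[threadIdx.x] = sumSq;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (threadIdx.x < s)
        {
            smem[threadIdx.x] += smem[threadIdx.x + s];
        }
        __syncthreads();
    }
    float const rrms = rsqrtf(smem[0] / static_cast<float>(hidden) + eps);

    // Step 3: normalize and apply the RMSNorm weight.
    for (int i = threadIdx.x; i < hidden; i += blockDim.x)
    {
        normedOut[token * hidden + i] = residualOut[token * hidden + i] * rrms * gamma[i];
    }
}
```

The PR's actual two-shot fp16 kernel additionally partitions work across ranks (the token_per_rank/start_token members visible in the review thread) and uses the multimem and warp-reduction helpers discussed there.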
Test Coverage
tests/unittest/_torch/multi_gpu/test_nccl_device.py
tests/microbenchmark/all_reduce.py
The microbenchmark has been updated slightly; it now includes the new strategy and, optionally, also UB for comparison.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
Important Caveat
This change requires NCCL 2.28 to run successfully.
Since the current TRT-LLM dev container does not use 2.28 yet, I would like to gather some feedback before 2.28 becomes available.
A real test will only be possible when version 2.28 is included in the dev container.
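Since the commit history mentions "safe guarding NCCL 2.27 build", one conventional way to express such a guard is sketched below. NCCL_VERSION_CODE and the NCCL_VERSION(major, minor, patch) macro come from nccl.h; the TLLM_ENABLE_NCCL_DEVICE_API define name is purely illustrative and not necessarily what the PR uses.

```cuda
#include <nccl.h>

// Compile the NCCL device-API fusion path only when the headers are new enough;
// otherwise the NCCL_DEVICE strategy can fall back to the host-side path.
#if defined(NCCL_VERSION_CODE) && NCCL_VERSION_CODE >= NCCL_VERSION(2, 28, 0)
#define TLLM_ENABLE_NCCL_DEVICE_API 1
#else
#define TLLM_ENABLE_NCCL_DEVICE_API 0
#endif
```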