Changes from 23 commits
24 commits
484a814
[Common] Deleted unused header (#2324)
Oleg-Goncharov Oct 31, 2025
cbb8248
[JAX] L1_jax_distributed_test suit with individual executions (#2321)
phu0ngng Nov 3, 2025
190a84c
for branch
Nov 5, 2025
9c15c57
clean up and tests
wdykas Nov 10, 2025
1446f0a
change tests
wdykas Nov 12, 2025
7e8f3ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
e89d7a9
[PyTorch debug] Fixes to debug tests failures (#2268)
pggPL Nov 4, 2025
2ca40f9
[PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137)
pggPL Nov 5, 2025
7a06387
[JAX] Fix bug with pre scale bias (#2300)
pggPL Nov 5, 2025
6693da3
[JAX] Try to use pre-downloaded dataset artifacts first (#2345)
jberchtold-nvidia Nov 6, 2025
a68eda5
Fix out of bounds access in the FP4 dequantize kernel (#2346)
ptrendx Nov 6, 2025
7b395bf
Make FP8 weights compatible with older MCore version (#2342)
kunlunl Nov 6, 2025
3299742
[JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#…
jberchtold-nvidia Nov 7, 2025
b97fb5a
Fix sharding of segment position to match id in ring attention. (#2349)
mgoldfarb-nvidia Nov 7, 2025
7217045
Disable cuDNN attention for known IMA and NaNs (#2344)
ksivaman Nov 7, 2025
624fe37
[JAX] Default to fused attention in JAX DPA (#2363)
KshitijLakhani Nov 7, 2025
868f18a
Update cudnn frontend to v1.16.0 (#2362)
ksivaman Nov 7, 2025
46a0b65
[common] Remove kvpacked and qkvpacked attention functions for every …
pggPL Nov 7, 2025
ac3a513
Move Triton to common (#2359)
tdophung Nov 10, 2025
16316ba
[JAX] Fused layers argument default values changed (#2347)
tdophung Nov 10, 2025
52a5f37
remove comment from gpt
wdykas Nov 12, 2025
cf381a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
20ee52a
Merge branch 'NVIDIA:main' into num-splits-attention
wdykas Nov 12, 2025
bba56ba
Merge branch 'main' into num-splits-attention
cyanguwa Nov 13, 2025
27 changes: 26 additions & 1 deletion tests/pytorch/attention/test_attention.py
@@ -117,7 +117,15 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
dtype,
model_configs,
model,
ckpt_attn,
workspace_opt,
qkv_layout,
swa,
pad_between_seqs,
num_splits=None,
):
"""Test DotProductAttention module"""

@@ -244,6 +252,7 @@ def test_dot_product_attention(
workspace_opt,
pad_between_seqs,
is_training,
num_splits=num_splits,
)

# Compare results
@@ -308,6 +317,19 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_0"])
def test_dpa_num_splits(dtype, model_configs, model):
"""Test DotProductAttention with FlashAttention-3 num_splits enabled"""
if not FlashAttentionUtils.v3_is_installed:
pytest.skip("num_splits requires FlashAttention-3.")
test_dot_product_attention(
dtype, model_configs, model, False, True, None, False, False, num_splits=2
)
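
A quick way to exercise only this new case locally is sketched below; it assumes the test file path shown in the diff header, and the test skips itself when FlashAttention-3 is not installed.

import pytest

# Select only the new num_splits test; it auto-skips without FlashAttention-3.
pytest.main(["tests/pytorch/attention/test_attention.py", "-k", "test_dpa_num_splits", "-v"])
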


model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -848,6 +870,7 @@ def _run_dot_product_attention(
workspace_opt: bool,
pad_between_seqs: bool,
is_training: bool,
num_splits=None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
# Set RNG and environment variables
@@ -1152,6 +1175,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
# Only pass num_splits when exercising the FlashAttention path
num_splits=(num_splits if backend == "FlashAttention" else None),
)
max_logit = None
if config.return_max_logit:
@@ -669,6 +669,7 @@ def forward(
inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
num_splits: Optional[int] = None,
) -> torch.Tensor:
"""flash-attn fprop"""

@@ -843,6 +844,16 @@ def forward(
use_flash_attn_3 = False
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
use_flash_attn_3 = True
# Enforce FA3 when num_splits is provided
if num_splits is not None and not use_flash_attn_3:
if not fa_utils.v3_is_installed:
raise ValueError(
"num_splits is only supported with FlashAttention-3, which is not installed. "
)
raise ValueError(
"num_splits is only supported with FlashAttention-3. "
"Please adjust configuration to enable FA3 for these inputs."
)
Comment on lines +848 to +856
style: The validation logic here correctly prevents num_splits from being used with non-FA3 backends. However, there's a potential issue with code maintainability: later in the code (line 939), num_splits is added to fa_optional_forward_kwargs without a version check, unlike other optional parameters. While the current validation prevents reaching that code with num_splits set and FA2, this creates fragile coupling between distant code sections. Consider either: (1) adding a version guard at line 939 similar to window_size and deterministic, or (2) adding an assertion that num_splits is None in the FA2 branch.
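
A rough sketch of those two options follows; names such as fa2_supports_num_splits are hypothetical placeholders, not existing fa_utils attributes.

from typing import Any, Dict, Optional


def add_fa2_num_splits_kwarg(
    kwargs: Dict[str, Any],
    num_splits: Optional[int],
    fa2_supports_num_splits: bool,
) -> Dict[str, Any]:
    """Option (1): forward num_splits only when the installed flash-attn 2.x accepts it.

    `fa2_supports_num_splits` stands in for a hypothetical capability flag,
    analogous to fa_utils.v2_3_plus / v2_4_1_plus.
    """
    if num_splits is not None:
        if not fa2_supports_num_splits:
            raise ValueError("Installed flash-attn 2.x does not accept num_splits.")
        kwargs["num_splits"] = num_splits
    return kwargs


def assert_no_num_splits_on_fa2(num_splits: Optional[int]) -> None:
    """Option (2): make the earlier FA3-only validation explicit inside the FA2 branch."""
    assert num_splits is None, "num_splits must be rejected before reaching the FA2 path"
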

if context_parallel and all(
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
):
@@ -925,6 +936,9 @@ def forward(
fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
if fa_utils.v2_4_1_plus:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits
Comment on lines +939 to +941
style: Verify that the flash-attn version supports num_splits parameter. Unlike other optional parameters (e.g., window_size has fa_utils.v2_3_plus check, deterministic has fa_utils.v2_4_1_plus check), this parameter is added without a version guard.

Comment on lines +939 to +941
logic: Unlike other optional parameters (window_size has fa_utils.v2_3_plus, deterministic has fa_utils.v2_4_1_plus), num_splits is added to FA2 kwargs without version checking. If flash-attn v2 doesn't support this parameter, this will cause a TypeError at runtime.

Suggested change
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits
if num_splits is not None:
# Only add num_splits if flash-attn supports it (check version if needed)
# TODO: Add version check once minimum flash-attn version with num_splits is determined
fa_optional_forward_kwargs["num_splits"] = num_splits
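
Another way to address the version-guard concern above, instead of a hard-coded version pin or TODO, is to probe the installed flash-attn entry point once at import time. A minimal sketch, assuming flash_attn_varlen_func is the callable this code path ultimately invokes (the actual import in this file may differ):

import inspect

# Assumed entry point for illustration; the real call site may use a different import.
from flash_attn import flash_attn_varlen_func


def accepts_num_splits(fn) -> bool:
    """Return True if the given flash-attn callable exposes a num_splits keyword."""
    return "num_splits" in inspect.signature(fn).parameters


# Computed once, in the spirit of the existing fa_utils.v2_3_plus / v2_4_1_plus flags.
FA2_SUPPORTS_NUM_SPLITS = accepts_num_splits(flash_attn_varlen_func)
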

Comment on lines +939 to +941
style: This is unreachable dead code. The validation at line 848 already ensures num_splits is None when use_flash_attn_3 is False, so this condition can never be true inside the if not use_flash_attn_3: block. Consider removing this check for code clarity.

Comment on lines +939 to +941
logic: Dead code: unreachable due to validation at line 848. The check at line 848 ensures that if num_splits is not None, execution raises ValueError unless use_flash_attn_3 is True. Therefore, inside the if not use_flash_attn_3: block (line 931), num_splits must always be None, making this check impossible to satisfy.

Suggested change
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits
# num_splits check removed - only supported in FA3 (enforced at line 848)

if inference_params is not None:
# use block_table kwarg to support thd_2bshd for non-paged
fa_optional_forward_kwargs["block_table"] = (
@@ -959,6 +973,9 @@ def forward(
fa_3_optional_forward_kwargs["page_table"] = (
inference_params.cache_manager.page_table[:batch_size]
)
if num_splits is not None:
# Forward optional split control to flash-attn v3 if supported
fa_3_optional_forward_kwargs["num_splits"] = num_splits
Comment on lines +976 to +978
style: Same as FA v2: verify flash-attn v3 supports num_splits to avoid potential TypeError

if fp8:
QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -799,6 +799,7 @@ def forward(
inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = None,
style: Missing documentation for the num_splits parameter in the docstring. Add a parameter description explaining what this controls and when it should be used (e.g., "Controls the number of splits for FlashAttention computation. Used for memory optimization.")
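
For reference while reviewing, a minimal usage sketch of the new argument; shapes, dtype, and hyperparameters are illustrative, and it assumes FlashAttention-3 is installed and selected for these inputs (otherwise the forward call raises ValueError, as enforced below).

import torch
from transformer_engine.pytorch import DotProductAttention

# Illustrative sbhd-layout tensors: sequence 2048, batch 2, 16 heads of size 64.
q = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

attn = DotProductAttention(num_attention_heads=16, kv_channels=64)

# num_splits is forwarded to FlashAttention-3 only; any other selected backend raises ValueError.
out = attn(q, k, v, num_splits=2)
print(out.shape)  # torch.Size([2048, 2, 1024])
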

) -> torch.Tensor:
"""
Dot Product Attention Layer.
@@ -973,6 +974,11 @@ def forward(
If true, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = `False`
Whether to enforce output to be in FP8 or not.
num_splits: Optional[int], default = `None`
Optional split control for FlashAttention-3 only. When set, the value is forwarded
to the FA3 backend to control internal kernel splitting. A ValueError is raised if
a non-FA3 backend is selected or if FlashAttention-3 is not installed.
"""

with torch.cuda.device(query_layer.device), self.prepare_forward(
@@ -1366,6 +1372,34 @@ def forward(
fused_attention_backend = _attention_backends["fused_attention_backend"]
use_unfused_attention = _attention_backends["use_unfused_attention"]

# If num_splits is requested, ensure we are using FlashAttention-3.
if num_splits is not None:
is_fa3_selected = (
use_flash_attention
and flash_attention_backend is not None
and flash_attention_backend == dpa_utils.FlashAttentionUtils.fa3_version
)
if not is_fa3_selected:
backend_name = (
f"FlashAttention ({str(flash_attention_backend)})"
if use_flash_attention
else (
"FusedAttention"
if use_fused_attention
else "UnfusedDotProductAttention"
)
)
if not dpa_utils.FlashAttentionUtils.v3_is_installed:
raise ValueError(
"num_splits is only supported with FlashAttention-3, which is not"
" installed. "
)
raise ValueError(
"num_splits is only supported with FlashAttention-3. Selected backend is"
f" {backend_name}. Please adjust configuration to enable FA3 for these"
" inputs."
)

# raise exception if no backend is available
if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
raise ValueError(
@@ -1413,6 +1447,7 @@ def forward(
inference_params=inference_params,
flash_attention_backend=flash_attention_backend,
fp8_output=fp8_output,
num_splits=num_splits,
)

if use_fused_attention: