25 changes: 22 additions & 3 deletions tests/pytorch/attention/test_attention.py
@@ -117,7 +117,15 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
dtype,
model_configs,
model,
ckpt_attn,
workspace_opt,
qkv_layout,
swa,
pad_between_seqs,
num_splits=None,
):
"""Test DotProductAttention module"""

@@ -244,6 +252,7 @@ def test_dot_product_attention(
workspace_opt,
pad_between_seqs,
is_training,
num_splits=num_splits,
)

# Compare results
@@ -301,11 +310,18 @@ def test_dpa_checkpoint(dtype, model_configs, model):
@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
@pytest.mark.parametrize("model", model_configs_max_logit.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
@pytest.mark.parametrize("num_splits", [None, 2])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout, num_splits):
Collaborator comment: Can you make num_splits a separate test, instead of piggybacking on the max_logit test? :) You can still call test_dot_product_attention in it the same way you do here. Thanks!

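A minimal sketch of what such a standalone test could look like, reusing test_dot_product_attention exactly as test_dpa_max_logit does below. The fixture names param_types and model_configs_base and the chosen num_splits values are assumptions for illustration, not part of this PR:

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
@pytest.mark.parametrize("num_splits", [2, 4])
def test_dpa_num_splits(dtype, model_configs, model, qkv_layout, num_splits):
    """Test DotProductAttention with an explicit FlashAttention-3 num_splits (sketch)."""
    if not FlashAttentionUtils.v3_is_installed:
        pytest.skip("num_splits requires FlashAttention-3.")
    test_dot_product_attention(
        dtype, model_configs, model, False, True, qkv_layout, False, False, num_splits=num_splits
    )
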
"""Test DotProductAttention module with checkpointing"""
config = model_configs[model]
config.return_max_logit = True
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
# Minimal guard: if num_splits is requested, require FA3 be installed
if num_splits is not None:
    if not FlashAttentionUtils.v3_is_installed:
        pytest.skip("num_splits requires FlashAttention-3.")
test_dot_product_attention(
    dtype, model_configs, model, False, True, qkv_layout, False, False, num_splits=num_splits
)


model_configs_softmax = {
@@ -848,6 +864,7 @@ def _run_dot_product_attention(
workspace_opt: bool,
pad_between_seqs: bool,
is_training: bool,
num_splits=None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
# Set RNG and environment variables
@@ -1152,6 +1169,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
# Only pass num_splits when exercising the FlashAttention path
num_splits=(num_splits if backend == "FlashAttention" else None),
)
max_logit = None
if config.return_max_logit:
@@ -669,6 +669,7 @@ def forward(
inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
num_splits: Optional[int] = None,
) -> torch.Tensor:
"""flash-attn fprop"""

@@ -843,6 +844,16 @@ def forward(
use_flash_attn_3 = False
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
    use_flash_attn_3 = True
# Enforce FA3 when num_splits is provided
if num_splits is not None and not use_flash_attn_3:
    if not fa_utils.v3_is_installed:
        raise ValueError(
            "num_splits is only supported with FlashAttention-3, which is not installed."
        )
    raise ValueError(
        "num_splits is only supported with FlashAttention-3. "
        "Please adjust the configuration to enable FA3 for these inputs."
    )
if context_parallel and all(
    not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
):
@@ -925,6 +936,9 @@ def forward(
fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
if fa_utils.v2_4_1_plus:
    fa_optional_forward_kwargs["deterministic"] = self.deterministic
if num_splits is not None:
    # Forward optional split control to flash-attn if available
    fa_optional_forward_kwargs["num_splits"] = num_splits
Comment on lines +939 to +941 (Contributor): style: Verify that the flash-attn version supports the num_splits parameter. Unlike other optional parameters (e.g., window_size has an fa_utils.v2_3_plus check, deterministic has an fa_utils.v2_4_1_plus check), this parameter is added without a version guard.

Comment on lines +939 to +941 (Contributor): logic: Unlike other optional parameters (window_size has fa_utils.v2_3_plus, deterministic has fa_utils.v2_4_1_plus), num_splits is added to the FA2 kwargs without a version check. If flash-attn v2 doesn't support this parameter, this will cause a TypeError at runtime.

Suggested change
-if num_splits is not None:
-    # Forward optional split control to flash-attn if available
-    fa_optional_forward_kwargs["num_splits"] = num_splits
+if num_splits is not None:
+    # Only add num_splits if flash-attn supports it (check version if needed)
+    # TODO: Add version check once minimum flash-attn version with num_splits is determined
+    fa_optional_forward_kwargs["num_splits"] = num_splits

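One hedged way to address both comments is a one-time capability probe instead of a hard-coded version gate; the flag name FA2_SUPPORTS_NUM_SPLITS is an assumption, and the snippet only illustrates the pattern, not this module's actual structure:

import inspect

from flash_attn import flash_attn_func  # standard FA2 entry point

# Hypothetical capability flag: True only if the installed flash-attn 2.x build
# actually accepts a num_splits keyword argument.
FA2_SUPPORTS_NUM_SPLITS = "num_splits" in inspect.signature(flash_attn_func).parameters

fa_optional_forward_kwargs = {}
num_splits = 2  # example value
if num_splits is not None and FA2_SUPPORTS_NUM_SPLITS:
    # The kwarg is only forwarded when the installed version is known to accept it,
    # so an older flash-attn never sees an unexpected keyword (no TypeError).
    fa_optional_forward_kwargs["num_splits"] = num_splits
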
if inference_params is not None:
    # use block_table kwarg to support thd_2bshd for non-paged
    fa_optional_forward_kwargs["block_table"] = (
@@ -959,6 +973,9 @@ def forward(
fa_3_optional_forward_kwargs["page_table"] = (
inference_params.cache_manager.page_table[:batch_size]
)
if num_splits is not None:
# Forward optional split control to flash-attn v3 if supported
fa_3_optional_forward_kwargs["num_splits"] = num_splits
Comment on lines +976 to +978 (Contributor): style: Same as FA v2: verify that flash-attn v3 supports num_splits to avoid a potential TypeError.

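The same probe works for the FA3 path; a sketch assuming the FA3 forward function is already bound to a local name (fa3_forward_func here is a placeholder for whatever this module imports):

import inspect


def _fa3_accepts_num_splits(fa3_forward_func) -> bool:
    """Return True only if the installed FlashAttention-3 build exposes num_splits."""
    try:
        return "num_splits" in inspect.signature(fa3_forward_func).parameters
    except (TypeError, ValueError):
        # C extensions may not expose an inspectable signature; stay conservative.
        return False
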
if fp8:
    QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
    torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -799,6 +799,7 @@ def forward(
inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = None,
Contributor comment: style: Missing documentation for the num_splits parameter in the docstring. Add a parameter description explaining what this controls and when it should be used (e.g., "Controls the number of splits for FlashAttention computation. Used for memory optimization.").

) -> torch.Tensor:
"""
Dot Product Attention Layer.
@@ -973,6 +974,11 @@ def forward(
If true, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = `False`
Whether to enforce output to be in FP8 or not.
num_splits: Optional[int], default = `None`
Optional split control for FlashAttention-3 only. When set, this value is forwarded
to the FA3 backend to control internal kernel splitting behavior. A ValueError is
raised if a non-FA3 backend is selected or if FlashAttention-3 is not installed.
"""

with torch.cuda.device(query_layer.device), self.prepare_forward(
@@ -1366,6 +1372,34 @@ def forward(
fused_attention_backend = _attention_backends["fused_attention_backend"]
use_unfused_attention = _attention_backends["use_unfused_attention"]

# If num_splits is requested, ensure we are using FlashAttention-3.
if num_splits is not None:
    is_fa3_selected = (
        use_flash_attention
        and flash_attention_backend is not None
        and flash_attention_backend == dpa_utils.FlashAttentionUtils.fa3_version
    )
    if not is_fa3_selected:
        backend_name = (
            f"FlashAttention ({str(flash_attention_backend)})"
            if use_flash_attention
            else ("FusedAttention" if use_fused_attention else "UnfusedDotProductAttention")
        )
        if not dpa_utils.FlashAttentionUtils.v3_is_installed:
            raise ValueError(
                "num_splits is only supported with FlashAttention-3, which is not installed."
            )
        raise ValueError(
            "num_splits is only supported with FlashAttention-3. Selected backend is"
            f" {backend_name}. Please adjust the configuration to enable FA3 for these inputs."
        )

# raise exception if no backend is available
if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
    raise ValueError(
@@ -1413,6 +1447,7 @@ def forward(
inference_params=inference_params,
flash_attention_backend=flash_attention_backend,
fp8_output=fp8_output,
num_splits=num_splits,
)

if use_fused_attention: