diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index a671f1eec2..ddb4cba369 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -117,7 +117,15 @@ def reset_global_fp8_state():
 @pytest.mark.parametrize("swa", [False])
 @pytest.mark.parametrize("pad_between_seqs", [False])
 def test_dot_product_attention(
-    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
+    dtype,
+    model_configs,
+    model,
+    ckpt_attn,
+    workspace_opt,
+    qkv_layout,
+    swa,
+    pad_between_seqs,
+    num_splits=None,
 ):
     """Test DotProductAttention module"""
 
@@ -244,6 +252,7 @@ def test_dot_product_attention(
             workspace_opt,
             pad_between_seqs,
             is_training,
+            num_splits=num_splits,
         )
 
     # Compare results
@@ -301,11 +310,18 @@ def test_dpa_checkpoint(dtype, model_configs, model):
 @pytest.mark.parametrize("model_configs", [model_configs_max_logit])
 @pytest.mark.parametrize("model", model_configs_max_logit.keys())
 @pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
-def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
+@pytest.mark.parametrize("num_splits", [None, 2])
+def test_dpa_max_logit(dtype, model_configs, model, qkv_layout, num_splits):
     """Test DotProductAttention module with checkpointing"""
     config = model_configs[model]
     config.return_max_logit = True
-    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
+    # Minimal guard: if num_splits is requested, require FA3 to be installed
+    if num_splits is not None:
+        if not FlashAttentionUtils.v3_is_installed:
+            pytest.skip("num_splits requires FlashAttention-3.")
+    test_dot_product_attention(
+        dtype, model_configs, model, False, True, qkv_layout, False, False, num_splits=num_splits
+    )
 
 
 model_configs_softmax = {
@@ -848,6 +864,7 @@ def _run_dot_product_attention(
     workspace_opt: bool,
     pad_between_seqs: bool,
     is_training: bool,
+    num_splits=None,
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """Run DotProductAttention module with one forward pass and one backward pass"""
     # Set RNG and environment varables
@@ -1152,6 +1169,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         core_attention_bias=bias,
         alibi_slopes=alibi_slopes,
         fast_zero_fill=True,
+        # Only pass num_splits when exercising the FlashAttention path
+        num_splits=(num_splits if backend == "FlashAttention" else None),
     )
     max_logit = None
     if config.return_max_logit:
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 147a85fc2f..e8ef3d8b8a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -669,6 +669,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
         fp8_output: bool = False,
+        num_splits: Optional[int] = None,
     ) -> torch.Tensor:
         """flash-attn fprop"""
 
@@ -843,6 +844,16 @@ def forward(
         use_flash_attn_3 = False
         if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
             use_flash_attn_3 = True
+        # Enforce FA3 when num_splits is provided
+        if num_splits is not None and not use_flash_attn_3:
+            if not fa_utils.v3_is_installed:
+                raise ValueError(
+                    "num_splits is only supported with FlashAttention-3, which is not installed."
+                )
+            raise ValueError(
+                "num_splits is only supported with FlashAttention-3. "
+                "Please adjust configuration to enable FA3 for these inputs."
+            )
         if context_parallel and all(
             not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
         ):
@@ -925,6 +936,9 @@ def forward(
                 fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
             if fa_utils.v2_4_1_plus:
                 fa_optional_forward_kwargs["deterministic"] = self.deterministic
+            if num_splits is not None:
+                # Forward the optional split control to flash-attn
+                fa_optional_forward_kwargs["num_splits"] = num_splits
             if inference_params is not None:
                 # use block_table kwarg to support thd_2bshd for non-paged
                 fa_optional_forward_kwargs["block_table"] = (
                     inference_params.cache_manager.page_table[:batch_size]
                 )
@@ -959,6 +973,9 @@ def forward(
                     fa_3_optional_forward_kwargs["page_table"] = (
                         inference_params.cache_manager.page_table[:batch_size]
                     )
+                if num_splits is not None:
+                    # Forward the optional split control to flash-attn v3
+                    fa_3_optional_forward_kwargs["num_splits"] = num_splits
             if fp8:
                 QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
                 torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 4278820e7a..0eec4a449a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -799,6 +799,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        num_splits: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Dot Product Attention Layer.
@@ -973,6 +974,11 @@ def forward(
                    If true, there are padding tokens between individual sequences in a packed batch.
        fp8_output: Optional[bool], default = `False`
                   Whether to enforce output to be in FP8 or not.
+       num_splits: Optional[int], default = `None`
+                  Split control for the FlashAttention-3 backend only. When set, the value is
+                  forwarded to the FA3 kernel to control its internal splitting behavior.
+                  A ValueError is raised if a non-FA3 backend is selected or if
+                  FlashAttention-3 is not installed.
         """
 
         with torch.cuda.device(query_layer.device), self.prepare_forward(
@@ -1366,6 +1372,34 @@ def forward(
             fused_attention_backend = _attention_backends["fused_attention_backend"]
             use_unfused_attention = _attention_backends["use_unfused_attention"]
 
+        # If num_splits is requested, ensure we are using FlashAttention-3.
+        if num_splits is not None:
+            is_fa3_selected = (
+                use_flash_attention
+                and flash_attention_backend is not None
+                and flash_attention_backend == dpa_utils.FlashAttentionUtils.fa3_version
+            )
+            if not is_fa3_selected:
+                backend_name = (
+                    f"FlashAttention ({str(flash_attention_backend)})"
+                    if use_flash_attention
+                    else (
+                        "FusedAttention"
+                        if use_fused_attention
+                        else "UnfusedDotProductAttention"
+                    )
+                )
+                if not dpa_utils.FlashAttentionUtils.v3_is_installed:
+                    raise ValueError(
+                        "num_splits is only supported with FlashAttention-3, which is not"
+                        " installed."
+                    )
+                raise ValueError(
+                    "num_splits is only supported with FlashAttention-3. Selected backend is"
+                    f" {backend_name}. Please adjust configuration to enable FA3 for these"
+                    " inputs."
+                )
+
         # raise exception if no backend is available
         if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
             raise ValueError(
@@ -1413,6 +1447,7 @@ def forward(
                 inference_params=inference_params,
                 flash_attention_backend=flash_attention_backend,
                 fp8_output=fp8_output,
+                num_splits=num_splits,
             )
 
         if use_fused_attention:
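Usage sketch (not part of the diff): a minimal, hypothetical example of the num_splits argument added to DotProductAttention.forward above. It assumes flash-attn 3 is installed and that the FlashAttention backend is selected (for example via NVTE_FLASH_ATTN=1); with any other backend the forward call raises the ValueError introduced in this change. Module arguments, shapes, and dtypes below are illustrative only.

    import torch
    import transformer_engine.pytorch as te

    # Illustrative sizes only; qkv_format defaults to "sbhd", so tensors are
    # [sequence, batch, heads, head_dim].
    seq_len, batch, heads, head_dim = 2048, 2, 16, 64
    dpa = te.DotProductAttention(num_attention_heads=heads, kv_channels=head_dim)

    q, k, v = (
        torch.randn(seq_len, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
        for _ in range(3)
    )

    # num_splits is forwarded to the FlashAttention-3 kernel; leaving it as None
    # keeps the backend's default splitting behavior.
    out = dpa(q, k, v, num_splits=2)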