FA num splits option #2357
base: main
Changes to the flash-attention backend forward ("flash-attn fprop"):
@@ -669,6 +669,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
         fp8_output: bool = False,
+        num_splits: Optional[int] = None,
     ) -> torch.Tensor:
         """flash-attn fprop"""
@@ -843,6 +844,16 @@ def forward(
         use_flash_attn_3 = False
         if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
             use_flash_attn_3 = True
+        # Enforce FA3 when num_splits is provided
+        if num_splits is not None and not use_flash_attn_3:
+            if not fa_utils.v3_is_installed:
+                raise ValueError(
+                    "num_splits is only supported with FlashAttention-3, which is not installed. "
+                )
+            raise ValueError(
+                "num_splits is only supported with FlashAttention-3. "
+                "Please adjust configuration to enable FA3 for these inputs."
+            )
         if context_parallel and all(
             not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
         ):
@@ -925,6 +936,9 @@ def forward(
             fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
         if fa_utils.v2_4_1_plus:
             fa_optional_forward_kwargs["deterministic"] = self.deterministic
+        if num_splits is not None:
+            # Forward optional split control to flash-attn if available
+            fa_optional_forward_kwargs["num_splits"] = num_splits
Comment on lines +939 to +941 (Contributor):
style: Verify that the flash-attn version supports `num_splits`.
Comment on lines +939 to +941 (Contributor):
logic: Unlike other optional parameters (…)
Suggested change
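One way to address both comments is to forward the kwarg only when the installed flash-attn entry point actually accepts it, instead of assuming a minimum version. A minimal sketch under that assumption; `maybe_set_num_splits` is a hypothetical helper, not part of this diff, and the dict and argument names simply mirror the ones used above:

```python
import inspect


def maybe_set_num_splits(fa_kwargs: dict, attn_fn, num_splits) -> None:
    """Add num_splits to the kwarg dict only if `attn_fn` accepts it.

    `attn_fn` is whichever flash-attn entry point the surrounding code is about
    to call (e.g. flash_attn_func or flash_attn_varlen_func).
    """
    if num_splits is None:
        return
    try:
        supported = "num_splits" in inspect.signature(attn_fn).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. C extensions) do not expose a signature.
        supported = False
    if supported:
        fa_kwargs["num_splits"] = num_splits
    else:
        raise ValueError(
            "num_splits was requested but the installed flash-attn build does not accept it."
        )
```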
         if inference_params is not None:
             # use block_table kwarg to support thd_2bshd for non-paged
             fa_optional_forward_kwargs["block_table"] = (
@@ -959,6 +973,9 @@ def forward(
                 fa_3_optional_forward_kwargs["page_table"] = (
                     inference_params.cache_manager.page_table[:batch_size]
                 )
+            if num_splits is not None:
+                # Forward optional split control to flash-attn v3 if supported
+                fa_3_optional_forward_kwargs["num_splits"] = num_splits
Comment on lines +976 to +978 (Contributor):
style: Same as FA v2: verify flash-attn v3 supports `num_splits`.
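The same capability check could guard this FA3 branch. A hypothetical call site reusing the helper sketched above; `fa3_attn_fn` stands in for whichever FA3 function this path actually invokes:

```python
# Hypothetical: gate the FA3 kwarg on the FA3 entry point's signature as well.
maybe_set_num_splits(fa_3_optional_forward_kwargs, fa3_attn_fn, num_splits)
```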
         if fp8:
             QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
             torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
Changes to DotProductAttention.forward:
@@ -799,6 +799,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        num_splits: Optional[int] = None,
Contributor comment:
style: Missing documentation for the `num_splits` parameter.
     ) -> torch.Tensor:
         """
         Dot Product Attention Layer.
@@ -973,6 +974,11 @@ def forward(
             If true, there are padding tokens between individual sequences in a packed batch.
         fp8_output: Optional[bool], default = `False`
             Whether to enforce output to be in FP8 or not.
+        num_splits: Optional[int], default = `None`
+            Optional split control for FlashAttention-3 only. When set, this value is forwarded
+            to the FA3 backend to control internal kernel splitting behavior. A ValueError is
+            raised if a non-FA3 backend is selected or if FlashAttention-3 is not installed;
+            the argument is never forwarded to any other backend.
         """

         with torch.cuda.device(query_layer.device), self.prepare_forward(
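For callers, the new argument is passed per forward call. A hedged usage sketch, assuming this PR is applied and an FA3-capable flash-attn build is selected at runtime; the constructor arguments and "sbhd" tensor shapes follow Transformer Engine's usual DotProductAttention conventions and are not part of this diff:

```python
import torch
from transformer_engine.pytorch import DotProductAttention

seq_len, batch, heads, head_dim = 2048, 2, 16, 64
dpa = DotProductAttention(num_attention_heads=heads, kv_channels=head_dim)

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim].
q = torch.randn(seq_len, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# num_splits is new in this PR and is only honored when FlashAttention-3 is
# the selected backend; any other backend raises a ValueError.
out = dpa(q, k, v, num_splits=4)
```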
@@ -1366,6 +1372,34 @@ def forward(
             fused_attention_backend = _attention_backends["fused_attention_backend"]
             use_unfused_attention = _attention_backends["use_unfused_attention"]

+            # If num_splits is requested, ensure we are using FlashAttention-3.
+            if num_splits is not None:
+                is_fa3_selected = (
+                    use_flash_attention
+                    and flash_attention_backend is not None
+                    and flash_attention_backend == dpa_utils.FlashAttentionUtils.fa3_version
+                )
+                if not is_fa3_selected:
+                    backend_name = (
+                        f"FlashAttention ({str(flash_attention_backend)})"
+                        if use_flash_attention
+                        else (
+                            "FusedAttention"
+                            if use_fused_attention
+                            else "UnfusedDotProductAttention"
+                        )
+                    )
+                    if not dpa_utils.FlashAttentionUtils.v3_is_installed:
+                        raise ValueError(
+                            "num_splits is only supported with FlashAttention-3, which is not"
+                            " installed. "
+                        )
+                    raise ValueError(
+                        "num_splits is only supported with FlashAttention-3. Selected backend is"
+                        f" {backend_name}. Please adjust configuration to enable FA3 for these"
+                        " inputs."
+                    )
+
             # raise exception if no backend is available
             if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
                 raise ValueError(
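The effect of this check, shown as a hedged sketch: if FlashAttention is disabled so FA3 cannot be selected, a `num_splits` request fails loudly. `NVTE_FLASH_ATTN` is Transformer Engine's usual backend toggle, but its use here, the shapes, and the constructor arguments are illustrative assumptions rather than taken from this diff:

```python
import os
import torch
from transformer_engine.pytorch import DotProductAttention

os.environ["NVTE_FLASH_ATTN"] = "0"  # assumption: forces a non-FlashAttention backend

dpa = DotProductAttention(num_attention_heads=16, kv_channels=64)
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

try:
    dpa(q, torch.randn_like(q), torch.randn_like(q), num_splits=2)
except ValueError as err:
    # With this PR, the message names FlashAttention-3; when an FA3 build is
    # installed but another backend was selected, it also names that backend.
    print(err)
```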
@@ -1413,6 +1447,7 @@ def forward(
                     inference_params=inference_params,
                     flash_attention_backend=flash_attention_backend,
                     fp8_output=fp8_output,
+                    num_splits=num_splits,
                 )

             if use_fused_attention:
Reviewer comment:
Can you make the num_splits a separate test, instead of piggybacking on the max_logit test :) You can still call test_dot_product_attention in it the same way you do here. Thanks!
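A possible shape for that separate test, kept deliberately loose: the parametrization, the config dict, and the exact argument list of `test_dot_product_attention` below are placeholders to be copied from the existing max_logit test, and it assumes the PR threads a `num_splits` argument through that helper:

```python
import pytest
import torch

# Placeholder configs: reuse whichever entries the max_logit test exercises.
model_configs_num_splits = dict(model_configs_base)  # assumed to exist in the test module


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
@pytest.mark.parametrize("num_splits", [1, 2, 4])
def test_dot_product_attention_num_splits(dtype, model, num_splits):
    """Standalone num_splits test instead of a branch inside the max_logit test."""
    # Call the shared helper with the same arguments the max_logit test uses,
    # adding only num_splits (the argument list here is schematic, not the real one).
    test_dot_product_attention(
        dtype,
        model_configs_num_splits,
        model,
        num_splits=num_splits,
    )
```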