Add num_splits support for FA3 backend #2357
base: main
Changes from 23 commits
The first file in the diff modifies the flash-attention backend's `forward` (the `"""flash-attn fprop"""` method):

```diff
@@ -669,6 +669,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
         fp8_output: bool = False,
+        num_splits: Optional[int] = None,
     ) -> torch.Tensor:
         """flash-attn fprop"""
@@ -843,6 +844,16 @@ def forward(
         use_flash_attn_3 = False
         if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
             use_flash_attn_3 = True
+        # Enforce FA3 when num_splits is provided
+        if num_splits is not None and not use_flash_attn_3:
+            if not fa_utils.v3_is_installed:
+                raise ValueError(
+                    "num_splits is only supported with FlashAttention-3, which is not installed."
+                )
+            raise ValueError(
+                "num_splits is only supported with FlashAttention-3. "
+                "Please adjust configuration to enable FA3 for these inputs."
+            )
         if context_parallel and all(
             not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
         ):
@@ -925,6 +936,9 @@ def forward(
             fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
             if fa_utils.v2_4_1_plus:
                 fa_optional_forward_kwargs["deterministic"] = self.deterministic
+            if num_splits is not None:
+                # Forward optional split control to flash-attn if available
+                fa_optional_forward_kwargs["num_splits"] = num_splits
```
Four review comments were left on lines +939 to +941 of this file:

**Contributor:**
> style: Verify that the installed flash-attn version supports `num_splits` before forwarding it.

**Contributor:**
> logic: Unlike other optional parameters (`window_size`, `deterministic`), `num_splits` is forwarded without a version guard.

**Contributor:**
> style: This is unreachable dead code. The validation at line 848 already ensures `num_splits` is `None` whenever FA3 is not in use.

**Contributor:**
> logic: Dead code: unreachable due to validation at line 848. The check at line 848 ensures that if `num_splits` is not `None`, then `use_flash_attn_3` is `True`, so this FA2 branch can never execute with `num_splits` set.
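To act on the version-check concern, here is a minimal sketch of one option, assuming the FA2 entry point in scope is `flash_attn.flash_attn_interface.flash_attn_varlen_func` (the helper name and its placement are illustrative, not from the PR): probe the installed function's signature once, and fail fast if `num_splits` is requested but not accepted.

```python
import inspect

from flash_attn.flash_attn_interface import flash_attn_varlen_func

# One-time probe: does the installed flash-attn build accept num_splits?
_FA_ACCEPTS_NUM_SPLITS = (
    "num_splits" in inspect.signature(flash_attn_varlen_func).parameters
)


def maybe_forward_num_splits(kwargs: dict, num_splits) -> dict:
    """Attach num_splits to the kwargs dict only if the kernel accepts it."""
    if num_splits is not None:
        if not _FA_ACCEPTS_NUM_SPLITS:
            raise ValueError(
                "num_splits was requested, but the installed flash-attn "
                "does not expose a num_splits argument."
            )
        kwargs["num_splits"] = num_splits
    return kwargs
```

This mirrors how the surrounding code gates `deterministic` behind `fa_utils.v2_4_1_plus`, but keys off the actual signature rather than a version number.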
```diff
             if inference_params is not None:
                 # use block_table kwarg to support thd_2bshd for non-paged
                 fa_optional_forward_kwargs["block_table"] = (
@@ -959,6 +973,9 @@ def forward(
                 fa_3_optional_forward_kwargs["page_table"] = (
                     inference_params.cache_manager.page_table[:batch_size]
                 )
+            if num_splits is not None:
+                # Forward optional split control to flash-attn v3 if supported
+                fa_3_optional_forward_kwargs["num_splits"] = num_splits
```
**Contributor** commented on lines +976 to +978:
> style: Same as FA v2: verify that flash-attn v3 supports `num_splits`.
```diff
         if fp8:
             QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
             torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
```
The second file extends the public `DotProductAttention.forward` API:

```diff
@@ -799,6 +799,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        num_splits: Optional[int] = None,
```
**Contributor** commented on the new parameter:
> style: Missing documentation for the `num_splits` parameter.
```diff
     ) -> torch.Tensor:
         """
         Dot Product Attention Layer.
@@ -973,6 +974,11 @@ def forward(
             If true, there are padding tokens between individual sequences in a packed batch.
         fp8_output: Optional[bool], default = `False`
             Whether to enforce output to be in FP8 or not.
+        num_splits: Optional[int], default = `None`
+            Optional split control for FlashAttention-3 only. When set, this value is
+            forwarded to the FA3 backend to control internal kernel splitting behavior.
+            A ValueError is raised if a non-FA3 backend is selected or if
+            FlashAttention-3 is not installed.
         """

         with torch.cuda.device(query_layer.device), self.prepare_forward(
```
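For orientation, a hypothetical end-to-end use of the new argument (tensor shapes and hyperparameters are arbitrary; per this PR, the call raises a `ValueError` unless the FA3 backend is selected):

```python
import torch
from transformer_engine.pytorch import DotProductAttention

attn = DotProductAttention(num_attention_heads=16, kv_channels=64)

# Default qkv_format is "sbhd": (sequence, batch, heads, head_dim).
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

# Ask the FA3 kernel to split its work into 4 chunks.
out = attn(q, k, v, num_splits=4)
```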
```diff
@@ -1366,6 +1372,34 @@ def forward(
             fused_attention_backend = _attention_backends["fused_attention_backend"]
             use_unfused_attention = _attention_backends["use_unfused_attention"]
+
+            # If num_splits is requested, ensure we are using FlashAttention-3.
+            if num_splits is not None:
+                is_fa3_selected = (
+                    use_flash_attention
+                    and flash_attention_backend is not None
+                    and flash_attention_backend == dpa_utils.FlashAttentionUtils.fa3_version
+                )
+                if not is_fa3_selected:
+                    backend_name = (
+                        f"FlashAttention ({str(flash_attention_backend)})"
+                        if use_flash_attention
+                        else (
+                            "FusedAttention"
+                            if use_fused_attention
+                            else "UnfusedDotProductAttention"
+                        )
+                    )
+                    if not dpa_utils.FlashAttentionUtils.v3_is_installed:
+                        raise ValueError(
+                            "num_splits is only supported with FlashAttention-3, which is not"
+                            " installed."
+                        )
+                    raise ValueError(
+                        "num_splits is only supported with FlashAttention-3. Selected backend is"
+                        f" {backend_name}. Please adjust configuration to enable FA3 for these"
+                        " inputs."
+                    )

             # raise exception if no backend is available
             if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
                 raise ValueError(
```
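This validation path is straightforward to cover in a test. A pytest-style sketch, with hypothetical fixtures (`attn_layer` and `qkv` are not from the PR) and assuming an environment where FA3 is not the selected backend:

```python
import pytest


def test_num_splits_rejected_without_fa3(attn_layer, qkv):
    """num_splits must raise when FlashAttention-3 is not selected."""
    q, k, v = qkv
    with pytest.raises(ValueError, match="only supported with FlashAttention-3"):
        attn_layer(q, k, v, num_splits=2)
```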
```diff
@@ -1413,6 +1447,7 @@ def forward(
                     inference_params=inference_params,
                     flash_attention_backend=flash_attention_backend,
                     fp8_output=fp8_output,
+                    num_splits=num_splits,
                 )

             if use_fused_attention:
```
**Contributor** commented on the validation block:
> style: The validation logic here correctly prevents `num_splits` from being used with non-FA3 backends. However, there is a potential maintainability issue: later in the code (line 939), `num_splits` is added to `fa_optional_forward_kwargs` without a version check, unlike other optional parameters. While the current validation prevents reaching that code with `num_splits` set and FA2, it creates fragile coupling between distant code sections. Consider either: (1) adding a version guard at line 939, similar to `window_size` and `deterministic`, or (2) adding an assertion that `num_splits` is `None` in the FA2 branch.
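Option (2) from this comment amounts to a one-line invariant in the FA2 branch; a minimal sketch (the placement is illustrative, not code from the PR):

```python
# In the FA2 (non-FA3) code path, just before assembling
# fa_optional_forward_kwargs for the v2 kernels:
assert num_splits is None, (
    "num_splits should have been rejected earlier for non-FA3 backends"
)
```

Either option localizes the invariant, so the kwarg forwarding at line 939 no longer depends on a check hundreds of lines away.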
num_splitsfrom being used with non-FA3 backends. However, there's a potential issue with code maintainability: later in the code (line 939),num_splitsis added tofa_optional_forward_kwargswithout a version check, unlike other optional parameters. While the current validation prevents reaching that code withnum_splitsset and FA2, this creates fragile coupling between distant code sections. Consider either: (1) adding a version guard at line 939 similar towindow_sizeanddeterministic, or (2) adding an assertion thatnum_splitsis None in the FA2 branch.