
Conversation

@wdykas wdykas commented Nov 6, 2025

Description

I want to be able to control num_splits in FA3 (FlashAttention-3); this PR exposes that argument.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Expose an optional num_splits argument on DotProductAttention.forward()
  • Forward num_splits to the FlashAttention v2 and v3 backends when it is set

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR exposes the num_splits parameter for FlashAttention v2 and v3 backends, allowing users to control memory optimization during attention computation.

Key Changes:

  • Added optional num_splits parameter to DotProductAttention.forward() method
  • Passes num_splits to both FlashAttention v2 and v3 backend implementations when provided
  • Parameter is conditionally added to kwargs only when not None
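The conditional-kwargs pattern described above can be sketched as follows. This is a minimal illustration of the idea, not the actual Transformer Engine code; the helper name is hypothetical:

```python
def build_fa_optional_kwargs(num_splits=None, window_size=None, deterministic=None):
    """Collect optional flash-attn keyword arguments.

    Mirrors the pattern described above: an argument is forwarded only
    when the caller explicitly sets it, so flash-attn versions that lack
    a given parameter never see an unexpected keyword.
    """
    kwargs = {}
    if num_splits is not None:
        kwargs["num_splits"] = num_splits
    if window_size is not None:
        kwargs["window_size"] = window_size
    if deterministic is not None:
        kwargs["deterministic"] = deterministic
    return kwargs
```

For example, `build_fa_optional_kwargs(num_splits=4)` yields `{"num_splits": 4}`, while calling it with no arguments yields an empty dict, leaving the backend's defaults untouched.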

Areas for Improvement:

  • Missing parameter documentation in the docstring
  • No version compatibility check for flash-attn (unlike other optional parameters like window_size and deterministic)
  • No tests demonstrating the new functionality

Confidence Score: 4/5

  • This PR is safe to merge with minor documentation improvements recommended
  • The implementation correctly follows the existing pattern for optional parameters in FlashAttention backends. The changes are minimal and well-scoped. However, the score is not 5 due to: (1) missing parameter documentation, (2) lack of version compatibility checks that other optional parameters have, and (3) no accompanying tests. These are quality-of-life improvements rather than critical issues.
  • No files require special attention - the implementation is straightforward and follows existing patterns

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/pytorch/attention/dot_product_attention/backends.py | 4/5 | Added the num_splits parameter to the FlashAttention forward method and passes it to both FA v2 and FA v3 backends when provided
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | 4/5 | Added the num_splits parameter to the DotProductAttention forward signature and forwards it to the FlashAttention backend

Sequence Diagram

sequenceDiagram
    participant User
    participant DotProductAttention
    participant FlashAttention
    participant FlashAttn_v2
    participant FlashAttn_v3

    User->>DotProductAttention: forward(query, key, value, num_splits=...)
    Note over DotProductAttention: New optional parameter
    DotProductAttention->>FlashAttention: forward(..., num_splits=num_splits)
    Note over FlashAttention: Check backend version
    alt Flash Attention v2
        FlashAttention->>FlashAttention: if num_splits is not None
        FlashAttention->>FlashAttention: fa_optional_forward_kwargs["num_splits"] = num_splits
        FlashAttention->>FlashAttn_v2: flash_attn_func(..., **fa_optional_forward_kwargs)
        FlashAttn_v2-->>FlashAttention: output
    else Flash Attention v3
        FlashAttention->>FlashAttention: if num_splits is not None
        FlashAttention->>FlashAttention: fa_3_optional_forward_kwargs["num_splits"] = num_splits
        FlashAttention->>FlashAttn_v3: flash_attn_func(..., **fa_3_optional_forward_kwargs)
        FlashAttn_v3-->>FlashAttention: output
    end
    FlashAttention-->>DotProductAttention: output
    DotProductAttention-->>User: output

2 files reviewed, 3 comments


inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = None,

style: Missing documentation for the num_splits parameter in the docstring. Add a parameter description explaining what this controls and when it should be used (e.g., "Controls the number of splits for FlashAttention computation. Used for memory optimization.")
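One way to address this comment is a docstring entry in the style of the surrounding parameters. The wording below is a suggestion only, shown on a toy stand-in function rather than the real forward method:

```python
def forward(query, key, value, num_splits=None):
    """Toy stand-in illustrating the suggested docstring entry.

    Parameters
    ----------
    num_splits : Optional[int], default = None
        Number of splits used by the flash-attn kernel for the attention
        computation (see the flash-attn documentation for the exact
        semantics). `None` lets flash-attn choose a heuristic value; the
        argument is only forwarded to the backend when explicitly set.
    """
    return query, key, value, num_splits
```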

Comment on lines +928 to +930
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits

style: Verify that the flash-attn version supports num_splits parameter. Unlike other optional parameters (e.g., window_size has fa_utils.v2_3_plus check, deterministic has fa_utils.v2_4_1_plus check), this parameter is added without a version guard.

Comment on lines +965 to +967
if num_splits is not None:
# Forward optional split control to flash-attn v3 if supported
fa_3_optional_forward_kwargs["num_splits"] = num_splits

style: Same as FA v2: verify flash-attn v3 supports num_splits to avoid potential TypeError
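An alternative to a hard-coded version check is to probe the callee's signature at import time. This is a sketch with a hypothetical helper name, not part of the PR:

```python
import inspect

def filter_supported_kwargs(fn, **optional_kwargs):
    """Drop keyword arguments that `fn` does not accept, avoiding the
    TypeError the comment above warns about. Everything is kept if `fn`
    declares a **kwargs catch-all."""
    params = inspect.signature(fn).parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_var_kw:
        return dict(optional_kwargs)
    return {k: v for k, v in optional_kwargs.items() if k in params}
```

Usage would be along the lines of `kwargs = filter_supported_kwargs(flash_attn_func, num_splits=num_splits)` followed by `flash_attn_func(..., **kwargs)`, so an older flash-attn simply never sees the argument.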
