Changes from 1 commit (23 commits in the pull request)
484a814
[Common] Deleted unused header (#2324)
Oleg-Goncharov Oct 31, 2025
cbb8248
[JAX] L1_jax_distributed_test suit with individual executions (#2321)
phu0ngng Nov 3, 2025
190a84c
for branch
Nov 5, 2025
9c15c57
clean up and tests
wdykas Nov 10, 2025
1446f0a
change tests
wdykas Nov 12, 2025
7e8f3ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
e89d7a9
[PyTorch debug] Fixes to debug tests failures (#2268)
pggPL Nov 4, 2025
2ca40f9
[PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137)
pggPL Nov 5, 2025
7a06387
[JAX] Fix bug with pre scale bias (#2300)
pggPL Nov 5, 2025
6693da3
[JAX] Try to use pre-downloaded dataset artifacts first (#2345)
jberchtold-nvidia Nov 6, 2025
a68eda5
Fix out of bounds access in the FP4 dequantize kernel (#2346)
ptrendx Nov 6, 2025
7b395bf
Make FP8 weights compatible with older MCore version (#2342)
kunlunl Nov 6, 2025
3299742
[JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#…
jberchtold-nvidia Nov 7, 2025
b97fb5a
Fix sharding of segment position to match id in ring attention. (#2349)
mgoldfarb-nvidia Nov 7, 2025
7217045
Disable cuDNN attention for known IMA and NaNs (#2344)
ksivaman Nov 7, 2025
624fe37
[JAX] Default to fused attention in JAX DPA (#2363)
KshitijLakhani Nov 7, 2025
868f18a
Update cudnn frontend to v1.16.0 (#2362)
ksivaman Nov 7, 2025
46a0b65
[common] Remove kvpacked and qkvpacked attention functions for every …
pggPL Nov 7, 2025
ac3a513
Move Triton to common (#2359)
tdophung Nov 10, 2025
16316ba
[JAX] Fused layers argument default values changed (#2347)
tdophung Nov 10, 2025
52a5f37
remove comment from gpt
wdykas Nov 12, 2025
cf381a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
20ee52a
Merge branch 'NVIDIA:main' into num-splits-attention
wdykas Nov 12, 2025
[JAX] Default to fused attention in JAX DPA (#2363)
* Default to fused attention in JAX DPA

Signed-off-by: Kshitij Lakhani <[email protected]>

* Consolidate documentation for DPA in JAX

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <[email protected]>

* Correctly update the documentation for defaults in JAX DPA

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <[email protected]>
2 people authored and wdykas committed Nov 12, 2025
commit 624fe37eb0be2d51a37591fa37d9162a9175d2c3
11 changes: 6 additions & 5 deletions transformer_engine/jax/flax/transformer.py
@@ -407,10 +407,10 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
variable:

- * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
- * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
-   kernel is not available on the system, a warning will be issued, and the module will
-   automatically fall back to the unfused backend.
+ * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
+ * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
+   attention kernel is not available on the system, a warning will be issued, and the module
+   will automatically fall back to the unfused backend.

.. note::
The DotProductAttention default setting enables non-deterministic kernels for reduced
@@ -602,7 +602,8 @@ def __call__(
else:
assert bias is not None

enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
# Use fused attn (if kernel check below passes) by default
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))

sequence_dim = 0 if self.transpose_batch_sequence else 1
seqlen_q = query.shape[sequence_dim]
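
For readers skimming the diff, a minimal illustrative sketch (not part of this commit) of what the new default means for users: the NVTE_FUSED_ATTN variable name and the "1" fallback come from the hunk above; the opt-out pattern around it is an assumption about typical usage.

import os

# Opt out of the new default before the attention call is traced;
# after this commit, fused (cuDNN) attention is requested unless the
# variable is explicitly set to "0".
os.environ["NVTE_FUSED_ATTN"] = "0"

# Mirrors the resolution logic added in this commit.
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
assert enable_fused_attn == 0  # unfused path explicitly requested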