Commits in this pull request (23):
* 484a814 [Common] Deleted unused header (#2324) (Oleg-Goncharov, Oct 31, 2025)
* cbb8248 [JAX] L1_jax_distributed_test suit with individual executions (#2321) (phu0ngng, Nov 3, 2025)
* 190a84c for branch (Nov 5, 2025)
* 9c15c57 clean up and tests (wdykas, Nov 10, 2025)
* 1446f0a change tests (wdykas, Nov 12, 2025)
* 7e8f3ab [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 10, 2025)
* e89d7a9 [PyTorch debug] Fixes to debug tests failures (#2268) (pggPL, Nov 4, 2025)
* 2ca40f9 [PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137) (pggPL, Nov 5, 2025)
* 7a06387 [JAX] Fix bug with pre scale bias (#2300) (pggPL, Nov 5, 2025)
* 6693da3 [JAX] Try to use pre-downloaded dataset artifacts first (#2345) (jberchtold-nvidia, Nov 6, 2025)
* a68eda5 Fix out of bounds access in the FP4 dequantize kernel (#2346) (ptrendx, Nov 6, 2025)
* 7b395bf Make FP8 weights compatible with older MCore version (#2342) (kunlunl, Nov 6, 2025)
* 3299742 [JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#… (jberchtold-nvidia, Nov 7, 2025)
* b97fb5a Fix sharding of segment position to match id in ring attention. (#2349) (mgoldfarb-nvidia, Nov 7, 2025)
* 7217045 Disable cuDNN attention for known IMA and NaNs (#2344) (ksivaman, Nov 7, 2025)
* 624fe37 [JAX] Default to fused attention in JAX DPA (#2363) (KshitijLakhani, Nov 7, 2025)
* 868f18a Update cudnn frontend to v1.16.0 (#2362) (ksivaman, Nov 7, 2025)
* 46a0b65 [common] Remove kvpacked and qkvpacked attention functions for every … (pggPL, Nov 7, 2025)
* ac3a513 Move Triton to common (#2359) (tdophung, Nov 10, 2025)
* 16316ba [JAX] Fused layers argument default values changed (#2347) (tdophung, Nov 10, 2025)
* 52a5f37 remove comment from gpt (wdykas, Nov 12, 2025)
* cf381a0 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 12, 2025)
* 20ee52a Merge branch 'NVIDIA:main' into num-splits-attention (wdykas, Nov 12, 2025)
[JAX] Fused layers argument default values changed (#2347)
* Change the default activations in MLP and TransformerLayer to ("gelu",), the default dropout rate after FC1 to 0.0, and the default return_layernorm_output to False

Signed-off-by: tdophung <[email protected]>

* Fix the failing tests by hard-coding arguments to the previous default values instead of relying on the new defaults

Signed-off-by: tdophung <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: tdophung <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <[email protected]>
2 people authored and wdykas committed Nov 12, 2025
commit 16316baf92b6699c092ccca88f90135e110d8fc1
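
This commit flips several fused-layer defaults in the JAX/Flax API: activations go from ("relu",) to ("gelu",), the post-FC1 (intermediate) dropout rate from 0.1 to 0.0, and return_layernorm_output from True to False. Code that depended on the old defaults can pin them explicitly. A minimal sketch, assuming the layers are imported from transformer_engine.jax.flax as elsewhere in this repository and using only keyword arguments that appear in the diff below:

```python
# Minimal sketch, assuming `transformer_engine.jax.flax` exports these layers.
# To keep pre-#2347 behaviour, spell out the old defaults instead of relying
# on the new ones.
from transformer_engine.jax.flax import LayerNormMLP, TransformerLayer

mlp = LayerNormMLP(
    activations=("relu",),          # new default: ("gelu",)
    intermediate_dropout_rate=0.1,  # new default: 0.0
    return_layernorm_output=True,   # new default: False
)

layer = TransformerLayer(
    mlp_activations=("relu",),      # new default: ("gelu",)
    intermediate_dropout=0.1,       # new default: 0.0
)
```

Constructing a Flax module is plain dataclass construction, so pinning these arguments only changes configuration; init and apply behave as before.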
2 changes: 2 additions & 0 deletions tests/jax/test_distributed_layernorm_mlp.py
@@ -389,6 +389,7 @@ def _test_layernorm_mlp(
 intermediate_dim=INTERMEDIATE,
 activations=activation_type,
 use_bias=use_bias,
+return_layernorm_output=True,
 )
 params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
 mlp_out_single, ln_out_single = ln_mlp_single.apply(
@@ -417,6 +418,7 @@ def _test_layernorm_mlp(
 dot_1_input_axes=DOT_1_INPUT_AXES,
 dot_2_input_axes=DOT_2_INPUT_AXES,
 name="mlp",
+return_layernorm_output=True,
 )
 params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
 mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
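
The two added return_layernorm_output=True lines pin the old behaviour because the test unpacks a second (layernorm) output from LayerNormMLP; under the new default that slot is None. A rough standalone illustration of the output contract, assuming LayerNormMLP can be exercised outside the distributed harness with its default settings:

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP

x = jnp.ones((8, 128, 512), dtype=jnp.float32)  # shape chosen arbitrarily
rng = jax.random.PRNGKey(0)

# New default: the layernorm output slot in the returned tuple is None.
mlp = LayerNormMLP(intermediate_dim=2048)
params = mlp.init(rng, x, deterministic=True)
mlp_out, ln_out = mlp.apply(params, x, deterministic=True)
assert ln_out is None

# Old behaviour, as hard-coded in the updated test above.
mlp = LayerNormMLP(intermediate_dim=2048, return_layernorm_output=True)
params = mlp.init(rng, x, deterministic=True)
mlp_out, ln_out = mlp.apply(params, x, deterministic=True)
assert ln_out is not None
```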
12 changes: 6 additions & 6 deletions tests/jax/utils.py
@@ -364,9 +364,9 @@ class MlpBlock(nn.Module):
 
 transpose_batch_sequence: bool
 intermediate_dim: int = 2048
-activations: Sequence[Union[str, Callable]] = ("relu",)
+activations: Sequence[Union[str, Callable]] = ("gelu",)
 kernel_init: Initializer = None
-intermediate_dropout_rate: float = 0.1
+intermediate_dropout_rate: float = 0.0
 intermediate_dropout_dims: Sequence[int] = ()
 use_bias: bool = False
 dtype: Any = jnp.float32
@@ -1035,14 +1035,14 @@ class EncoderLayer(nn.Module):
 hidden_dropout: float = 0.1
 hidden_dropout_dims: Sequence[int] = ()
 attention_dropout: float = 0.1
-intermediate_dropout: float = 0.1
+intermediate_dropout: float = 0.0
 intermediate_dropout_dims: Sequence[int] = ()
 transpose_batch_sequence: bool = True
 float32_attention_logits: bool = False
 scale_attn_logits: bool = False
 scaled_query_init: bool = True
 mlp_dim: int = 2048
-mlp_activations: Sequence[str] = ("relu",)
+mlp_activations: Sequence[str] = ("gelu",)
 use_bias: bool = False
 dtype: Any = jnp.float32
 apply_residual_connection_post_layernorm: bool = False
@@ -1199,14 +1199,14 @@ class DecoderLayer(nn.Module):
 hidden_dropout: float = 0.1
 hidden_dropout_dims: Sequence[int] = ()
 attention_dropout: float = 0.1
-intermediate_dropout: float = 0.1
+intermediate_dropout: float = 0.0
 intermediate_dropout_dims: Sequence[int] = ()
 transpose_batch_sequence: bool = True
 float32_attention_logits: bool = False
 scale_attn_logits: bool = False
 scaled_query_init: bool = True
 mlp_dim: int = 2048
-mlp_activations: Sequence[str] = ("relu",)
+mlp_activations: Sequence[str] = ("gelu",)
 use_bias: bool = False
 dtype: Any = jnp.float32
 apply_residual_connection_post_layernorm: bool = False
16 changes: 8 additions & 8 deletions transformer_engine/jax/flax/module.py
@@ -597,7 +597,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
 bias_axes: Tuple[str, ...], default = ()
 The name of axes used to shard bias with a corresponding mesh,
 only used when :attr:`use_bias=True`.
-return_layernorm_output: bool, default = True
+return_layernorm_output: bool, default = False
 Indicate whether to return the output of layer normalization.
 If set False, return None as the second tensor in outputs.
 enable_low_rank_adaptation: bool, default = False
@@ -644,7 +644,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
 use_bias: bool = False
 bias_init: Initializer = nn.initializers.zeros
 bias_axes: Tuple[str, ...] = ()
-return_layernorm_output: bool = True
+return_layernorm_output: bool = False
 enable_low_rank_adaptation: bool = False
 low_rank_adaptation_dim: int = 32
 low_rank_adaptation_alpha: float = None
@@ -891,10 +891,10 @@ class LayerNormMLP(TransformerEngineBase):
 The name of axes used to shard bias with a corresponding mesh for
 the weight of the second dense layer transformation.
 Only used when :attr:`use_bias=True`.
-return_layernorm_output: bool, default = True
+return_layernorm_output: bool, default = False
 Indicate whether to return the output of layer normalization.
 If set False, return None as the second tensor in outputs.
-activations: Sequence[Union[str, Callable]], default = ('relu',)
+activations: Sequence[Union[str, Callable]], default = ('gelu',)
 The sequence of activation functions to apply after the first dense layer transformation.
 Each activation has its own transformation layer.
 activation_params: dict, default = None
@@ -903,7 +903,7 @@ class LayerNormMLP(TransformerEngineBase):
 need additional parameters.
 intermediate_dropout_rng_name: str, default = 'dropout'
 The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
-intermediate_dropout_rate: float, default = 0.1
+intermediate_dropout_rate: float, default = 0.0
 Dropout probability for the dropout op after the :attr:`activations`.
 intermediate_hidden_dropout_dims: Sequence[int], default = ()
 Dimensions that will share the same dropout mask for hidden
@@ -959,11 +959,11 @@ class LayerNormMLP(TransformerEngineBase):
 bias_init: Initializer = nn.initializers.zeros
 bias_axes_1: Tuple[str, ...] = ("act", "mlp")
 bias_axes_2: Tuple[str, ...] = ("embed",)
-return_layernorm_output: bool = True
-activations: Sequence[Union[str, Callable]] = ("relu",)
+return_layernorm_output: bool = False
+activations: Sequence[Union[str, Callable]] = ("gelu",)
 activation_params: dict = None
 intermediate_dropout_rng_name: str = "dropout"
-intermediate_dropout_rate: float = 0.1
+intermediate_dropout_rate: float = 0.0
 intermediate_hidden_dropout_dims: Sequence[int] = ()
 enable_low_rank_adaptation: bool = False
 low_rank_adaptation_dim: int = 32
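
Since intermediate_dropout_rate now defaults to 0.0, the dropout op after the activations is disabled unless requested, and when it is enabled the mask comes from the RNG stream named by intermediate_dropout_rng_name ('dropout'), as the docstring above describes. A hedged sketch of turning it back on, assuming the usual flax.linen.Module.apply rngs plumbing:

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP

x = jnp.ones((8, 128, 512), dtype=jnp.float32)

# Re-enable the post-FC1 dropout that used to be on by default (0.1).
mlp = LayerNormMLP(intermediate_dropout_rate=0.1)
params = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)

# Training step: dropout is active, so supply the 'dropout' RNG stream
# (the key name comes from intermediate_dropout_rng_name).
out, _ = mlp.apply(params, x, deterministic=False,
                   rngs={"dropout": jax.random.PRNGKey(1)})
```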
8 changes: 4 additions & 4 deletions transformer_engine/jax/flax/transformer.py
@@ -1620,7 +1620,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
 Dimensions that will share the same dropout mask for hidden
 attention_dropout: float, default = 0.1
 Dropout probability for the dropout op during multi-head attention.
-intermediate_dropout: float, default = 0.1
+intermediate_dropout: float, default = 0.0
 Dropout probability for the dropout op after FC1 layer.
 intermediate_dropout_dims: Sequence[int], default = ()
 Dimensions that will share the same dropout mask for hidden after FC1 layer.
@@ -1635,7 +1635,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
 flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
 Used for initializing weights of FC1 and FC2 layers.
 It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
-mlp_activations: Sequence[str], default = ('relu', )
+mlp_activations: Sequence[str], default = ('gelu', )
 The sequence of activation functions to apply after the first linear transformation.
 Each activation has its own transformation layer.
 mlp_activation_params: dict = None
@@ -1755,12 +1755,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
 hidden_dropout: float = 0.1
 hidden_dropout_dims: Sequence[int] = ()
 attention_dropout: float = 0.1
-intermediate_dropout: float = 0.1
+intermediate_dropout: float = 0.0
 intermediate_dropout_dims: Sequence[int] = ()
 dropout_rng_name: str = "dropout"
 mha_kernel_init: Initializer = None
 mlp_kernel_init: Initializer = None
-mlp_activations: Sequence[str] = ("relu",)
+mlp_activations: Sequence[str] = ("gelu",)
 mlp_activation_params: dict = None
 use_bias: bool = False
 bias_init: Initializer = nn.initializers.zeros
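
Flax modules are dataclasses, so the new TransformerLayer defaults can be sanity-checked on a bare instance without initializing any parameters; a small sketch, assuming TransformerLayer is constructible from defaults alone:

```python
from transformer_engine.jax.flax import TransformerLayer

layer = TransformerLayer()                 # all attributes fall back to their defaults
assert layer.mlp_activations == ("gelu",)  # was ("relu",) before this commit
assert layer.intermediate_dropout == 0.0   # was 0.1 before this commit
assert layer.hidden_dropout == 0.1         # unchanged
assert layer.attention_dropout == 0.1      # unchanged
```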