Merged
Changes from 1 commit
58 commits
f335cc7
custom tests for selective activation checkpointing for layernorm mlp
jaimec00 Oct 27, 2025
e349f46
add selective layernorm mlp to te.pytorch
jaimec00 Oct 27, 2025
aa18e74
update test and fix SLNMLP bug
jaimec00 Oct 27, 2025
8f50f4a
implement slnmlp
jaimec00 Oct 28, 2025
f6f034b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
00841c2
fix tests pointed out by greptile app bot, still pass
jaimec00 Oct 28, 2025
955f068
minor formatting change in tests/pytorch/selective_layernorm_mlp/dist…
jaimec00 Oct 28, 2025
5e47706
remove duplicate import in test/pytorch/selective_layernorm_mlp/test_…
jaimec00 Oct 28, 2025
9a69a6c
clean up tests, remove unused imports
jaimec00 Oct 28, 2025
ea8270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
f896579
remove unused paths in test_deferred_init
jaimec00 Oct 28, 2025
9ee2df8
fix issue with zero_centered_gamma in test_numerics reference impleme…
jaimec00 Oct 28, 2025
05d3908
clean up tests
jaimec00 Oct 28, 2025
435fe9c
make comparison.py more extensive, cleaner output
jaimec00 Oct 28, 2025
903f37e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
0a31a70
fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
jaimec00 Oct 28, 2025
418dce6
fix typo flagged by greptile bot in compare.py
jaimec00 Oct 28, 2025
31cdd9d
make selective activation checkpointing optional in slnmlp via check…
jaimec00 Oct 28, 2025
fae6052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
a6a927e
add comments to clarify logic
jaimec00 Oct 29, 2025
16b816b
add checkpoint param to pytests, change compare.py to compare checkpp…
jaimec00 Oct 29, 2025
f623124
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
ff6f58f
refactor tests to call modified LayerNormMLP
jaimec00 Oct 29, 2025
8cbdb91
refactor to implement selective activation checkpointing directly int…
jaimec00 Oct 29, 2025
c46ad4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
b068c5f
fix skip explanation for cuda_graphs.py
jaimec00 Oct 29, 2025
f0670ed
make _recompute deal with lists instead of tuples
jaimec00 Oct 29, 2025
5a34186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
e12fa7c
fix MOST cuda graph failures by initializing identical quantizers dur…
jaimec00 Oct 30, 2025
9b29e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2025
cc52db5
fix cuda graphs issue, all tests pass now
jaimec00 Oct 31, 2025
e94ef33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
ebd2329
fix small logic bugs, clean up
jaimec00 Nov 1, 2025
212fadb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
402e5f9
integrate tests into main testing scripts
jaimec00 Nov 5, 2025
483bbf6
incorporate rng state tracking in checkpointing
jaimec00 Nov 5, 2025
643a3c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
0d0255f
clean up tests
jaimec00 Nov 5, 2025
d86bc00
fix return type mismatches
jaimec00 Nov 5, 2025
9aaa1b9
merge main into features/SLNMLP
jaimec00 Nov 12, 2025
07ff0c1
remove checkpoint test from test_recipe, add separate test in test_num…
jaimec00 Nov 12, 2025
8dec0fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
78b3437
minor typo fix
jaimec00 Nov 12, 2025
2ec3f18
merge main into features/SLNMLP
jaimec00 Nov 17, 2025
b959044
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2025
197ef5e
clear up assertions in tests/pytorch/layernorm_mlp/test_selective_act…
jaimec00 Nov 17, 2025
332c5c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2025
b509cec
add license and copyright info
jaimec00 Nov 17, 2025
7e6648b
fix lint issues in layernorm_mlp
jaimec00 Nov 17, 2025
fe88ceb
Merge branch 'main' into features/SLNMLP
ksivaman Nov 17, 2025
1b3ff5f
fix cpu_offload_v1 error
jaimec00 Nov 18, 2025
5fd59e1
possibly fix recomputation in cuda graph bug
jaimec00 Nov 18, 2025
26682b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
80a7229
skip cuda graphs test for SLNMLP with SM>=10.0 and using delayed scaling
jaimec00 Nov 18, 2025
ef21ac6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
4a09a0f
fix typo for setting IS_FIRST_FP8_MODULE
jaimec00 Nov 18, 2025
e21fe22
Merge remote-tracking branch 'upstream/main' into features/SLNMLP
jaimec00 Nov 18, 2025
906ca4b
Merge branch 'main' into features/SLNMLP
ksivaman Nov 18, 2025
refactor tests to call modified LayerNormMLP
Signed-off-by: Jaime Cardenas <[email protected]>
jaimec00 committed Oct 29, 2025
commit ff6f58fe8a7d5eb2751442157d4daff4353058b0
8 changes: 4 additions & 4 deletions tests/pytorch/selective_layernorm_mlp/compare.py
@@ -1,6 +1,6 @@
import time
import torch
Member

A general comment about this file - it is really nice, but it is not a test - it doesn't actually test anything, it just measures. We could introduce some test functionality here, e.g. by ensuring that the error between the checkpointed and non-checkpointed LayerNormMLP is zero (since checkpointing shouldn't affect numerics), or that the memory used is lower (ideally we would quantify the expected memory usage and test against that, but for now even just making sure that memory usage goes down would be good).

Contributor Author

Sounds good, I converted it into a test that checks memory goes down at least 6x in the forward pass. I also asserted that checkpointing is slower than not checkpointing in the backward pass (not sure if this is helpful, but let me know), and that the output and gradient differences are 0. I put this test in tests/pytorch/layernorm_mlp/test_selective_activation_checkpointing.py because I wasn't sure where it fit with the rest of the testing scripts, but let me know if it would be better elsewhere! (A rough sketch of this assertion pattern follows this file's diff below.)

from transformer_engine.pytorch import SelectiveLayerNormMLP
from transformer_engine.pytorch import LayerNormMLP
from collections import defaultdict

torch.manual_seed(1234)
@@ -32,10 +32,10 @@ def build(self):

ln_list, sln_list = [], []
for _ in range(self._layers):
ln = SelectiveLayerNormMLP(
ln = LayerNormMLP(
self._hidden_size, self._ffn_hidden_size, checkpoint=False
).to(device)
sln = SelectiveLayerNormMLP(
sln = LayerNormMLP(
self._hidden_size, self._ffn_hidden_size, checkpoint=True
).to(device)
with torch.no_grad():
@@ -180,7 +180,7 @@ def _run_bwd(model, out):
self.stats[desc]["diff"][key] = self._max_diff(ln_grads[key], sln_grads[key])

def summarize(self):
_modules = [("ln_stats", "LayerNormMLP"), ("sln_stats", "SelectiveLayerNormMLP")]
_modules = [("ln_stats", "No Checkpointing"), ("sln_stats", "Checkpointing")]
_metric_map = {"time": (1, "ms"), "mem": (1e-6, "MB")}

left_w = 18 # "fwd time" / "bwd mem" label
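For context, the following is a minimal sketch of the assertion pattern described in the review thread above; it is not the test file added in this PR. It assumes the checkpoint keyword that this change adds to te.LayerNormMLP, and the helper name, shapes, and the simple "memory goes down" check are illustrative (the actual test also times the backward pass and checks the roughly 6x forward-memory reduction).

# Minimal sketch of the checkpoint-vs-no-checkpoint assertions discussed above.
# Assumes the `checkpoint` kwarg on LayerNormMLP from this PR; names, shapes,
# and thresholds are illustrative, not the actual test added in this change.
import torch
from transformer_engine.pytorch import LayerNormMLP


def _forward_retained_mem(module, inp):
    """Run a forward pass and return (bytes retained across the call, output).

    The retained bytes approximate the output plus tensors saved for backward,
    which is what selective activation checkpointing is meant to shrink.
    """
    torch.cuda.synchronize()
    before = torch.cuda.memory_allocated()
    out = module(inp)
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated() - before, out


def test_checkpoint_reduces_memory_without_changing_numerics():
    torch.manual_seed(1234)
    hidden, ffn_hidden = 512, 2048
    ref = LayerNormMLP(hidden, ffn_hidden, checkpoint=False).to("cuda")
    ckpt = LayerNormMLP(hidden, ffn_hidden, checkpoint=True).to("cuda")
    ckpt.load_state_dict(ref.state_dict())  # tie weights so outputs are comparable

    x_ref = torch.randn(128, 32, hidden, device="cuda", requires_grad=True)
    x_ckpt = x_ref.detach().clone().requires_grad_()

    mem_ref, out_ref = _forward_retained_mem(ref, x_ref)
    mem_ckpt, out_ckpt = _forward_retained_mem(ckpt, x_ckpt)

    # Checkpointing should not change numerics: forward outputs match exactly...
    assert torch.equal(out_ckpt, out_ref)
    # ...and less memory should be retained for backward (the compare script in
    # this PR observed roughly a 6x reduction in the forward pass).
    assert mem_ckpt < mem_ref

    out_ref.sum().backward()
    out_ckpt.sum().backward()
    # Gradients after recomputation should match exactly as well.
    assert torch.equal(x_ckpt.grad, x_ref.grad)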
@@ -162,7 +162,7 @@ def main(argv=None, namespace=None):
HIDDEN_SIZE = 512

test_dict = [
test_selective_layernorm_mlp,
test_layernorm_mlp,
]

for test in test_dict:
@@ -378,13 +378,13 @@ def _alloc_main_grad(model_single_node, model_distributed):


############################################
# SelectiveLayerNormMLP #
# LayerNormMLP #
############################################


@run_distributed_test()
def _test_selective_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the SelectiveLayerNormMLP with specified parallel mode and sequence parallelization.
def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the LayerNormMLP with specified parallel mode and sequence parallelization.

Args:
set_parallel_mode (bool): Enable parallel mode.
@@ -396,8 +396,8 @@ def _test_selective_layernorm_mlp(set_parallel_mode=None, sequence_parallel=Fals
FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128

# Create models
model_single_node = te.SelectiveLayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs)
model_distributed = te.SelectiveLayerNormMLP(
model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs)
model_distributed = te.LayerNormMLP(
HIDDEN_SIZE,
FFN_HIDDEN_SIZE,
tp_size=WORLD_SIZE,
@@ -464,7 +464,7 @@ def _test_selective_layernorm_mlp(set_parallel_mode=None, sequence_parallel=Fals
)


def test_selective_layernorm_mlp():
def test_layernorm_mlp():
kwargs_list = [
{},
{"init_method": _constant},
@@ -485,4 +485,4 @@ def test_selective_layernorm_mlp():
for kwargs in kwargs_list:
for set_parallel_mode in [True]:
for sequence_parallel in [False, True]:
_test_selective_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)
_test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)
10 changes: 5 additions & 5 deletions tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py
@@ -7,7 +7,7 @@

import torch
from transformer_engine.pytorch import (
SelectiveLayerNormMLP,
LayerNormMLP,
autocast,
quantized_model_init,
make_graphed_callables,
@@ -165,7 +165,7 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:


# Supported modules
_test_cuda_graphs_modules: List[str] = ["selective_layernorm_mlp"]
_test_cuda_graphs_modules: List[str] = ["layernorm_mlp"]


def _test_cuda_graphs(
@@ -192,9 +192,9 @@ def _test_cuda_graphs(
# Create modules.
with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe):

if module == "selective_layernorm_mlp":
if module == "layernorm_mlp":
modules = [
SelectiveLayerNormMLP(
LayerNormMLP(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
@@ -322,7 +322,7 @@ def test_make_graphed_callables(


_test_make_graphed_callables_with_fp8_weight_caching_modules = [
"selective_layernorm_mlp",
"layernorm_mlp",
]


4 changes: 2 additions & 2 deletions tests/pytorch/selective_layernorm_mlp/test_deferred_init.py
@@ -8,7 +8,7 @@
import transformer_engine.pytorch as te

_core_modules = [
te.SelectiveLayerNormMLP,
te.LayerNormMLP,
]
_composed_modules = []

@@ -26,7 +26,7 @@ def get_module_args(module, checkpoint):
hidden_size = num_heads * head_dim
args = (hidden_size,)
kwargs = {"params_dtype": dtype, "device": "meta"}
if module == te.SelectiveLayerNormMLP:
if module == te.LayerNormMLP:
ffn_hidden_size = 2 * hidden_size
args += (ffn_hidden_size,)
kwargs["bias"] = True
12 changes: 6 additions & 6 deletions tests/pytorch/selective_layernorm_mlp/test_numerics.py
@@ -22,7 +22,7 @@
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch import (
autocast,
SelectiveLayerNormMLP,
LayerNormMLP,
get_device_compute_capability,
is_fp8_available,
is_mxfp8_available,
@@ -388,13 +388,13 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_selective_layernorm_mlp_accuracy(
def test_layernorm_mlp_accuracy(
dtype, bs, model, activation, normalization, return_bias, bias, checkpoint
):
config = model_configs[model]

te_ln_mlp = TestReturnBiasModule(
SelectiveLayerNormMLP,
LayerNormMLP,
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
activation=activation,
@@ -466,12 +466,12 @@ def test_selective_layernorm_mlp_accuracy(
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_selective_layernorm_mlp_accuracy_delay_wgrad_compute(
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype, bs, model, bias, fuse_wgrad_accumulation, checkpoint
):
config = model_configs[model]

ln_mlp = SelectiveLayerNormMLP(
ln_mlp = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
@@ -483,7 +483,7 @@ def test_selective_layernorm_mlp_accuracy_delay_wgrad_compute(
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()

ln_mlp_ref = SelectiveLayerNormMLP(
ln_mlp_ref = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
4 changes: 2 additions & 2 deletions tests/pytorch/selective_layernorm_mlp/test_recipe.py
@@ -8,7 +8,7 @@
import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
quantized_model_init,
SelectiveLayerNormMLP,
LayerNormMLP,
)

from transformer_engine.common.recipe import DelayedScaling
@@ -35,7 +35,7 @@ def setup_class(cls) -> None:

@pytest.mark.parametrize(
"module_class",
[SelectiveLayerNormMLP],
[LayerNormMLP],
)
@pytest.mark.parametrize("checkpoint", (True, False))
def test_quantizer_update(self, module_class, checkpoint):
6 changes: 3 additions & 3 deletions tests/pytorch/selective_layernorm_mlp/test_sanity.py
@@ -15,7 +15,7 @@
)
from transformer_engine.pytorch import (
autocast,
SelectiveLayerNormMLP,
LayerNormMLP,
is_bf16_available,
)
from transformer_engine.common import recipe
@@ -158,7 +158,7 @@ def _test_sanity_common(
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_selective_layernorm_mlp(
def test_sanity_layernorm_mlp(
dtype,
fp8_recipe,
model,
@@ -182,7 +182,7 @@ def test_sanity_selective_layernorm_mlp(
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

block = SelectiveLayerNormMLP(
block = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
init_method=init_method,