Changes from 1 commit (39 commits in this pull request)
f335cc7
custom tests for selective activation checkpointing for layernorm mlp
jaimec00 Oct 27, 2025
e349f46
add selective layernorm mlp to te.pytorch
jaimec00 Oct 27, 2025
aa18e74
update test and fix SLNMLP bug
jaimec00 Oct 27, 2025
8f50f4a
implement slnmlp
jaimec00 Oct 28, 2025
f6f034b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
00841c2
fix tests pointed out by greptile app bot, still pass
jaimec00 Oct 28, 2025
955f068
minor formatting change in tests/pytorch/selective_layernorm_mlp/dist…
jaimec00 Oct 28, 2025
5e47706
remove duplicate import in test/pytorch/selective_layernorm_mlp/test_…
jaimec00 Oct 28, 2025
9a69a6c
clean up tests, remove unused imports
jaimec00 Oct 28, 2025
ea8270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
f896579
remove unused paths in test_deffered_init
jaimec00 Oct 28, 2025
9ee2df8
fix issue with zero_centered_gamma in test_numerics reference impleme…
jaimec00 Oct 28, 2025
05d3908
clean up tests
jaimec00 Oct 28, 2025
435fe9c
make comparison.py more extensive, cleaner output
jaimec00 Oct 28, 2025
903f37e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
0a31a70
fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
jaimec00 Oct 28, 2025
418dce6
fix typo by grepbot in compare.py
jaimec00 Oct 28, 2025
31cdd9d
make selectiuve activation checkpointing optional in slnmlp via check…
jaimec00 Oct 28, 2025
fae6052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
a6a927e
add comments to clarify logic
jaimec00 Oct 29, 2025
16b816b
add checkpoint param to pytests, change compare.py to compare checkpp…
jaimec00 Oct 29, 2025
f623124
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
ff6f58f
refactor tests to call modified LayerNormMLP
jaimec00 Oct 29, 2025
8cbdb91
refactor to implement selective activation checkpointing directly int…
jaimec00 Oct 29, 2025
c46ad4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
b068c5f
fix skip explanation for cuda_graphs.py
jaimec00 Oct 29, 2025
f0670ed
make _recompute deal with lists instead of tuples
jaimec00 Oct 29, 2025
5a34186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
e12fa7c
fix MOST cuda graph failures by initializing identical quantizers dur…
jaimec00 Oct 30, 2025
9b29e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2025
cc52db5
fix cuda graphs issue, all tests pass now
jaimec00 Oct 31, 2025
e94ef33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
ebd2329
fix small logic bugs, clean up
jaimec00 Nov 1, 2025
212fadb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
402e5f9
integrate tests into main testing scripts
jaimec00 Nov 5, 2025
483bbf6
incorporate rng state tracking in checkpointing
jaimec00 Nov 5, 2025
643a3c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
0d0255f
clean up tests
jaimec00 Nov 5, 2025
d86bc00
fix return type mismatches
jaimec00 Nov 5, 2025
update test and fix SLNMLP bug
Signed-off-by: Jaime Cardenas <[email protected]>
jaimec00 committed Oct 29, 2025
commit aa18e74bdcf4d8a61fbb9b0661718fbddd5990f0
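For context on the feature these commits build up: selective activation checkpointing means the layer's intermediate activations are not kept between forward and backward; they are recomputed inside the backward pass, trading extra compute for activation memory. The sketch below illustrates the idea in plain PyTorch with a hypothetical `CheckpointedLayerNormMLP` wrapper; it is not the PR's implementation, which folds the recompute path into `transformer_engine.pytorch.LayerNormMLP` behind a `checkpoint` option (per the commit messages above).

```python
# Illustrative sketch only (plain PyTorch, not Transformer Engine code): the
# general idea behind the selective activation checkpointing these commits add
# to LayerNormMLP. A hypothetical wrapper drops the LayerNorm/MLP intermediates
# after the forward pass and recomputes them during backward.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedLayerNormMLP(nn.Module):
    """LayerNorm -> Linear -> GELU -> Linear, with optional recompute in backward."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, checkpoint_mlp: bool = True):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size)
        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
        self.checkpoint_mlp = checkpoint_mlp  # toggles selective recomputation

    def _mlp(self, x: torch.Tensor) -> torch.Tensor:
        # These intermediates (normalized input, fc1 output, GELU output) are the
        # activations that checkpointing avoids keeping around.
        return self.fc2(torch.nn.functional.gelu(self.fc1(self.ln(x))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.checkpoint_mlp and self.training and torch.is_grad_enabled():
            # Non-reentrant checkpoint: _mlp runs again inside backward to rebuild
            # the dropped activations (RNG state is preserved by default).
            return checkpoint(self._mlp, x, use_reentrant=False)
        return self._mlp(x)


if __name__ == "__main__":
    layer = CheckpointedLayerNormMLP(64, 256)
    out = layer(torch.randn(16, 16, 64, requires_grad=True))
    out.sum().backward()  # recompute happens here when checkpoint_mlp=True
```

The "selective" part is that only this block's intermediates are dropped and rebuilt, rather than checkpointing the whole model.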
104 changes: 64 additions & 40 deletions tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py
@@ -9,42 +9,27 @@
import os
import sys
from functools import wraps
import warnings
import math

import torch
from torch import nn
import torch.distributed as dist

import transformer_engine.pytorch as te

import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
NVFP4BlockScaling,
Format,
Recipe,
QParams,
)

SEQ_LEN, BATCH_SIZE = 16, 16
HIDDEN_SIZE = 64
NR_HEADS = 4
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None

if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
# The numerics of all layers should be the same when debug=True.
# A dummy feature is fed in to keep debug from being switched off,
# which can happen if no feature is active.
import nvdlfw_inspect.api as debug_api

debug_api.initialize(
os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)

from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.distributed import gather_along_first_dim

def _compare_tensors(name, test, ref, rtol, atol):
# Make sure tensors aren't zero and we don't pass trivially
@@ -85,6 +70,36 @@ def _compare_tensors(name, test, ref, rtol, atol):

return numerics_failed, numerics_info


SEQ_LEN, BATCH_SIZE = 16, 16
HIDDEN_SIZE = 64
NR_HEADS = 4
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None

if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
# The numerics of all layers should be the same when debug=True.
# A dummy feature is fed in to keep debug from being switched off,
# which can happen if no feature is active.
import nvdlfw_inspect.api as debug_api

debug_api.initialize(
os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)


def nvfp4_vanilla():
nvfp4_recipe = NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = QParams()
nvfp4_recipe.fp4_quant_fwd_weight = QParams()
nvfp4_recipe.fp4_quant_bwd_grad = QParams()
return nvfp4_recipe


# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "fp8":
@@ -97,7 +112,9 @@ def quantization_recipe() -> Recipe:
return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
return te.fp8.get_default_fp8_recipe()
if QUANTIZATION == "nvfp4":
return nvfp4_vanilla()
return te.quantization.get_default_fp8_recipe()


def main(argv=None, namespace=None):
@@ -134,16 +151,22 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
if QUANTIZATION in ("fp8", "mxfp8"):
if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"):
SEQ_LEN = 32
BATCH_SIZE = 32
HIDDEN_SIZE = 128
# For fp8 block scaling, block size is 128,
# and to make low precision TP work, input tensor
# must be 128x128 divisible to be eligible for
# low precision All-Gather when needed
elif QUANTIZATION == "fp8_block_scaling":
SEQ_LEN = 128
BATCH_SIZE = 128
HIDDEN_SIZE = 512

test_dict = [test_layernorm_mlp]
test_dict = [
test_selective_layernorm_mlp,
]

for test in test_dict:
test()
@@ -207,6 +230,9 @@ def _get_tolerances(dtype):
# row parallel & sequence parallel, because we do the all_gather in backward pass
if QUANTIZATION == "fp8_cs":
return {"rtol": 0.4, "atol": 0.25}
elif QUANTIZATION == "nvfp4":
# TODO(zhongboz): investigate why the tolerance is so large
return {"rtol": 0.125, "atol": 0.12}
elif QUANTIZATION is not None:
return {"rtol": 0.125, "atol": 0.0625}

@@ -323,15 +349,15 @@ def _apply_models(
_alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True
input_single_node.requires_grad_()
input_distributed.requires_grad_()
with te.fp8_autocast(
with te.autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
recipe=quantization_recipe(),
):
output_single_node = model_single_node(input_single_node, **kwargs)
with te.fp8_autocast(
with te.autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
fp8_group=NCCL_WORLD,
recipe=quantization_recipe(),
amax_reduction_group=NCCL_WORLD,
):
output_distributed = model_distributed(input_distributed, **kwargs)
return output_single_node, output_distributed
@@ -354,13 +380,14 @@ def _alloc_main_grad(model_single_node, model_distributed):
param.main_grad = torch.zeros_like(param, dtype=torch.float32)



############################################
# LayerNormMLP #
############################################


@run_distributed_test()
def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwargs):
def _test_selective_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the LayerNormMLP with specified parallel mode and sequence parallelization.

Args:
@@ -373,8 +400,8 @@
FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128

# Create models
model_single_node = te.SelectiveLayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs)
model_distributed = te.SelectiveLayerNormMLP(
model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs)
model_distributed = te.LayerNormMLP(
HIDDEN_SIZE,
FFN_HIDDEN_SIZE,
tp_size=WORLD_SIZE,
@@ -441,15 +468,15 @@
)


def test_layernorm_mlp():
def test_selective_layernorm_mlp():
kwargs_list = [
{},
{"init_method": _constant},
{"output_layer_init_method": _constant},
{"normalization": "RMSNorm"},
{"zero_centered_gamma": True},
{"bias": False},
{"params_dtype": torch.float16},
{"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"activation": "relu"},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
@@ -460,8 +487,5 @@ def test_layernorm_mlp():
for kwargs in kwargs_list:
for set_parallel_mode in [True]:
for sequence_parallel in [False, True]:
_test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)

_test_selective_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)

if __name__ == "__main__":
sys.exit(main())
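One change worth calling out in the hunks above: the test now enters quantized execution through `te.autocast(recipe=..., amax_reduction_group=...)` rather than `te.fp8_autocast(fp8_recipe=..., fp8_group=...)`. Below is a condensed sketch of the resulting call pattern; it reuses only names that appear in this diff, and the `run_step` helper is illustrative rather than part of the test.

```python
# Condensed from the updated run_numerics.py hunks above. The te.autocast
# argument names (enabled=, recipe=, amax_reduction_group=) are taken from this
# diff and assumed, not verified against a specific Transformer Engine release;
# run_step is a hypothetical helper, not part of the test.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

HIDDEN_SIZE, FFN_HIDDEN_SIZE = 128, 128


def run_step(quantization_enabled: bool, amax_group=None):
    """One forward/backward step under an (optional) quantization recipe."""
    model = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE).cuda()
    x = torch.randn(32, 32, HIDDEN_SIZE, device="cuda", requires_grad=True)
    with te.autocast(
        enabled=quantization_enabled,     # the test uses QUANTIZATION is not None
        recipe=Float8CurrentScaling(),    # stand-in for quantization_recipe()
        amax_reduction_group=amax_group,  # NCCL_WORLD in the distributed half
    ):
        out = model(x)
    out.sum().backward()                  # the test drives an MSE loss instead
    return out
```

In the distributed half of the test, passing the NCCL world group lets the quantization scaling statistics be reduced across ranks, which is what the `fp8_group` argument did before the rename.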
17 changes: 11 additions & 6 deletions tests/pytorch/selective_layernorm_mlp/distributed/test_numerics.py
@@ -8,7 +8,7 @@

import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch as te

"""
Distributed numerics tests
@@ -26,11 +26,12 @@
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
@@ -51,7 +52,9 @@ def _run_test(quantization):
all_boolean = [True, False]


@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
@pytest.mark.parametrize(
"quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
@@ -61,4 +64,6 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
_run_test(quantization)
127 changes: 0 additions & 127 deletions tests/pytorch/selective_layernorm_mlp/test_cpu_offloading.py

This file was deleted.
