Changes from 1 commit (39 commits total)
f335cc7
custom tests for selective activation checkpointing for layernorm mlp
jaimec00 Oct 27, 2025
e349f46
add selective layernorm mlp to te.pytorch
jaimec00 Oct 27, 2025
aa18e74
update test and fix SLNMLP bug
jaimec00 Oct 27, 2025
8f50f4a
implement slnmlp
jaimec00 Oct 28, 2025
f6f034b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
00841c2
fix tests pointed out by greptile app bot, still pass
jaimec00 Oct 28, 2025
955f068
minor formatting change in tests/pytorch/selective_layernorm_mlp/dist…
jaimec00 Oct 28, 2025
5e47706
remove duplicate import in test/pytorch/selective_layernorm_mlp/test_…
jaimec00 Oct 28, 2025
9a69a6c
clean up tests, remove unused imports
jaimec00 Oct 28, 2025
ea8270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
f896579
remove unused paths in test_deffered_init
jaimec00 Oct 28, 2025
9ee2df8
fix issue with zero_centered_gamma in test_numerics reference impleme…
jaimec00 Oct 28, 2025
05d3908
clean up tests
jaimec00 Oct 28, 2025
435fe9c
make comparison.py more extensive, cleaner output
jaimec00 Oct 28, 2025
903f37e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
0a31a70
fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
jaimec00 Oct 28, 2025
418dce6
fix typo by grepbot in compare.py
jaimec00 Oct 28, 2025
31cdd9d
make selectiuve activation checkpointing optional in slnmlp via check…
jaimec00 Oct 28, 2025
fae6052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
a6a927e
add comments to clarify logic
jaimec00 Oct 29, 2025
16b816b
add checkpoint param to pytests, change compare.py to compare checkpp…
jaimec00 Oct 29, 2025
f623124
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
ff6f58f
refactor tests to call modified LayerNormMLP
jaimec00 Oct 29, 2025
8cbdb91
refactor to implement selective activation checkpointing directly int…
jaimec00 Oct 29, 2025
c46ad4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
b068c5f
fix skip explanation for cuda_graphs.py
jaimec00 Oct 29, 2025
f0670ed
make _recompute deal with lists instead of tuples
jaimec00 Oct 29, 2025
5a34186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
e12fa7c
fix MOST cuda graph failures by initializing identical quantizers dur…
jaimec00 Oct 30, 2025
9b29e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2025
cc52db5
fix cuda graphs issue, all tests pass now
jaimec00 Oct 31, 2025
e94ef33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
ebd2329
fix small logic bugs, clean up
jaimec00 Nov 1, 2025
212fadb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
402e5f9
integrate tests into main testing scripts
jaimec00 Nov 5, 2025
483bbf6
incorporate rng state tracking in checkpointing
jaimec00 Nov 5, 2025
643a3c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
0d0255f
clean up tests
jaimec00 Nov 5, 2025
d86bc00
fix return type mismatches
jaimec00 Nov 5, 2025
fix cuda graphs issue, all tests pass now
Signed-off-by: Jaime Cardenas <[email protected]>
jaimec00 committed Oct 31, 2025
commit cc52db533281cdc0c34ddf15749e1424d7415d2e
14 changes: 4 additions & 10 deletions tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py
@@ -4,6 +4,8 @@

from typing import Iterable, List, Union
import pytest
from torch.profiler import profile, ProfilerActivity
import torch.cuda.nvtx as nvtx

import torch
from transformer_engine.pytorch import (
@@ -246,14 +248,15 @@ def _test_cuda_graphs(
# Training steps.
for _ in range(3):
optimizer.zero_grad(set_to_none=False)
for grad_accumulation_step in range(2):
for grad_accumulation_step in range(1):
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
with autocast(enabled=fp8, recipe=fp8_recipe):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = grad_accumulation_step == 0
output = model(input_, **kwargs)

output.backward(grad_output)
optimizer.step()

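For reference, a minimal sketch of the training loop this hunk edits is shown below, assembled from the surrounding test code; `model`, `optimizer`, `generate_data`, `fp8_weight_caching`, `model_config`, and `dtype` are defined elsewhere in the test file, and `num_microbatches` stands in for the hard-coded microbatch count that the hunk changes.

```python
# Sketch of the grad-accumulation pattern exercised by this test.
for _ in range(3):  # training steps
    optimizer.zero_grad(set_to_none=False)
    for grad_accumulation_step in range(num_microbatches):
        input_ = generate_data(model_config, dtype)
        grad_output = generate_data(model_config, dtype, requires_grad=False)
        with autocast(enabled=fp8, recipe=fp8_recipe):
            kwargs = {}
            if fp8_weight_caching:
                # Weights are (re)quantized only on the first microbatch;
                # later microbatches reuse the cached FP8 weights.
                kwargs["is_first_microbatch"] = grad_accumulation_step == 0
            output = model(input_, **kwargs)
        output.backward(grad_output)
    optimizer.step()
```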
@@ -296,15 +299,6 @@ def test_make_graphed_callables(
)
if fp8_params:
pytest.skip("NVFP4 params not supported")
if (
checkpoint
and type(fp8_recipe).__name__ == "Float8CurrentScaling"
and dtype != torch.float32
):
pytest.skip(
"CUDA graphs for LayerNormMLP with checkpointing, Float8CurrentScaling recipe, with"
f" {dtype} dtype tensors not supported yet"
)

# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
181 changes: 64 additions & 117 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -150,86 +150,6 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None):
raise NotImplementedError("Activation type " + activation + " is not supported!")
return funcs[activation]


@torch.no_grad()
def _copy_quantizer(
quantizer: Optional[Quantizer],
device: Optional[Union[str, torch.device]] = None,
) -> Optional[Quantizer]:
if quantizer is None:
return None

inferred_device = None
for attr in ("scale", "amax"):
tensor = getattr(quantizer, attr, None)
if isinstance(tensor, torch.Tensor):
inferred_device = tensor.device
break
if device is not None:
inferred_device = torch.device(device)
if inferred_device is None:
inferred_device = torch.device("cuda")

if isinstance(quantizer, Float8BlockQuantizer):
q = Float8BlockQuantizer(
fp8_dtype=quantizer.dtype,
rowwise=quantizer.rowwise_usage,
columnwise=quantizer.columnwise_usage,
amax_epsilon=quantizer.amax_epsilon,
force_pow_2_scales=quantizer.force_pow_2_scales,
block_scaling_dim=quantizer.block_scaling_dim,
all_gather_usage=quantizer.all_gather_usage,
)
elif isinstance(quantizer, Float8Quantizer):
q = Float8Quantizer(
scale=quantizer.scale.clone(),
amax=quantizer.amax.clone(),
fp8_dtype=quantizer.dtype,
rowwise=quantizer.rowwise_usage,
columnwise=quantizer.columnwise_usage,
)
elif isinstance(quantizer, Float8CurrentScalingQuantizer):
q = Float8CurrentScalingQuantizer(
fp8_dtype=quantizer.dtype,
device=inferred_device,
rowwise=quantizer.rowwise_usage,
columnwise=quantizer.columnwise_usage,
use_existing_amax=quantizer.use_existing_amax,
with_amax_reduction=quantizer.with_amax_reduction,
amax_reduction_group=quantizer.amax_reduction_group,
force_pow_2_scales=quantizer.force_pow_2_scales,
amax_epsilon=quantizer.amax_epsilon,
)
q.scale = quantizer.scale.clone()
q.amax = quantizer.amax.clone()
elif isinstance(quantizer, MXFP8Quantizer):
q = MXFP8Quantizer(
fp8_dtype=quantizer.dtype,
rowwise=quantizer.rowwise_usage,
columnwise=quantizer.columnwise_usage,
)
elif isinstance(quantizer, NVFP4Quantizer):
q = NVFP4Quantizer(
fp4_dtype=quantizer.dtype,
rowwise=quantizer.rowwise_usage,
columnwise=quantizer.columnwise_usage,
with_amax_reduction=quantizer.with_amax_reduction,
amax_reduction_group=quantizer.amax_reduction_group,
with_rht=quantizer.with_rht,
with_post_rht_amax=quantizer.with_post_rht_amax,
with_2d_quantization=quantizer.with_2d_quantization,
stochastic_rounding=quantizer.stochastic_rounding,
with_random_sign_mask=quantizer.rht_matrix_random_sign_mask_t != 0,
)
else:
raise NotImplementedError(
"Checkpointing in LayerNormMLP not implemented for"
f" {type(quantizer).__name__} quantizer yet"
)
q.internal = quantizer.internal
return q


class _LayerNormMLP(torch.autograd.Function):
"""LayerNormMLP semi-top level module
Calls custom cuda extensions.
@@ -298,8 +218,11 @@ def _forward(
# if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take
if is_grad_enabled and not recompute_for_bwd:
ctx.checkpoint = checkpoint

# few helper flags for simpler logic
if checkpoint:
# save the state of autocast and quantizers for recomputation
ctx.autocast_state = FP8GlobalStateManager.get_autocast_state() # to restore autocast state during recomputation
if fp8 and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling": # only applicable for delayed scaling
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(module.fp8_meta) # to restore quantizers during recomputation

# whether to save activations regularly, or save inputs for recomputation in bwd
save_for_checkpoint = checkpoint and is_grad_enabled and not recompute_for_bwd
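Reformatted for readability, the capture logic added above has the following shape (a sketch only; the diff itself is authoritative, and the matching restore path appears in `_recompute` further down).

```python
# Sketch of the forward-side state capture when checkpointing is enabled.
if is_grad_enabled and not recompute_for_bwd:
    ctx.checkpoint = checkpoint
    if checkpoint:
        # Record the autocast state so recomputation in backward can run
        # under the same settings as the original forward.
        ctx.autocast_state = FP8GlobalStateManager.get_autocast_state()
        # DelayedScaling keeps a running amax/scale history, so snapshot the
        # forward FP8 meta tensors for the later recompute.
        if fp8 and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling":
            FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(module.fp8_meta)

# Save activations normally, or save only the inputs for recomputation in bwd.
save_for_checkpoint = checkpoint and is_grad_enabled and not recompute_for_bwd
```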
@@ -323,29 +246,25 @@
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects

# save other arguments to _forward
def clone_quantizer(q: Optional[Quantizer]) -> Optional[Quantizer]:
return _copy_quantizer(q, device=inp.device) if q is not None else None

ctx.other_args = {
"eps": eps,
"is_first_microbatch": is_first_microbatch,
"fp8": fp8,
"fp8_calibration": fp8_calibration,
"wgrad_store": wgrad_store,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"fc1_input_quantizer": clone_quantizer(fc1_input_quantizer),
"fc1_weight_quantizer": clone_quantizer(fc1_weight_quantizer),
"fc1_output_quantizer": clone_quantizer(fc1_output_quantizer),
"fc1_grad_input_quantizer": clone_quantizer(fc1_grad_input_quantizer),
"fc1_grad_weight_quantizer": clone_quantizer(fc1_grad_weight_quantizer),
"fc1_grad_output_quantizer": clone_quantizer(fc1_grad_output_quantizer),
"fc2_input_quantizer": clone_quantizer(fc2_input_quantizer),
"fc2_weight_quantizer": clone_quantizer(fc2_weight_quantizer),
"fc2_output_quantizer": clone_quantizer(fc2_output_quantizer),
"fc2_grad_input_quantizer": clone_quantizer(fc2_grad_input_quantizer),
"fc2_grad_weight_quantizer": clone_quantizer(fc2_grad_weight_quantizer),
"fc2_grad_output_quantizer": clone_quantizer(fc2_grad_output_quantizer),
"fc1_input_quantizer": fc1_input_quantizer,
"fc1_weight_quantizer": fc1_weight_quantizer,
"fc1_output_quantizer": fc1_output_quantizer,
"fc1_grad_input_quantizer": fc1_grad_input_quantizer,
"fc1_grad_weight_quantizer": fc1_grad_weight_quantizer,
"fc1_grad_output_quantizer": fc1_grad_output_quantizer,
"fc2_input_quantizer": fc2_input_quantizer,
"fc2_weight_quantizer": fc2_weight_quantizer,
"fc2_output_quantizer": fc2_output_quantizer,
"fc2_grad_input_quantizer": fc2_grad_input_quantizer,
"fc2_grad_weight_quantizer": fc2_grad_weight_quantizer,
"fc2_grad_output_quantizer": fc2_grad_output_quantizer,
"cpu_offloading": cpu_offloading,
"tp_group": tp_group,
"tp_size": tp_size,
@@ -738,7 +657,7 @@ def clone_quantizer(q: Optional[Quantizer]) -> Optional[Quantizer]:
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensorStorage):
fc2_weight_final.update_usage(columnwise_usage=True)

ctx.fc1_weight_quantizer = fc1_weight_quantizer
ctx.fc2_weight_quantizer = fc2_weight_quantizer

@@ -1032,9 +951,23 @@ def _recompute(ctx):
ctx.tensor_objects = None

if ctx.checkpoint: # do recomputation from the original args
return _LayerNormMLP._forward(

# backward is not in autocast context, so we set the state here
# we also have to set the quantizer states to what they were before the forward pass (only relevant for DelayedScaling recipe)
final_autocast_state = FP8GlobalStateManager.get_autocast_state() # get current autocast state
FP8GlobalStateManager.set_autocast_state(ctx.autocast_state) # set old autocast state
if ctx.other_args["fp8"] and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling": # only applicable for delayed scaling
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(ctx.other_args["module"].fp8_meta) # set old quantizer state
out = _LayerNormMLP._forward( # recompute
ctx, *tensors, *ctx.other_args.values(), recompute_for_bwd=True
)

FP8GlobalStateManager.set_autocast_state(final_autocast_state) # restore autocast state
if ctx.other_args["fp8"] and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling":
FP8GlobalStateManager.restore_fp8_meta_tensors(ctx.other_args["module"].fp8_meta) # restore quantizers

return out

else: # load from saved (return ctx is just because the other branch does too)
return [ctx] + tensors

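Again reformatted for readability, the checkpoint branch of `_recompute` added above boils down to the sketch below; `delayed_scaling` is a condensed helper name introduced only for this illustration.

```python
# Sketch of the recompute path: swap in the forward-time state, rerun the
# forward, then swap the current state back.
if ctx.checkpoint:
    # Backward runs outside the autocast context, so restore the state
    # captured in forward before recomputing.
    final_autocast_state = FP8GlobalStateManager.get_autocast_state()
    FP8GlobalStateManager.set_autocast_state(ctx.autocast_state)
    delayed_scaling = (
        ctx.other_args["fp8"]
        and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling"
    )
    if delayed_scaling:
        # Point the quantizers back at the FP8 meta tensors snapshotted in forward.
        FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(ctx.other_args["module"].fp8_meta)

    out = _LayerNormMLP._forward(ctx, *tensors, *ctx.other_args.values(), recompute_for_bwd=True)

    # Undo the swap so the rest of backward sees the current global state.
    FP8GlobalStateManager.set_autocast_state(final_autocast_state)
    if delayed_scaling:
        FP8GlobalStateManager.restore_fp8_meta_tensors(ctx.other_args["module"].fp8_meta)
    return out

# No checkpointing: everything needed was already saved in forward.
return [ctx] + tensors
```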
@@ -1044,24 +977,24 @@ def backward(
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring

( # pylint: disable=unbalanced-tuple-unpacking
ctx,
inputmat,
ln_weight,
ln_out,
fc1_weight,
origin_fc1_weight,
fc1_bias,
fc1_out,
fc1_out_without_bias,
act_out,
fc2_weight,
origin_fc2_weight,
fc2_bias,
mu,
rsigma,
) = _LayerNormMLP._recompute(ctx)
with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
( # pylint: disable=unbalanced-tuple-unpacking
ctx,
inputmat,
ln_weight,
ln_out,
fc1_weight,
origin_fc1_weight,
fc1_bias,
fc1_out,
fc1_out_without_bias,
act_out,
fc2_weight,
origin_fc2_weight,
fc2_bias,
mu,
rsigma,
) = _LayerNormMLP._recompute(ctx)

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
fc1_weight_main_grad = (
@@ -1725,12 +1658,26 @@ def fc1_wgrad_gemm(
# )
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
# inputmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,

dgamma,
# ln_weight,

dbeta,
# ln_weight,

fc1_wgrad,
# origin_fc1_weight,

fc1_bias_grad if fc1_bias is not None else None,
# fc1_bias,

fc2_wgrad, # pylint: disable=possibly-used-before-assignment
# origin_fc2_weight,

fc2_bias_grad,
# fc2_bias,

None, # eps
None, # is_first_microbatch
None, # fp8