Changes from 1 commit
Commits (43)
f335cc7
custom tests for selective activation checkpointing for layernorm mlp
jaimec00 Oct 27, 2025
e349f46
add selective layernorm mlp to te.pytorch
jaimec00 Oct 27, 2025
aa18e74
update test and fix SLNMLP bug
jaimec00 Oct 27, 2025
8f50f4a
implement slnmlp
jaimec00 Oct 28, 2025
f6f034b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
00841c2
fix tests pointed out by greptile app bot, still pass
jaimec00 Oct 28, 2025
955f068
minor formatting change in tests/pytorch/selective_layernorm_mlp/dist…
jaimec00 Oct 28, 2025
5e47706
remove duplicate import in test/pytorch/selective_layernorm_mlp/test_…
jaimec00 Oct 28, 2025
9a69a6c
clean up tests, remove unused imports
jaimec00 Oct 28, 2025
ea8270d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
f896579
remove unused paths in test_deffered_init
jaimec00 Oct 28, 2025
9ee2df8
fix issue with zero_centered_gamma in test_numerics reference impleme…
jaimec00 Oct 28, 2025
05d3908
clean up tests
jaimec00 Oct 28, 2025
435fe9c
make comparison.py more extensive, cleaner output
jaimec00 Oct 28, 2025
903f37e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
0a31a70
fix small typo in tests/pytorch/selective_layernorm_mlp/compare.py
jaimec00 Oct 28, 2025
418dce6
fix typo by grepbot in compare.py
jaimec00 Oct 28, 2025
31cdd9d
make selectiuve activation checkpointing optional in slnmlp via check…
jaimec00 Oct 28, 2025
fae6052
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
a6a927e
add comments to clarify logic
jaimec00 Oct 29, 2025
16b816b
add checkpoint param to pytests, change compare.py to compare checkpp…
jaimec00 Oct 29, 2025
f623124
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
ff6f58f
refactor tests to call modified LayerNormMLP
jaimec00 Oct 29, 2025
8cbdb91
refactor to implement selective activation checkpointing directly int…
jaimec00 Oct 29, 2025
c46ad4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
b068c5f
fix skip explanation for cuda_graphs.py
jaimec00 Oct 29, 2025
f0670ed
make _recompute deal with lists instead of tuples
jaimec00 Oct 29, 2025
5a34186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
e12fa7c
fix MOST cuda graph failures by initializing identical quantizers dur…
jaimec00 Oct 30, 2025
9b29e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2025
cc52db5
fix cuda graphs issue, all tests pass now
jaimec00 Oct 31, 2025
e94ef33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
ebd2329
fix small logic bugs, clean up
jaimec00 Nov 1, 2025
212fadb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
402e5f9
integrate tests into main testing scripts
jaimec00 Nov 5, 2025
483bbf6
incorporate rng state tracking in checkpointing
jaimec00 Nov 5, 2025
643a3c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
0d0255f
clean up tests
jaimec00 Nov 5, 2025
d86bc00
fix return type mismatches
jaimec00 Nov 5, 2025
9aaa1b9
merge main into features/SLNMLP
jaimec00 Nov 12, 2025
07ff0c1
remove checkpoint test from test_recipe, add sperate test in test_num…
jaimec00 Nov 12, 2025
8dec0fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
78b3437
minor typo fix
jaimec00 Nov 12, 2025
add comments to clarify logic
Signed-off-by: Jaime Cardenas <[email protected]>
jaimec00 committed Oct 29, 2025
commit a6a927eec03bd7310883c0f93c5a0a0404eed4c6
57 changes: 42 additions & 15 deletions transformer_engine/pytorch/module/selective_layernorm_mlp.py
@@ -216,14 +216,22 @@ def _forward(
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring

# save the initial state for recomputation by bwd
# would be better to do prep stuff and save that in the beginning
# but want to get a working version first
if not recompute_for_bwd:
# if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take
if is_grad_enabled and not recompute_for_bwd:
ctx.checkpoint = checkpoint

# few helper flags for simpler logic

# whether to save activations regularly, or save inputs for recomputation in bwd
save_for_checkpoint = checkpoint and is_grad_enabled and not recompute_for_bwd

# whether we are in the forward stage, or recomputing in the bwd stage (false if not checkpointing)
is_recomputation = checkpoint and is_grad_enabled and recompute_for_bwd

# save the initial state for recomputation by bwd
if save_for_checkpoint:

# save tensors
tensors_to_save, tensor_objects = prepare_for_saving(
inp,
ln_weight,
@@ -235,6 +243,8 @@ )
)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects

# save other arguments to _forward
ctx.other_args = [
eps,
is_first_microbatch,
@@ -303,9 +313,14 @@ ln_bias = cast_if_needed(ln_bias, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)

tp_world_size = get_distributed_world_size(tp_group)

# bwd needs fc1 input when grad is enabled, fc1 needs grad, and either
# 1) no checkpointing
# or 2) doing the recomputation with checkpointing
backwards_needs_fc1_input = fc1_weight.requires_grad and (
is_recomputation or (is_grad_enabled and not checkpoint)
)

device = inp.device

# Configure Userbuffers communication (comm+GEMM overlap)
@@ -363,6 +378,8 @@
zero_centered_gamma,
)
ln_out_return = None

# do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing
if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation:
ln_out_return = ln_out

@@ -371,6 +388,8 @@
ln_out_total = None
ub_obj_lnout = None
if sequence_parallel:

# do not return ln output if checkpointing and in recomputation, not necessary
if return_layernorm_output_gathered and not is_recomputation:
# Perform all-gather in high precision if gathered
# norm output will be returned
@@ -420,7 +439,7 @@
# FP8 cast to workspace buffer
update_workspace = (
is_first_microbatch is None or is_first_microbatch
) and not is_recomputation
) and not is_recomputation # only update workspace if not checkpointing or checkpointing with no recomp, otherwise cache workspace
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace(
@@ -508,6 +527,9 @@
# ------------------------------------------------------

# Deallocate FC1 GEMM input tensor if no longer needed
# the first part of the if statement means we only clear ln_out_total if
# 1) checkpointing and not recomputing (in the forward stage, not the bwd recompute stage), or
# 2) not checkpointing
if (not is_recomputation and checkpoint) and (ln_out_total is not ln_out_return):
clear_tensor_data(ln_out_total)

@@ -545,15 +567,16 @@
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)

# we want to skip fc2 computation if we are checkpointing and recomputing,
# otherwise we compute fc2
if not (is_recomputation and checkpoint):

# if we get to this point,
# we know this is not bwd recomputation, bc would have returned above block
# if we get to this point, we know this is not bwd recomputation
# so we must be in the fwd
# is_grad_enabled can be true or false
# if false, can safely delete
# if true, we can only delete if checkpoint is true, since we will recompute anyways,
# otherwise, checkpoint is false, so cant delete
# now is_grad_enabled can be true or false
# if false, can safely delete
# if true, we can only delete if checkpoint is true, since we will recompute anyways,
# otherwise, checkpoint is false, so cant delete
if (
checkpoint or not is_grad_enabled
): # we can safely get rid of these if this is the case
@@ -595,9 +618,9 @@
# ------------------------------------------------------

# Deallocate tensors if no longer needed, again, can safely deallocate
if (
if (  # same logic as last clear_tensor_data block
checkpoint or not is_grad_enabled
): # we can safely get rid of these if this is the case
):
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)

# Prepare output tensor
@@ -618,6 +641,9 @@
fc2_out = gemm_out
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])

# now saving stuff for bwd:
# if we are using checkpointing, this information will be saved in the bwd recomputation stage, so can skip it in fwd
# if we are not checkpointing, then we must save this if grad is enabled
if is_grad_enabled and not save_for_checkpoint:

# Weight with column-wise usage is needed for dgrad GEMM.
@@ -782,6 +808,7 @@ def _forward(
rsigma,
)

# we only get to this point if we are not recomputing for bwd, since that would have returned in the block above
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
@@ -1775,7 +1802,7 @@ class SelectiveLayerNormMLP(TransformerEngineBaseModule):
checkpoint: bool, default = False
whether to use selective activation checkpointing, where activations are not saved for bwd,
and instead are recomputed (skipping fc2, as it is not needed for backward). Trades compute
for memory. default is false, in which activations are saved in fwd.
for memory. Default is False, in which case activations are saved in the forward pass. Not supported for ONNX forward.
"""

def __init__(
@@ -1810,7 +1837,7 @@ def __init__(
ub_bulk_wgrad: bool = False,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
checkpoint: Optional[bool] = True,
checkpoint: Optional[bool] = False,
) -> None:
super().__init__()

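
The control flow these added comments describe can be summarized as a standalone pattern: with checkpointing enabled, forward saves only its inputs, and backward first recomputes the intermediates up through the activation (skipping FC2, whose output is never needed for any gradient) before running the usual dgrad/wgrad GEMMs. Below is a minimal sketch of that pattern with a toy two-layer MLP in plain PyTorch; it is illustrative only and leaves out everything the Transformer Engine module handles (LayerNorm, FP8 quantizers, workspace caching, RNG state, tensor/sequence parallelism).

import torch


class ToySelectiveMLP(torch.autograd.Function):
    """Toy illustration of selective activation checkpointing (no LayerNorm/FP8)."""

    @staticmethod
    def forward(ctx, inp, w1, w2, checkpoint):
        ctx.checkpoint = checkpoint
        fc1_out = inp @ w1.t()
        act_out = torch.relu(fc1_out)
        fc2_out = act_out @ w2.t()
        if checkpoint:
            # save only the inputs; intermediates are recomputed in backward
            ctx.save_for_backward(inp, w1, w2)
        else:
            # save the intermediates backward needs directly
            ctx.save_for_backward(inp, w1, w2, fc1_out, act_out)
        return fc2_out

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.checkpoint:
            inp, w1, w2 = ctx.saved_tensors
            # recompute FC1 and the activation; FC2 is skipped because
            # its output is never used when forming gradients
            fc1_out = inp @ w1.t()
            act_out = torch.relu(fc1_out)
        else:
            inp, w1, w2, fc1_out, act_out = ctx.saved_tensors
        grad_act = grad_out @ w2                                 # dgrad through FC2
        grad_w2 = grad_out.t() @ act_out                         # wgrad for FC2
        grad_fc1 = grad_act * (fc1_out > 0).to(grad_act.dtype)   # through ReLU
        grad_w1 = grad_fc1.t() @ inp                             # wgrad for FC1
        grad_inp = grad_fc1 @ w1                                 # dgrad through FC1
        return grad_inp, grad_w1, grad_w2, None


# both paths produce the same gradients; checkpoint=True just saves less memory
x = torch.randn(8, 16, requires_grad=True)
w1 = torch.randn(32, 16, requires_grad=True)
w2 = torch.randn(16, 32, requires_grad=True)
ToySelectiveMLP.apply(x, w1, w2, True).sum().backward()

In the module itself this trade-off is controlled by the checkpoint constructor argument shown in the diff (default False); per the later commits in the list above, the same selective checkpointing is eventually folded directly into the modified LayerNormMLP.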