Skip to content

Conversation

@jaimec00
Copy link

@jaimec00 jaimec00 commented Oct 28, 2025

Description

Implement Selective Activation Checkpointing for LayerNormMLP by adding a "checkpoint" flag. If checkpoint=True, activations are recomputed in the backward pass, while skipping the recomputation of fc2, as it is not needed for the backward pass. This reduces memory significantly, allowing for larger MLPs without running into OOMs, while still keeping the functionality for SequenceParallel and TensorParallel. Only functionality that is changed is cpu offloading, since there are no more activations to offload when checkpoint=True.

When checkpoint=False, runs regular LayerNormMLP, and all tests in tests/pytorch/selective_layernorm_mlp (listed in "Changes") pass. When checkpoint=True, all tests pass, except for test_cuda_graphs.py.

NOTE: all tests pass now. The cuda_graphs.py issue was because the recomputation was being done outside of the autocast context. saving the autocast state (and quantizer states) via FP8GlobalStateManager in fwd, setting them to that in recomputation, and restoring for bwd fixed the issue.

Fixes #623

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • restructure transformer_engine/pytorch/module/layernorm_mlp.py to allow for selective activation checkpointing with checkpoint=True
  • Add tests for checkpoint=True (in tests/pytorch/selective_layernorm_mlp). Note, these are basically just refactored tests for LayerNormMLP:
    • test_numerics.py
    • test_sanity.py
    • test_cuda_graphs.py
    • test_deferred_init.py
    • test_recipe.py
    • distributed/test_numerics.py
    • compare.py, compares LayerNormMLP with and without checkpoint flag, here are a couple of example comparisons

Performance

with checkpoint=True, forward pass sees no notable change in runtime, but >6X reduction in memory. Backward pass is $\approx$ 20% slower, with larger memory than with checkpoint=False, but still significantly less memory overall.

#########################################

Model Config

seq hidden ffn_hidden layers
65536 2048 8192 12

TIME (ms)

Metric No Checkpointing Checkpointing
fwd time 166.687 167.204
bwd time 452.350 550.005

MEM (MB)

Metric No Checkpointing Checkpointing
fwd mem 64,430.801 10,737.943
bwd mem 2,751.767 7,584.130

MAX ABSOLUTE ERRORS

Tensor Max Abs Error
output 0.000e+00
layer_norm_weight.grad 0.000e+00
layer_norm_bias.grad 0.000e+00
fc1_weight.grad 0.000e+00
fc1_bias.grad 0.000e+00
fc2_weight.grad 0.000e+00
fc2_bias.grad 0.000e+00

#########################################

Model Config

seq hidden ffn_hidden layers
16384 1024 4096 12

TIME (ms)

Metric No Checkpointing Checkpointing
fwd time 11.924 12.428
bwd time 31.612 39.968

MEM (MB)

Metric No Checkpointing Checkpointing
fwd mem 8,054.637 1,342.308
bwd mem 352.343 956.453

MAX ABSOLUTE ERRORS

Tensor Max Abs Error
output 0.000e+00
layer_norm_weight.grad 0.000e+00
layer_norm_bias.grad 0.000e+00
fc1_weight.grad 0.000e+00
fc1_bias.grad 0.000e+00
fc2_weight.grad 0.000e+00
fc2_bias.grad 0.000e+00

#########################################

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers the entire PR, which introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The core implementation trades compute for memory by avoiding saving intermediate activations (ln_out, fc1_out, gelu_out) during forward passes and instead recomputing them during backward passes—crucially skipping fc2 recomputation since it's not needed for gradients. The module integrates with TransformerEngine's existing infrastructure (sequence_parallel, tensor_parallel, FP8 quantization) while explicitly disabling CPU offloading. The benchmark data shows dramatic memory reduction (64GB → 10GB forward memory in the example) at the cost of 20% slower backward passes. The change maintains full compatibility with existing LayerNormMLP configurations and is implemented as an additional module rather than modifying existing code.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 3/5 Implements the new SelectiveLayerNormMLP module with selective checkpointing; contains commented-out FSDP/offloading code that needs verification
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 2/5 Adds deferred init tests but has a critical bug: seq_length kwarg is not set for SelectiveLayerNormMLP (line 34 condition only checks LayerNormMLP)
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 4/5 Distributed test file that creates LayerNormMLP instances instead of SelectiveLayerNormMLP (lines 403-404), so doesn't actually test the new feature
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 5/5 Adds CUDA graph tests with proper module instantiation and parametrization; PR description notes tests fail with 4e-4 errors
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Minor formatting changes only: function signature reformatted and blank lines removed for style consistency
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Comprehensive sanity tests for the new module with minor formatting adjustments to match project style
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Formatting cleanup: removed duplicate imports and adjusted parameter list formatting
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Benchmark script comparing LayerNormMLP vs SelectiveLayerNormMLP; formatting changes only (PEP8 compliance)

Confidence score: 2/5

  • This PR requires careful review due to critical bugs in test files and incomplete CUDA graph compatibility
  • Score reflects three major issues: (1) test_deferred_init.py will fail due to missing seq_length parameter handling, (2) distributed/run_numerics.py doesn't actually test the new module, (3) CUDA graph tests are failing per PR description, and (4) extensive commented-out code in the main implementation suggests incomplete feature development particularly around FSDP and CPU offloading
  • Pay close attention to transformer_engine/pytorch/module/selective_layernorm_mlp.py (commented sections lines 550-571, 915-929, 1545-1552), tests/pytorch/selective_layernorm_mlp/test_deferred_init.py (line 34 conditional logic), and tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (lines 403-404 model instantiation)

8 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the formatting and style changes made since the last review, not the entire PR. The changes consist entirely of whitespace adjustments: removing extraneous blank lines from seven test files (test_sanity.py, test_cuda_graphs.py, compare.py, test_recipe.py, test_numerics.py, and distributed/run_numerics.py) and the end-of-file blank line. These modifications align the new test suite with the project's formatting standards enforced by pre-commit hooks (black formatter, end-of-file-fixer, trailing-whitespace checker) and the .clang-format configuration that specifies 100-character column limits, 2-space indentation for C++, and strict whitespace rules. Since SelectiveLayerNormMLP is a memory-optimized variant of LayerNormMLP with identical API and forward-pass behavior, the test files are appropriately adapted copies of the existing LayerNormMLP test infrastructure, ensuring comprehensive coverage of FP8 recipes, distributed configurations, deferred initialization, and numerical accuracy without duplicating test logic.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Removed two extraneous blank lines (lines 131,161) for formatting consistency
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 5/5 Removed trailing blank line at end of file
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Removed leading blank line and reformatted long print statement to fit 100-char limit
tests/pytorch/selective_layernorm_mlp/test_recipe.py 4/5 Contains duplicate import of transformer_engine_torch (lines 22 and 29)
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 5/5 No visible formatting changes in this diff
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Removed three blank lines (lines 202, 273, 531) and reformatted function signature (lines 406-408)
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Removed three blank lines (lines 33, 383, 491) for consistent spacing
transformer_engine/pytorch/module/selective_layernorm_mlp.py 5/5 No whitespace-related changes; core implementation is unchanged

Confidence score: 4/5

  • This PR is generally safe to merge with respect to the formatting changes; however, the test_cuda_graphs.py failure and the duplicate import in test_recipe.py require attention before merging.
  • Score reflects that the formatting changes are correct and align with project standards, but the duplicate import (test_recipe.py lines 22 and 29) should be removed, and the test_cuda_graphs.py failures mentioned in the PR description remain unresolved and may indicate numerical instability or a CUDA graph compatibility issue that requires investigation.
  • Pay close attention to tests/pytorch/selective_layernorm_mlp/test_recipe.py (duplicate import) and confirm that test_cuda_graphs.py failures are understood and either resolved or deemed acceptable before merging.

8 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the most recent docstring correction in tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py (line 391), updating the reference from "LayerNormMLP" to "SelectiveLayerNormMLP". This is a trivial documentation fix that completes the refactoring of existing tests to validate the new SelectiveLayerNormMLP module introduced by this PR. All substantive issues (duplicate import, missing condition logic, extra blank lines) were already flagged in earlier reviews and are not repeated here.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Corrected docstring to reference SelectiveLayerNormMLP instead of LayerNormMLP (line 391).

Confidence score: 5/5

  • This single‑line docstring fix is safe to merge with no risk; it improves documentation accuracy.
  • Score remains perfect because the change is purely cosmetic—no logic, imports, or functionality is affected.
  • No files require additional attention; earlier reviews have already identified all substantive issues in this test file.

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. A duplicate import of transformer_engine_torch as tex has been removed from the test file tests/pytorch/selective_layernorm_mlp/test_recipe.py, cleaning up the code. However, the file still contains the same import on line 22, indicating either an incomplete change or a subsequent revert not captured in the diff.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Removed duplicate import statement, no functional impact on test logic

Confidence score: 5/5

  • This change is safe to merge with minimal risk
  • The duplicate import removal is a simple cleanup with no functional impact; the test logic remains identical and validation of FP8 quantizer behavior is unchanged
  • No files require special attention

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The most recent updates focus on cleaning up test files for the new SelectiveLayerNormMLP module by removing unused imports and fixing minor formatting issues. The changes address previous review feedback by eliminating duplicate imports (like transformer_engine_torch as tex) and removing extraneous blank lines. However, several issues remain unresolved from earlier reviews, including the incomplete fix in test_deferred_init.py where the conditional logic in get_module_args still references unused module types (te.Linear, te.LayerNormLinear, te.MultiheadAttention, te.TransformerLayer), creating unreachable dead code. The CUDA graph test failures mentioned in the PR description (errors around 4e-4 and 0.1 absolute errors) persist and are acknowledged as incomplete work.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Removed numerous unused imports while preserving FP8 recipe test functionality
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 4/5 Removed blank line but left dead code in get_module_args referencing unused module types
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 3/5 Simplified to test only SelectiveLayerNormMLP; tests currently failing per PR description
tests/pytorch/selective_layernorm_mlp/test_sanity.py 4/5 Cleaned up imports for focused SelectiveLayerNormMLP testing with comprehensive parameterization
tests/pytorch/selective_layernorm_mlp/test_numerics.py 4/5 Removed unused imports and functions; minor formatting inconsistency remains
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Removed unused imports with no functional changes

Confidence score: 3/5

  • This PR requires careful review due to incomplete CUDA graph functionality and remaining code quality issues.
  • Score reflects acknowledged test failures (test_cuda_graphs.py), dead code in helper methods (test_deferred_init.py line 25-43), and minor formatting inconsistencies. The changes are primarily test infrastructure cleanup, but the core feature is not yet production-ready.
  • Pay close attention to test_cuda_graphs.py (numerical precision failures) and test_deferred_init.py (unreachable conditional branches for te.Linear, te.MultiheadAttention, te.TransformerLayer).

6 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR introduces SelectiveLayerNormMLP, a memory-optimized variant of LayerNormMLP that implements selective activation checkpointing. The forward pass saves only input tensors and weights rather than intermediate activations (fc1_out, gelu_out), reducing memory consumption by ~80% during training. During backward, FC1 and activation are recomputed, but FC2 recomputation is skipped since it's unnecessary for gradient calculation. The implementation maintains compatibility with TransformerEngine's quantization systems (FP8, MXFP8, NVFP4) and distributed features (tensor parallel, sequence parallel), though CPU offloading is explicitly disabled. The module integrates into the existing PyTorch API alongside LayerNormMLP and includes comprehensive test coverage mirroring the standard module's test suite. CUDA graph tests currently fail with ~4e-4 numerical differences, indicating potential interactions between graph capture and the selective recomputation strategy.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 4/5 New 1000+ line module implementing selective checkpointing with bifurcated forward logic for training vs. recomputation
transformer_engine/pytorch/module/init.py 5/5 Adds SelectiveLayerNormMLP import to module's public API
transformer_engine/pytorch/init.py 5/5 Exposes SelectiveLayerNormMLP in top-level PyTorch API
tests/pytorch/selective_layernorm_mlp/test_numerics.py 1/5 Critical bug: reference implementations initialize normalization weights to zeros instead of ones when zero_centered_gamma=False
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 4/5 Missing import warnings causes NameError when zero-tensors are detected; otherwise sound distributed validation
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 4/5 CUDA graph tests currently failing per PR description; unused import and incorrect return type annotation present
tests/pytorch/selective_layernorm_mlp/compare.py 3/5 Performance comparison script with reversed weight-copy direction (copies from SLN to LN instead of vice versa)
tests/pytorch/selective_layernorm_mlp/utils.py 4/5 Test utilities with incorrect return type annotation (declares 2-tuple but returns 3 elements)
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Comprehensive sanity tests covering dtypes, recipes, activations, normalizations, and microbatching
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 4/5 Deferred initialization test with dead code for untested modules
tests/pytorch/selective_layernorm_mlp/test_recipe.py 4/5 FP8 recipe validation test with unused capability-check imports
tests/pytorch/selective_layernorm_mlp/distributed/test_numerics.py 4/5 Distributed test wrapper with typo in docstring and unused variable

Confidence score: 2/5

  • This PR cannot be merged safely due to critical bugs in the test reference implementations and missing imports that will cause runtime failures.
  • Score reflects: (1) test_numerics.py initializes reference normalization weights to zeros instead of ones, making all tests invalid; (2) distributed/run_numerics.py uses warnings.warn() without importing warnings; (3) CUDA graph tests are explicitly failing per PR description; (4) compare.py has reversed weight-copy direction; (5) multiple type annotation mismatches that will cause type-checking failures.
  • Pay close attention to tests/pytorch/selective_layernorm_mlp/test_numerics.py (lines 144-145, 170-171), distributed/run_numerics.py (line 34), and the core module's commented-out FSDP code which may indicate incomplete distributed functionality.

Sequence Diagram

sequenceDiagram
    participant User
    participant SelectiveLayerNormMLP
    participant _SelectiveLayerNormMLP
    participant ForwardPass
    participant BackwardPass
    participant Quantizers
    participant GEMM

    User->>SelectiveLayerNormMLP: forward(inp, is_first_microbatch)
    SelectiveLayerNormMLP->>SelectiveLayerNormMLP: prepare_forward()
    SelectiveLayerNormMLP->>SelectiveLayerNormMLP: _get_quantizers()
    SelectiveLayerNormMLP->>Quantizers: Initialize quantizers
    Quantizers-->>SelectiveLayerNormMLP: Return quantizers
    
    SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=False)
    Note over _SelectiveLayerNormMLP: Save tensors for backward (inp, weights, etc.)
    
    _SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
    ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
    
    alt sequence_parallel
        _SelectiveLayerNormMLP->>ForwardPass: gather_along_first_dim()
        ForwardPass-->>_SelectiveLayerNormMLP: ln_out_total
    end
    
    _SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
    Note over GEMM: FC1 GEMM with optional gelu fusion
    GEMM-->>_SelectiveLayerNormMLP: fc1_out
    
    _SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
    ForwardPass-->>_SelectiveLayerNormMLP: act_out
    
    _SelectiveLayerNormMLP->>GEMM: general_gemm(fc2_weight, act_out)
    Note over GEMM: FC2 GEMM
    GEMM-->>_SelectiveLayerNormMLP: fc2_out
    
    alt sequence_parallel
        _SelectiveLayerNormMLP->>ForwardPass: reduce_scatter_along_first_dim()
        ForwardPass-->>_SelectiveLayerNormMLP: fc2_out
    end
    
    _SelectiveLayerNormMLP-->>SelectiveLayerNormMLP: fc2_out
    SelectiveLayerNormMLP-->>User: Return output
    
    User->>BackwardPass: loss.backward()
    BackwardPass->>_SelectiveLayerNormMLP: _recompute(ctx)
    Note over _SelectiveLayerNormMLP: Restore saved tensors
    
    _SelectiveLayerNormMLP->>_SelectiveLayerNormMLP: _forward(..., recompute_for_bwd=True)
    Note over _SelectiveLayerNormMLP: Recompute activations WITHOUT saving fc2
    
    _SelectiveLayerNormMLP->>ForwardPass: apply_normalization()
    ForwardPass-->>_SelectiveLayerNormMLP: ln_out, mu, rsigma
    
    _SelectiveLayerNormMLP->>GEMM: general_gemm(fc1_weight, ln_out_total)
    GEMM-->>_SelectiveLayerNormMLP: fc1_out
    
    _SelectiveLayerNormMLP->>ForwardPass: activation_func(fc1_out)
    ForwardPass-->>_SelectiveLayerNormMLP: act_out
    
    Note over _SelectiveLayerNormMLP: Skip FC2 forward (not needed for backward)
    
    _SelectiveLayerNormMLP->>BackwardPass: backward(grad_outputs)
    
    BackwardPass->>GEMM: general_gemm(fc2_weight, grad_output)
    Note over GEMM: FC2 DGRAD
    GEMM-->>BackwardPass: fc2_dgrad
    
    BackwardPass->>ForwardPass: activation_func_bwd(fc2_dgrad, fc1_out)
    ForwardPass-->>BackwardPass: dact
    
    alt fc2_weight_requires_grad
        BackwardPass->>GEMM: general_gemm(act_out, grad_output)
        Note over GEMM: FC2 WGRAD
        GEMM-->>BackwardPass: fc2_wgrad, fc2_bias_grad
    end
    
    BackwardPass->>GEMM: general_gemm(fc1_weight, dact)
    Note over GEMM: FC1 DGRAD
    GEMM-->>BackwardPass: fc1_dgrad
    
    alt fc1_weight_requires_grad
        BackwardPass->>GEMM: general_gemm(ln_out_total, dact)
        Note over GEMM: FC1 WGRAD
        GEMM-->>BackwardPass: fc1_wgrad, fc1_bias_grad
    end
    
    BackwardPass->>ForwardPass: layernorm_bwd/rmsnorm_bwd()
    ForwardPass-->>BackwardPass: dgrad, dgamma, dbeta
    
    BackwardPass-->>User: Return gradients
Loading

12 files reviewed, 13 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer removed redundant nested if-condition logic in the deferred initialization test file for SelectiveLayerNormMLP. Previously, the code checked if module == te.SelectiveLayerNormMLP twice: once at line 29 and again at line 33 within the same block. This review captures the cleanup of that duplication, streamlining the conditional logic. The test file verifies that SelectiveLayerNormMLP can be instantiated on a 'meta' device with zero memory allocation and then properly materialized via reset_parameters().

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 5/5 Removed redundant nested if-check for SelectiveLayerNormMLP module condition

Confidence score: 5/5

  • This PR is safe to merge with minimal risk
  • Score reflects straightforward refactoring with no functional changes or edge cases introduced
  • No files require special attention

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest commit fixes a critical bug in the test reference implementations (TorchLayerNorm and TorchRMSNorm) within tests/pytorch/selective_layernorm_mlp/test_numerics.py. Previously, the weight initialization logic was inverted: when zero_centered_gamma=False, the gamma parameter was incorrectly initialized to zeros instead of ones, which would cause the normalized output to be scaled incorrectly. The fix swaps the conditional branches so that zero_centered_gamma=False correctly initializes weights to ones (standard LayerNorm behavior), and zero_centered_gamma=True initializes to zeros (zero-centered variant). This correction ensures that the torch-based reference models produce the expected numerical outputs, allowing accurate validation of the SelectiveLayerNormMLP module against known-correct baselines. This change is isolated to the test infrastructure and does not affect the production implementation of SelectiveLayerNormMLP.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Fixed inverted weight initialization logic in TorchLayerNorm and TorchRMSNorm reference implementations

Confidence score: 5/5

  • This change is safe to merge—it fixes a clear bug in test reference code with no risk to production paths
  • The fix corrects an inverted conditional that was causing reference models to initialize gamma weights incorrectly, ensuring accurate numerical validation going forward
  • No files require special attention—the change is a straightforward two-line correction in test utilities

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

@jaimec00 jaimec00 changed the title Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) [PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) Oct 28, 2025
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer has addressed previous feedback by fixing the parameter-setting logic in test_deferred_init.py, correcting the weight-copy direction in compare.py, and refining timing/measurement code. However, several critical issues remain: a type annotation bug in test_cuda_graphs.py (get_nvfp4_inp_supported_dtypes returns List[torch.dtype] but is annotated as bool), unused imports (warnings in run_numerics.py, time in compare.py), and a typo in the PR description (ffn_fidden on line 261 of compare.py). The test files validate that SelectiveLayerNormMLP maintains numerical parity with LayerNormMLP across distributed and non-distributed configurations while significantly reducing memory usage by recomputing activations instead of caching them. The PR description notes that test_cuda_graphs.py is failing with numerical errors (typically 4e-4, occasionally 0.1 absolute), which correlates with the type annotation bug in that file.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 2/5 Critical type annotation bug on line 71: function returns List[torch.dtype] but annotated as bool, causing type-checking failures and possibly contributing to test failures
tests/pytorch/selective_layernorm_mlp/test_numerics.py 4/5 Adds comprehensive numerical validation comparing SelectiveLayerNormMLP against torch reference implementations across dtypes, activations, and FP8/NVFP4 recipes
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 3/5 Validates distributed numerics for tensor/sequence parallelism; unused warnings import added but comment updated correctly
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 4/5 Simplified parameter-setting logic by removing redundant conditional check; correct for single-module test file
tests/pytorch/selective_layernorm_mlp/test_sanity.py 4/5 Standard sanity tests adapted for SelectiveLayerNormMLP; comprehensive coverage of configurations and FP8 recipes
tests/pytorch/selective_layernorm_mlp/compare.py 3/5 Benchmark script refactored to use CUDA events and per-config stats; unused time import remains and typo ffn_fidden on line 261

Confidence score: 3/5

  • This PR requires careful attention due to a critical type annotation bug and known test failures (cuda_graphs), though the core numerical validation appears sound
  • Score reflects: (1) critical type annotation error in test_cuda_graphs.py that breaks type checking and may contribute to reported test failures; (2) unused imports in two files suggesting incomplete cleanup; (3) PR description acknowledges cuda_graphs tests are failing with numerical errors, indicating incomplete functionality
  • Pay close attention to test_cuda_graphs.py (line 71 type annotation) and investigate why CUDA graph tests fail with 4e-4 errors—the annotation bug may be masking logic errors in the supported-dtypes check

6 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The most recent changes apply code formatting to tests/pytorch/selective_layernorm_mlp/compare.py, reformatting a multi-line f-string descriptor to comply with line-length and style guidelines. While spacing, operator placement, and list formatting improvements are beneficial, a critical syntax error was introduced: the descriptor string is now missing its closing parenthesis after the f-string definition, preventing the script from running. No other changes were made to imports, logic, or functionality in this update.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/compare.py 0/5 Applied code formatting but introduced syntax error: missing closing parenthesis after multi-line f-string definition (line 261)

Confidence score: 0/5

  • This PR will fail to run due to a syntax error that prevents the comparison script from executing
  • Score is zero because a closing parenthesis is missing after the f-string definition on line 261, causing a guaranteed SyntaxError when the script is imported or executed
  • The file tests/pytorch/selective_layernorm_mlp/compare.py requires immediate attention to add the missing closing parenthesis after line 262

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This update addresses a single-character syntax fix in the benchmark comparison script for SelectiveLayerNormMLP. The change adds a missing closing parenthesis on line 262 of tests/pytorch/selective_layernorm_mlp/compare.py that completes the multi-line f-string definition for the desc variable. This variable formats configuration metadata (sequence length, hidden size, FFN hidden size, and number of layers) that labels benchmark results comparing LayerNormMLP and SelectiveLayerNormMLP. Without this closing parenthesis, the script would raise a SyntaxError and fail to execute. The fix ensures the benchmarking tool can run properly and display properly formatted configuration descriptions in its output.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Added missing closing parenthesis to complete multi-line f-string definition

Confidence score: 5/5

  • This PR is safe to merge with minimal risk
  • The change fixes a syntax error that would have prevented the script from running, and introduces no new logic or behavioral changes
  • No files require special attention; this is a trivial syntax correction

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

@jaimec00 jaimec00 changed the title [PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) [Draft][RFC][PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) Oct 28, 2025
@jaimec00 jaimec00 changed the title [Draft][RFC][PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) [RFC][PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP (SelectiveLayerNormMLP) Oct 28, 2025
@ptrendx
Copy link
Member

ptrendx commented Oct 28, 2025

Hi @jaimec00, thank you for the contribution!
The first general question (before really diving into the actual review) is why was it necessary to make a completely new module for that functionality? We should be able to introduce this as an option to the regular LayerNormMLP. Or is the idea to first show the approach in its own module and then integrate it back into LayerNormMLP?

@jaimec00
Copy link
Author

jaimec00 commented Oct 28, 2025

Thanks for the comment, @ptrendx!
I started as a separate module to de-risk changes to LayerNormMLP and to enable clean A/B comparisons with LayerNormMLP while iterating on selective checkpointing (and CUDA Graphs/TP overlap specifics). Functionally it’s a refactor of LayerNormMLP, but as of right now, it only implements selective activation checkpointing (no option for standard LayerNormMLP).
The plan is to merge this back into LayerNormMLP behind a flag (e.g. checkpoint_policy={'none', 'selective'}), add tests for both policies, and then retire the standalone module.
If you’d rather skip the interim step, I can integrate it into LayerNormMLP in this PR and keep it gated. What do you think? If you have a different approach that you would prefer, I would be happy to follow that.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The most recent update adds a new SelectiveLayerNormMLP module that implements selective activation checkpointing, trading ~20% additional backward compute time for dramatic memory savings (up to 6x reduction in forward peak memory). The implementation introduces a checkpoint parameter that gates three key behaviors: (1) which tensors are saved during forward (only inputs/weights vs. full activations), (2) whether FC2 is recomputed during backward (it's skipped since FC2's output isn't needed for weight gradients), and (3) CPU offloading availability (disabled when checkpointing is active). The module closely mirrors LayerNormMLP's structure, reusing the same FP8/tensor-parallel/sequence-parallel infrastructure while adding conditional logic through save_for_checkpoint and is_recomputation flags. This new module integrates with the existing PyTorch module hierarchy under transformer_engine/pytorch/module/, alongside sibling modules like LayerNormMLP and LayerNormLinear.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 3/5 Adds new SelectiveLayerNormMLP module with selective activation checkpointing that skips FC2 recomputation during backward, reducing memory at the cost of additional compute

Confidence score: 3/5

  • This PR requires careful review due to incomplete testing and a known CUDA graphs compatibility issue that causes numerical errors (4e-4 typical, 0.1 absolute in some cases)
  • Score reflects the unresolved CUDA graphs test failures and the significant complexity of the selective checkpointing logic that conditionally saves/recomputes tensors across forward/backward passes—while the core implementation appears sound, the interaction between checkpoint flags, FP8 metadata tracking, and CUDA graph capture needs verification
  • Pay close attention to the CUDA graphs test failures in test_cuda_graphs.py (lines documenting 4e-4 typical errors with occasional 0.1 absolute errors), the conditional tensor saving logic in lines 621-691 (which determines what gets saved vs. recomputed based on checkpoint/is_grad_enabled/save_for_checkpoint flags), and the _recompute method's dual-path behavior (lines 920-925) that either reruns forward with recompute_for_bwd=True or loads saved tensors

1 file reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer has addressed the review feedback by correcting module references, removing dead code, fixing initialization logic for RMSNorm/LayerNorm gamma parameters, refactoring the comparison script to test SelectiveLayerNormMLP against itself (with/without checkpointing), and adding missing imports. The core implementation of selective activation checkpointing in SelectiveLayerNormMLP._forward remains unchanged. Test infrastructure has been updated to parametrize the new checkpoint flag across all test suites, though CUDA graph tests remain skipped when checkpoint=True.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/module/selective_layernorm_mlp.py 5/5 Core implementation of selective activation checkpointing for LayerNormMLP; no changes since last review
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Refactored to compare SelectiveLayerNormMLP(checkpoint=False) vs SelectiveLayerNormMLP(checkpoint=True) instead of comparing against LayerNormMLP
tests/pytorch/selective_layernorm_mlp/test_numerics.py 5/5 Added checkpoint parameter to test matrix; fixed gamma initialization logic for RMSNorm/LayerNorm
tests/pytorch/selective_layernorm_mlp/distributed/run_numerics.py 5/5 Added test cases for checkpoint=True and checkpoint=False; added missing warnings import
tests/pytorch/selective_layernorm_mlp/test_sanity.py 5/5 Added checkpoint parameter to sanity test matrix
tests/pytorch/selective_layernorm_mlp/test_recipe.py 5/5 Added checkpoint parameter to quantizer update test
tests/pytorch/selective_layernorm_mlp/test_deferred_init.py 5/5 Added checkpoint parameter to deferred initialization tests
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 4/5 Added checkpoint parameter but explicitly skips tests when checkpoint=True due to known failures

Confidence score: 5/5

  • This PR is safe to merge with minimal risk; the selective checkpointing implementation is sound and well-tested
  • Score reflects that all previous review issues were addressed, tests pass (except known CUDA graph limitations), and the implementation follows TE patterns for custom autograd functions with memory management
  • Pay close attention to test_cuda_graphs.py—the developer explicitly asks for help with CUDA graph failures when checkpoint=True, which remain unresolved and are currently skipped

8 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

jaimec00 and others added 11 commits October 29, 2025 01:22
Signed-off-by: Jaime Cardenas <[email protected]>
Signed-off-by: Jaime Cardenas <[email protected]>
…ributed/run_numerics.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Jaime <[email protected]>
Signed-off-by: Jaime Cardenas <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest changes address previous review feedback by:

  1. Fixing the skip message in test_cuda_graphs.py: Updated the skip message from "SelectiveLayerNorm" to "LayerNormMLP" to correctly reference the module being tested (lines 297-298). This aligns with the PR's approach of adding a checkpoint parameter directly to LayerNormMLP rather than creating a separate module.

  2. Minor formatting in compare.py: Collapsed lines 35-36 from two lines to one for improved readability. No functional changes were made to the comparison script, which benchmarks LayerNormMLP with and without the checkpoint flag to validate the performance claims (6X memory reduction, ~20% slower backward).

  3. Core implementation in layernorm_mlp.py: The main logic implements selective activation checkpointing by (a) adding a recompute_for_bwd parameter to _forward that skips FC2 computation during recomputation, (b) saving minimal state when checkpointing is enabled, and (c) introducing a _recompute dispatcher that either re-executes the forward pass or loads saved tensors. The change bypasses CPU offloading and FSDP scattering when checkpoint=True since there are no activations to offload.

The implementation integrates with TransformerEngine's existing FP8 quantization, distributed training (sequence/tensor parallelism), and autocast infrastructure. The primary trade-off is ~20% slower backward pass for ~83% memory reduction, enabling training of larger MLP models.

Important Files Changed

Filename Score Overview
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 5/5 Fixed typo in skip message from "SelectiveLayerNorm" to "LayerNormMLP"
tests/pytorch/selective_layernorm_mlp/compare.py 5/5 Minor formatting change collapsing module instantiation from 2 lines to 1
transformer_engine/pytorch/module/layernorm_mlp.py 3/5 Core implementation of selective checkpointing with checkpoint flag, recompute_for_bwd parameter, and recomputation logic; potential issues with return type consistency and stale workspace comment

Confidence score: 3/5

  • This PR requires careful review before merging due to unresolved issues in the core implementation logic and incomplete CUDA graphs support
  • Score reflects three main concerns: (1) return type inconsistency in _recompute where the non-checkpoint branch returns a list but caller expects tuple unpacking, (2) workspace update comment at line 442 that references commented-out code and may be stale, and (3) CUDA graphs tests are explicitly skipped when checkpoint=True which means a known compatibility gap exists
  • Pay close attention to transformer_engine/pytorch/module/layernorm_mlp.py especially the _recompute dispatcher (line 950), the FC2-skip condition (line 574), and the workspace update logic (line 442)

3 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@jaimec00
Copy link
Author

Hi @ptrendx, just wanted to give an update on the PR. I implemented selective activation checkpointing directly in LayerNormMLP, which can be controlled by initializing the module with a 'checkpoint' flag. When checkpoint=False, behavior is unchanged, and it is simply LayerNormMLP (all tests pass, INCLUDING test_cuda_graphs.py). When checkpoint=True, selective activation checkpointing is implemented, getting >6X memory reduction in the forward pass at the cost of 20% slower backward pass. All tests pass EXCEPT for test_cuda_graphs.py. I have started debugging this and I am confident that this is due to the quantizer state changing in between the initial forward pass and the recomputation forward pass. I am currently working on a fix for this. In the meantime, I was wondering if this feature is something that you and the people at NVIDIA would be interested in merging into the main repo. If so, are there any notes you have for me? Thanks!

@ptrendx
Copy link
Member

ptrendx commented Oct 29, 2025

@jaimec00 yes, we are definitely interested in the functionality :-). It is still a little sad to me that pyTorch does not really give us a way to do that in a more generic way. I will review this PR either this or beginning of next week.

@jaimec00
Copy link
Author

@ptrendx perfect, thank you!

jaimec00 and others added 2 commits October 30, 2025 01:58
…ing fwd. Float8CurrentScaling with bf16 and fp16 still fail with checkpointing

Signed-off-by: Jaime Cardenas <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint flag. When enabled, activations are recomputed during backward pass instead of being saved, reducing memory usage by ~83% with only ~20% slower backward pass. The implementation recomputes LayerNorm and FC1 operations but skips FC2 (not needed for gradients).

Key changes:

  • Refactored _LayerNormMLP.forward() into _forward() with recompute_for_bwd parameter
  • Added _recompute() method that either recomputes from saved inputs (checkpoint=True) or loads saved activations (checkpoint=False)
  • Clones all quantizers during forward pass and saves them to ctx.other_args for recomputation
  • Added comprehensive test coverage including CUDA graph compatibility tests

Critical issues:

  • Quantizer state cloning (lines 191-204, 1034-1037): The _copy_quantizer() function creates independent copies of scale/amax tensors for Float8CurrentScalingQuantizer. These cloned tensors capture a snapshot of the quantizer state during forward pass but don't reflect updates that occur during training. During backward recomputation, stale quantizer values are used, causing numerical mismatches. This is the root cause of CUDA graph test failures with Float8CurrentScaling + checkpoint=True + non-fp32 dtypes.
  • The workspace update logic (line 520-528) is not skipped during recomputation, but quantized weights may have changed between forward and backward passes when using dynamic quantization.

Confidence Score: 3/5

  • This PR has a critical quantizer cloning issue that breaks Float8CurrentScaling with CUDA graphs, but is safe for other quantization recipes and non-graph usage
  • Score of 3 reflects that the core checkpointing mechanism is well-implemented and tests pass for most scenarios, but the quantizer state management has a fundamental design flaw. The cloning approach creates stale copies that don't track live quantizer updates, breaking Float8CurrentScaling compatibility. The PR author acknowledges this limitation and skips affected test cases. The feature provides significant memory benefits (6x reduction) with acceptable performance tradeoff (~20% slower backward) for supported configurations.
  • transformer_engine/pytorch/module/layernorm_mlp.py requires attention for the quantizer cloning logic (especially _copy_quantizer function and its usage in recomputation)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/layernorm_mlp.py 2/5 Implements selective activation checkpointing with checkpoint flag. Critical issue: quantizer cloning creates stale state copies that break Float8CurrentScaling with CUDA graphs. Additional concerns with skipped workspace updates during recomputation.
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 4/5 Test file for CUDA graph compatibility with checkpoint flag. Skip condition at line 299-307 correctly identifies known failure case (Float8CurrentScaling + checkpoint + non-fp32 dtypes).

Sequence Diagram

sequenceDiagram
    participant User
    participant LayerNormMLP
    participant Forward as _forward()
    participant Backward as backward()
    participant Recompute as _recompute()

    User->>LayerNormMLP: forward(checkpoint=True)
    LayerNormMLP->>Forward: Call with checkpoint=True, recompute_for_bwd=False
    
    Note over Forward: save_for_checkpoint=True<br/>is_recomputation=False
    
    Forward->>Forward: Save input tensors (inp, weights, bias)
    Forward->>Forward: Clone all quantizers to ctx.other_args
    Forward->>Forward: Compute LayerNorm
    Forward->>Forward: Compute FC1 GEMM
    Forward->>Forward: Apply activation (GELU)
    Forward->>Forward: Compute FC2 GEMM
    Forward->>Forward: Clear activation tensors (fc1_out, act_out)
    Forward-->>LayerNormMLP: Return fc2_out
    
    User->>LayerNormMLP: backward(grad_output)
    LayerNormMLP->>Backward: Process gradients
    Backward->>Recompute: Call _recompute(ctx)
    
    alt checkpoint=True
        Recompute->>Forward: Call _forward() with recompute_for_bwd=True
        Note over Forward: is_recomputation=True<br/>Recompute from saved inputs
        Forward->>Forward: Restore saved tensors from ctx
        Forward->>Forward: Use cloned quantizers from ctx.other_args
        Forward->>Forward: Re-execute LayerNorm
        Forward->>Forward: Re-execute FC1 GEMM
        Forward->>Forward: Re-execute activation
        Forward->>Forward: Skip FC2 GEMM (not needed for backward)
        Forward->>Forward: Save recomputed activations to ctx
        Forward-->>Recompute: Return [ctx, activations...]
    else checkpoint=False
        Recompute->>Recompute: Load saved activations from ctx
        Recompute-->>Backward: Return [ctx, activations...]
    end
    
    Backward->>Backward: Compute FC2 weight/input gradients
    Backward->>Backward: Compute activation gradients
    Backward->>Backward: Compute FC1 weight/input gradients
    Backward->>Backward: Compute LayerNorm gradients
    Backward-->>User: Return input/weight gradients
Loading

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, activations are not saved during forward pass; instead, LayerNorm and FC1 are recomputed during backward (FC2 is skipped as it's not needed for gradients).

Key changes:

  • Refactored forward() to call _forward() with recompute_for_bwd flag
  • Added _recompute() method that either loads saved activations (checkpoint=False) or recomputes from inputs (checkpoint=True)
  • Saves and restores autocast state and FP8 quantizer states (for DelayedScaling recipe) to ensure numerical consistency between forward and recompute
  • Conditionally skips FC2, CPU offloading, and FSDP scattering when checkpointing is enabled
  • Tests confirm CUDA graphs now work with checkpoint=True after fixing autocast/quantizer state management

Trade-offs:

  • Memory: ~83% reduction (6X lower in forward pass)
  • Compute: ~22% slower backward pass
  • Numerics: Identical outputs (0.0 max absolute error)

Issue found:

  • backwards_needs_fc1_input condition doesn't match its comment—should include is_recomputation case

Confidence Score: 4/5

  • This PR is safe to merge with one logical fix needed for the backwards_needs_fc1_input condition
  • The implementation is well-tested (all tests pass including CUDA graphs), has impressive benchmarks showing proper memory/compute tradeoffs, and correctly handles complex state management for autocast and quantizers. The one logic issue with backwards_needs_fc1_input should be fixed before merge—the condition doesn't match the comment and may cause incorrect quantizer usage during recomputation. Otherwise, the checkpointing logic is sound: inputs are saved during forward, FC2 is correctly skipped during recompute (not needed for gradients), and state restoration ensures numerical consistency.
  • transformer_engine/pytorch/module/layernorm_mlp.py needs fix for backwards_needs_fc1_input logic at line 330

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/layernorm_mlp.py 4/5 Implements selective activation checkpointing by refactoring forward into _forward, adding recomputation logic with autocast/quantizer state management, and conditionally skipping FC2 during backward recompute. Minor logic issue with backwards_needs_fc1_input condition.
tests/pytorch/selective_layernorm_mlp/test_cuda_graphs.py 5/5 Test file for CUDA graphs with checkpoint parameter, correctly parametrized to test both checkpoint=True and checkpoint=False with various FP8 recipes and dtypes.

Sequence Diagram

sequenceDiagram
    participant User
    participant Forward as forward()
    participant FwdImpl as _forward()
    participant Recompute as _recompute()
    participant Backward as backward()
    
    Note over User,Backward: checkpoint=True Path
    
    User->>Forward: input, checkpoint=True
    Forward->>FwdImpl: recompute_for_bwd=False
    
    Note over FwdImpl: Save ctx.checkpoint=True
    Note over FwdImpl: Save autocast state
    Note over FwdImpl: Copy quantizer states (DelayedScaling only)
    
    Note over FwdImpl: save_for_checkpoint=True
    Note over FwdImpl: Save inputs: inp, ln_weight, ln_bias, fc1_weight, etc.
    
    Note over FwdImpl: Compute LayerNorm
    Note over FwdImpl: Compute FC1 + Activation
    Note over FwdImpl: Skip FC2 caching (cleared immediately)
    Note over FwdImpl: Compute FC2
    
    FwdImpl-->>Forward: fc2_out
    Forward-->>User: output
    
    User->>Backward: grad_output
    Backward->>Recompute: ctx.checkpoint=True
    
    Note over Recompute: Restore saved inputs
    Note over Recompute: Set autocast state to saved
    Note over Recompute: Restore quantizer states (DelayedScaling)
    
    Recompute->>FwdImpl: recompute_for_bwd=True
    Note over FwdImpl: is_recomputation=True
    Note over FwdImpl: Compute LayerNorm
    Note over FwdImpl: Compute FC1 + Activation
    Note over FwdImpl: Skip FC2 (not needed for backward)
    
    FwdImpl-->>Recompute: [ctx, recomputed tensors]
    
    Note over Recompute: Restore autocast state
    Note over Recompute: Restore quantizer states
    
    Recompute-->>Backward: tensors for backward
    
    Note over Backward: Compute FC2 gradients
    Note over Backward: Compute activation gradients
    Note over Backward: Compute FC1 gradients
    Note over Backward: Compute LayerNorm gradients
    
    Backward-->>User: input_grad
Loading

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, it skips saving large intermediate activations (fc1_out, act_out) in the forward pass and instead recomputes them during backward, while cleverly skipping the FC2 computation during recomputation since it's not needed for gradient calculation.

Key Implementation Details:

  • Forward pass with checkpoint=True saves only 7 input tensors instead of 14 activation tensors, achieving ~83% memory reduction
  • Recomputation phase (recompute_for_bwd=True) executes LayerNorm + FC1 + Activation, but skips FC2
  • Properly handles autocast state and FP8 quantizer states for DelayedScaling recipe
  • CPU offloading and FSDP scattering are bypassed when checkpoint=True (no activations to offload)

Performance Tradeoffs:

  • Forward: negligible overhead
  • Backward: ~22% slower (recomputation cost)
  • Memory: 6-8x reduction in forward pass, overall significant savings

Critical Issue Found:

  • Line 829-835: fc2_out is undefined when is_recomputation=True and execution reaches the return statement, causing UnboundLocalError in edge cases where is_grad_enabled=False during recomputation

Confidence Score: 3/5

  • This PR has a critical logic bug that could cause runtime errors in specific edge cases, but the core implementation is sound
  • Score reflects well-designed selective checkpointing mechanism with proper FP8 state management and comprehensive test coverage. However, critical issue at line 829-835 where fc2_out is undefined during recomputation could cause UnboundLocalError. The bug is localized and fixable, but must be resolved before merge.
  • transformer_engine/pytorch/module/layernorm_mlp.py requires immediate attention to fix the undefined fc2_out variable issue in the return logic (lines 829-835)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/layernorm_mlp.py 3/5 Implements selective activation checkpointing with checkpoint flag. Adds recomputation logic to skip FC2 during backward pass. Found critical issue where fc2_out is undefined when is_recomputation=True with is_grad_enabled=False, causing UnboundLocalError at return statement.

Sequence Diagram

sequenceDiagram
    participant User
    participant LayerNormMLP
    participant Forward as _forward
    participant Recompute as _recompute
    participant Backward

    User->>LayerNormMLP: forward(inp, weights, checkpoint=True)
    LayerNormMLP->>Forward: _forward(checkpoint=True, recompute_for_bwd=False)
    
    Note over Forward: save_for_checkpoint=True
    Note over Forward: is_recomputation=False
    
    Forward->>Forward: Save autocast state & quantizer states
    Forward->>Forward: Save input tensors to ctx (7 tensors)
    Forward->>Forward: Save other_args dict to ctx
    Forward->>Forward: Compute LayerNorm
    Forward->>Forward: Compute FC1 + Activation
    Forward->>Forward: Compute FC2 (full forward)
    Forward->>Forward: Skip activation saving (checkpoint=True)
    Forward-->>LayerNormMLP: fc2_out
    LayerNormMLP-->>User: output
    
    User->>LayerNormMLP: backward(grad_output)
    LayerNormMLP->>Recompute: _recompute(ctx)
    
    alt checkpoint=True
        Note over Recompute: Restore saved tensors (7 tensors)
        Recompute->>Recompute: Set autocast state from ctx
        Recompute->>Recompute: Restore quantizer states (DelayedScaling)
        Recompute->>Forward: _forward(saved_tensors, checkpoint=True, recompute_for_bwd=True)
        
        Note over Forward: is_recomputation=True
        Forward->>Forward: Compute LayerNorm
        Forward->>Forward: Compute FC1 + Activation
        Forward->>Forward: SKIP FC2 (is_recomputation=True)
        Forward->>Forward: Save recomputed activations to ctx
        Forward-->>Recompute: [ctx, activations...]
        
        Recompute->>Recompute: Restore autocast & quantizer states
        Recompute-->>Backward: [ctx, activations...]
    else checkpoint=False
        Note over Recompute: Restore saved tensors (14 tensors)
        Recompute-->>Backward: [ctx, activations...]
    end
    
    Backward->>Backward: Unpack activations
    Backward->>Backward: Compute gradients using recomputed activations
    Backward-->>User: gradients
Loading

Additional Comments (1)

  1. transformer_engine/pytorch/module/layernorm_mlp.py, line 829-835 (link)

    logic: fc2_out is undefined when is_recomputation=True and checkpoint=True (FC2 is skipped at line 592), causing UnboundLocalError. The return at line 826 exits early only when is_grad_enabled=True, but if is_grad_enabled=False during recomputation, execution continues to line 835 where fc2_out doesn't exist.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@jaimec00 jaimec00 changed the title [RFC][PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag [PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag Nov 2, 2025
@jaimec00
Copy link
Author

jaimec00 commented Nov 3, 2025

Hi @ptrendx, I have another quick update: all tests pass now, including the CUDA graph tests. Solution is described in the edited description. Thanks!

@ptrendx ptrendx self-requested a review November 4, 2025 19:18
@@ -0,0 +1,269 @@
import time
import torch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A general comment about this file - it is really nice, but it is not a test - it doesn't actually test anything, it just measures. We could introduce some test functionality here by e.g. ensuring that the error between the checkpointed LayerNormMLP is zero (since this shouldn't affect numerics) or that the memory used is lower (ideally we would quantify the expected memory usage and test against that, but for now even just making sure that the memory usage goes down would be good.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I converted it into a test for checking that memory goes down at least 6X in the forward pass. I also asserted that checkpointing is slower than not checkpointing in the backward pass (not sure if this is helpful, but let me know), and that the differences are 0. I put this test in tests/pytorch/layernorm_mlp/test_selective_activation_checkpointing.py because I wasn't sure where it fit in the rest of the testing scripts, but let me know if this test would be better elsewhere!

Comment on lines +223 to +226
# save the state of autocast and quantizers for recomputation
ctx.autocast_state = (
FP8GlobalStateManager.get_autocast_state()
) # to restore autocast state during recomputation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to also save the RNG state here, e.g. to get the same results when using stochastic rounding with FP4. You can see in transformer_engine/pytorch/distributed.py how the checkpoint function handles that when there is no custom RNG tracker present. @ksivaman do you think we should add here logic to handle custom RNG trackers for Megatron?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, just added this for cpu and cuda rng states, let me know if I should try to add functionality for custom rng tracker

@ptrendx
Copy link
Member

ptrendx commented Nov 4, 2025

@jaimec00 I left some comments. Generally it looks good, other that the duplication in the test infrastructure and not handling the RNG state. I did resolve some of the Greptile comments that you already addressed, but there are still a few that apply, so I left those - please address those as well.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, the forward pass saves only inputs instead of intermediate activations, achieving >6X memory reduction. During the backward pass, FC1 and activation are recomputed while FC2 computation is skipped (since it's not needed for gradients).

Key Implementation Details:

  • Adds checkpoint boolean parameter to LayerNormMLP constructor (default False)
  • Refactors forward() into _forward() with recompute_for_bwd flag to control behavior
  • Introduces _recompute() method that restores FP8 quantizer states, autocast context, and RNG states before recomputing
  • Skips FC2 GEMM during recomputation via if not (is_recomputation and checkpoint) guard
  • Properly handles FP8 DelayedScaling recipe by saving/restoring quantizer states using FP8GlobalStateManager

Test Coverage:

  • New dedicated test file validates numerical correctness (zero gradient/output differences)
  • Existing test suites extended with checkpoint parameter across numerics, sanity, recipe, CUDA graphs, and distributed tests
  • Performance benchmarks confirm expected tradeoffs: ~83% memory reduction with ~20% backward slowdown

Minor Issues:

  • Several typo fixes included (UnsupporedUnsupported, corectnesscorrectness)
  • CPU offloading and FSDP scattering bypassed when checkpoint=True (expected behavior per PR description)

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk, pending verification of edge cases around return values during recomputation
  • The implementation is well-structured with comprehensive test coverage across multiple scenarios (FP8 recipes, distributed modes, CUDA graphs). The core logic for skipping FC2 during recomputation is sound. FP8 quantizer state and RNG state management appears correct. However, there are some edge cases that should be verified: (1) behavior when return_layernorm_output=True during recomputation may not match expectations, and (2) the interaction between checkpoint mode and features like CPU offloading/FSDP is intentionally disabled but should be clearly documented. Previous review comments about quantizer cloning and state management have been addressed.
  • transformer_engine/pytorch/module/layernorm_mlp.py - verify return value handling during recomputation with return_layernorm_output=True

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/layernorm_mlp.py 4/5 Implements selective activation checkpointing by adding checkpoint parameter. Recomputes FC1 and activation in backward pass while skipping FC2. Properly handles FP8 quantizer state and RNG state restoration. Some edge cases around return values during recomputation need verification.
tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py 5/5 New test file validating checkpoint feature. Tests numerical correctness (outputs and gradients match), memory reduction (>6X), and expected backward slowdown. Covers multiple model sizes and sequence lengths.
tests/pytorch/test_numerics.py 5/5 Adds checkpoint parameter to existing LayerNormMLP test cases. Includes typo fixes ('Unsuppored' → 'Unsupported', 'Shoule' → 'Should'). Ensures checkpoint functionality is tested across all accuracy and wgrad tests.
tests/pytorch/test_cuda_graphs.py 5/5 Splits layernorm_mlp test case into separate layernorm_mlp_nocheckpoint and layernorm_mlp_checkpoint variants to test CUDA graph compatibility with both checkpoint modes. Tests now pass for both configurations.

Sequence Diagram

sequenceDiagram
    participant User
    participant LayerNormMLP
    participant Forward as _forward()
    participant Ctx as Context
    participant FC1 as FC1 GEMM
    participant Act as Activation
    participant FC2 as FC2 GEMM
    participant Backward as backward()
    participant Recompute as _recompute()

    Note over User,Recompute: Forward Pass (checkpoint=True)
    User->>LayerNormMLP: forward(input)
    LayerNormMLP->>Forward: _forward(checkpoint=True, recompute_for_bwd=False)
    Forward->>Ctx: Save checkpoint flag, autocast state, RNG state
    Forward->>Ctx: Save input tensors for recompute
    Forward->>Forward: LayerNorm computation
    Forward->>FC1: Compute FC1 GEMM
    FC1-->>Forward: fc1_out
    Forward->>Act: Apply activation (GELU)
    Act-->>Forward: act_out
    Forward->>FC2: Compute FC2 GEMM
    FC2-->>Forward: fc2_out
    Note over Forward: Deallocate activations (saved memory!)
    Forward-->>User: fc2_out

    Note over User,Recompute: Backward Pass (checkpoint=True)
    User->>Backward: backward(grad_output)
    Backward->>Recompute: _recompute(ctx)
    Recompute->>Ctx: Restore autocast state & RNG state
    Recompute->>Forward: _forward(checkpoint=True, recompute_for_bwd=True)
    Note over Forward: Recompute from saved inputs
    Forward->>Forward: LayerNorm computation
    Forward->>FC1: Compute FC1 GEMM (again)
    FC1-->>Forward: fc1_out
    Forward->>Act: Apply activation (again)
    Act-->>Forward: act_out
    Note over Forward,FC2: Skip FC2 computation (not needed for gradients!)
    Forward->>Ctx: Save recomputed activations to ctx
    Forward-->>Recompute: [ctx, activations]
    Recompute->>Ctx: Restore final autocast & RNG states
    Recompute-->>Backward: ctx + recomputed tensors
    Backward->>Backward: Compute gradients using recomputed activations
    Backward-->>User: input gradients
Loading

17 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Jaime Cardenas <[email protected]>
Signed-off-by: Jaime Cardenas <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR implements selective activation checkpointing for LayerNormMLP by adding a checkpoint parameter. When enabled, it trades ~20% slower backward pass for >6X memory reduction by recomputing activations during backward instead of saving them.

Key implementation details:

  • Refactored forward into _forward (main logic) and added _recompute method for backward recomputation
  • When checkpoint=True, only inputs (weights, biases, input tensor) are saved during forward pass instead of all intermediate activations
  • During backward, activations are recomputed by calling _forward again with recompute_for_bwd=True
  • FC2 computation is intelligently skipped during recomputation since it's not needed for gradient calculation
  • RNG state and FP8 quantizer states are properly saved/restored to ensure numerical correctness during recomputation
  • CPU offloading and FSDP scattering only apply when checkpoint=False (when activations exist to offload/scatter)

Test coverage:

  • New test validates >6X memory reduction with identical numerical outputs and gradients
  • Existing test suites extended with checkpoint parameter across accuracy, sanity, CUDA graphs, and distributed tests
  • All tests reportedly pass including CUDA graphs (fixed by proper autocast/quantizer state management)

Confidence Score: 4/5

  • This PR is safe to merge with thorough testing and review of the checkpointing logic
  • The implementation is well-structured with proper state management for RNG and quantizers, comprehensive test coverage validates correctness and performance claims, and the feature is opt-in (default checkpoint=False maintains existing behavior). Score is 4 rather than 5 due to the complexity of the checkpointing logic and potential edge cases in distributed/FP8 scenarios that may need additional validation in production workloads.
  • Primary attention needed for transformer_engine/pytorch/module/layernorm_mlp.py due to complex state management logic in _forward and _recompute methods, particularly around FP8 quantizer state handling and conditional tensor saving/clearing paths.

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/layernorm_mlp.py 4/5 Implements selective activation checkpointing by adding checkpoint parameter, refactoring forward into _forward and _recompute, with proper RNG/quantizer state management. Logic correctly skips FC2 during recomputation and handles activation saving conditionally.
tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py 5/5 New test validates checkpointing achieves >6X memory reduction in forward pass with identical numerical results. Tests appropriately verify performance tradeoffs and gradient correctness.
tests/pytorch/test_numerics.py 5/5 Adds checkpoint parameter to existing LayerNormMLP accuracy tests, plus typo fixes. Changes are minimal and properly extend test coverage.
tests/pytorch/test_cuda_graphs.py 5/5 Splits LayerNormMLP CUDA graph tests into separate checkpoint/no-checkpoint cases to test both modes independently.

Sequence Diagram

sequenceDiagram
    participant User
    participant LayerNormMLP
    participant _LayerNormMLP.forward
    participant _LayerNormMLP._forward
    participant _LayerNormMLP.backward
    participant _LayerNormMLP._recompute

    Note over User,_LayerNormMLP._recompute: Forward Pass (checkpoint=True)
    User->>LayerNormMLP: forward(input)
    LayerNormMLP->>_LayerNormMLP.forward: apply(input, checkpoint=True)
    _LayerNormMLP.forward->>_LayerNormMLP._forward: _forward(..., recompute_for_bwd=False)
    Note over _LayerNormMLP._forward: save_for_checkpoint=True
    Note over _LayerNormMLP._forward: Save inputs only (inp, weights, biases)
    Note over _LayerNormMLP._forward: Save RNG state & FP8 quantizer state
    Note over _LayerNormMLP._forward: Compute LayerNorm → FC1 → Activation → FC2
    Note over _LayerNormMLP._forward: Clear intermediate activations
    _LayerNormMLP._forward-->>_LayerNormMLP.forward: fc2_out
    _LayerNormMLP.forward-->>User: output

    Note over User,_LayerNormMLP._recompute: Backward Pass (checkpoint=True)
    User->>_LayerNormMLP.backward: backward(grad_output)
    _LayerNormMLP.backward->>_LayerNormMLP._recompute: _recompute(ctx)
    Note over _LayerNormMLP._recompute: Restore RNG state & quantizer state
    _LayerNormMLP._recompute->>_LayerNormMLP._forward: _forward(..., recompute_for_bwd=True)
    Note over _LayerNormMLP._forward: is_recomputation=True
    Note over _LayerNormMLP._forward: Recompute LayerNorm → FC1 → Activation
    Note over _LayerNormMLP._forward: Skip FC2 (not needed for gradients)
    Note over _LayerNormMLP._forward: Save recomputed activations for backward
    _LayerNormMLP._forward-->>_LayerNormMLP._recompute: (ctx, inputmat, ln_out, fc1_out, act_out, ...)
    Note over _LayerNormMLP._recompute: Restore current RNG & quantizer state
    _LayerNormMLP._recompute-->>_LayerNormMLP.backward: saved tensors
    Note over _LayerNormMLP.backward: Compute gradients using recomputed activations
    _LayerNormMLP.backward-->>User: input_grad, weight_grads
Loading

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@jaimec00
Copy link
Author

jaimec00 commented Nov 5, 2025

@ptrendx Thank you for reviewing! I commited a few changes which I believe cover the points you made

  1. integrate duplicated tests into main testing framework
  2. handle CPU and CUDA RNG state tracking for reproducible FP4 quantization
  3. address greptile comments
  4. tidying up stray comments

A few points:

  • I changed the test/pytorch/selective_layernorm_mlp/compare.py file to tests/pytorch/layernorm_mlp/test_selective_activation_checkpoint.py. It is now a test that checks for memory reduction and numerical accuracy. I was not sure where this fit into the current testing framework, so it is a standalone test for now, let me know if I should change this
  • Also please let me know if I should try to address the point you made about custom RNG state trackers

Thanks!

@jaimec00 jaimec00 requested a review from ptrendx November 7, 2025 19:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Selective Activation Checkpointing with LayerNormMLP

2 participants