
[recipe][megatron] refactor: isolate megatron recipe and extend distillation losses#4797

Closed
process-cxr wants to merge 1 commit into verl-project:main from process-cxr:megatron/gkd-refactor

Conversation

@process-cxr

This PR refactors the Megatron-based GKD training pipeline and extends the distillation loss implementation to improve modularity and future extensibility.

Specifically, it:

  • Adds megatron_distill_losses with support for RKL, JSD, and KL+RL losses, implemented as reusable loss operators
  • Exposes distillation losses for flexible invocation inside Megatron workers
  • Moves all Megatron-related training code under recipe/gkd/megatron to clearly isolate it from other backends
  • Prepares the directory structure for future integration of an FSDP-based training framework

What does this PR do?

This PR reorganizes the Megatron training recipe in GKD by isolating Megatron-specific code into a dedicated directory and extending the distillation loss module.

The refactor improves code clarity and separation of concerns, while the new loss operators make it easier to experiment with alternative distillation objectives. The new directory layout also lays the groundwork for adding an FSDP-based training backend.
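
To make the idea of a reusable loss operator concrete, here is a minimal sketch of a reverse-KL (RKL) term computed over the teacher's top-k tokens. The function name, arguments, and tensor layout are illustrative assumptions; the PR's actual implementation uses custom autograd Functions over vocab-parallel logits.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch only; not the PR's code. The real operators work on
# vocab-parallel logits inside custom autograd Functions.
def reverse_kl_loss(student_logits, teacher_topk_logps, teacher_topk_indices):
    # Student log-probs over the full vocabulary.
    student_logps = F.log_softmax(student_logits.float(), dim=-1)
    # Restrict the student distribution to the teacher's top-k token indices.
    student_topk_logps = torch.gather(student_logps, -1, teacher_topk_indices)
    # RKL per position: sum_k p_s(k) * (log p_s(k) - log p_t(k)).
    return (student_topk_logps.exp() * (student_topk_logps - teacher_topk_logps)).sum(-1)
```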

Checklist Before Starting

  • Format the PR title as [{modules}] {type}: {description}

Test

This PR mainly introduces refactoring and loss function extensions that are not fully covered by existing CI tests.

The changes have been validated by running Megatron-based GKD training and verifying correct loss computation and training stability when switching between KL, RKL, JSD, and KL+RL losses.
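
As a small illustration of the loss-computation checks described above, the following self-contained sanity test (assumed shapes and top-k size; not taken from the PR) verifies that the reverse-KL term vanishes when student and teacher coincide:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 8, 32)                          # (batch, seq, vocab)
logps = F.log_softmax(logits, dim=-1)
teacher_topk_logps, topk_idx = logps.topk(4, dim=-1)    # teacher top-k
student_topk_logps = torch.gather(logps, -1, topk_idx)  # identical student
rkl = (student_topk_logps.exp() * (student_topk_logps - teacher_topk_logps)).sum(-1)
assert torch.allclose(rkl, torch.zeros_like(rkl), atol=1e-6)
```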

API and Usage Example

No public API is changed in this PR. The new distillation losses are internal to the Megatron training pipeline.
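
For orientation, one plausible shape for such an internal, name-keyed dispatch is sketched below; the registry and function names are assumptions, not the PR's actual identifiers.

```python
from typing import Callable, Dict
import torch

LossFn = Callable[..., torch.Tensor]
_DISTILL_LOSSES: Dict[str, LossFn] = {}  # hypothetical internal registry

def register_loss(name: str):
    def wrap(fn: LossFn) -> LossFn:
        _DISTILL_LOSSES[name] = fn
        return fn
    return wrap

def get_distill_loss(name: str) -> LossFn:
    try:
        return _DISTILL_LOSSES[name]
    except KeyError as err:
        raise ValueError(
            f"unknown distillation loss {name!r}; available: {sorted(_DISTILL_LOSSES)}"
        ) from err
```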

Design & Code Changes

  • Introduced a unified distillation loss module for Megatron-based GKD training
  • Refactored directory structure to move Megatron training logic into recipe/gkd/megatron
  • Reduced coupling between Megatron and other training backends to improve maintainability

Checklist Before Submitting

  • Read the Contribute Guide
  • Apply pre-commit checks (basic formatting and lint checks applied)
  • Add / Update documentation (not required for this refactor)
  • Add unit or end-to-end tests (not feasible due to training-level validation)

[recipe][megatron] refactor: isolate megatron recipe and extend distillation losses

Add megatron_distill_losses with support for RKL, JSD and KL+RL losses

Expose distillation losses as callable operators in Megatron workers

Move all Megatron-related training code under recipe/gkd/megatron

Prepare directory structure for future FSDP training framework

CLAassistant commented Jan 5, 2026

CLA assistant check
All committers have signed the CLA.

@gemini-code-assist (bot) left a comment


Code Review

This pull request refactors the Megatron-based GKD training pipeline by isolating Megatron-specific code and extending the distillation loss implementations. The new megatron_distill_losses.py file introduces several distillation loss functions (KL, RKL, JSD, etc.) as custom PyTorch autograd Functions.

My review focuses on the correctness and robustness of these new loss implementations. I've identified a critical issue in all custom autograd.Function implementations where an input tensor is modified in-place without being marked as dirty using ctx.mark_dirty(). This can lead to incorrect gradients and must be fixed. Additionally, I've found a high-severity issue in the configuration factory function where a broad except Exception can hide configuration errors, leading to silent failures.

The rest of the changes, which mainly involve refactoring file structures and updating call sites to use the new loss factory, look good and align with the goal of improving modularity.

Comment on lines +50 to +51
```python
def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
    eps = 1e-20
```

critical

The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 61-62, 71), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.

Suggested change:

```diff
 def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
+    ctx.mark_dirty(vocab_parallel_logits)
     eps = 1e-20
```
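
For readers unfamiliar with the pattern, here is a standalone, runnable sketch (not from the PR) of an autograd.Function that mutates an input in-place and marks it dirty, as the suggestion requires:

```python
import torch

class ScaleInPlace(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale: float):
        ctx.scale = scale
        x.mul_(scale)       # in-place mutation of an input tensor...
        ctx.mark_dirty(x)   # ...so it must be marked dirty and returned
        return x

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None

x = torch.randn(4, requires_grad=True)
# In-place ops on a leaf that requires grad are forbidden, so clone first.
y = ScaleInPlace.apply(x.clone(), 2.0)
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```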

Comment on lines +142 to +143
```python
def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
    eps = 1e-20
```

critical

The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 153-154, 163), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.

Suggested change:

```diff
 def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
+    ctx.mark_dirty(vocab_parallel_logits)
     eps = 1e-20
```

Comment on lines +276 to +277
```python
def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio: float = 0.1):
    eps = 1e-20
```

critical

The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 293-294, 303), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.

Suggested change:

```diff
 def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio: float = 0.1):
+    ctx.mark_dirty(vocab_parallel_logits)
     eps = 1e-20
```

Comment on lines +436 to +437
```python
def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, beta: float):
    beta = min(max(float(beta), 1e-6), 1.0 - 1e-6)
```

critical

The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 449-450, 459), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.

Suggested change:

```diff
 def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, beta: float):
+    ctx.mark_dirty(vocab_parallel_logits)
     beta = min(max(float(beta), 1e-6), 1.0 - 1e-6)
```
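
For context, here is a self-contained sketch of a beta-interpolated generalized JSD over top-k distributions, consistent with the beta clamping quoted above; names and tensor layout are assumptions, not the PR's code.

```python
import math
import torch

def jsd_topk(student_logps, teacher_logps, beta: float):
    beta = min(max(float(beta), 1e-6), 1.0 - 1e-6)
    # Mixture m = beta * p_teacher + (1 - beta) * p_student, in log space.
    m_logps = torch.logsumexp(
        torch.stack([teacher_logps + math.log(beta),
                     student_logps + math.log(1.0 - beta)]),
        dim=0,
    )
    kl_teacher = (teacher_logps.exp() * (teacher_logps - m_logps)).sum(-1)
    kl_student = (student_logps.exp() * (student_logps - m_logps)).sum(-1)
    # Generalized JSD_beta = beta * KL(teacher||m) + (1 - beta) * KL(student||m).
    return beta * kl_teacher + (1.0 - beta) * kl_student
```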

```python
        cfg = dict(loss_cfg)
    else:
        cfg = {}
except Exception:
```

high

Using a broad except Exception: can mask important errors during configuration processing. For instance, if OmegaConf.to_container fails for a reason other than a missing import, this will be silently ignored, and the default loss configuration will be used. This can lead to misconfigured training runs that are hard to debug. It's better to catch specific exceptions like ImportError.

Suggested change:

```diff
-except Exception:
+except ImportError:
```
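
A hypothetical reconstruction of the surrounding factory logic with the narrower handling the review suggests (helper and field names are assumptions, not the PR's code):

```python
def _to_plain_dict(loss_cfg):
    if loss_cfg is None:
        return {}
    try:
        from omegaconf import OmegaConf
    except ImportError:
        # OmegaConf not installed: assume a plain mapping was passed.
        return dict(loss_cfg)
    if OmegaConf.is_config(loss_cfg):
        return OmegaConf.to_container(loss_cfg, resolve=True)
    return dict(loss_cfg)
```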

@wuxibin89
Collaborator

The recipe directory has been moved to verl-project/verl-recipe as a submodule (#4795). Please submit this PR to verl-recipe instead.

@wuxibin89 closed this Jan 6, 2026