[recipe][megatron] refactor: isolate megatron recipe and extend distillation losses #4797
process-cxr wants to merge 1 commit into verl-project:main
Conversation
…tion losses

- Add megatron_distill_losses with support for RKL, JSD and KL+RL losses
- Expose distillation losses as callable operators in Megatron workers
- Move all Megatron-related training code under recipe/gkd/megatron
- Prepare directory structure for future FSDP training framework
Code Review
This pull request refactors the Megatron-based GKD training pipeline by isolating Megatron-specific code and extending the distillation loss implementations. The new megatron_distill_losses.py file introduces several distillation loss functions (KL, RKL, JSD, etc.) as custom PyTorch autograd Functions.
My review focuses on the correctness and robustness of these new loss implementations. I've identified a critical issue in all custom autograd.Function implementations where an input tensor is modified in-place without being marked as dirty using ctx.mark_dirty(). This can lead to incorrect gradients and must be fixed. Additionally, I've found a high-severity issue in the configuration factory function where a broad except Exception can hide configuration errors, leading to silent failures.
The rest of the changes, which mainly involve refactoring file structures and updating call sites to use the new loss factory, look good and align with the goal of improving modularity.

    def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
        eps = 1e-20
The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 61-62, 71), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.
Suggested change:

    -def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
    -    eps = 1e-20
    +def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
    +    ctx.mark_dirty(vocab_parallel_logits)
    +    eps = 1e-20
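For readers less familiar with this contract, here is a minimal, self-contained sketch of a custom autograd.Function that mutates its input in-place and declares it via ctx.mark_dirty(); the InPlaceExp class is a toy illustration, not code from this PR:

```python
import torch


class InPlaceExp(torch.autograd.Function):
    """Toy Function that overwrites its input with exp(input) in-place."""

    @staticmethod
    def forward(ctx, x):
        x.exp_()                  # in-place mutation of an input tensor
        ctx.mark_dirty(x)         # tell autograd the input was mutated in-place
        ctx.save_for_backward(x)  # x now holds exp(original input)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        return grad_output * y    # d/dx exp(x) = exp(x), i.e. the saved output


x = torch.randn(4, requires_grad=True)
y = InPlaceExp.apply(x.clone())  # clone(): a leaf that requires grad cannot be mutated in-place
y.sum().backward()               # x.grad now equals exp(x)
```

Mutating the input while skipping the mark_dirty call is exactly what the review flags: autograd is not told that the tensor changed, and gradients computed from it elsewhere in the graph can be silently wrong.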

    def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
        eps = 1e-20
The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 153-154, 163), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.
Suggested change:

    -def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
    -    eps = 1e-20
    +def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices):
    +    ctx.mark_dirty(vocab_parallel_logits)
    +    eps = 1e-20

    def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio: float = 0.1):
        eps = 1e-20
The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 293-294, 303), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.
Suggested change:

    -def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio: float = 0.1):
    -    eps = 1e-20
    +def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio: float = 0.1):
    +    ctx.mark_dirty(vocab_parallel_logits)
    +    eps = 1e-20

    def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, beta: float):
        beta = min(max(float(beta), 1e-6), 1.0 - 1e-6)
The forward method modifies the vocab_parallel_logits tensor in-place (e.g., lines 449-450, 459), but it doesn't mark it as dirty. This can lead to incorrect gradient calculations. According to PyTorch documentation, you must use ctx.mark_dirty(vocab_parallel_logits) when modifying an input tensor in-place.
Suggested change:

    -def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, beta: float):
    -    beta = min(max(float(beta), 1e-6), 1.0 - 1e-6)
    +def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, beta: float):
    +    ctx.mark_dirty(vocab_parallel_logits)
    +    beta = min(max(float(beta), 1e-6), 1.0 - 1e-6)
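For context on the clamped beta above: assuming the generalized Jensen-Shannon divergence commonly used for GKD-style distillation (an assumption about the intended objective, not something stated in this diff), the loss mixes teacher and student distributions as

$$
\mathrm{JSD}_\beta(P \,\|\, Q) = \beta\,\mathrm{KL}(P \,\|\, M) + (1-\beta)\,\mathrm{KL}(Q \,\|\, M),
\qquad M = \beta P + (1-\beta) Q .
$$

Clamping beta to [1e-6, 1 - 1e-6] keeps M a strict mixture of both distributions, so neither KL term degenerates at the endpoints.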

                cfg = dict(loss_cfg)
            else:
                cfg = {}
        except Exception:
Using a broad except Exception: can mask important errors during configuration processing. For instance, if OmegaConf.to_container fails for a reason other than a missing import, this will be silently ignored, and the default loss configuration will be used. This can lead to misconfigured training runs that are hard to debug. It's better to catch specific exceptions like ImportError.
Suggested change:

    -except Exception:
    +except ImportError:
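To make the suggested narrowing concrete, here is a minimal sketch of a config-normalization helper that only swallows the missing-dependency case; the name normalize_loss_cfg and its exact branching are hypothetical and not taken from this PR:

```python
from typing import Any

try:
    from omegaconf import DictConfig, OmegaConf
except ImportError:  # treat omegaconf as an optional dependency in this sketch
    DictConfig = None
    OmegaConf = None


def normalize_loss_cfg(loss_cfg: Any) -> dict:
    """Convert a loss config (OmegaConf node, plain mapping, or None) into a dict."""
    if loss_cfg is None:
        return {}
    if OmegaConf is not None and isinstance(loss_cfg, DictConfig):
        # Let genuine conversion errors propagate instead of being swallowed
        # by a broad `except Exception` and silently falling back to defaults.
        return OmegaConf.to_container(loss_cfg, resolve=True)
    return dict(loss_cfg)
```

With this shape, a real failure inside OmegaConf.to_container surfaces immediately instead of producing a misconfigured run that is hard to trace back.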
The recipe has been moved to verl-project/verl-recipe as a submodule, #4795. Please submit a PR to verl-recipe.
This PR refactors the Megatron-based GKD training pipeline and extends the distillation loss implementation to improve modularity and future extensibility.
Specifically, it:

- Adds megatron_distill_losses with support for RKL, JSD, and KL+RL losses, implemented as reusable loss operators
- Moves the Megatron training code under recipe/gkd/megatron to clearly isolate it from other backends

What does this PR do?
This PR reorganizes the Megatron training recipe in GKD by isolating Megatron-specific code into a dedicated directory and extending the distillation loss module.
The refactor improves code clarity and separation of concerns, while the new loss operators make it easier to experiment with alternative distillation objectives. The new directory layout also lays the groundwork for adding an FSDP-based training backend.
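To illustrate what exposing distillation losses as callable operators can look like, here is a minimal registry sketch; DISTILL_LOSSES, register_distill_loss, and the truncated top-k reverse-KL body below are hypothetical illustrations, not the PR's Megatron tensor-parallel implementation:

```python
from typing import Callable, Dict

import torch

# Hypothetical registry mapping a config string to a distillation loss callable.
DISTILL_LOSSES: Dict[str, Callable[..., torch.Tensor]] = {}


def register_distill_loss(name: str):
    def decorator(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        DISTILL_LOSSES[name] = fn
        return fn
    return decorator


@register_distill_loss("rkl")
def reverse_kl_loss(student_logits, teacher_topk_logps, teacher_topk_indices):
    """Reverse KL restricted to the teacher's top-k support (unnormalized truncation)."""
    student_logps = torch.log_softmax(student_logits, dim=-1)
    student_topk_logps = torch.gather(student_logps, -1, teacher_topk_indices)
    q = student_topk_logps.exp()  # student probabilities on the teacher's top-k tokens
    return (q * (student_topk_logps - teacher_topk_logps)).sum(-1).mean()


def get_distill_loss(name: str) -> Callable[..., torch.Tensor]:
    return DISTILL_LOSSES[name]
```

A worker can then select the objective by name from its config and switch between KL, RKL, JSD, or mixed losses without touching the training loop.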
Checklist Before Starting
[{modules}] {type}: {description}

Test
This PR mainly introduces refactoring and loss function extensions that are not fully covered by existing CI tests.
The changes have been validated by running Megatron-based GKD training and verifying correct loss computation and training stability when switching between KL, RKL, JSD, and KL+RL losses.
API and Usage Example
No public API is changed in this PR. The new distillation losses are internal to the Megatron training pipeline.
Design & Code Changes
- Megatron-related GKD training code is isolated under recipe/gkd/megatron

Checklist Before Submitting