
Commit 2595de2

[megatron] fix: mbridge load optimizer dist_ckpt (verl-project#3850)
### What does this PR do?

Fix training resume from mbridge load optimizer dist_ckpt. This PR fixes verl-project#3517.

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: Gao Shiyuan <gaoshiyuan@baidu.com>
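The fix registers the optimizer classes (`torch.optim.AdamW` and Transformer Engine's `FusedAdam`) with `torch.serialization.add_safe_globals` so that safe (`weights_only`) checkpoint loading accepts them. The allowlist mechanism behind this can be sketched with stdlib `pickle` alone; `FusedAdamStub` and `AllowlistUnpickler` below are hypothetical names for illustration, not verl or torch APIs:

```python
import io
import pickle


class FusedAdamStub:
    """Hypothetical stand-in for an optimizer class embedded in a checkpoint."""

    def __init__(self, lr=1e-3):
        self.lr = lr


SAFE_GLOBALS = set()


class AllowlistUnpickler(pickle.Unpickler):
    """Simplified analogue of torch's weights_only loading: a checkpoint may
    only reference globals that were explicitly allowlisted beforehand."""

    def find_class(self, module, name):
        cls = super().find_class(module, name)
        if cls not in SAFE_GLOBALS:
            raise pickle.UnpicklingError(f"{module}.{name} is not allowlisted")
        return cls


payload = pickle.dumps({"optimizer": FusedAdamStub(lr=0.01)})

# Without allowlisting, the custom class is refused.
try:
    AllowlistUnpickler(io.BytesIO(payload)).load()
except pickle.UnpicklingError as exc:
    print("blocked:", exc)

# Allowlist the class, analogous to torch.serialization.add_safe_globals([...]).
SAFE_GLOBALS.add(FusedAdamStub)
ckpt = AllowlistUnpickler(io.BytesIO(payload)).load()
print(ckpt["optimizer"].lr)
```

Without the allowlist calls, resuming a run whose checkpoint embeds these optimizer classes fails at load time, which is the symptom reported in verl-project#3517.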
1 parent bb80a26 commit 2595de2

File tree: 1 file changed (+16 −3 lines)


verl/utils/checkpoint/megatron_checkpoint_manager.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -231,7 +231,11 @@ def get_checkpoint_name(
         return os.path.join(common_path, basename)
 
     def generate_state_dict(
-        self, generate_model: bool = True, generate_optimizer: bool = True, generate_extra: bool = True
+        self,
+        generate_model: bool = True,
+        generate_optimizer: bool = True,
+        generate_extra: bool = True,
+        is_loading: bool = False,
     ):
         # For save dist checkpointing
         state_dict = {}
@@ -252,7 +256,7 @@ def generate_state_dict(
         # Optimizer State Dict
         if generate_optimizer:
             torch.distributed.barrier()
-            optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict)
+            optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict, is_loading=is_loading)
             state_dict["optimizer"] = optimizer_sharded_states
 
         if self.lr_scheduler is not None:
@@ -292,11 +296,20 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
         if local_path is not None:
             assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist."
 
+        # For load optimizer dist_ckpt
+        import transformer_engine
+
+        torch.serialization.add_safe_globals([torch.optim.AdamW])
+        torch.serialization.add_safe_globals([transformer_engine.pytorch.optimizers.fused_adam.FusedAdam])
+
         dist_checkpoint_path = get_dist_checkpoint_path(local_path)
 
         # Get State Dict for loading
         sharded_state_dict = self.generate_state_dict(
-            self.should_load_model and self.use_dist_checkpointing, self.should_load_optimizer, self.should_load_extra
+            self.should_load_model and self.use_dist_checkpointing,
+            self.should_load_optimizer,
+            self.should_load_extra,
+            is_loading=True,
         )
         log_with_rank(f"Generated state dict for loading: {sharded_state_dict.keys()}", rank=self.rank, logger=logger)
```
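The signature change threads an `is_loading` flag from `load_checkpoint` down to the optimizer's `sharded_state_dict`, defaulting to `False` so the save path is unchanged. The pattern can be sketched in plain Python; `MockOptimizer` and `MockCheckpointManager` are hypothetical stand-ins, not the real Megatron or verl classes:

```python
class MockOptimizer:
    """Hypothetical stand-in for Megatron's distributed optimizer."""

    def sharded_state_dict(self, model_state_dict, is_loading=False):
        # The real method builds sharded-tensor metadata; with is_loading=True
        # it prepares a layout that dist checkpointing can fill during load.
        return {"mode": "load" if is_loading else "save"}


class MockCheckpointManager:
    """Hypothetical stand-in for MegatronCheckpointManager."""

    def __init__(self):
        self.optimizer = MockOptimizer()

    def generate_state_dict(self, generate_model=True, generate_optimizer=True,
                            generate_extra=True, is_loading=False):
        state_dict = {}
        if generate_optimizer:
            # The fix: forward is_loading so the optimizer produces a
            # load-compatible sharded state dict instead of a save-time one.
            state_dict["optimizer"] = self.optimizer.sharded_state_dict(
                state_dict, is_loading=is_loading
            )
        return state_dict


mgr = MockCheckpointManager()
print(mgr.generate_state_dict()["optimizer"])                 # save path
print(mgr.generate_state_dict(is_loading=True)["optimizer"])  # load path
```

Defaulting `is_loading=False` keeps every existing save-side caller working, while `load_checkpoint` opts in with `is_loading=True`.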
