[megatron] feat: use mbridge as megatron adaptor #2064
New file (21 added lines), `@@ -0,0 +1,21 @@`:

```python
try:
    from mbridge import AutoBridge
    from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model
except ImportError:
    import subprocess
    import sys

    print("mbridge package not found. This package is required for model bridging functionality.")
    print("Install mbridge with `pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps`")

    def install_mbridge():
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/ISEEKYAN/mbridge.git", "--no-deps"])
        except subprocess.CalledProcessError:
            print("Failed to install mbridge")
            raise

    install_mbridge()
    from mbridge import *

__all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"]
```
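For context on what these re-exports do, here is a minimal usage sketch. It is based on mbridge's published examples rather than on this PR, so the exact method names (`from_pretrained`, `get_model`, `load_weights`) should be treated as assumptions:

```python
# Minimal sketch, assuming mbridge's documented entry points (not code from this PR).
from mbridge import AutoBridge
from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model

HF_PATH = "Qwen/Qwen2.5-7B"  # illustrative HF checkpoint path

bridge = AutoBridge.from_pretrained(HF_PATH)  # pick the bridge matching the architecture
model = bridge.get_model()                    # build the Megatron-Core model
bridge.load_weights(model, HF_PATH)           # stream HF weights into the Megatron model

# make_value_model / freeze_moe_router are post-creation callbacks, e.g. for turning
# the LM head into a critic head or freezing the MoE router during RL training.
```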
```diff
@@ -94,6 +94,7 @@ actor_rollout_ref:
       dist_checkpointing_path: null
       seed: 42
       override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage
+      use_mbridge: False
```
Inline review thread on the `use_mbridge: False` line:

**Collaborator:** Actually, `dist_checkpointing` and mbridge should be an either-or relation? Maybe we should use some naming like … Also, we may need to consider how this combines with …

**Collaborator:** @ccclyu @dataproblems, could you give some advice on the API design? How about something like:

```yaml
checkpoint:
  pre_load: # first-time load
    format: [hf, dist_ckpt] # hf by default, via mbridge
  load:
    format: [hf, dist_ckpt]
  save:
    format: [hf, dist_ckpt]
```

But maybe this will break some APIs.

**Reply:** I think the current way is OK in the config, since it's possible to have some relationship between … Implementation-wise, I would add an abstraction that captures the checkpoint-saving logic away from the checkpoint manager and the workers; that way the checkpoint manager and the workers rely on a stable interface, which allows you to provide more options while modifying less code. Is that something you were looking for, or am I missing the point here?

**Collaborator:** Thanks, the latter part makes sense to me; it's a refactor point, while here I hope to focus on the API design. So …

**Author:** It looks good to me.

**Collaborator:** The current config LGTM. Long-term, if we migrate to …

**Author:** Personally I prefer to use the HF format for the whole lifetime of training. We would deprecate …
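To make the reviewer's suggestion concrete, here is a minimal sketch of such an abstraction. Everything below is hypothetical: the class and function names are invented for illustration, and the mbridge `load_weights`/`save_weights` calls are assumptions, not code from this PR.

```python
# Hypothetical sketch of a stable checkpoint interface, per the review discussion.
from abc import ABC, abstractmethod


class CheckpointBackend(ABC):
    """Stable interface the workers and checkpoint manager would depend on."""

    @abstractmethod
    def load(self, model, path: str) -> None: ...

    @abstractmethod
    def save(self, model, path: str) -> None: ...


class HFBackend(CheckpointBackend):
    """HF-format checkpoints, delegated to an mbridge bridge object."""

    def __init__(self, bridge):
        self.bridge = bridge  # e.g. AutoBridge.from_pretrained(...) (assumed API)

    def load(self, model, path: str) -> None:
        self.bridge.load_weights(model, path)  # assumed mbridge method

    def save(self, model, path: str) -> None:
        self.bridge.save_weights(model, path)  # assumed mbridge method


def backend_for(fmt: str, bridge=None) -> CheckpointBackend:
    """Map the proposed checkpoint.{pre_load,load,save}.format values to a backend."""
    if fmt == "hf":
        return HFBackend(bridge)
    if fmt == "dist_ckpt":
        raise NotImplementedError("would wrap Megatron dist_checkpointing here")
    raise ValueError(f"unknown checkpoint format: {fmt}")
```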
The config hunk continues:

```diff
     profile: # profile the actor model in `update_policy`
       use_profile: False # open it when you want to profile the actor model
       profile_ranks: null # list, you can specify the ranks to profile
```
```diff
@@ -124,6 +125,7 @@ actor_rollout_ref:
       dist_checkpointing_path: null
       seed: ${actor_rollout_ref.actor.megatron.seed}
       override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
+      use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
     profile:
       use_profile: False
       profile_ranks: null
```
```diff
@@ -245,6 +247,7 @@ critic:
     dist_checkpointing_path: null
     seed: ${actor_rollout_ref.actor.megatron.seed}
     override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
+    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
   load_weight: True
   ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
   ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
```
```diff
@@ -284,6 +287,7 @@ reward_model:
     dist_checkpointing_path: null
     seed: ${actor_rollout_ref.actor.megatron.seed}
     override_transformer_config: {}
+    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
   model:
     input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
     path: ~/models/FsfairX-LLaMA3-RM-v0.1
```
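The `${...}` values above are OmegaConf interpolations: the reference, critic, and reward-model sections all inherit `use_mbridge` from `actor_rollout_ref.actor.megatron`, so flipping that single key toggles mbridge for every role. A standalone demonstration of the mechanism (not code from this PR):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "actor_rollout_ref": {"actor": {"megatron": {"use_mbridge": False}}},
    "critic": {"megatron": {"use_mbridge": "${actor_rollout_ref.actor.megatron.use_mbridge}"}},
})

print(cfg.critic.megatron.use_mbridge)  # False: the interpolation resolves on access

# Interpolations are lazy, so overriding the source updates every consumer.
cfg.actor_rollout_ref.actor.megatron.use_mbridge = True
print(cfg.critic.megatron.use_mbridge)  # True
```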