Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,24 @@
from megatron.core.transformer.enums import AttnBackend
from transformers import GenerationConfig

# For load optimizer dist_ckpt
import transformer_engine
torch.serialization.add_safe_globals([torch.optim.AdamW])
torch.serialization.add_safe_globals([transformer_engine.pytorch.optimizers.fused_adam.FusedAdam])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Modifying global state at module import time is generally discouraged as it can lead to hard-to-debug side effects and makes the system's behavior dependent on import order. Here, torch.serialization._safe_globals is modified when this module is imported.

This logic should be moved into an explicit initialization function that is called once during your application's setup, for example at the beginning of the load_checkpoint method. This makes the dependency explicit and avoids unintended consequences in other parts of the codebase that might also use torch.serialization.


from verl.models.weight_loader_registry import get_weight_saver
from verl.utils.device import get_device_name, get_torch_device

Check failure on line 36 in verl/utils/checkpoint/megatron_checkpoint_manager.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E402)

verl/utils/checkpoint/megatron_checkpoint_manager.py:36:1: E402 Module level import not at top of file
from verl.utils.fs import is_non_local, local_mkdir_safe

Check failure on line 37 in verl/utils/checkpoint/megatron_checkpoint_manager.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E402)

verl/utils/checkpoint/megatron_checkpoint_manager.py:37:1: E402 Module level import not at top of file
from verl.utils.logger import log_with_rank

Check failure on line 38 in verl/utils/checkpoint/megatron_checkpoint_manager.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E402)

verl/utils/checkpoint/megatron_checkpoint_manager.py:38:1: E402 Module level import not at top of file
from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing

Check failure on line 39 in verl/utils/checkpoint/megatron_checkpoint_manager.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E402)

verl/utils/checkpoint/megatron_checkpoint_manager.py:39:1: E402 Module level import not at top of file
from verl.utils.megatron_utils import (

Check failure on line 40 in verl/utils/checkpoint/megatron_checkpoint_manager.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E402)

verl/utils/checkpoint/megatron_checkpoint_manager.py:40:1: E402 Module level import not at top of file
get_dist_checkpoint_path,
get_hf_model_checkpoint_path,
get_transformer_config_checkpoint_path,
)

Check failure on line 45 in verl/utils/checkpoint/megatron_checkpoint_manager.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E402)

verl/utils/checkpoint/megatron_checkpoint_manager.py:41:1: E402 Module level import not at top of file
from .checkpoint_manager import BaseCheckpointManager

Check failure on line 47 in verl/utils/checkpoint/megatron_checkpoint_manager.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E402)

verl/utils/checkpoint/megatron_checkpoint_manager.py:47:1: E402 Module level import not at top of file
# Setup logging
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))
Expand Down Expand Up @@ -231,7 +236,11 @@
return os.path.join(common_path, basename)

def generate_state_dict(
self, generate_model: bool = True, generate_optimizer: bool = True, generate_extra: bool = True
self,
generate_model: bool = True,
generate_optimizer: bool = True,
generate_extra: bool = True,
is_loading: bool = False
):
# For save dist checkpointing
state_dict = {}
Expand All @@ -252,7 +261,7 @@
# Optimizer State Dict
if generate_optimizer:
torch.distributed.barrier()
optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict)
optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict, is_loading=is_loading)
state_dict["optimizer"] = optimizer_sharded_states

if self.lr_scheduler is not None:
Expand Down Expand Up @@ -296,7 +305,10 @@

# Get State Dict for loading
sharded_state_dict = self.generate_state_dict(
self.should_load_model and self.use_dist_checkpointing, self.should_load_optimizer, self.should_load_extra
self.should_load_model and self.use_dist_checkpointing,
self.should_load_optimizer,
self.should_load_extra,
is_loading=True
)
log_with_rank(f"Generated state dict for loading: {sharded_state_dict.keys()}", rank=self.rank, logger=logger)

Expand Down
Loading