diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index bc418f0b1fa..b506e9d610b 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -711,6 +711,21 @@ global_profiler: context: all stacks: all kw_args: {} + precision_debugger: + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + enable: false + config_path: null + data_dir: outputs/precision_debug + steps: null + stages: + - rollout_generate + - update_actor + - actor_compute_log_prob + - ref_compute_log_prob + - compute_values + - critic_update + - compute_rm_score + - train transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index a3baaf52af3..20263fd3efd 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -645,6 +645,21 @@ global_profiler: context: all stacks: all kw_args: {} + precision_debugger: + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + enable: false + config_path: null + data_dir: outputs/precision_debug + steps: null + stages: + - rollout_generate + - update_actor + - actor_compute_log_prob + - ref_compute_log_prob + - compute_values + - critic_update + - compute_rm_score + - train transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 76ba4c57575..296a94b041e 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -236,6 +236,13 @@ global_profiler: stacks: "all" # devices, record_context etc. kw_args: {} + precision_debugger: + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + enable: False + config_path: null + data_dir: "outputs/precision_debug" + steps: null + stages: ["rollout_generate", "update_actor", "actor_compute_log_prob", "ref_compute_log_prob", "compute_values", "critic_update", "compute_rm_score", "train"] # configs for TransferQueue transfer_queue: diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7489b522fa2..f88a8ba5123 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -300,6 +300,26 @@ global_profiler: # devices, record_context etc. 
kw_args: {} + # precision debugger config + precision_debugger: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + + # Whether to enable precision debugger + enable: False + + # Path to msprobe config.json + config_path: null + + # Dump root directory + data_dir: "outputs/precision_debug" + + # Steps to collect, null means all + steps: null + + # Stages to collect + stages: ["rollout_generate", "update_actor", "actor_compute_log_prob", "ref_compute_log_prob", "compute_values", "critic_update", "compute_rm_score", "train"] # configs for TransferQueue transfer_queue: diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 46271507934..4d01d4e9eca 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -271,6 +271,7 @@ def __init__( self.config = config self.reward_fn = reward_fn self.val_reward_fn = val_reward_fn + self._propagate_precision_debugger_config() self.hybrid_engine = config.actor_rollout_ref.hybrid_engine assert self.hybrid_engine, "Currently, only support hybrid engine" @@ -1073,6 +1074,18 @@ def _stop_profiling(self, do_profile: bool) -> None: if self.use_rm and not self.use_reward_loop: self.rm_wg.stop_profile() + def _propagate_precision_debugger_config(self) -> None: + precision_cfg = OmegaConf.select(self.config, "global_profiler.global_tool_config.precision_debugger") + if precision_cfg is None: + return + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref") is not None: + self.config.actor_rollout_ref.precision_debugger = precision_cfg + if OmegaConf.select(self.config, "critic") is not None: + self.config.critic.precision_debugger = precision_cfg + if OmegaConf.select(self.config, "reward_model") is not None: + self.config.reward_model.precision_debugger = precision_cfg + def _get_dp_size(self, worker_group, role: str) -> int: """Get data parallel size from worker group dispatch info. 
@@ -1369,6 +1382,7 @@ def fit(self): else curr_step_profile ) batch: DataProto = DataProto.from_single_dict(batch_dict) + batch.meta_info["global_steps"] = self.global_steps batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature # add uid to batch diff --git a/verl/utils/import_utils.py b/verl/utils/import_utils.py index ee78b580675..4e97ba070e8 100644 --- a/verl/utils/import_utils.py +++ b/verl/utils/import_utils.py @@ -69,6 +69,15 @@ def is_trl_available(): return trl_spec is not None +@cache +def is_msprobe_available(): + try: + msprobe_spec = importlib.util.find_spec("msprobe") + except ModuleNotFoundError: + msprobe_spec = None + return msprobe_spec is not None + + def import_external_libs(external_libs=None): if external_libs is None: return diff --git a/verl/utils/profiler/__init__.py b/verl/utils/profiler/__init__.py index 73edb01a02c..84fc6a63d93 100644 --- a/verl/utils/profiler/__init__.py +++ b/verl/utils/profiler/__init__.py @@ -15,6 +15,7 @@ from ..device import is_npu_available from ..import_utils import is_nvtx_available from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer +from .precision_hook import PrecisionDebuggerLogger from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig # Select marker implementations by availability, but keep DistProfiler as our dispatcher @@ -37,4 +38,5 @@ "ProfilerConfig", "simple_timer", "marked_timer", + "PrecisionDebuggerLogger", ] diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py index 4430d758698..442982d9f12 100644 --- a/verl/utils/profiler/config.py +++ b/verl/utils/profiler/config.py @@ -78,6 +78,29 @@ def __post_init__(self) -> None: assert self.stack_depth > 0, f"stack_depth must be positive, got {self.stack_depth}" +@dataclass +class PrecisionDebuggerToolConfig(BaseConfig): + """Precision debugger tool config (msprobe).""" + + enable: bool = False + config_path: Optional[str] = None + data_dir: str = "outputs/precision_debug" + steps: Optional[list[int]] = None + stages: Optional[list[str]] = None + strict: bool = False + + def __post_init__(self) -> None: + assert isinstance(self.enable, bool), f"enable must be bool, got {type(self.enable)}" + if self.config_path is not None: + assert isinstance(self.config_path, str), f"config_path must be str, got {type(self.config_path)}" + assert isinstance(self.data_dir, str), f"data_dir must be str, got {type(self.data_dir)}" + if self.steps is not None: + assert isinstance(self.steps, list), f"steps must be list[int], got {type(self.steps)}" + if self.stages is not None: + assert isinstance(self.stages, list), f"stages must be list[str], got {type(self.stages)}" + assert isinstance(self.strict, bool), f"strict must be bool, got {type(self.strict)}" + + @dataclass class NPUToolConfig(NsightToolConfig): """NPU profiler too; config.""" diff --git a/verl/utils/profiler/precision_debugger_profile.py b/verl/utils/profiler/precision_debugger_profile.py new file mode 100644 index 00000000000..be49a9ac8db --- /dev/null +++ b/verl/utils/profiler/precision_debugger_profile.py @@ -0,0 +1,137 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +from dataclasses import asdict +from typing import Optional + +from verl.utils.import_utils import is_msprobe_available +from verl.utils.profiler.config import PrecisionDebuggerToolConfig + + +_GLOBAL_LOCK = threading.Lock() + + +class PrecisionDebuggerProfiler: + """Precision debugger wrapper for msprobe. + + This class implements a minimal start/stop contract and is intentionally + not a DistProfiler subclass to keep the dependency one-way. + """ + + def __init__(self, precision_cfg, rank: Optional[int] = None): + self.rank = rank + self.precision_cfg = self._normalize_config(precision_cfg) + self._active_lock: Optional[threading.Lock] = None + self._enabled = self._is_enabled(self.precision_cfg) + self._available = is_msprobe_available() + self._debugger = None + + @staticmethod + def _normalize_config(precision_cfg) -> PrecisionDebuggerToolConfig: + if precision_cfg is None: + return PrecisionDebuggerToolConfig() + if isinstance(precision_cfg, PrecisionDebuggerToolConfig): + return precision_cfg + if hasattr(precision_cfg, "to_container"): + precision_cfg = precision_cfg.to_container(resolve=True) + if isinstance(precision_cfg, dict): + return PrecisionDebuggerToolConfig(**precision_cfg) + return PrecisionDebuggerToolConfig(**asdict(precision_cfg)) + + @staticmethod + def _is_enabled(precision_cfg: PrecisionDebuggerToolConfig) -> bool: + return bool(precision_cfg.enable) + + def _should_collect(self, stage: str, global_step: Optional[int]) -> bool: + if not self._enabled: + return False + if self.precision_cfg.stages is not None and stage not in set(self.precision_cfg.stages): + return False + if self.precision_cfg.steps is not None and global_step is not None: + if int(global_step) not in set(self.precision_cfg.steps): + return False + return True + + def _get_lock(self) -> threading.Lock: + return _GLOBAL_LOCK + + def start(self, stage: str, global_step: Optional[int] = None, model=None) -> bool: + if not self._should_collect(stage=stage, global_step=global_step): + return False + if not self._available: + if self.precision_cfg.strict: + raise ImportError("msprobe is not available but precision_debugger.strict is True") + return False + + config_path = self.precision_cfg.config_path + data_dir = self.precision_cfg.data_dir + if not config_path or not data_dir: + return False + + step_tag = f"step_{global_step}" if global_step is not None else "step_unknown" + rank_tag = f"rank_{self.rank}" if self.rank is not None else "rank_unknown" + dump_path = os.path.join(data_dir, step_tag, stage, rank_tag) + os.makedirs(dump_path, exist_ok=True) + + lock = self._get_lock() + lock.acquire() + self._active_lock = lock + try: + from msprobe.pytorch import PrecisionDebugger + + debugger = None + if self._debugger is None: + debugger = PrecisionDebugger(config_path=config_path, dump_path=dump_path) + if debugger is None: + if self.precision_cfg.strict: + raise RuntimeError("Failed to create PrecisionDebugger instance") + return False + self._debugger = debugger + else: + debugger = self._debugger + if hasattr(debugger, "service") and hasattr(debugger.service, 
"config"): + debugger.service.config.dump_path = dump_path + debugger.start(model) + return True + except Exception: + self._release_lock() + if self.precision_cfg.strict: + raise + return False + + def stop(self, started: bool = False, step: bool = False) -> None: + if not started: + self._release_lock() + return + if not self._available: + self._release_lock() + return + try: + debugger = self._debugger + if debugger is None: + return + debugger.stop() + if step: + if hasattr(debugger, "step"): + debugger.step() + finally: + self._release_lock() + + def _release_lock(self) -> None: + lock = self._active_lock + self._active_lock = None + if lock is not None: + lock.release() diff --git a/verl/utils/profiler/precision_hook.py b/verl/utils/profiler/precision_hook.py new file mode 100644 index 00000000000..b38f3fc510b --- /dev/null +++ b/verl/utils/profiler/precision_hook.py @@ -0,0 +1,131 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from .precision_debugger_profile import PrecisionDebuggerProfiler + + +def _resolve_attr(obj, attr): + if not isinstance(attr, str): + return None + if "." in attr: + current = obj + for part in attr.split("."): + current = getattr(current, part, None) + if current is None: + return None + return current + return getattr(obj, attr, None) + + +def _get_model(self_instance, precision_model_attr): + if precision_model_attr is None: + return None + if isinstance(precision_model_attr, (list, tuple)): + for attr in precision_model_attr: + val = _resolve_attr(self_instance, attr) + if val is not None: + return val + return None + return _resolve_attr(self_instance, precision_model_attr) + + +def _get_global_step(self_instance, args, kwargs, precision_global_step_attr: Optional[str]): + for val in list(args) + list(kwargs.values()): + if hasattr(val, "meta_info"): + meta = getattr(val, "meta_info") + if isinstance(meta, dict) and "global_steps" in meta: + return meta.get("global_steps") + if isinstance(val, dict) and "global_steps" in val: + return val.get("global_steps") + if precision_global_step_attr and hasattr(self_instance, precision_global_step_attr): + return getattr(self_instance, precision_global_step_attr) + if hasattr(self_instance, "precision_global_step"): + return getattr(self_instance, "precision_global_step") + return None + + +def build_precision_impl(self_instance, precision_stage: Optional[str]): + precision_cfg = getattr(self_instance, "precision_debugger_cfg", None) + if not precision_cfg or not precision_stage: + return None + rank = getattr(getattr(self_instance, "profiler", None), "rank", None) + return PrecisionDebuggerProfiler(precision_cfg, rank=rank) + + +def start_precision( + precision_impl: Optional[PrecisionDebuggerProfiler], + self_instance, + args, + kwargs, + precision_stage: Optional[str], + precision_model_attr, + precision_global_step_attr: Optional[str], +) -> bool: + if precision_impl is None: + return False + global_step = 
_get_global_step(self_instance, args, kwargs, precision_global_step_attr) + model = _get_model(self_instance, precision_model_attr) + return precision_impl.start(stage=precision_stage, global_step=global_step, model=model) + + +def stop_precision( + precision_impl: Optional[PrecisionDebuggerProfiler], + started: bool, + precision_step: bool, +) -> None: + if precision_impl is None: + return + precision_impl.stop(started=started, step=precision_step) + + +class PrecisionDebuggerLogger: + """Decorator to run PrecisionDebugger around a method call. + + Example: + >>> @PrecisionDebuggerLogger(stage="train", model_attr="actor_module") + >>> def update_policy(self, batch): ... + """ + + def __init__( + self, + stage: str, + model_attr: Optional[object] = None, + global_step_attr: Optional[str] = None, + step: bool = False, + ): + self.stage = stage + self.model_attr = model_attr + self.global_step_attr = global_step_attr + self.step = step + + def __call__(self, decorated_function: callable): + def wrapper(self_instance, *args, **kwargs): + precision_impl = build_precision_impl(self_instance, self.stage) + started = start_precision( + precision_impl, + self_instance, + args, + kwargs, + self.stage, + self.model_attr, + self.global_step_attr, + ) + try: + return decorated_function(self_instance, *args, **kwargs) + finally: + stop_precision(precision_impl, started, self.step) + + return wrapper diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py index 8e3145a66bb..7984181ce9f 100644 --- a/verl/utils/profiler/profile.py +++ b/verl/utils/profiler/profile.py @@ -77,6 +77,7 @@ class DistProfiler: - npu: NPUProfiler (Ascend) - torch: PyTorch torch.profiler wrapper - torch_memory: Torch CUDA memory snapshot dump + - precision_debugger: msprobe precision debugger """ def __init__( @@ -96,6 +97,7 @@ def __init__( self._tool = getattr(config, "tool", None) self._enable = config.enable self._this_step = False + self.rank = rank # Normalize rank selection self._this_rank = False @@ -125,6 +127,10 @@ def __init__( self._impl = _Torch(rank=rank, config=config, tool_config=tool_config) elif self._tool == "torch_memory": self._impl = TorchMemoryProfiler(rank=rank, config=config, tool_config=tool_config) + elif self._tool == "precision_debugger": + from .precision_debugger_profile import PrecisionDebuggerProfiler as _Precision + + self._impl = _Precision(precision_cfg=tool_config, rank=rank) else: # Fallback to a no-op impl self._impl = _NoOpProfiler() @@ -151,6 +157,7 @@ def stop(self): self._this_step = False return getattr(self._impl, "stop", lambda: None)() + @classmethod def annotate( cls, @@ -160,29 +167,34 @@ def annotate( category: Optional[str] = None, **kwargs_outer, ) -> Callable: - def decorator(func): - @functools.wraps(func) + def _decorate_with_profiler(impl, func_inner): + if hasattr(impl, "annotate"): + return impl.annotate(message=message, color=color, domain=domain, category=category, **kwargs_outer)( + func_inner + ) + return func_inner + + def _should_profile(self_instance) -> bool: + profiler = getattr(self_instance, "profiler", None) + return ( + profiler + and profiler.check_enable() + and profiler.check_this_step() + and profiler.check_this_rank() + ) + + def decorator(func_inner): + @functools.wraps(func_inner) def wrapper(self_instance, *args, **kwargs_inner): - profiler = getattr(self_instance, "profiler", None) - if ( - not profiler - or not profiler.check_enable() - or not profiler.check_this_step() - or not profiler.check_this_rank() - ): - return 
func(self_instance, *args, **kwargs_inner) - - impl = profiler._impl - if hasattr(impl, "annotate"): - try: - actual_decorator = impl.annotate( - message=message, color=color, domain=domain, category=category, **kwargs_outer - ) - - return actual_decorator(func)(self_instance, *args, **kwargs_inner) - except Exception: - return func(self_instance, *args, **kwargs_inner) - return func(self_instance, *args, **kwargs_inner) + # profile the call when enabled for this step and rank, otherwise run it plainly + if _should_profile(self_instance): + impl = self_instance.profiler._impl + wrapped = _decorate_with_profiler(impl, func_inner) + try: + return wrapped(self_instance, *args, **kwargs_inner) + except Exception: + return func_inner(self_instance, *args, **kwargs_inner) + return func_inner(self_instance, *args, **kwargs_inner) return wrapper @@ -197,6 +209,8 @@ def stop(self): return + + class TorchMemoryProfiler: """Profiler that dumps CUDA memory snapshots at step boundaries. diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index d524f0e2ba1..df98ac58b14 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -31,12 +31,13 @@ from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input from verl.utils.device import get_device_id, get_device_name from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ -from verl.utils.profiler import GPUMemoryLogger +from verl.utils.profiler import DistProfiler, GPUMemoryLogger, PrecisionDebuggerLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.utils.torch_dtypes import PrecisionType from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs + from verl.workers.actor import BasePPOActor from verl.workers.config import ActorConfig @@ -388,6 +389,8 @@ def _forward_micro_batch( outputs["sum_pi_squared"] = sum_pi_squared return outputs + @PrecisionDebuggerLogger(stage="update_actor", model_attr="actor_module", step=True) + @DistProfiler.annotate() def _optimizer_step(self): assert self.config.grad_clip is not None if self.scaler is not None: @@ -499,6 +502,8 @@ def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> return outputs @GPUMemoryLogger(role="dp actor", logger=logger) + @PrecisionDebuggerLogger(stage="train", model_attr="actor_module") + @DistProfiler.annotate() def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() @@ -581,6 +586,7 @@ def update_policy(self, data: DataProto): outputs = self._forward_micro_batch( model_inputs, temperature=temperature, calculate_entropy=calculate_entropy ) + # keep handle active across backward log_prob = outputs["log_probs"] entropy = outputs["entropys"] if calculate_entropy else None diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 7fdaa6e9811..f32a097f6e0 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -39,6 +39,7 @@ from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty from verl.utils.device import get_device_id, get_torch_device +from verl.utils.profiler import DistProfiler, PrecisionDebuggerLogger from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction from
verl.utils.megatron.router_replay_utils import ( @@ -54,6 +55,7 @@ from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor + from verl.workers.actor import BasePPOActor from verl.workers.config import MtpConfig @@ -753,6 +755,8 @@ def logits_processor(logits, label, label_mask): return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) + @PrecisionDebuggerLogger(stage="train", model_attr="actor_module") + @DistProfiler.annotate() def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = False) -> dict: """Update the policy with an iterator of DataProto @@ -805,7 +809,7 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. - update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() + update_successful, grad_norm, num_zeros_in_grad = self._optimizer_step_with_precision() data = {"actor/grad_norm": grad_norm} append_to_dict(metrics, data) @@ -822,3 +826,8 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals self.actor_optimizer.zero_grad() get_torch_device().empty_cache() return metrics + + @PrecisionDebuggerLogger(stage="update_actor", model_attr="actor_module", step=True) + @DistProfiler.annotate() + def _optimizer_step_with_precision(self): + return self.actor_optimizer.step() diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index a4f0d9f4c77..bd95c65643e 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -37,7 +37,13 @@ from verl.utils.flops_counter import FlopsCounter from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.metric.utils import Metric -from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + PrecisionDebuggerLogger, + ProfilerConfig, + log_gpu_memory_usage, +) from verl.utils.py_functional import append_to_dict from verl.utils.tensordict_utils import maybe_fix_3d_position_ids from verl.utils.torch_functional import allgather_dict_into_dict @@ -69,6 +75,7 @@ def __init__(self, config: TrainingWorkerConfig): initialize_global_process_group_ray(timeout_second=None) self.config = config + self.precision_debugger_cfg = config.get("precision_debugger", None) self.model_config = self.config.model_config self.engine_config = self.config.engine_config self.optimizer_config = self.config.optimizer_config @@ -551,12 +558,14 @@ def init_model(self): ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) + @PrecisionDebuggerLogger(stage="ref_compute_log_prob", model_attr="ref") @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: TensorDict) -> TensorDict: output = self.ref.infer_batch(data=data) return output.cpu() if output is not None else None @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @PrecisionDebuggerLogger(stage="actor_compute_log_prob", model_attr="actor") @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") def compute_log_prob(self, data: TensorDict) -> TensorDict: output = self.actor.infer_batch(data) diff --git 
a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index a5e72f84f92..d76e7f8ce29 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -80,7 +80,14 @@ from verl.utils.import_utils import import_external_libs from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import compute_position_id_with_mask, convert_weight_keys -from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + PrecisionDebuggerLogger, + ProfilerConfig, + log_gpu_memory_usage, + simple_timer, +) from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max from verl.utils.py_functional import convert_to_regular_types from verl.utils.ray_utils import get_event_loop @@ -147,6 +154,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs): Worker.__init__(self) self.config = config + self.precision_debugger_cfg = config.get("precision_debugger", None) import torch.distributed if not torch.distributed.is_initialized(): @@ -821,6 +829,8 @@ def init_model(self): self.actor = DataParallelPPOActor( config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer ) + self.actor.precision_debugger_cfg = self.precision_debugger_cfg + self.actor.profiler = self.profiler if self._is_rollout: self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) @@ -935,7 +945,12 @@ def update_actor(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) - @DistProfiler.annotate(color="red", role="rollout_generate") + @DistProfiler.annotate( + color="red", + role="rollout_generate", + precision_stage="rollout_generate", + precision_model_attr=("actor_module_fsdp", "actor_module"), + ) def generate_sequences(self, prompts: DataProto): # Support all hardwares assert self._is_rollout @@ -985,7 +1000,12 @@ def generate_sequences(self, prompts: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) - @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + @DistProfiler.annotate( + color="blue", + role="actor_compute_log_prob", + precision_stage="actor_compute_log_prob", + precision_model_attr=("actor_module_fsdp", "actor_module"), + ) def compute_log_prob(self, data: DataProto): # when is_lora is True, we use the actor without lora applied to calculate the log_prob # which is mostly used for ref log_prob calculation @@ -1037,7 +1057,12 @@ def compute_log_prob(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) - @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + @DistProfiler.annotate( + color="olive", + role="ref_compute_log_prob", + precision_stage="ref_compute_log_prob", + precision_model_attr=("ref_module_fsdp", "actor_module_fsdp"), + ) def compute_ref_log_prob(self, data: DataProto): if self._is_lora: # if _is_lora, actor without lora applied is the ref @@ -1536,7 +1561,12 @@ def init_model(self): ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="cyan", role="compute_values") + @DistProfiler.annotate( + color="cyan", + role="compute_values", + precision_stage="compute_values", + precision_model_attr="critic_module_fsdp", + ) def compute_values(self, data: DataProto): if 
self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -1556,7 +1586,12 @@ def compute_values(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="pink", role="critic_update") + @DistProfiler.annotate( + color="pink", + role="critic_update", + precision_stage="critic_update", + precision_model_attr="critic_module_fsdp", + ) def update_critic(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -1924,7 +1959,12 @@ def _switch_chat_template(self, data: DataProto): return DataProto.from_dict(rm_inputs) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) - @DistProfiler.annotate(color="brown", role="compute_rm_score") + @DistProfiler.annotate( + color="brown", + role="compute_rm_score", + precision_stage="compute_rm_score", + precision_model_attr="reward_model_module_fsdp", + ) def compute_rm_score(self, data: DataProto): import itertools diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 14aa17949f9..846a28da83a 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -253,6 +253,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config: DictConfig, role: str, **kwargs): Worker.__init__(self) self.config = config + self.precision_debugger_cfg = config.get("precision_debugger", None) if repatch is not None: # NPU MindSpeed patch, will be refactored with MindSpeedEngine. repatch(self.config.actor.megatron.get("override_transformer_config", {})) @@ -612,6 +613,8 @@ def init_model(self): actor_optimizer=self.actor_optimizer, mtp_config=self.config.model.mtp if self.config.model.mtp.enable else None, ) + self.actor.precision_debugger_cfg = self.precision_debugger_cfg + self.actor.profiler = self.profiler print(f"routing replay layers: {len(RouterReplay.router_instances)}") log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) @@ -753,6 +756,7 @@ def update_actor(self, data: DataProto): micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size dataloader = self.actor.make_minibatch_iterator(data=data) + self.actor.precision_global_step = data.meta_info.get("global_steps", None) with Timer(name="update_policy", logger=None) as timer: metrics = self.actor.update_policy(dataloader=dataloader) delta_time = timer.last @@ -786,7 +790,12 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) @GPUMemoryLogger(role="generate_sequences", logger=logger) - @DistProfiler.annotate(color="red", role="rollout_generate") + @DistProfiler.annotate( + color="red", + role="rollout_generate", + precision_stage="rollout_generate", + precision_model_attr="actor_module", + ) def generate_sequences(self, prompts: DataProto): assert self._is_rollout prompts = prompts.to(get_device_name()) @@ -836,7 +845,12 @@ def generate_sequences(self, prompts: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) - @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + @DistProfiler.annotate( + color="olive", + role="ref_compute_log_prob", + precision_stage="ref_compute_log_prob", + precision_model_attr=("ref_module", "actor_module"), + ) def compute_ref_log_prob(self, data: DataProto): 
if self.peft_cls is not None: # if is lora, actor without lora applied is the ref @@ -862,7 +876,12 @@ def compute_ref_log_prob(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="compute_log_prob", logger=logger) - @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + @DistProfiler.annotate( + color="blue", + role="actor_compute_log_prob", + precision_stage="actor_compute_log_prob", + precision_model_attr="actor_module", + ) def compute_log_prob(self, data: DataProto): assert self._is_actor if self._is_offload_param: @@ -1212,7 +1231,12 @@ def init_model(self): ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="cyan", role="compute_values") + @DistProfiler.annotate( + color="cyan", + role="compute_values", + precision_stage="compute_values", + precision_model_attr="critic_module", + ) def compute_values(self, data: DataProto): micro_batch_size = self.config.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -1229,7 +1253,12 @@ def compute_values(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="pink", role="critic_update") + @DistProfiler.annotate( + color="pink", + role="critic_update", + precision_stage="critic_update", + precision_model_attr="critic_module", + ) def update_critic(self, data: DataProto): data = data.to(get_device_id()) @@ -1453,7 +1482,12 @@ def init_model(self): # TODO: reward model use itself tokenizer instead of sft tokenizer # the input_ids, responses, attention_mask and position_ids may be different! @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) - @DistProfiler.annotate(color="brown", role="compute_rm_score") + @DistProfiler.annotate( + color="brown", + role="compute_rm_score", + precision_stage="compute_rm_score", + precision_model_attr="reward_model_module", + ) def compute_rm_score(self, data: DataProto): data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 21a620dbc35..f1224160d0b 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -168,6 +168,8 @@ def __init__( self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) + self.precision_debugger_cfg = getattr(self.config, "precision_debugger", None) + self.precision_global_step = None max_position_embeddings = get_max_position_embeddings(self.model_config.hf_config) if self.config.max_model_len is None: self.config.max_model_len = max_position_embeddings @@ -220,6 +222,9 @@ def get_master_address(self): """Get master address and port for init NCCL process group.""" return self._master_address, self._master_port + def set_precision_global_step(self, global_step: int) -> None: + self.precision_global_step = global_step + def get_server_address(self): """Get http server address and port.""" assert self._server_port is not None, "http server is not launched, port is None" diff --git a/verl/workers/rollout/sglang_rollout/http_server_engine.py 
b/verl/workers/rollout/sglang_rollout/http_server_engine.py index 6822a9e52da..f3a01b842b9 100644 --- a/verl/workers/rollout/sglang_rollout/http_server_engine.py +++ b/verl/workers/rollout/sglang_rollout/http_server_engine.py @@ -103,6 +103,10 @@ async def _read_async_response(resp: aiohttp.ClientResponse) -> dict[str, Any]: } +def _launch_server_with_precision(server_args: ServerArgs) -> None: + launch_server(server_args) + + def launch_server_process( server_args: ServerArgs, timeout: float = DEFAULT_TIMEOUT, @@ -134,7 +138,7 @@ def launch_server_process( This is for consistency; except for the process obtained by node_rank = 0, other processes have no actual effect. """ - p = multiprocessing.Process(target=launch_server, args=(server_args,)) + p = multiprocessing.Process(target=_launch_server_with_precision, args=(server_args,)) if server_args.node_rank != 0 or not first_rank_in_node: logger.info(f"Server process started with PID {p.pid} for node rank {server_args.node_rank}", flush=True) return p diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 2be15fc5b05..c160efd4cb7 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -190,6 +190,11 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39 """ await self._init_server_adapter() + if not hasattr(self, "_precision_global_step"): + self._precision_global_step = -1 + self._precision_global_step += 1 + if hasattr(self, "server_actor") and self.server_actor is not None: + await self.server_actor.set_precision_global_step.remote(self._precision_global_step) update_weights_bucket_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) << 20 if self.config.get("quantization", None) == "fp8": diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index a0e738a25d4..fe0b32cbf6a 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -172,11 +172,54 @@ def monkey_patch_model(self, vocab_size: int): monkey_patch_compute_logits(self.model_runner.model, vocab_size) # patch weight loader to support MoE model patch_vllm_moe_model_weight_loader(self.model_runner.model) + self._attach_precision_debugger() + + def _attach_precision_debugger(self) -> None: + cfg_json = os.getenv("VERL_PRECISION_DEBUGGER_CONFIG_JSON", None) + if not cfg_json: + return + try: + precision_cfg = json.loads(cfg_json) + except Exception: + return + if not precision_cfg or not precision_cfg.get("enable", False): + return + + model = self.model_runner.model + if not hasattr(model, "forward"): + return + + original_forward = model.forward + extension_self = self + + def precision_forward(self, *args, **kwargs): + from verl.utils.profiler.precision_debugger_profile import PrecisionDebuggerProfiler + + if not hasattr(extension_self, "_precision_global_step"): + extension_self._precision_global_step = None + profiler = PrecisionDebuggerProfiler( + precision_cfg, rank=getattr(extension_self, "local_rank", None) + ) + started = profiler.start( + stage="rollout_generate", + global_step=getattr(extension_self, "_precision_global_step", None), + model=model, + ) + try: + return original_forward(*args, **kwargs) + finally: + profiler.stop(started=started) + + 
model.forward = MethodType(precision_forward, model) def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False, use_shm: bool = False): """Update the weights of the rollout model.""" from vllm.platforms import current_platform + if not hasattr(self, "_precision_global_step"): + self._precision_global_step = -1 + self._precision_global_step += 1 + if current_platform.device_type == "npu" and self.device is None: self.device = torch.device(f"npu:{self.local_rank}") diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index f4e26f13fde..4f327699547 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -112,6 +112,7 @@ def __init__( self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) + self.precision_debugger_cfg = getattr(self.config, "precision_debugger", None) max_position_embeddings = get_max_position_embeddings(self.model_config.hf_config) if self.config.max_model_len is None: self.config.max_model_len = max_position_embeddings @@ -171,6 +172,19 @@ def __init__( f"data_parallel_rpc_port: {self._dp_rpc_port}, data_parallel_master_port: {self._dp_master_port}" ) + def _export_precision_debugger_env(self) -> None: + precision_cfg = self.precision_debugger_cfg + if not precision_cfg or not getattr(precision_cfg, "enable", False): + return + try: + if hasattr(precision_cfg, "to_container"): + precision_cfg = precision_cfg.to_container(resolve=True) + if isinstance(precision_cfg, dict): + os.environ["VERL_PRECISION_DEBUGGER_CONFIG_JSON"] = json.dumps(precision_cfg) + except Exception: + # Best-effort only; precision debugger should not block server launch + return + def get_master_address(self): """Get master address and port for data parallel. Returns: @@ -207,6 +221,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non self._dp_rpc_port = dp_rpc_port # 1. setup vllm serve cli args + self._export_precision_debugger_env() engine_kwargs = self.config.get("engine_kwargs", {}).get("vllm", {}) or {} engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} if self.config.get("limit_images", None): # support for multi-image data
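Usage sketch (illustrative, not part of the patch): the snippet below shows how the new PrecisionDebuggerLogger decorator and PrecisionDebuggerToolConfig introduced above fit together on a worker-like object. ToyActor, its update_policy method, and the literal config values are hypothetical; the imports, field names, and decorator signature come from this patch. With enable=False, with no config_path, or with msprobe not installed and strict=False, the decorated call runs as a plain passthrough.

    from verl.utils.profiler import PrecisionDebuggerLogger
    from verl.utils.profiler.config import PrecisionDebuggerToolConfig


    class ToyActor:  # hypothetical stand-in for DataParallelPPOActor / MegatronPPOActor
        def __init__(self, precision_debugger_cfg):
            # The decorator looks up precision_debugger_cfg on the instance,
            # mirroring the assignments done in fsdp_workers.py / megatron_workers.py.
            self.precision_debugger_cfg = precision_debugger_cfg
            self.actor_module = None  # stands in for the real model handle

        @PrecisionDebuggerLogger(stage="update_actor", model_attr="actor_module", step=True)
        def update_policy(self, batch):
            # When enable=True, config_path points to an msprobe config.json, msprobe is
            # importable, and the stage/step filters match, msprobe dumping wraps this call.
            return {"loss": 0.0}


    cfg = PrecisionDebuggerToolConfig(
        enable=False,                        # flip to True (with a valid config_path) to collect
        config_path=None,                    # path to an msprobe config.json
        data_dir="outputs/precision_debug",
        steps=[1, 2],                        # collect only these global steps; None means all
        stages=["update_actor"],
    )
    ToyActor(cfg).update_policy(batch={})    # plain call while the debugger is disabled

When collection is active, PrecisionDebuggerProfiler.start writes dumps under data_dir/step_{N}/{stage}/rank_{R}. At the trainer level the same settings live under global_profiler.global_tool_config.precision_debugger; _propagate_precision_debugger_config in ray_trainer.py copies that node onto actor_rollout_ref, critic, and reward_model, which is how the workers end up seeing a precision_debugger entry in their configs.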
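One more illustrative sketch (not part of the patch): on the vLLM path the rollout server exports the precision-debugger settings as JSON in the VERL_PRECISION_DEBUGGER_CONFIG_JSON environment variable (_export_precision_debugger_env), presumably because the engine workers run in separate processes, and _attach_precision_debugger in vllm_rollout/utils.py reads it back before patching model.forward. The standalone round-trip below uses a hypothetical config dict and path; only the variable name and the enable gate come from the patch.

    import json
    import os

    # Server side (cf. _export_precision_debugger_env): serialize the resolved config
    # so that separately spawned engine worker processes can see it.
    precision_cfg = {
        "enable": True,
        "config_path": "/tmp/msprobe_config.json",  # hypothetical msprobe config.json path
        "data_dir": "outputs/precision_debug",
    }
    os.environ["VERL_PRECISION_DEBUGGER_CONFIG_JSON"] = json.dumps(precision_cfg)

    # Worker side (cf. _attach_precision_debugger): read it back and gate on "enable"
    # before monkey-patching model.forward with the precision wrapper.
    cfg_json = os.getenv("VERL_PRECISION_DEBUGGER_CONFIG_JSON")
    loaded = json.loads(cfg_json) if cfg_json else {}
    print("attach precision debugger:", bool(loaded.get("enable", False)))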