diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 5ea7c0e19d0..e5b03f6496d 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -71,6 +71,6 @@ FSDP Utilities Debug Utilities ------------------- -.. automodule:: verl.utils.debug +.. automodule:: verl.utils.profiler :members: log_gpu_memory_usage, GPUMemoryLogger diff --git a/recipe/dapo/config/dapo_trainer.yaml b/recipe/dapo/config/dapo_trainer.yaml index 0c518b7a93a..47ac00fd6a0 100644 --- a/recipe/dapo/config/dapo_trainer.yaml +++ b/recipe/dapo/config/dapo_trainer.yaml @@ -19,6 +19,7 @@ reward_model: algorithm: filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig enable: False # We try to avoid forgetting to set enable metric: null # acc / score / seq_reward / seq_final_reward / ... max_num_gen_batches: 0 # Non-positive values mean no upper limit diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index ac14026785d..11eedfdb8f9 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -40,7 +40,7 @@ compute_advantage, compute_response_mask, ) -from verl.utils.debug import marked_timer +from verl.utils.profiler import marked_timer class RayDAPOTrainer(RayPPOTrainer): diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py index 1ffb68cea32..545a1744ed7 100644 --- a/recipe/dapo/main_dapo.py +++ b/recipe/dapo/main_dapo.py @@ -15,20 +15,44 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +import copy import os import socket import hydra import ray -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from verl.trainer.ppo.reward import get_custom_reward_fn +from verl.utils import omega_conf_to_dataclass from .dapo_ray_trainer import RayDAPOTrainer +def trainer_dict_to_dataclass(conf: DictConfig): + """Convert specific nested sections of a DictConfig object into dataclass instances. + + Args: + conf (DictConfig): An instance of DictConfig, typically from the omegaconf library, + representing a configuration dictionary. + + Returns: + DictConfig: A deep copy of the input `conf` with specific sections converted to dataclasses. 
+ """ + # Create a deep copy of the input configuration to avoid modifying the original object + config = copy.deepcopy(conf) + config.algorithm = omega_conf_to_dataclass(config.algorithm) + config.critic.profiler = omega_conf_to_dataclass(config.critic.profiler) + config.reward_model.profiler = omega_conf_to_dataclass(config.reward_model.profiler) + config.actor_rollout_ref.actor.profiler = omega_conf_to_dataclass(config.actor_rollout_ref.actor.profiler) + config.actor_rollout_ref.ref.profiler = omega_conf_to_dataclass(config.actor_rollout_ref.ref.profiler) + config.actor_rollout_ref.rollout.profiler = omega_conf_to_dataclass(config.actor_rollout_ref.rollout.profiler) + return config + + @hydra.main(config_path="config", config_name="dapo_trainer", version_base=None) -def main(config): +def main(config_dict): + config = trainer_dict_to_dataclass(config_dict) run_ppo(config) diff --git a/recipe/entropy/entropy_ray_trainer.py b/recipe/entropy/entropy_ray_trainer.py index a2c06d72eb2..fc914ed2c2b 100644 --- a/recipe/entropy/entropy_ray_trainer.py +++ b/recipe/entropy/entropy_ray_trainer.py @@ -39,7 +39,7 @@ compute_advantage, compute_response_mask, ) -from verl.utils.debug import simple_timer +from verl.utils.profiler import simple_timer class RayEntropyTrainer(RayPPOTrainer): diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 68c30edf60b..e35340464c4 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -25,7 +25,6 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_tokenizer from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import log_gpu_memory_usage from verl.utils.device import get_device_id, get_device_name, get_nccl_backend from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_local_path_from_hdfs @@ -39,6 +38,7 @@ offload_fsdp_optimizer, ) from verl.utils.import_utils import import_external_libs +from verl.utils.profiler import log_gpu_memory_usage from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 3da0a0ca6d3..c66fe53d265 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -33,8 +33,8 @@ from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn -from verl.utils.debug.performance import simple_timer from verl.utils.metric import reduce_metrics +from verl.utils.profiler.performance import simple_timer from . 
import prime_core_algos diff --git a/recipe/spin/fsdp_workers.py b/recipe/spin/fsdp_workers.py index 17b0ed414da..e8a43e0d8d8 100644 --- a/recipe/spin/fsdp_workers.py +++ b/recipe/spin/fsdp_workers.py @@ -31,7 +31,6 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_tokenizer from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import log_gpu_memory_usage from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local @@ -46,6 +45,7 @@ ) from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask +from verl.utils.profiler import log_gpu_memory_usage from verl.workers.fsdp_workers import ActorRolloutRefWorker from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager diff --git a/recipe/sppo/dp_actor.py b/recipe/sppo/dp_actor.py index e87bea8ffb4..df14c0b4ed6 100644 --- a/recipe/sppo/dp_actor.py +++ b/recipe/sppo/dp_actor.py @@ -21,8 +21,8 @@ import verl.utils.torch_functional as verl_F from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, kl_penalty -from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import rearrange_micro_batches from verl.workers.actor.dp_actor import DataParallelPPOActor diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py index fc32c4c6d14..c6140158360 100644 --- a/recipe/sppo/sppo_ray_trainer.py +++ b/recipe/sppo/sppo_ray_trainer.py @@ -44,7 +44,7 @@ compute_response_mask, ) from verl.trainer.ppo.reward import compute_reward, compute_reward_async -from verl.utils.debug.performance import simple_timer +from verl.utils.profiler.performance import simple_timer from verl.utils.tracking import ValidationGenerationsLogger diff --git a/recipe/sppo/sppo_worker.py b/recipe/sppo/sppo_worker.py index 8f7fbbefa54..fbe3a6e48b4 100644 --- a/recipe/sppo/sppo_worker.py +++ b/recipe/sppo/sppo_worker.py @@ -20,10 +20,10 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import log_gpu_memory_usage from verl.utils.flops_counter import FlopsCounter from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer from verl.utils.import_utils import import_external_libs +from verl.utils.profiler import log_gpu_memory_usage from verl.workers.fsdp_workers import ActorRolloutRefWorker logger = logging.getLogger(__file__) diff --git a/tests/special_distributed/test_tensor_dict.py b/tests/special_distributed/test_tensor_dict.py index 27da6f5a2f2..0a7f8039d90 100644 --- a/tests/special_distributed/test_tensor_dict.py +++ b/tests/special_distributed/test_tensor_dict.py @@ -58,8 +58,8 @@ def test_all_gather_data_proto(): def test_vocab_parallel_entropy(): from megatron.core import parallel_state as mpu - from verl.utils.debug import log_gpu_memory_usage from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy + from verl.utils.profiler import log_gpu_memory_usage from verl.utils.torch_functional import entropy_from_logits mpu.initialize_model_parallel( diff --git a/tests/special_sanity/check_api_docs.py 
b/tests/special_sanity/check_api_docs.py index 994882d9ef9..fa31ec8c5dc 100644 --- a/tests/special_sanity/check_api_docs.py +++ b/tests/special_sanity/check_api_docs.py @@ -39,13 +39,13 @@ _ALLOW_LIST = [ "verl.third_party.vllm.LLM", "verl.third_party.vllm.parallel_state", - "verl.utils.debug.WorkerProfiler", - "verl.utils.debug.WorkerProfilerExtension", - "verl.utils.debug.log_gpu_memory_usage", - "verl.utils.debug.log_print", - "verl.utils.debug.mark_annotate", - "verl.utils.debug.mark_end_range", - "verl.utils.debug.mark_start_range", + "verl.utils.profiler.WorkerProfiler", + "verl.utils.profiler.WorkerProfilerExtension", + "verl.utils.profiler.log_gpu_memory_usage", + "verl.utils.profiler.log_print", + "verl.utils.profiler.mark_annotate", + "verl.utils.profiler.mark_end_range", + "verl.utils.profiler.mark_start_range", "verl.models.mcore.qwen2_5_vl.get_vision_model_config", "verl.models.mcore.qwen2_5_vl.get_vision_projection_config", ] diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index ee8a366e885..c8988db55a5 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -28,7 +28,7 @@ "recipe/prime/prime_ray_trainer.py", # appear in default device_name "recipe/spin/spin_trainer.py", # appear in default device_name "recipe/sppo/sppo_ray_trainer.py", # appear in default device_name - "verl/utils/debug/nvtx_profile.py", # appear in NsightSystemsProfiler + "verl/utils/profiler/nvtx_profile.py", # appear in NsightSystemsProfiler "verl/utils/kernel/linear_cross_entropy.py", # appear in nvidia nvtx "verl/utils/rendezvous/ray_backend.py", # appear in cupy importance "verl/single_controller/ray/base.py", # appear in default device_name diff --git a/tests/special_sanity/validate_structure.py b/tests/special_sanity/validate_structure.py index 27929116f2c..a5390b15acd 100644 --- a/tests/special_sanity/validate_structure.py +++ b/tests/special_sanity/validate_structure.py @@ -86,7 +86,7 @@ def main() -> None: parser.add_argument( "--allow-files", nargs="*", - default=["tests/test_protocol_on_cpu.py"], + default=["tests/test_protocol_on_cpu.py", "tests/test_base_config_on_cpu.py"], help="Extra top-level test folders that are exempt from the rule", ) args = parser.parse_args() diff --git a/tests/test_base_config_on_cpu.py b/tests/test_base_config_on_cpu.py new file mode 100644 index 00000000000..9a50235c8ff --- /dev/null +++ b/tests/test_base_config_on_cpu.py @@ -0,0 +1,42 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from verl.base_config import BaseConfig + + +@pytest.fixture +def base_config_mock(): + """Fixture to create a mock BaseConfig instance with test attributes.""" + mock_config = BaseConfig() + mock_config.test_attr = "test_value" + return mock_config + + +def test_getitem_success(base_config_mock): + """Test __getitem__ with existing attribute (happy path).""" + assert base_config_mock["test_attr"] == "test_value" + + +def test_getitem_nonexistent_attribute(base_config_mock): + """Test __getitem__ with non-existent attribute (exception path 1).""" + with pytest.raises(AttributeError): + _ = base_config_mock["nonexistent_attr"] + + +def test_getitem_invalid_key_type(base_config_mock): + """Test __getitem__ with invalid key type (exception path 2).""" + with pytest.raises(TypeError): + _ = base_config_mock[123] # type: ignore diff --git a/tests/trainer/config/__init__.py b/tests/trainer/config/__init__.py new file mode 100644 index 00000000000..1ce90c5eb35 --- /dev/null +++ b/tests/trainer/config/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/trainer/config/test_algo_config_on_cpu.py b/tests/trainer/config/test_algo_config_on_cpu.py new file mode 100644 index 00000000000..ab3f646499b --- /dev/null +++ b/tests/trainer/config/test_algo_config_on_cpu.py @@ -0,0 +1,194 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from omegaconf import OmegaConf + +from verl.trainer.config import AlgoConfig, KLControlConfig, PFPPOConfig +from verl.trainer.ppo.core_algos import ( + compute_gae_advantage_return, + compute_grpo_outcome_advantage, + get_adv_estimator_fn, +) +from verl.utils.config import omega_conf_to_dataclass + + +class TestAlgoConfig(unittest.TestCase): + """Test the AlgoConfig dataclass and its integration with core algorithms.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a sample algorithm config as DictConfig (similar to what comes from YAML) + self.config_dict = { + "_target_": "verl.trainer.config.AlgoConfig", + "gamma": 0.99, + "lam": 0.95, + "adv_estimator": "gae", + "norm_adv_by_std_in_grpo": True, + "use_kl_in_reward": True, + "kl_penalty": "kl", + "kl_ctrl": { + "_target_": "verl.trainer.config.KLControlConfig", + "type": "adaptive", + "kl_coef": 0.002, + "horizon": 5000, + "target_kl": 0.05, + }, + "use_pf_ppo": True, + "pf_ppo": {"_target_": "verl.trainer.config.PFPPOConfig", "reweight_method": "max_min", "weight_pow": 3.0}, + } + self.omega_config = OmegaConf.create(self.config_dict) + self.algo_config = AlgoConfig( + gamma=0.99, + lam=0.95, + adv_estimator="gae", + norm_adv_by_std_in_grpo=True, + use_kl_in_reward=True, + kl_penalty="kl", + kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05), + use_pf_ppo=True, + pf_ppo=PFPPOConfig(reweight_method="max_min", weight_pow=3.0), + ) + + def test_dataclass_creation_from_dict(self): + """Test creating AlgoConfig from dictionary.""" + config = omega_conf_to_dataclass(self.config_dict) + + self.assertIsInstance(config, AlgoConfig) + self.assertEqual(config.gamma, 0.99) + self.assertEqual(config.lam, 0.95) + self.assertEqual(config.adv_estimator, "gae") + self.assertTrue(config.norm_adv_by_std_in_grpo) + self.assertTrue(config.use_kl_in_reward) + self.assertEqual(config.kl_penalty, "kl") + self.assertTrue(config.use_pf_ppo) + + def test_dataclass_creation_from_omega_config(self): + """Test creating AlgoConfig from OmegaConf DictConfig.""" + config = omega_conf_to_dataclass(self.omega_config) + + self.assertIsInstance(config, AlgoConfig) + self.assertEqual(config.gamma, 0.99) + self.assertEqual(config.lam, 0.95) + + def test_nested_configs(self): + """Test that nested configurations are properly converted.""" + config = omega_conf_to_dataclass(self.omega_config) + + # Test KL control config + self.assertIsInstance(config.kl_ctrl, KLControlConfig) + self.assertEqual(config.kl_ctrl.type, "adaptive") + self.assertEqual(config.kl_ctrl.kl_coef, 0.002) + self.assertEqual(config.kl_ctrl.horizon, 5000) + self.assertEqual(config.kl_ctrl.target_kl, 0.05) + + # Test PF PPO config + self.assertIsInstance(config.pf_ppo, PFPPOConfig) + self.assertEqual(config.pf_ppo.reweight_method, "max_min") + self.assertEqual(config.pf_ppo.weight_pow, 3.0) + + def test_default_values(self): + """Test that default values are properly set.""" + minimal_config = {"gamma": 0.8} + config = omega_conf_to_dataclass(minimal_config, AlgoConfig) + + self.assertEqual(config.gamma, 0.8) + self.assertEqual(config.lam, 1.0) # default value + self.assertEqual(config.adv_estimator, "gae") # default value + self.assertTrue(config.norm_adv_by_std_in_grpo) # default value + self.assertFalse(config.use_kl_in_reward) # default value + self.assertEqual(config.kl_penalty, "kl") # default value + self.assertFalse(config.use_pf_ppo) # default value + + def 
test_get_method_backward_compatibility(self): + """Test the get method for backward compatibility.""" + config = omega_conf_to_dataclass(self.omega_config) + + # Test existing attribute + self.assertEqual(config.get("gamma"), 0.99) + self.assertEqual(config.get("gamma", 1.0), 0.99) + + # Test non-existing attribute + self.assertIsNone(config.get("non_existing")) + self.assertEqual(config.get("non_existing", "default"), "default") + + def test_advantage_estimator_with_cfg(self): + """Test integration with advantage estimators from core_algos.""" + config = self.algo_config + + # Test GAE advantage estimator + adv_fn = get_adv_estimator_fn(config.adv_estimator) + self.assertIsNotNone(adv_fn) + + # Test with actual GAE computation + batch_size, seq_len = 2, 5 + token_level_rewards = torch.randn(batch_size, seq_len) + values = torch.randn(batch_size, seq_len) + response_mask = torch.ones(batch_size, seq_len) + + advantages, returns = compute_gae_advantage_return( + token_level_rewards=token_level_rewards, + values=values, + response_mask=response_mask, + gamma=config.gamma, + lam=config.lam, + ) + + self.assertEqual(advantages.shape, (batch_size, seq_len)) + self.assertEqual(returns.shape, (batch_size, seq_len)) + + def test_grpo_advantage_estimator_with_cfg(self): + """Test integration with GRPO advantage estimator.""" + grpo_config = AlgoConfig(adv_estimator="grpo", norm_adv_by_std_in_grpo=True) + + # Test GRPO advantage computation + batch_size, seq_len = 4, 3 + token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]]) + response_mask = torch.ones(batch_size, seq_len) + index = np.array([0, 0, 1, 1]) # Two groups + + advantages, returns = compute_grpo_outcome_advantage( + token_level_rewards=token_level_rewards, + response_mask=response_mask, + index=index, + norm_adv_by_std_in_grpo=grpo_config.norm_adv_by_std_in_grpo, + ) + + self.assertEqual(advantages.shape, (batch_size, seq_len)) + self.assertEqual(returns.shape, (batch_size, seq_len)) + + def test_post_init_nested_configs(self): + """Test that __post_init__ properly initializes nested configs when None.""" + # Create config without nested configs + minimal_config = AlgoConfig(gamma=0.9) + + # Check that nested configs are initialized + self.assertIsNotNone(minimal_config.kl_ctrl) + self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig) + self.assertIsNone(minimal_config.pf_ppo) + + def test_config_init_from_yaml(self): + cfg = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") + algo_config = omega_conf_to_dataclass(cfg.algorithm) + from verl.trainer.config import AlgoConfig, PFPPOConfig + + assert isinstance(algo_config, AlgoConfig) + assert isinstance(algo_config.pf_ppo, PFPPOConfig) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_nvtx_profile.py b/tests/utils/test_nvtx_profile.py index 63b67e9a77c..63d425990ae 100644 --- a/tests/utils/test_nvtx_profile.py +++ b/tests/utils/test_nvtx_profile.py @@ -19,8 +19,8 @@ from omegaconf import OmegaConf from verl.utils import omega_conf_to_dataclass -from verl.utils.debug import ProfilerConfig -from verl.utils.debug.nvtx_profile import NsightSystemsProfiler +from verl.utils.profiler import ProfilerConfig +from verl.utils.profiler.nvtx_profile import NsightSystemsProfiler class TestNsightSystemsProfiler(unittest.TestCase): @@ -79,8 +79,8 @@ def test_func(self, *args, **kwargs): return "result" with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, patch( - 
"verl.utils.debug.nvtx_profile.mark_start_range" - ) as mock_start_range, patch("verl.utils.debug.nvtx_profile.mark_end_range") as mock_end_range: + "verl.utils.profiler.nvtx_profile.mark_start_range" + ) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range: result = test_func(mock_self) self.assertEqual(result, "result") mock_start_range.assert_called_once() @@ -100,8 +100,8 @@ def test_func(self, *args, **kwargs): return "result" with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, patch( - "verl.utils.debug.nvtx_profile.mark_start_range" - ) as mock_start_range, patch("verl.utils.debug.nvtx_profile.mark_end_range") as mock_end_range: + "verl.utils.profiler.nvtx_profile.mark_start_range" + ) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range: result = test_func(mock_self) self.assertEqual(result, "result") mock_start_range.assert_called_once() @@ -119,11 +119,20 @@ def test_config_init(self): arr.ref.profiler, arr.rollout.profiler, ]: - profiler_config = omega_conf_to_dataclass(config, ProfilerConfig) - self.assertEqual(profiler_config.discrete, False) - self.assertEqual(profiler_config.all_ranks, False) - self.assertEqual(profiler_config.ranks, []) + profiler_config = omega_conf_to_dataclass(config) + self.assertEqual(profiler_config.discrete, config.discrete) + self.assertEqual(profiler_config.all_ranks, config.all_ranks) + self.assertEqual(profiler_config.ranks, config.ranks) assert isinstance(profiler_config, ProfilerConfig) + with self.assertRaises(AttributeError): + _ = profiler_config.non_existing_key + assert config.get("non_existing_key") == profiler_config.get("non_existing_key") + assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1) + assert config["discrete"] == profiler_config["discrete"] + from dataclasses import FrozenInstanceError + + with self.assertRaises(FrozenInstanceError): + profiler_config.discrete = False if __name__ == "__main__": diff --git a/verl/base_config.py b/verl/base_config.py new file mode 100644 index 00000000000..d413160ded5 --- /dev/null +++ b/verl/base_config.py @@ -0,0 +1,74 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from dataclasses import fields # Import the fields function to inspect dataclass fields +from typing import Any + + +# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary +class BaseConfig(collections.abc.Mapping): + """The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config. + + The BaseConfig class implements the Mapping Abstract Base Class. + This allows instances of this class to be used like dictionaries. + """ + + def get(self, key: str, default: Any = None) -> Any: + """Get the value associated with the given key. If the key does not exist, return the default value. 
+ + Args: + key (str): The attribute name to retrieve. + default (Any, optional): The value to return if the attribute does not exist. Defaults to None. + + Returns: + Any: The value of the attribute or the default value. + """ + try: + return getattr(self, key) + except AttributeError: + return default + + def __getitem__(self, key: str): + """Implement the [] operator for the class. Allows accessing attributes like dictionary items. + + Args: + key (str): The attribute name to retrieve. + + Returns: + Any: The value of the attribute. + + Raises: + AttributeError: If the attribute does not exist. + TypeError: If the key type is not string + """ + return getattr(self, key) + + def __iter__(self): + """Implement the iterator protocol. Allows iterating over the attribute names of the instance. + + Yields: + str: The name of each field in the dataclass. + """ + for f in fields(self): + yield f.name + + def __len__(self): + """ + Return the number of fields in the dataclass. + + Returns: + int: The number of fields in the dataclass. + """ + return len(fields(self)) diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py index 5738dc586ab..18ab8024ed5 100644 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -17,7 +17,7 @@ from uuid import uuid4 from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput -from verl.utils.debug import simple_timer +from verl.utils.profiler import simple_timer logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index 5f9ef1df0c8..14685adcf27 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -24,7 +24,7 @@ from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput from verl.tools.utils.tool_registry import initialize_tools_from_config -from verl.utils.debug import simple_timer +from verl.utils.profiler import simple_timer logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) diff --git a/verl/trainer/config/__init__.py b/verl/trainer/config/__init__.py index 1ce90c5eb35..f4cc9b8e2c7 100644 --- a/verl/trainer/config/__init__.py +++ b/verl/trainer/config/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .algorithm import AlgoConfig, FilterGroupsConfig, KLControlConfig, PFPPOConfig + +__all__ = [ + "AlgoConfig", + "FilterGroupsConfig", + "KLControlConfig", + "PFPPOConfig", +] diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py new file mode 100644 index 00000000000..55d2e78253a --- /dev/null +++ b/verl/trainer/config/algorithm.py @@ -0,0 +1,63 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from verl.base_config import BaseConfig + + +@dataclass(frozen=True) +class KLControlConfig(BaseConfig): + """Configuration for KL control.""" + + type: str = "fixed" # "fixed" or "adaptive" + kl_coef: float = 0.001 # Initial coefficient for KL penalty + horizon: int = 10000 # Horizon value for adaptive controller + target_kl: float = 0.1 # Target KL divergence for adaptive controller + + +@dataclass(frozen=True) +class PFPPOConfig(BaseConfig): + """Configuration for preference feedback PPO.""" + + reweight_method: str = "pow" # "pow", "max_min", or "max_random" + weight_pow: float = 2.0 # Power used for weight scaling in "pow" method + + +@dataclass(frozen=True) +class FilterGroupsConfig(BaseConfig): + """Configuration for filter groups (used in DAPO and Entropy).""" + + enable: bool = False # Whether to enable filter groups + metric: Optional[str] = None # Metric to use for filtering: "acc", "score", "seq_reward", "seq_final_reward", etc. + max_num_gen_batches: int = 0 # Non-positive values mean no upper limit + + +@dataclass(frozen=True) +class AlgoConfig(BaseConfig): + """Configuration for the algorithm.""" + + gamma: float = 1.0 # Discount factor for future rewards + lam: float = 1.0 # Trade-off between bias and variance in the GAE estimator + adv_estimator: str = "gae" # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. 
+ norm_adv_by_std_in_grpo: bool = True # Whether to normalize advantages by std (specific to GRPO) + use_kl_in_reward: bool = False # Whether to enable in-reward KL penalty + kl_penalty: str = "kl" # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full" + kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig) # KL control configuration + use_pf_ppo: bool = False # Whether to enable preference feedback PPO + pf_ppo: Optional[PFPPOConfig] = None # Preference feedback PPO settings + + # Filter groups parameters (used in DAPO and Entropy) + filter_groups: Optional[FilterGroupsConfig] = None # Filter groups configuration diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 56d2840e4df..dac4599b4f2 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -119,9 +119,11 @@ actor_rollout_ref: load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} # Nsight system profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig discrete: False all_ranks: False - ranks: null + ranks: [] ref: strategy: ${actor_rollout_ref.actor.strategy} use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} @@ -153,9 +155,11 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} # Nsight system profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig discrete: False all_ranks: False - ranks: null + ranks: [] rollout: name: vllm mode: sync # sync: LLM, async: AsyncLLM @@ -263,9 +267,11 @@ actor_rollout_ref: calculate_log_probs: False # Nsight system profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig discrete: False all_ranks: False - ranks: null + ranks: [] critic: rollout_n: ${actor_rollout_ref.rollout.n} @@ -342,9 +348,11 @@ critic: load_contents: ${critic.checkpoint.save_contents} # Nsight system profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig discrete: False all_ranks: False - ranks: null + ranks: [] reward_model: enable: False strategy: ${actor_rollout_ref.actor.strategy} @@ -383,15 +391,19 @@ reward_model: memory_limit_mb: 1024 # Max memory limit for each sandbox process in MB # Nsight system profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig discrete: False all_ranks: False - ranks: null + ranks: [] custom_reward_function: path: null name: compute_score algorithm: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig gamma: 1.0 lam: 1.0 adv_estimator: gae @@ -399,12 +411,16 @@ algorithm: use_kl_in_reward: False kl_penalty: kl # how to estimate kl divergence kl_ctrl: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig type: fixed kl_coef: 0.001 horizon: 
10000 target_kl: 0.1 use_pf_ppo: False pf_ppo: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig reweight_method: pow # ["pow", "max_min", "max_random"] weight_pow: 2.0 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index e4652bc4a2b..793f4dbc549 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -309,14 +309,17 @@ actor_rollout_ref: # profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + # True for each task has its own database, False for all tasks in one training step share one database. discrete: False # Whether to profile all ranks. all_ranks: False - # The ranks that will be profiled. null or [0,1,...] - ranks: null + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] # Reference model config. # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. @@ -373,14 +376,17 @@ actor_rollout_ref: # profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + # True for each task has its own database, False for all tasks in one training step share one database. discrete: False # Whether to profile all ranks. all_ranks: False - # The ranks that will be profiled. null or [0,1,...] - ranks: null + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] # Rollout model config. rollout: @@ -561,14 +567,17 @@ actor_rollout_ref: # profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + # True for each task has its own database, False for all tasks in one training step share one database. discrete: False # Whether to profile all ranks. all_ranks: False - # The ranks that will be profiled. null or [0,1,...] - ranks: null + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] # [Experimental] agent loop based rollout configs agent: @@ -735,17 +744,20 @@ critic: load_contents: ${critic.checkpoint.save_contents} # profiler configs - # the corresponding dataclass is verl.utils.debug.ProfilerConfig. + # the corresponding dataclass is verl.utils.profiler.ProfilerConfig. profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + # True for each task has its own database, False for all tasks in one training step share one database. discrete: False # Whether to profile all ranks. all_ranks: False - # The ranks that will be profiled. null or [0,1,...] - ranks: null + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] # configs for the reward model reward_model: @@ -849,14 +861,17 @@ reward_model: # profiler configs profiler: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.utils.profiler.ProfilerConfig + # True for each task has its own database, False for all tasks in one training step share one database. discrete: False # Whether to profile all ranks. all_ranks: False - # The ranks that will be profiled. null or [0,1,...] 
- ranks: null + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] # custom reward function definition custom_reward_function: @@ -871,6 +886,9 @@ custom_reward_function: # config for the algorithm algorithm: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.AlgoConfig + # Discount factor for future rewards gamma: 1.0 @@ -892,6 +910,9 @@ algorithm: # KL control configuration kl_ctrl: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.KLControlConfig + # KL control type: "fixed" or "adaptive" type: fixed @@ -910,6 +931,9 @@ algorithm: # Preference feedback PPO settings pf_ppo: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + _target_: verl.trainer.config.PFPPOConfig + # Method for reweighting samples: "pow", "max_min", or "max_random" reweight_method: pow diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 1137fcce3a6..531ebab6276 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -43,7 +43,6 @@ import verl.utils.hdfs_io as hdfs_io from verl.utils.dataset import SFTDataset from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset -from verl.utils.debug import log_gpu_memory_usage from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group from verl.utils.fs import copy_to_local @@ -57,6 +56,7 @@ get_init_weight_context_manager, init_fn, ) +from verl.utils.profiler import log_gpu_memory_usage from verl.utils.py_functional import convert_to_regular_types from verl.utils.torch_dtypes import PrecisionType from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 94da5b5fd69..e0e7dba5325 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -15,19 +15,43 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +import copy import os import socket import hydra import ray -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.config import omega_conf_to_dataclass + + +def trainer_dict_to_dataclass(conf: DictConfig): + """Convert specific nested sections of a DictConfig object into dataclass instances. + + Args: + conf (DictConfig): An instance of DictConfig, typically from the omegaconf library, + representing a configuration dictionary. + + Returns: + DictConfig: A deep copy of the input `conf` with specific sections converted to dataclasses. 
+ """ + # Create a deep copy of the input configuration to avoid modifying the original object + config = copy.deepcopy(conf) + config.algorithm = omega_conf_to_dataclass(config.algorithm) + config.critic.profiler = omega_conf_to_dataclass(config.critic.profiler) + config.reward_model.profiler = omega_conf_to_dataclass(config.reward_model.profiler) + config.actor_rollout_ref.actor.profiler = omega_conf_to_dataclass(config.actor_rollout_ref.actor.profiler) + config.actor_rollout_ref.ref.profiler = omega_conf_to_dataclass(config.actor_rollout_ref.ref.profiler) + config.actor_rollout_ref.rollout.profiler = omega_conf_to_dataclass(config.actor_rollout_ref.rollout.profiler) + return config @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) -def main(config): +def main(config_dict): + config = trainer_dict_to_dataclass(config_dict) run_ppo(config) diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index ef227f84cb8..71b77e24171 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -22,11 +22,13 @@ from collections import defaultdict from enum import Enum +from typing import Optional import numpy as np import torch import verl.utils.torch_functional as verl_F +from verl.trainer.config import AlgoConfig POLICY_LOSS_REGISTRY = {} @@ -213,8 +215,9 @@ def compute_grpo_outcome_advantage( response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, - norm_adv_by_std_in_grpo: str = True, -): + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response). @@ -224,10 +227,18 @@ def compute_grpo_outcome_advantage( shape is (bs, response_length) response_mask: `(torch.Tensor)` shape is (bs, response_length) - norm_adv_by_std_in_grpo: (bool) - whether to scale the GRPO advantage. - If True, the advantage is scaled by the std, as in the original GRPO. - If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). + index: `(np.ndarray)` + index array for grouping + epsilon: `(float)` + small value to avoid division by zero + norm_adv_by_std_in_grpo: `(bool)` + whether to scale the GRPO advantage + config: `(Optional[AlgoConfig])` + algorithm configuration object + + Note: + If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). Returns: advantages: `(torch.Tensor)` @@ -271,9 +282,9 @@ def compute_grpo_passk_outcome_advantage( index: np.ndarray, epsilon: float = 1e-6, norm_adv_by_std_in_grpo: bool = True, - config=None, + config: Optional[AlgoConfig] = None, **kwargs, -): +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for Pass@k using a GRPO-style outcome reward formulation. Only the best response per group gets a non-zero advantage: r_max - r_second_max. 
@@ -285,7 +296,7 @@ def compute_grpo_passk_outcome_advantage( response_mask: (bs, response_length) index: (bs,) → group ID per sample epsilon: float for numerical stability - config: (dict) algorithm settings, which contains "norm_adv_by_std_in_grpo" + config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo" Returns: advantages: (bs, response_length) @@ -334,9 +345,9 @@ def compute_reinforce_plus_plus_baseline_outcome_advantage( response_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6, - config=None, + config: Optional[AlgoConfig] = None, **kwargs, -): +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward (with only one scalar reward for each response). @@ -346,7 +357,7 @@ def compute_reinforce_plus_plus_baseline_outcome_advantage( shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) - config: (dict) algorithm config + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -386,9 +397,9 @@ def compute_rloo_outcome_advantage( response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, - config=None, + config: Optional[AlgoConfig] = None, **kwargs, -): +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 @@ -397,7 +408,7 @@ def compute_rloo_outcome_advantage( shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) - config: (dict) algorithm config + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -438,9 +449,9 @@ def compute_opo_outcome_advantage( response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, - config=None, + config: Optional[AlgoConfig] = None, **kwargs, -): +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585 @@ -449,7 +460,7 @@ def compute_opo_outcome_advantage( shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) - config: (dict) algorithm config + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -488,8 +499,8 @@ def compute_opo_outcome_advantage( @register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") def compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config=None, **kwargs -): + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for REINFORCE++. 
This implementation is based on the paper: https://arxiv.org/abs/2501.03262 @@ -499,7 +510,7 @@ def compute_reinforce_plus_plus_outcome_advantage( shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) - config: (dict) algorithm config + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -530,9 +541,9 @@ def compute_remax_outcome_advantage( token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor, - config=None, + config: Optional[AlgoConfig] = None, **kwargs, -): +) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for ReMax, operating only on Outcome reward This implementation is based on the paper: https://arxiv.org/abs/2310.10505 @@ -545,7 +556,7 @@ def compute_remax_outcome_advantage( shape: (bs,) response_mask: `(torch.Tensor)` shape: (bs, response_length) - config: (dict) algorithm config + config: (AlgoConfig) algorithm config Returns: advantages: `(torch.Tensor)` @@ -762,13 +773,13 @@ def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, l @register_policy_loss("clip_cov") def compute_policy_loss_clip_cov( - old_log_prob, - log_prob, - advantages, - response_mask, - loss_agg_mode="token-mean", - config=None, -): + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the clipped policy objective and related metrics for Clip-Cov. @@ -852,13 +863,13 @@ def compute_policy_loss_clip_cov( @register_policy_loss("kl_cov") def compute_policy_loss_kl_cov( - old_log_prob, - log_prob, - advantages, - response_mask, - loss_agg_mode="token-mean", - config=None, -): + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the clipped policy objective and related metrics for Clip-Cov. 
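For orientation, a minimal usage sketch of the typed config flow introduced above (illustrative only, not part of this diff): the `_target_`-carrying algorithm section is materialized into the new `AlgoConfig` dataclass via `omega_conf_to_dataclass`, and the estimator registry consumes its fields. Imports mirror modules added in this change; the literal values are assumptions for the example.

    from omegaconf import OmegaConf

    from verl.trainer.config import AlgoConfig
    from verl.trainer.ppo.core_algos import get_adv_estimator_fn
    from verl.utils.config import omega_conf_to_dataclass

    # The algorithm section now carries _target_, so no explicit dataclass type is needed;
    # omega_conf_to_dataclass falls back to hydra.utils.instantiate in that case.
    algo_cfg = omega_conf_to_dataclass(
        OmegaConf.create(
            {
                "_target_": "verl.trainer.config.AlgoConfig",
                "gamma": 0.99,  # illustrative values, not defaults from the YAML
                "lam": 0.95,
                "adv_estimator": "gae",
            }
        )
    )
    assert isinstance(algo_cfg, AlgoConfig)

    # The estimators typed with Optional[AlgoConfig] above can then be looked up by name
    # and receive the dataclass through their `config` keyword.
    adv_fn = get_adv_estimator_fn(algo_cfg.adv_estimator)
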
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 23b76e204c2..5e63b3221ad 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -41,6 +41,7 @@ from verl.single_controller.base import Worker from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig from verl.trainer.ppo import core_algos from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss from verl.trainer.ppo.metric_utils import ( @@ -51,10 +52,10 @@ ) from verl.trainer.ppo.reward import compute_reward, compute_reward_async from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi -from verl.utils.debug import marked_timer from verl.utils.metric import ( reduce_metrics, ) +from verl.utils.profiler import marked_timer from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger @@ -204,13 +205,13 @@ def compute_response_mask(data: DataProto): def compute_advantage( data: DataProto, - adv_estimator, - gamma=1.0, - lam=1.0, - num_repeat=1, - norm_adv_by_std_in_grpo=True, - config=None, -): + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, + num_repeat: int = 1, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> DataProto: """Compute advantage estimates for policy optimization. This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. @@ -218,7 +219,7 @@ def compute_advantage( Args: data (DataProto): The data containing batched model outputs and inputs. - adv_estimator: The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. 
@@ -247,8 +248,8 @@ def compute_advantage( if config.get("use_pf_ppo", False): data = core_algos.compute_pf_ppo_reweight_data( data, - config.get("pf_ppo_reweight_method", "pow"), - config.get("pf_ppo_weight_pow", 2.0), + config.pf_ppo.reweight_method, + config.pf_ppo.weight_pow, ) elif adv_estimator == AdvantageEstimator.GRPO: # Initialize the mask for GRPO calculation @@ -347,8 +348,8 @@ def __init__( # define in-reward KL control # kl loss control currently not suppoorted - if config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: self.use_critic = True @@ -482,7 +483,7 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): "seq-mean-token-sum-norm", ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" - if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: print("NOTICE: You have both enabled in-reward kl and kl loss.") # critic diff --git a/verl/utils/config.py b/verl/utils/config.py index 2c636c16dac..8847ef8cf9b 100644 --- a/verl/utils/config.py +++ b/verl/utils/config.py @@ -13,24 +13,37 @@ # limitations under the License. from dataclasses import is_dataclass -from typing import Any, Dict, Type, Union +from typing import Any, Dict, Optional, Type, Union from omegaconf import DictConfig, OmegaConf __all__ = ["omega_conf_to_dataclass"] -def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Type[Any]) -> Any: +def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Optional[Type[Any]] = None) -> Any: """ Convert an OmegaConf DictConfig to a dataclass. Args: config: The OmegaConf DictConfig or dict to convert. - dataclass_type: The dataclass type to convert to. + dataclass_type: The dataclass type to convert to. When dataclass_type is None, + the DictConfig must contain _target_ to be instantiated via hydra.instantiate API. Returns: The dataclass instance. """ + if dataclass_type is not None and isinstance(config, dataclass_type): + return config + + if dataclass_type is None: + assert "_target_" in config, ( + "When dataclass_type is not provided, config must contain _target_." + "See trainer/config/ppo_trainer.yaml algorithm section for an example." + ) + from hydra.utils import instantiate + + return instantiate(config, _convert_="partial") + if not is_dataclass(dataclass_type): raise ValueError(f"{dataclass_type} must be a dataclass") cfg = OmegaConf.create(config) # in case it's a dict diff --git a/verl/utils/debug/__init__.py b/verl/utils/debug/__init__.py index 436d1dd8caf..eb67df1b772 100644 --- a/verl/utils/debug/__init__.py +++ b/verl/utils/debug/__init__.py @@ -12,27 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ..import_utils import is_nvtx_available -from .performance import GPUMemoryLogger, log_gpu_memory_usage, log_print, simple_timer -from .profile import DistProfilerExtension, ProfilerConfig - -if is_nvtx_available(): - from .nvtx_profile import NsightSystemsProfiler as DistProfiler - from .nvtx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer -else: - from .performance import marked_timer - from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range - -__all__ = [ - "GPUMemoryLogger", - "log_gpu_memory_usage", - "log_print", - "mark_start_range", - "mark_end_range", - "mark_annotate", - "DistProfiler", - "DistProfilerExtension", - "ProfilerConfig", - "simple_timer", - "marked_timer", -] +# APIs kept for backward compatibility purpose +# For new features please develop in verl/utils/profiler/ +from ..profiler import * # noqa diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index 56d439889e6..9186e125a20 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -12,186 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import inspect -import logging -from contextlib import contextmanager -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.distributed as dist -from codetiming import Timer - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import DecoratorLoggerBase - - -def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]: - """Get current memory usage.""" - assert unit in ["GB", "MB", "KB"] - divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024 - mem_allocated = get_torch_device().memory_allocated() - mem_reserved = get_torch_device().memory_reserved() - # use get_torch_device().mem_get_info to profile device memory - # since vllm's sleep mode works below pytorch - # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 - mem_free, mem_total = get_torch_device().mem_get_info() - mem_used = mem_total - mem_free - mem_allocated = f"{mem_allocated / divisor:.{precision}f}" - mem_reserved = f"{mem_reserved / divisor:.{precision}f}" - mem_used = f"{mem_used / divisor:.{precision}f}" - mem_total = f"{mem_total / divisor:.{precision}f}" - return mem_allocated, mem_reserved, mem_used, mem_total - - -def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): - if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank): - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = ( - f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " - f"device memory used/total (GB): {mem_used}/{mem_total}" - ) - - if logger is None: - print(message) - else: - logger.log(msg=message, level=level) - - -class GPUMemoryLogger(DecoratorLoggerBase): - """A decorator class to log GPU memory usage. - - Example: - >>> from verl.utils.debug.performance import GPUMemoryLogger - >>> @GPUMemoryLogger(role="actor") - >>> def update_actor(self, batch): - ... # real actor update logics - ... 
return - """ - - def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): - if dist.is_initialized() and dist.get_world_size() > 1: - rank = dist.get_rank() - else: - rank = 0 - super().__init__(role, logger, level, rank, log_only_rank_0) - - def __call__(self, decorated_function: callable): - def f(*args, **kwargs): - return self.log(decorated_function, *args, **kwargs) - - return f - - def log(self, func, *args, **kwargs): - name = func.__name__ - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = ( - f"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " - f"device memory used/total (GB): {mem_used}/{mem_total}" - ) - self.logging_function(message) - - output = func(*args, **kwargs) - - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = ( - f"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " - f"device memory used/total (GB): {mem_used}/{mem_total}" - ) - - self.logging_function(message) - return output - - -def log_print(ctn: Any): - current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - frame = inspect.currentframe().f_back - function_name = frame.f_code.co_name - line_number = frame.f_lineno - file_name = frame.f_code.co_filename.split("/")[-1] - print(f"[{current_time}-{file_name}:{line_number}:{function_name}]: {ctn}") - - -def _timer(name: str, timing_raw: Dict[str, float]): - """Inner function that handles the core timing logic. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - """ - with Timer(name=name, logger=None) as timer: - yield - if name not in timing_raw: - timing_raw[name] = 0 - timing_raw[name] += timer.last - - -@contextmanager -def simple_timer(name: str, timing_raw: Dict[str, float]): - """Context manager for basic timing without NVTX markers. - - This utility function measures the execution time of code within its context - and accumulates the timing information in the provided dictionary. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - - Yields: - None: This is a context manager that yields control back to the code block. - """ - yield from _timer(name, timing_raw) - - -@contextmanager -def marked_timer( - name: str, - timing_raw: Dict[str, float], - color: str = None, - domain: Optional[str] = None, - category: Optional[str] = None, -): - """Context manager for timing with platform markers. - - This utility function measures the execution time of code within its context, - accumulates the timing information, and adds platform markers for profiling. - This function is a default implementation when hardware profiler is not available. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - color (Optional[str]): Color for the marker. Defaults to None. - domain (Optional[str]): Domain for the marker. Defaults to None. - category (Optional[str]): Category for the marker. Defaults to None. - - Yields: - None: This is a context manager that yields control back to the code block. - """ - yield from _timer(name, timing_raw) - - -def reduce_timing(timing_raw: Dict[str, float]) -> Dict[str, float]: - """Reduce timing information across all processes. 
- - This function uses distributed communication to gather and sum the timing - information from all processes in a distributed environment. - - Args: - timing_raw (Dict[str, float]): Dictionary containing timing information. - - Returns: - Dict[str, float]: Reduced timing information. - """ - if not dist.is_initialized(): - return timing_raw - - key_list, timing_list = [], [] - for key in sorted(timing_raw.keys()): - key_list.append(key) - timing_list.append(timing_raw[key]) - timing_list = torch.tensor(timing_list, dtype=torch.float32, device=get_device_id()) - torch.distributed.all_reduce(timing_list, op=torch.distributed.ReduceOp.AVG) - timing_list = [tensor.item() for tensor in timing_list.to("cpu")] - timing_generate = {key_list[i]: timing_list[i] for i in range(len(key_list))} - return timing_generate +# APIs kept for backward compatibility purpose +# This file is deprecated, for new features please develop in profiler/performance.py +from verl.utils.profiler.performance import simple_timer, reduce_timing # noqa diff --git a/verl/utils/profiler/__init__.py b/verl/utils/profiler/__init__.py new file mode 100644 index 00000000000..fefcc6c1214 --- /dev/null +++ b/verl/utils/profiler/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..import_utils import is_nvtx_available +from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer +from .profile import DistProfilerExtension, ProfilerConfig + +if is_nvtx_available(): + from .nvtx_profile import NsightSystemsProfiler as DistProfiler + from .nvtx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer +else: + from .performance import marked_timer + from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range + +__all__ = [ + "GPUMemoryLogger", + "log_gpu_memory_usage", + "mark_start_range", + "mark_end_range", + "mark_annotate", + "DistProfiler", + "DistProfilerExtension", + "ProfilerConfig", + "simple_timer", + "marked_timer", +] diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py new file mode 100644 index 00000000000..295956bb5c6 --- /dev/null +++ b/verl/utils/profiler/config.py @@ -0,0 +1,51 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from verl.base_config import BaseConfig + + +@dataclass(frozen=True) +class ProfilerConfig(BaseConfig): + """Worker profiler config. 
Currently only supports the Nsight Systems profiler."""
+
+    # True: each task gets its own database; False: all tasks in one training step share one database.
+    discrete: bool = False
+
+    # Whether to profile all ranks.
+    all_ranks: bool = False
+
+    # The ranks that will be profiled. [] or [0,1,...]
+    ranks: list[int] = field(default_factory=list)
+
+    def union(self, other: "ProfilerConfig") -> "ProfilerConfig":
+        return ProfilerConfig(
+            all_ranks=self.all_ranks or other.all_ranks,
+            ranks=list(set(self.ranks or []) | set(other.ranks or [])),
+            discrete=self.discrete or other.discrete,
+        )
+
+    def intersect(self, other: "ProfilerConfig") -> "ProfilerConfig":
+        return ProfilerConfig(
+            all_ranks=self.all_ranks and other.all_ranks,
+            ranks=list(set(self.ranks or []) & set(other.ranks or [])),
+            discrete=self.discrete and other.discrete,
+        )
+
+    def __post_init__(self) -> None:
+        """Config validation logic goes here."""
+        assert isinstance(self.ranks, (set, list, tuple)), (
+            f"Profiler ranks must be of type list, got {type(self.ranks)}"
+        )
diff --git a/verl/utils/debug/empty_annotations.py b/verl/utils/profiler/empty_annotations.py
similarity index 100%
rename from verl/utils/debug/empty_annotations.py
rename to verl/utils/profiler/empty_annotations.py
diff --git a/verl/utils/debug/nvtx_profile.py b/verl/utils/profiler/nvtx_profile.py
similarity index 90%
rename from verl/utils/debug/nvtx_profile.py
rename to verl/utils/profiler/nvtx_profile.py
index 9d8c563c0b9..9e9c51bf708 100644
--- a/verl/utils/debug/nvtx_profile.py
+++ b/verl/utils/profiler/nvtx_profile.py
@@ -111,18 +111,24 @@ def marked_timer(
 
 
 class NsightSystemsProfiler(DistProfiler):
-    """
-    Nsight system profiler. Installed in a worker to control the Nsight system profiler.
-    """
+    """Nsight Systems profiler. Installed in a worker to control the Nsight Systems profiler."""
+
+    def __init__(self, rank: int, config: Optional[ProfilerConfig]):
+        """Initialize the NsightSystemsProfiler.
 
-    def __init__(self, rank: int, config: ProfilerConfig):
-        config = config
+        Args:
+            rank (int): The rank of the current process.
+            config (Optional[ProfilerConfig]): Configuration for the profiler. If None, a default configuration is used.
+        """
+        # If no configuration is provided, create a default ProfilerConfig with an empty list of ranks
+        if not config:
+            config = ProfilerConfig(ranks=[])
         self.this_step: bool = False
         self.discrete: bool = config.discrete
         self.this_rank: bool = False
         if config.all_ranks:
             self.this_rank = True
-        elif config.ranks is not None:
+        elif config.ranks:
             self.this_rank = rank in config.ranks
 
     def start(self):
diff --git a/verl/utils/profiler/performance.py b/verl/utils/profiler/performance.py
new file mode 100644
index 00000000000..cdf9cc4c694
--- /dev/null
+++ b/verl/utils/profiler/performance.py
@@ -0,0 +1,205 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import datetime +import inspect +import logging +from contextlib import contextmanager +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.distributed as dist +from codetiming import Timer + +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.logger import DecoratorLoggerBase + + +def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> Tuple[str]: + """Get current memory usage.""" + assert unit in ["GB", "MB", "KB"] + divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024 + mem_allocated = get_torch_device().memory_allocated() + mem_reserved = get_torch_device().memory_reserved() + # use get_torch_device().mem_get_info to profile device memory + # since vllm's sleep mode works below pytorch + # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 + mem_free, mem_total = get_torch_device().mem_get_info() + mem_used = mem_total - mem_free + mem_allocated = f"{mem_allocated / divisor:.{precision}f}" + mem_reserved = f"{mem_reserved / divisor:.{precision}f}" + mem_used = f"{mem_used / divisor:.{precision}f}" + mem_total = f"{mem_total / divisor:.{precision}f}" + return mem_allocated, mem_reserved, mem_used, mem_total + + +def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): + """Log GPU memory usage information. + + Args: + head (str): A descriptive header for the memory usage log message. + logger (logging.Logger, optional): Logger instance to use for logging. If None, prints to stdout. + level: Logging level to use. Defaults to logging.DEBUG. + rank (int): The rank of the process to log memory for. Defaults to 0. + """ + if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank): + mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() + message = ( + f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " + f"device memory used/total (GB): {mem_used}/{mem_total}" + ) + + if logger is None: + print(message) + else: + logger.log(msg=message, level=level) + + +class GPUMemoryLogger(DecoratorLoggerBase): + """A decorator class to log GPU memory usage. + + Example: + >>> from verl.utils.profiler.performance import GPUMemoryLogger + >>> @GPUMemoryLogger(role="actor") + >>> def update_actor(self, batch): + ... # real actor update logics + ... 
return + """ + + def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): + if dist.is_initialized() and dist.get_world_size() > 1: + rank = dist.get_rank() + else: + rank = 0 + super().__init__(role, logger, level, rank, log_only_rank_0) + + def __call__(self, decorated_function: callable): + def f(*args, **kwargs): + return self.log(decorated_function, *args, **kwargs) + + return f + + def log(self, func, *args, **kwargs): + name = func.__name__ + mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() + message = ( + f"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " + f"device memory used/total (GB): {mem_used}/{mem_total}" + ) + self.logging_function(message) + + output = func(*args, **kwargs) + + mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() + message = ( + f"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " + f"device memory used/total (GB): {mem_used}/{mem_total}" + ) + + self.logging_function(message) + return output + + +def log_print(ctn: Any): + current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + frame = inspect.currentframe().f_back + function_name = frame.f_code.co_name + line_number = frame.f_lineno + file_name = frame.f_code.co_filename.split("/")[-1] + print(f"[{current_time}-{file_name}:{line_number}:{function_name}]: {ctn}") + + +def _timer(name: str, timing_raw: Dict[str, float]): + """Inner function that handles the core timing logic. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + """ + with Timer(name=name, logger=None) as timer: + yield + if name not in timing_raw: + timing_raw[name] = 0 + timing_raw[name] += timer.last + + +@contextmanager +def simple_timer(name: str, timing_raw: Dict[str, float]): + """Context manager for basic timing without NVTX markers. + + This utility function measures the execution time of code within its context + and accumulates the timing information in the provided dictionary. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + + Yields: + None: This is a context manager that yields control back to the code block. + """ + yield from _timer(name, timing_raw) + + +@contextmanager +def marked_timer( + name: str, + timing_raw: Dict[str, float], + color: str = None, + domain: Optional[str] = None, + category: Optional[str] = None, +): + """Context manager for timing with platform markers. + + This utility function measures the execution time of code within its context, + accumulates the timing information, and adds platform markers for profiling. + This function is a default implementation when hardware profiler is not available. + + Args: + name (str): The name/identifier for this timing measurement. + timing_raw (Dict[str, float]): Dictionary to store timing information. + color (Optional[str]): Color for the marker. Defaults to None. + domain (Optional[str]): Domain for the marker. Defaults to None. + category (Optional[str]): Category for the marker. Defaults to None. + + Yields: + None: This is a context manager that yields control back to the code block. + """ + yield from _timer(name, timing_raw) + + +def reduce_timing(timing_raw: Dict[str, float]) -> Dict[str, float]: + """Reduce timing information across all processes. 
+ + This function uses distributed communication to gather and sum the timing + information from all processes in a distributed environment. + + Args: + timing_raw (Dict[str, float]): Dictionary containing timing information. + + Returns: + Dict[str, float]: Reduced timing information. + """ + if not dist.is_initialized(): + return timing_raw + + key_list, timing_list = [], [] + for key in sorted(timing_raw.keys()): + key_list.append(key) + timing_list.append(timing_raw[key]) + timing_list = torch.tensor(timing_list, dtype=torch.float32, device=get_device_id()) + torch.distributed.all_reduce(timing_list, op=torch.distributed.ReduceOp.AVG) + timing_list = [tensor.item() for tensor in timing_list.to("cpu")] + timing_generate = {key_list[i]: timing_list[i] for i in range(len(key_list))} + return timing_generate diff --git a/verl/utils/debug/profile.py b/verl/utils/profiler/profile.py similarity index 84% rename from verl/utils/debug/profile.py rename to verl/utils/profiler/profile.py index 7a97037ed13..28dd77d146e 100644 --- a/verl/utils/debug/profile.py +++ b/verl/utils/profiler/profile.py @@ -13,12 +13,13 @@ # limitations under the License. import os -from dataclasses import dataclass from typing import Callable, Optional import torch import torch.distributed +from .config import ProfilerConfig + class Profiler: """A PyTorch profiler wrapper class for collecting performance metrics. @@ -119,10 +120,23 @@ def mark_start_range( domain: Optional[str] = None, category: Optional[str] = None, ) -> None: + """Start a profiling range marker (no-op implementation). + + Args: + message (Optional[str]): Message to associate with the range marker. + color (Optional[str]): Color for the marker visualization. + domain (Optional[str]): Domain for the marker. + category (Optional[str]): Category for the marker. + """ pass def mark_end_range(range_id: str) -> None: + """End a profiling range marker (no-op implementation). + + Args: + range_id (str): Identifier of the range to end. + """ pass @@ -132,44 +146,22 @@ def mark_annotate( domain: Optional[str] = None, category: Optional[str] = None, ) -> Callable: - def decorator(func): - return func - - return decorator - - -@dataclass -class ProfilerConfig: - """Worker profiler config. Currently only support Nsight system profiler.""" - - # True for each task has its own database, False for all tasks in one training step share one database. - discrete: bool = False - - # Whether to profile all ranks. - all_ranks: bool = False + """Decorator to annotate a function with profiling markers (no-op implementation). - # The ranks that will be profiled. None or [0,1,...] - ranks: Optional[list[int]] = None + Args: + message (Optional[str]): Message to associate with the annotation. + color (Optional[str]): Color for the marker visualization. + domain (Optional[str]): Domain for the marker. + category (Optional[str]): Category for the marker. - def union(self, other: "ProfilerConfig") -> "ProfilerConfig": - return ProfilerConfig( - all_ranks=self.all_ranks or other.all_ranks, - ranks=list(set(self.ranks or []) | set(other.ranks or [])), - discrete=self.discrete or other.discrete, - ) + Returns: + Callable: Decorator function that returns the original function unchanged. 
+ """ - def intersect(self, other: "ProfilerConfig") -> "ProfilerConfig": - return ProfilerConfig( - all_ranks=self.all_ranks and other.all_ranks, - ranks=list(set(self.ranks or []) & set(other.ranks or [])), - discrete=self.discrete and other.discrete, - ) + def decorator(func): + return func - def __post_init__(self) -> None: - """config validation logics go here""" - if self.ranks is None: - self.ranks = [] - assert isinstance(self.ranks, (set, list, tuple)) + return decorator class DistProfiler: diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index b3556992962..e9b365ad461 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -29,9 +29,9 @@ import verl.utils.torch_functional as verl_F from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty -from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import logprobs_from_logits diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 0891ef1fe86..efc0b430981 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -39,12 +39,12 @@ from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty -from verl.utils.debug import GPUMemoryLogger -from verl.utils.debug.profile import Profiler from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits from verl.utils.megatron_utils import get_model_config +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.profiler.profile import Profiler from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index b96c7da5fa8..ac777584288 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -26,9 +26,9 @@ from verl import DataProto from verl.trainer.ppo import core_algos -from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import masked_mean diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index d22fc9e10a3..1d44a887614 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -31,9 +31,9 @@ from verl import DataProto from verl.trainer.ppo import core_algos -from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id, get_torch_device from 
verl.utils.megatron.pipeline_parallel import make_batch_generator +from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor, masked_mean diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 743bc2d0507..b3039ca62d2 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -38,11 +38,9 @@ from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register -from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass +from verl.utils import hf_processor, hf_tokenizer from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.debug import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer -from verl.utils.debug.performance import reduce_timing from verl.utils.device import ( get_device_id, get_device_name, @@ -70,6 +68,8 @@ ) from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask +from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer +from verl.utils.profiler.performance import reduce_timing from verl.utils.py_functional import convert_to_regular_types from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager @@ -150,11 +150,11 @@ def __init__(self, config: DictConfig, role: str): profiler_config: Optional[ProfilerConfig] = None if self._is_actor: - profiler_config = omega_conf_to_dataclass(config.actor.get("profiler", {}), ProfilerConfig) + profiler_config = config.actor.get("profiler", {}) if self._is_rollout: - profiler_config = omega_conf_to_dataclass(config.rollout.get("profiler", {}), ProfilerConfig) + profiler_config = config.rollout.get("profiler", {}) if self._is_ref: - profiler_config = omega_conf_to_dataclass(config.ref.get("profiler", {}), ProfilerConfig) + profiler_config = config.ref.get("profiler", {}) DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) @@ -915,8 +915,7 @@ def stop_profile(self) -> None: class CriticWorker(Worker, DistProfilerExtension): def __init__(self, config): Worker.__init__(self) - profiler_config = omega_conf_to_dataclass(config.get("profiler", {}), ProfilerConfig) - DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler", {}))) import torch.distributed if not torch.distributed.is_initialized(): @@ -1301,8 +1300,7 @@ class RewardModelWorker(Worker, DistProfilerExtension): def __init__(self, config): Worker.__init__(self) - profiler_config = omega_conf_to_dataclass(config.get("profiler", {}), ProfilerConfig) - DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler", {}))) import torch.distributed diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 190e493c88d..277ebb9733f 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -19,7 +19,7 @@ 
import logging import os import time -from typing import Union +from typing import Optional, Union import psutil import torch @@ -31,17 +31,8 @@ from verl import DataProto from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.base.megatron.worker import MegatronWorker -from verl.utils import hf_tokenizer, omega_conf_to_dataclass +from verl.utils import hf_tokenizer from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager -from verl.utils.debug import ( - DistProfiler, - DistProfilerExtension, - GPUMemoryLogger, - ProfilerConfig, - log_gpu_memory_usage, - simple_timer, -) -from verl.utils.debug.performance import reduce_timing from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local @@ -52,6 +43,15 @@ offload_megatron_optimizer, ) from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + GPUMemoryLogger, + ProfilerConfig, + log_gpu_memory_usage, + simple_timer, +) +from verl.utils.profiler.performance import reduce_timing from verl.workers.actor.megatron_actor import MegatronPPOActor from verl.workers.critic.megatron_critic import MegatronPPOCritic from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel @@ -127,19 +127,13 @@ def __init__(self, config: DictConfig, role: str): self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] self._is_ref = self.role in ["ref", "actor_rollout_ref"] - profiler_config = ProfilerConfig() + profiler_config: Optional[ProfilerConfig] = None if self._is_actor: - profiler_config = profiler_config.union( - ProfilerConfig(**OmegaConf.to_object(config.actor.get("profiler", DictConfig({})))) - ) + profiler_config = config.actor.get("profiler", {}) if self._is_rollout: - profiler_config = profiler_config.union( - ProfilerConfig(**OmegaConf.to_object(config.rollout.get("profiler", DictConfig({})))) - ) + profiler_config = config.rollout.get("profiler", {}) if self._is_ref: - profiler_config = profiler_config.union( - ProfilerConfig(**OmegaConf.to_object(config.ref.get("profiler", DictConfig({})))) - ) + profiler_config = config.ref.get("profiler", {}) DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) @@ -399,8 +393,6 @@ def init_model(self): importlib.import_module(self.config.model.external_lib) - from omegaconf import OmegaConf - from verl.utils.torch_dtypes import PrecisionType override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) @@ -718,8 +710,7 @@ async def sleep(self): class CriticWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config): MegatronWorker.__init__(self) - profiler_config = omega_conf_to_dataclass(config.get("profiler", {}), ProfilerConfig) - DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler", {}))) self.config = config # NOTE(sgm): We utilize colocate WorkerGroup by default. 
@@ -854,7 +845,6 @@ def megatron_critic_model_provider(pre_process, post_process): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # create critic - from omegaconf import OmegaConf from verl.utils.torch_dtypes import PrecisionType @@ -994,8 +984,7 @@ class RewardModelWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config): MegatronWorker.__init__(self) - profiler_config = omega_conf_to_dataclass(config.get("profiler", {}), ProfilerConfig) - DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=config.get("profiler", {}))) self.config = config # NOTE(sgm): We utilize colocate WorkerGroup by default. @@ -1103,7 +1092,6 @@ def megatron_rm_model_provider(pre_process, post_process): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # create critic - from omegaconf import OmegaConf from verl.utils.torch_dtypes import PrecisionType diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index c558afa61fa..939e9c52463 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -60,8 +60,8 @@ from verl.tools.base_tool import BaseTool from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall from verl.tools.utils.tool_registry import initialize_tools_from_config -from verl.utils.debug import GPUMemoryLogger from verl.utils.net_utils import is_ipv6 +from verl.utils.profiler import GPUMemoryLogger from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length from verl.workers.rollout.base import BaseRollout from verl.workers.rollout.schemas import ( diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 093df8a7007..b67fd9eb28d 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -51,7 +51,7 @@ from vllm.worker.worker_base import WorkerWrapperBase from verl import DataProto -from verl.utils.debug import GPUMemoryLogger +from verl.utils.profiler import GPUMemoryLogger from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length from verl.workers.rollout.base import BaseRollout diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py index 02cec3595a9..26094204fdc 100644 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ b/verl/workers/sharding_manager/fsdp_sglang.py @@ -30,10 +30,10 @@ from verl import DataProto from verl.protocol import all_gather_data_proto -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from verl.utils.device import get_device_id, get_torch_device from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from verl.utils.model import convert_weight_keys +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from verl.utils.torch_functional import check_device_is_available from .base import BaseShardingManager diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index efe70d41ee5..e0bbba5981c 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -34,7 +34,6 @@ from verl.protocol import 
all_gather_data_proto from verl.third_party.vllm import LLM from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from verl.utils.device import get_device_id, get_device_name, get_torch_device from verl.utils.fsdp_utils import ( fsdp_version, @@ -43,6 +42,7 @@ offload_fsdp_model_to_cpu, ) from verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from verl.utils.torch_functional import check_device_is_available from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py index ba55a202336..61cf8aed4da 100644 --- a/verl/workers/sharding_manager/megatron_sglang.py +++ b/verl/workers/sharding_manager/megatron_sglang.py @@ -27,9 +27,9 @@ from torch.distributed.device_mesh import DeviceMesh from verl.protocol import DataProto, all_gather_data_proto -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from verl.utils.device import get_torch_device from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from .base import BaseShardingManager diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index b3621f803fa..b04352c249e 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -30,10 +30,10 @@ from verl.protocol import all_gather_data_proto from verl.third_party.vllm import LLM from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage -from verl.utils.debug.performance import simple_timer from verl.utils.device import get_torch_device from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator +from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage +from verl.utils.profiler.performance import simple_timer from verl.utils.torch_functional import check_device_is_available from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
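
Reviewer note: to make the semantics of the relocated ProfilerConfig (verl/utils/profiler/config.py above) concrete, here is a minimal sketch of how per-role configs combine. The field values are made up for illustration and do not come from any shipped YAML; it only assumes verl is importable.

from verl.utils.profiler import ProfilerConfig

# Two illustrative per-role configs (values are hypothetical).
actor = ProfilerConfig(discrete=True, ranks=[0])
rollout = ProfilerConfig(discrete=False, ranks=[0, 1])

# union(): profile whatever either role asks for (OR on the flags, set-union on the ranks).
merged = actor.union(rollout)      # discrete=True, all_ranks=False, ranks covering {0, 1}

# intersect(): profile only what both roles agree on (AND on the flags, set-intersection on the ranks).
common = actor.intersect(rollout)  # discrete=False, all_ranks=False, ranks covering {0}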
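Similarly, a small usage sketch of the timing and memory helpers that now live in verl.utils.profiler (new file verl/utils/profiler/performance.py above). The workload bodies are placeholders, and log_gpu_memory_usage / GPUMemoryLogger assume a visible CUDA or other supported accelerator; the old verl.utils.debug imports keep working through the compatibility shims but are deprecated.

import logging

from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, marked_timer, simple_timer

logger = logging.getLogger(__name__)
timing_raw = {}

# simple_timer accumulates wall-clock seconds under a key; reusing the same
# key adds onto the same entry in timing_raw.
with simple_timer("gen", timing_raw):
    pass  # placeholder for sequence generation

# marked_timer behaves the same and additionally emits NVTX ranges when nvtx
# is available; otherwise color/domain/category are ignored.
with marked_timer("update_actor", timing_raw, color="blue"):
    pass  # placeholder for the actor update

log_gpu_memory_usage("After update_actor", logger=logger)


# GPUMemoryLogger logs device memory before and after the decorated method.
class Actor:
    @GPUMemoryLogger(role="actor", logger=logger)
    def update_actor(self, batch):
        return batch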