Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/workers/config/test_model_config_on_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright Amazon.com and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pprint import pp

from omegaconf import OmegaConf

from verl.workers.config.model import HFModelConfig


def test_target_modules_accepts_list_via_omegaconf():
    """
    Test that the target_modules field accepts a list value when merging
    OmegaConf configs (simulates CLI override behavior).

    The purpose is to ensure we can pass
    actor_rollout_ref.model.target_modules='["k_proj","o_proj","down_proj","q_proj"]'
    """
    model_path = "~/models/Qwen/Qwen2.5-0.5B"  # Just a path string, not loaded
    expected_modules = ["k_proj", "o_proj", "q_proj", "v_proj"]

    # Create structured config from the dataclass defaults
    # This is what omega_conf_to_dataclass does internally
    cfg_from_dataclass = OmegaConf.structured(HFModelConfig)

    # Must be an f-string: without the prefix the literal text
    # "{cfg_from_dataclass=}" is printed instead of the config.
    pp(f"{cfg_from_dataclass=}")

    # Simulate CLI override with target_modules as a list
    cli_config = OmegaConf.create(
        {
            "path": model_path,
            "target_modules": list(expected_modules),
        }
    )

    pp(f"{cli_config=}")

    # This merge should NOT raise ValidationError
    # Before the fix (target_modules: str), this would fail with:
    # "Cannot convert 'ListConfig' to string"
    merged = OmegaConf.merge(cfg_from_dataclass, cli_config)

    # Verify the list was merged correctly
    assert list(merged.target_modules) == expected_modules


def test_target_modules_accepts_string_via_omegaconf():
    """Check that a plain string target_modules value survives an OmegaConf merge."""
    overrides = {
        "path": "~/models/some-model",
        "target_modules": "all-linear",
    }

    # Structured defaults first, then the simulated override on top.
    base_cfg = OmegaConf.structured(HFModelConfig)
    override_cfg = OmegaConf.create(overrides)
    result = OmegaConf.merge(base_cfg, override_cfg)

    assert result.target_modules == "all-linear"
10 changes: 9 additions & 1 deletion verl/workers/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class HFModelConfig(BaseConfig):
# fsdp lora related. We may set up a separate config later
lora_rank: int = 0
lora_alpha: int = 16
target_modules: Optional[str] = "all-linear"
target_modules: Optional[Any] = "all-linear" # allow both "all-linear" and ["q_proj","k_proj"]
target_parameters: Optional[list[str]] = None # for lora adapter on nn.Parameter

exclude_modules: Optional[str] = None
Expand Down Expand Up @@ -204,5 +204,13 @@ def __post_init__(self):
if getattr(self.hf_config, "model_type", None) == "kimi_vl":
self.hf_config.text_config.topk_method = "greedy"

# Ensure target_modules is a str or list[str]
if self.target_modules is None:
self.target_modules = "all-linear"
assert isinstance(self.target_modules, (str | list))
if isinstance(self.target_modules, list):
for x in self.target_modules:
assert isinstance(x, str)

def get_processor(self):
    """Return ``self.processor`` when it is not None, otherwise ``self.tokenizer``."""
    if self.processor is None:
        return self.tokenizer
    return self.processor