Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[rollout] feat: add custom sampling parameters for rollout generation
Adds sampling_kwargs field to RolloutConfig and SamplingConfig, allowing
users to pass arbitrary sampling parameters at runtime without modifying
hardcoded configurations. This enables experimental flexibility for
testing custom sampling strategies.

Changes:
- Add sampling_kwargs: Optional[dict[str, Any]] to config dataclasses
- Apply custom kwargs in agent loop workers with proper OmegaConf handling
- Update rollout.yaml documentation
- Backward compatible (defaults to None)
  • Loading branch information
guillemgt authored and Guillem Tarrach committed Jan 30, 2026
commit 6ae1d68a92f17cb16377fb6de71d5e30af7cf97c
6 changes: 6 additions & 0 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,12 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
sampling_params["top_p"] = config.val_kwargs.top_p
sampling_params["top_k"] = config.val_kwargs.top_k
sampling_params["temperature"] = config.val_kwargs.temperature
if config.val_kwargs.sampling_kwargs is not None:
sampling_params = (
OmegaConf.to_container(config.val_kwargs.sampling_kwargs, resolve=True) | sampling_params
)
elif config.sampling_kwargs is not None:
sampling_params = OmegaConf.to_container(config.sampling_kwargs, resolve=True) | sampling_params

# by default, we assume it's a single turn agent
if "agent_name" not in batch.non_tensor_batch:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import hydra
import numpy as np
import ray
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf

from verl.experimental.agent_loop.agent_loop import (
AgentLoopManager,
Expand Down Expand Up @@ -109,6 +109,12 @@ async def generate_sequences_no_post(
if batch.meta_info.get("validate", False):
sampling_params["top_p"] = config.val_kwargs.top_p
sampling_params["temperature"] = config.val_kwargs.temperature
if config.val_kwargs.sampling_kwargs is not None:
sampling_params = (
OmegaConf.to_container(config.val_kwargs.sampling_kwargs, resolve=True) | sampling_params
)
elif config.sampling_kwargs is not None:
sampling_params = OmegaConf.to_container(config.sampling_kwargs, resolve=True) | sampling_params

if "agent_name" not in batch.non_tensor_batch:
default_agent_loop = config.agent.default_agent_loop
Expand Down
2 changes: 2 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ actor_rollout_ref:
temperature: 1.0
top_k: -1
top_p: 1
sampling_kwargs: null
prompt_length: ${oc.select:data.max_prompt_length,512}
response_length: ${oc.select:data.max_response_length,512}
dtype: bfloat16
Expand Down Expand Up @@ -242,6 +243,7 @@ actor_rollout_ref:
top_k: -1
top_p: 1.0
temperature: 0
sampling_kwargs: null
'n': 1
do_sample: false
multi_turn:
Expand Down
2 changes: 2 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ actor_rollout_ref:
temperature: 1.0
top_k: -1
top_p: 1
sampling_kwargs: null
prompt_length: ${oc.select:data.max_prompt_length,512}
response_length: ${oc.select:data.max_response_length,512}
dtype: bfloat16
Expand Down Expand Up @@ -233,6 +234,7 @@ actor_rollout_ref:
top_k: -1
top_p: 1.0
temperature: 0
sampling_kwargs: null
'n': 1
do_sample: false
multi_turn:
Expand Down
6 changes: 6 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ top_k: -1
# Top-p sampling parameter. Default 1.0.
top_p: 1

# Custom sampling kwargs: Default null (no kwargs)
sampling_kwargs: null

# typically the same as data max prompt length
# same as data.max_prompt_length if it exists
prompt_length: ${oc.select:data.max_prompt_length,512}
Expand Down Expand Up @@ -144,6 +147,9 @@ val_kwargs:
# Sampling temperature for rollout.
temperature: 0

# Custom sampling kwargs: Default null (no kwargs)
sampling_kwargs: null

# whether to repeat n times for validation
n: 1

Expand Down
4 changes: 3 additions & 1 deletion verl/workers/config/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from typing import Optional
from typing import Any, Optional

from omegaconf import MISSING

Expand Down Expand Up @@ -41,6 +41,7 @@ class SamplingConfig(BaseConfig):
top_p: float = 1.0
do_sample: bool = True
n: int = 1
sampling_kwargs: Optional[dict[str, Any]] = None


@dataclass
Expand Down Expand Up @@ -145,6 +146,7 @@ class RolloutConfig(BaseConfig):
do_sample: bool = True
n: int = 1
repetition_penalty: float = 1.0
sampling_kwargs: Optional[dict[str, Any]] = None

# Early termination threshold for multi-turn rollout in sglang.
# Abort remaining requests when (1 - over_sample_rate) * total_requests are completed.
Expand Down
Loading