Interaction System for Multi-turn RL Training
=============================================

Overview
--------

The verl interaction system enables dynamic, multi-turn conversational feedback during reinforcement learning training. This system allows models to engage in iterative problem-solving scenarios where an interaction agent can provide corrective feedback, guidance, or evaluation based on the model's responses.

Key features:

- **Async-based Architecture**: Non-blocking interaction processing for distributed training
- **Instance Management**: Stateful session handling with unique instance IDs for concurrent interactions
- **SGLang Integration**: Seamless integration with the SGLang rollout system for multi-turn conversations
- **Configuration-driven**: Dynamic agent loading via YAML configuration files
- **Reward Integration**: Turn-level scoring mechanism integrated with verl's reward system

Architecture
------------

The interaction system follows a plugin-based architecture with a clear separation of concerns:

.. code-block:: text

    BaseInteraction (Abstract Interface)
                    ↓
    Gsm8kInteraction (Concrete Implementation)
                    ↓
    SGLang Rollout Integration
                    ↓
    Async Request Lifecycle Management

Core Components
~~~~~~~~~~~~~~~

**BaseInteraction Interface**

All interaction agents must implement the ``BaseInteraction`` abstract class:

.. code-block:: python

    from typing import Any, Dict, List, Optional, Tuple

    class BaseInteraction:
        async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:
            """Initialize an interaction session and return its instance_id."""

        async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[bool, str, float, Dict[str, Any]]:
            """Generate a response; return (should_terminate, response, score, metadata)."""

        async def calculate_score(self, instance_id: str, **kwargs) -> float:
            """Calculate the turn-level score used for RL training."""

        async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
            """Clean up resources for the session."""

**Request Lifecycle**

The interaction system integrates with SGLang's async rollout via state management:

1. ``PENDING`` → initialize the interaction via ``start_interaction()``
2. ``GENERATING`` → the model generates a response
3. ``INTERACTING`` → the response is processed via ``generate_response()``
4. ``GENERATING`` → continue if not terminated; otherwise ``COMPLETED``
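The loop through these states can be sketched as a toy state machine. This is an illustrative simplification, not verl's actual rollout code; the real SGLang integration drives many such requests concurrently and asynchronously:

```python
from enum import Enum, auto

class State(Enum):
    PENDING = auto()
    GENERATING = auto()
    INTERACTING = auto()
    COMPLETED = auto()

def run_lifecycle(generate, interact, max_turns=10):
    """Drive one request through the PENDING -> GENERATING -> INTERACTING loop.

    generate(messages) returns the model's next reply;
    interact(messages) returns (should_terminate, feedback).
    """
    messages, turns = [], 0
    state = State.PENDING
    state = State.GENERATING  # start_interaction() succeeded
    while state is not State.COMPLETED and turns < max_turns:
        # model turn
        messages.append({"role": "assistant", "content": generate(messages)})
        state = State.INTERACTING
        # interaction agent turn
        should_terminate, feedback = interact(messages)
        if should_terminate:
            state = State.COMPLETED
        else:
            messages.append({"role": "user", "content": feedback})
            state = State.GENERATING
        turns += 1
    return messages
```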

Configuration
-------------

**Basic Setup**

Enable interaction in your rollout configuration:

.. code-block:: yaml

    actor_rollout_ref:
      rollout:
        multi_turn:
          enable: true
          interaction_config_path: "path/to/interaction_config.yaml"
          max_user_turns: 10
          max_assistant_turns: 10

**Interaction Configuration File**

Create an interaction configuration file (e.g., ``gsm8k_interaction_config.yaml``):

.. code-block:: yaml

    interaction:
      - class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction"
        config: {}

The system will dynamically load the specified interaction class using importlib.
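Dynamic loading of this kind boils down to splitting the dotted path and importing the module. The sketch below illustrates the mechanism; it is not verl's exact loader:

```python
import importlib

def load_interaction_class(class_name: str):
    """Resolve a dotted path such as
    'verl.interactions.gsm8k_interaction.Gsm8kInteraction' to a class object."""
    module_path, _, cls_name = class_name.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, cls_name)
```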

Implementation Example: GSM8K
-----------------------------

The GSM8K interaction demonstrates a complete implementation for math problem-solving scenarios:

.. code-block:: python

    from uuid import uuid4

    from verl.interactions.base import BaseInteraction
    from verl.utils.reward_score import gsm8k

    class Gsm8kInteraction(BaseInteraction):
        def __init__(self, config: dict):
            super().__init__(config)
            self._instance_dict = {}

        async def start_interaction(self, instance_id=None, ground_truth=None, **kwargs):
            if instance_id is None:
                instance_id = str(uuid4())
            self._instance_dict[instance_id] = {
                "response": "",
                "ground_truth": ground_truth,
                "reward": 0.0,
            }
            return instance_id

        async def generate_response(self, instance_id, messages, **kwargs):
            # Extract the content of the last user message
            content = ""
            for item in reversed(messages):
                if item.get("role") == "user":
                    content = item.get("content", "")
                    break

            # Ensure GSM8K answer format ("#### " prefix)
            if content.startswith("#### "):
                self._instance_dict[instance_id]["response"] = content
            else:
                self._instance_dict[instance_id]["response"] = "#### " + content

            reward = await self.calculate_score(instance_id)
            if reward == 1.0:
                return True, "Your response is correct!", 1.0, {}
            else:
                return False, "Your response is incorrect! You need to reflect on your answer and try again.", 0.0, {}

        async def calculate_score(self, instance_id, **kwargs):
            return gsm8k.compute_score(
                self._instance_dict[instance_id]["response"],
                self._instance_dict[instance_id]["ground_truth"],
                method="flexible",
                format_score=0.0,
                score=1.0,
            )

        async def finalize_interaction(self, instance_id, **kwargs):
            del self._instance_dict[instance_id]

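To see the whole lifecycle in one place, here is a self-contained sketch that drives a minimal stand-in for the class above end to end. The exact-match check replaces ``gsm8k.compute_score`` so the snippet runs without verl installed; ``MiniGsm8kInteraction`` and ``demo`` are illustrative names:

```python
import asyncio
from uuid import uuid4

class MiniGsm8kInteraction:
    """Minimal stand-in mirroring the Gsm8kInteraction API."""

    def __init__(self, config):
        self._instance_dict = {}

    async def start_interaction(self, instance_id=None, ground_truth=None, **kwargs):
        instance_id = instance_id or str(uuid4())
        self._instance_dict[instance_id] = {"response": "", "ground_truth": ground_truth}
        return instance_id

    async def generate_response(self, instance_id, messages, **kwargs):
        content = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
        inst = self._instance_dict[instance_id]
        inst["response"] = content if content.startswith("#### ") else "#### " + content
        # exact match stands in for gsm8k.compute_score
        correct = inst["response"] == "#### " + inst["ground_truth"]
        msg = "Your response is correct!" if correct else "Your response is incorrect!"
        return correct, msg, float(correct), {}

    async def finalize_interaction(self, instance_id, **kwargs):
        del self._instance_dict[instance_id]

async def demo():
    agent = MiniGsm8kInteraction({})
    iid = await agent.start_interaction(ground_truth="4")
    done, msg, reward, _ = await agent.generate_response(iid, [{"role": "user", "content": "4"}])
    await agent.finalize_interaction(iid)
    return done, reward

# asyncio.run(demo())  # -> (True, 1.0)
```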
Training Integration
--------------------

**Training Script Configuration**

Include the interaction configuration in your training command:

.. code-block:: bash

    python3 -m verl.trainer.main_ppo \
        --config-path="$CONFIG_PATH" \
        --config-name='gsm8k_multiturn_grpo_w_interaction' \
        algorithm.adv_estimator=grpo \
        data.train_batch_size=512 \
        data.return_raw_chat=True \
        actor_rollout_ref.rollout.name=sglang \
        actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \
        trainer.total_epochs=15

**Data Requirements**

Ensure your dataset includes interaction parameters:

.. code-block:: python

    # The dataset should include interaction_kwargs in non_tensor_batch
    interaction_kwargs = [
        {"query": "What is 2+2?", "ground_truth": "4"},
        {"query": "What is 3+3?", "ground_truth": "6"},
    ]

Best Practices
--------------

**Resource Management**

- Always implement proper cleanup in ``finalize_interaction()``
- Use unique instance IDs to avoid conflicts in concurrent training
- Handle edge cases such as empty messages or malformed content

**Performance Optimization**

- Keep interaction logic lightweight to avoid blocking training
- Use async/await properly to maintain non-blocking behavior
- Consider caching expensive computations within interaction instances
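As an example of the caching point, work that repeats across turns (such as parsing the ground truth) can be memoized per instance. The class and method names below are hypothetical, not verl APIs:

```python
class CachedParsing:
    """Cache expensive, repeated work per interaction instance."""

    def __init__(self):
        self._cache = {}

    def parse_ground_truth(self, ground_truth: str) -> float:
        # parsing runs once per distinct ground truth, then hits the cache
        if ground_truth not in self._cache:
            self._cache[ground_truth] = float(ground_truth.replace(",", "").strip())
        return self._cache[ground_truth]
```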

**Testing**

Comprehensive testing is essential for interaction systems:

.. code-block:: python

    import pytest

    @pytest.mark.asyncio
    async def test_interaction_workflow():
        interaction = YourInteraction({})

        # Test the complete workflow
        instance_id = await interaction.start_interaction(ground_truth="expected_answer")

        messages = [{"role": "user", "content": "user_response"}]
        should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages)

        assert should_terminate in (True, False)
        assert isinstance(reward, float)

        await interaction.finalize_interaction(instance_id)

Advanced Usage
--------------

**Custom Scoring Functions**

You can integrate custom reward functions:

.. code-block:: python

    async def calculate_score(self, instance_id, **kwargs):
        response = self._instance_dict[instance_id]["response"]
        ground_truth = self._instance_dict[instance_id]["ground_truth"]

        # Custom evaluation logic (custom_evaluation_function is a placeholder)
        if custom_evaluation_function(response, ground_truth):
            return 1.0
        else:
            return 0.0

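For instance, a partial-credit scorer can be built from the standard library using string similarity. ``similarity_score`` is an illustrative helper, not part of verl:

```python
import difflib

def similarity_score(response: str, ground_truth: str) -> float:
    """Partial-credit score in [0, 1] based on string similarity."""
    return difflib.SequenceMatcher(None, response.strip(), ground_truth.strip()).ratio()
```

Unlike an exact-match check, this rewards near-miss answers, which can give the policy a smoother learning signal.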
**Multi-step Interactions**

For complex scenarios requiring multiple feedback rounds:

.. code-block:: python

    async def generate_response(self, instance_id, messages, **kwargs):
        # assumes start_interaction() initialized instance["attempts"] = 0
        instance = self._instance_dict[instance_id]
        instance["attempts"] += 1

        # Evaluate the current response
        reward = await self.calculate_score(instance_id)

        if reward > 0.8:
            return True, "Excellent work!", reward, {}
        elif instance["attempts"] < 3:
            return False, "Good attempt, but try to improve...", reward, {}
        else:
            return True, "Maximum attempts reached.", reward, {}

Troubleshooting
---------------

**Common Issues**

1. **Instance ID Conflicts**: Ensure unique instance IDs across concurrent sessions
2. **Memory Leaks**: Always call ``finalize_interaction()`` to clean up resources
3. **Blocking Operations**: Keep interaction logic async and non-blocking
4. **Configuration Errors**: Verify that the interaction config path and class name are correct

**Debugging**

Enable debug logging to trace the interaction flow:

.. code-block:: bash

    export VERL_LOGGING_LEVEL=DEBUG

**Performance Monitoring**

Monitor the interaction system's impact on training throughput and adjust accordingly.

Related Documentation
---------------------

- :doc:`multiturn`: Basic multi-turn rollout configuration
- :doc:`sandbox_fusion`: Tool integration with SGLang
- :doc:`search_tool_example`: Search tool implementation example