@@ -2,12 +2,14 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Annotated, Literal, cast
+from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable
 
-from langchain_core.messages import SystemMessage, ToolMessage
+    from langgraph.runtime import Runtime
+
+from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
 from langchain_core.tools import tool
 from langgraph.types import Command
 from typing_extensions import NotRequired, TypedDict
@@ -135,7 +137,9 @@ class TodoListMiddleware(AgentMiddleware):
     into task completion status.
 
     The middleware automatically injects system prompts that guide the agent on when
-    and how to use the todo functionality effectively.
+    and how to use the todo functionality effectively. It also enforces that the
+    `write_todos` tool is called at most once per model turn, since the tool replaces
+    the entire todo list and parallel calls would create ambiguity about precedence.
 
     Example:
         ```python
@@ -222,3 +226,79 @@ async def awrap_model_call(
             content=cast("list[str | dict[str, str]]", new_system_content)
         )
         return await handler(request.override(system_message=new_system_message))
+
+    def after_model(
+        self,
+        state: AgentState,
+        runtime: Runtime,  # noqa: ARG002
+    ) -> dict[str, Any] | None:
+        """Check for parallel `write_todos` tool calls and return errors if detected.
+
+        The todo list is designed to be updated at most once per model turn. Since
+        the `write_todos` tool replaces the entire todo list with each call, making
+        multiple parallel calls would create ambiguity about which update should take
+        precedence. This method prevents such conflicts by rejecting any response that
+        contains multiple `write_todos` tool calls.
+
+        Args:
+            state: The current agent state containing messages.
+            runtime: The LangGraph runtime instance.
+
+        Returns:
+            A dict containing error ToolMessages for each `write_todos` call if multiple
+            parallel calls are detected, otherwise None to allow normal execution.
+        """
+        messages = state["messages"]
+        if not messages:
+            return None
+
+        last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
+        if not last_ai_msg or not last_ai_msg.tool_calls:
+            return None
+
+        # Count write_todos tool calls
+        write_todos_calls = [tc for tc in last_ai_msg.tool_calls if tc["name"] == "write_todos"]
+
+        if len(write_todos_calls) > 1:
+            # Create error tool messages for all write_todos calls
+            error_messages = [
+                ToolMessage(
+                    content=(
+                        "Error: The `write_todos` tool should never be called multiple times "
+                        "in parallel. Please call it only once per model invocation to update "
+                        "the todo list."
+                    ),
+                    tool_call_id=tc["id"],
+                    status="error",
+                )
+                for tc in write_todos_calls
+            ]
+
+            # Keep the tool calls in the AI message but return error messages
+            # This follows the same pattern as HumanInTheLoopMiddleware
+            return {"messages": error_messages}
+
+        return None
+
+    async def aafter_model(
+        self,
+        state: AgentState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Check for parallel `write_todos` tool calls and return errors if detected.
+
+        Async version of `after_model`. The todo list is designed to be updated at
+        most once per model turn. Since the `write_todos` tool replaces the entire
+        todo list with each call, making multiple parallel calls would create ambiguity
+        about which update should take precedence. This method prevents such conflicts
+        by rejecting any response that contains multiple `write_todos` tool calls.
+
+        Args:
+            state: The current agent state containing messages.
+            runtime: The LangGraph runtime instance.
+
+        Returns:
+            A dict containing error ToolMessages for each `write_todos` call if multiple
+            parallel calls are detected, otherwise None to allow normal execution.
+        """
+        return self.after_model(state, runtime)
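
For context, here is a minimal sketch of how the new guard behaves when invoked directly. This is not part of the PR: it assumes `TodoListMiddleware` is importable from `langchain.agents.middleware` (adjust the import to wherever the class lives in your checkout) and passes `None` for the runtime, which `after_model` ignores (`noqa: ARG002`).

```python
# Hypothetical sketch, not from this diff: exercising after_model directly.
from langchain.agents.middleware import TodoListMiddleware  # assumed import path
from langchain_core.messages import AIMessage

middleware = TodoListMiddleware()

# An AI message that (incorrectly) issues two parallel write_todos calls.
msg = AIMessage(
    content="",
    tool_calls=[
        {"name": "write_todos", "args": {"todos": []}, "id": "call_1"},
        {"name": "write_todos", "args": {"todos": []}, "id": "call_2"},
    ],
)

# The runtime argument is unused by this hook, so None suffices in a direct call.
result = middleware.after_model({"messages": [msg]}, None)

# Two error ToolMessages come back, one per rejected call; a response with a
# single write_todos call (or none) would return None and proceed normally.
assert result is not None and len(result["messages"]) == 2
assert all(m.status == "error" for m in result["messages"])
```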