Commit bda22aa

hwchase17 and eyurtsev authored
fix(langchain): handle parallel usage of the todo tool in planning middleware (#34637)
The agent should only make a single call to update the todo list at a time. A parallel call doesn't make sense, and also cannot work, as there is no obvious reducer to use. On parallel calls of the todo tool, we return error ToolMessages that guide the LLM not to call the tool in parallel.

---------

Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent 48cd131 commit bda22aa
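
To make the fix concrete, here is a minimal sketch of the new guard in isolation. It assumes `TodoListMiddleware` is importable from `langchain.agents.middleware.todo` (matching the file path below) and that `after_model` tolerates a plain dict for state and `None` for the unused runtime argument; both are illustrative assumptions, not part of the commit.

```python
from langchain_core.messages import AIMessage

from langchain.agents.middleware.todo import TodoListMiddleware

middleware = TodoListMiddleware()

# An AI turn that (incorrectly) issues two parallel write_todos calls.
ai_msg = AIMessage(
    content="",
    tool_calls=[
        {"name": "write_todos", "args": {"todos": []}, "id": "call_1"},
        {"name": "write_todos", "args": {"todos": []}, "id": "call_2"},
    ],
)

# after_model ignores its runtime argument (noqa: ARG002 in the diff),
# so None stands in here for illustration only.
result = middleware.after_model({"messages": [ai_msg]}, None)

# One error ToolMessage comes back per offending tool call.
for tool_msg in result["messages"]:
    print(tool_msg.tool_call_id, tool_msg.status)
```

A single `write_todos` call would take the `len(write_todos_calls) > 1` branch's fallthrough instead, returning `None` and letting the tool run normally.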

File tree

2 files changed (+360, −6 lines)

  • libs/langchain_v1
    • langchain/agents/middleware
    • tests/unit_tests/agents/middleware/implementations


libs/langchain_v1/langchain/agents/middleware/todo.py

Lines changed: 83 additions & 3 deletions
@@ -2,12 +2,14 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Annotated, Literal, cast
+from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable
 
-from langchain_core.messages import SystemMessage, ToolMessage
+    from langgraph.runtime import Runtime
+
+from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
 from langchain_core.tools import tool
 from langgraph.types import Command
 from typing_extensions import NotRequired, TypedDict
@@ -135,7 +137,9 @@ class TodoListMiddleware(AgentMiddleware):
     into task completion status.
 
     The middleware automatically injects system prompts that guide the agent on when
-    and how to use the todo functionality effectively.
+    and how to use the todo functionality effectively. It also enforces that the
+    `write_todos` tool is called at most once per model turn, since the tool replaces
+    the entire todo list and parallel calls would create ambiguity about precedence.
 
     Example:
         ```python
@@ -222,3 +226,79 @@ async def awrap_model_call(
             content=cast("list[str | dict[str, str]]", new_system_content)
         )
         return await handler(request.override(system_message=new_system_message))
+
+    def after_model(
+        self,
+        state: AgentState,
+        runtime: Runtime,  # noqa: ARG002
+    ) -> dict[str, Any] | None:
+        """Check for parallel write_todos tool calls and return errors if detected.
+
+        The todo list is designed to be updated at most once per model turn. Since
+        the `write_todos` tool replaces the entire todo list with each call, making
+        multiple parallel calls would create ambiguity about which update should take
+        precedence. This method prevents such conflicts by rejecting any response that
+        contains multiple write_todos tool calls.
+
+        Args:
+            state: The current agent state containing messages.
+            runtime: The LangGraph runtime instance.
+
+        Returns:
+            A dict containing error ToolMessages for each write_todos call if multiple
+            parallel calls are detected, otherwise None to allow normal execution.
+        """
+        messages = state["messages"]
+        if not messages:
+            return None
+
+        last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
+        if not last_ai_msg or not last_ai_msg.tool_calls:
+            return None
+
+        # Count write_todos tool calls
+        write_todos_calls = [tc for tc in last_ai_msg.tool_calls if tc["name"] == "write_todos"]
+
+        if len(write_todos_calls) > 1:
+            # Create error tool messages for all write_todos calls
+            error_messages = [
+                ToolMessage(
+                    content=(
+                        "Error: The `write_todos` tool should never be called multiple times "
+                        "in parallel. Please call it only once per model invocation to update "
+                        "the todo list."
+                    ),
+                    tool_call_id=tc["id"],
+                    status="error",
+                )
+                for tc in write_todos_calls
+            ]
+
+            # Keep the tool calls in the AI message but return error messages
+            # This follows the same pattern as HumanInTheLoopMiddleware
+            return {"messages": error_messages}
+
+        return None
+
+    async def aafter_model(
+        self,
+        state: AgentState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Check for parallel write_todos tool calls and return errors if detected.
+
+        Async version of `after_model`. The todo list is designed to be updated at
+        most once per model turn. Since the `write_todos` tool replaces the entire
+        todo list with each call, making multiple parallel calls would create ambiguity
+        about which update should take precedence. This method prevents such conflicts
+        by rejecting any response that contains multiple write_todos tool calls.
+
+        Args:
+            state: The current agent state containing messages.
+            runtime: The LangGraph runtime instance.
+
+        Returns:
+            A dict containing error ToolMessages for each write_todos call if multiple
+            parallel calls are detected, otherwise None to allow normal execution.
+        """
+        return self.after_model(state, runtime)
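
For orientation, a hedged sketch of how the middleware is typically wired into an agent; the `create_agent` call shape, model identifier, and empty tool list are illustrative assumptions rather than part of this commit:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.todo import TodoListMiddleware

# Model id and tool list are placeholders for illustration only.
agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[],
    middleware=[TodoListMiddleware()],
)
```

With the guard in place, an AI turn that issues several `write_todos` calls is answered with one error ToolMessage per call instead of mutating the todo list, nudging the model to retry with a single call.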
