
Commit efa88f7

Wait for input guardrails in streaming runs (openai#1730)
1 parent 581111c commit efa88f7

13 files changed: +87 -23 lines
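The change in src/agents/result.py below makes stream_events() wait for the input-guardrail task before finishing, so a guardrail that trips after streaming has started is raised to the caller instead of being lost. A minimal consumer-side sketch, assuming the SDK's public names (Agent, Runner.run_streamed, stream_events, InputGuardrailTripwireTriggered) and an agent that already has an input guardrail attached:

import asyncio

from agents import Agent, InputGuardrailTripwireTriggered, Runner


async def main() -> None:
    # Assumes an input guardrail is configured on the agent elsewhere.
    agent = Agent(name="Assistant", instructions="Answer briefly.")

    result = Runner.run_streamed(agent, input="Hello")
    try:
        # With this commit, stream_events() awaits the input-guardrail task before
        # completing, so a late tripwire is raised here rather than silently dropped.
        async for event in result.stream_events():
            print(event.type)
    except InputGuardrailTripwireTriggered:
        print("Input guardrail tripped; run was aborted.")


if __name__ == "__main__":
    asyncio.run(main())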

examples/basic/usage_tracking.py

Lines changed: 0 additions & 1 deletion
@@ -43,4 +43,3 @@ async def main() -> None:
 
 if __name__ == "__main__":
     asyncio.run(main())
-

examples/model_providers/litellm_auto.py

Lines changed: 2 additions & 0 deletions
@@ -15,11 +15,13 @@
 # import logging
 # logging.basicConfig(level=logging.DEBUG)
 
+
 @function_tool
 def get_weather(city: str):
     print(f"[debug] getting weather for {city}")
     return f"The weather in {city} is sunny."
 
+
 class Result(BaseModel):
     output_text: str
     tool_results: list[str]

examples/realtime/app/server.py

Lines changed: 6 additions & 4 deletions
@@ -198,9 +198,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
                         {"type": "input_text", "text": prompt_text},
                     ]
                     if prompt_text
-                    else [
-                        {"type": "input_image", "image_url": data_url, "detail": "high"}
-                    ]
+                    else [{"type": "input_image", "image_url": data_url, "detail": "high"}]
                 ),
             }
             await manager.send_user_message(session_id, user_msg)
@@ -271,7 +269,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
                 "role": "user",
                 "content": (
                     [
-                        {"type": "input_image", "image_url": data_url, "detail": "high"},
+                        {
+                            "type": "input_image",
+                            "image_url": data_url,
+                            "detail": "high",
+                        },
                         {"type": "input_text", "text": prompt_text},
                     ]
                     if prompt_text
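For readers skimming the reformatted blocks above, this is the user-message shape the example server builds when both an image and a caption are present. A sketch with hypothetical placeholder values; in the real server, data_url comes from the uploaded image and prompt_text from the client:

# Hypothetical stand-ins for client-supplied values.
prompt_text = "What is in this image?"
data_url = "data:image/png;base64,..."  # placeholder

user_msg = {
    "role": "user",
    "content": [
        {
            "type": "input_image",
            "image_url": data_url,
            "detail": "high",
        },
        {"type": "input_text", "text": prompt_text},
    ],
}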

examples/realtime/cli/demo.py

Lines changed: 12 additions & 5 deletions
@@ -23,8 +23,8 @@
 FORMAT = np.int16
 CHANNELS = 1
 ENERGY_THRESHOLD = 0.015  # RMS threshold for barge‑in while assistant is speaking
-PREBUFFER_CHUNKS = 3  # initial jitter buffer (~120ms with 40ms chunks)
-FADE_OUT_MS = 12  # short fade to avoid clicks when interrupting
+PREBUFFER_CHUNKS = 3  # initial jitter buffer (~120ms with 40ms chunks)
+FADE_OUT_MS = 12  # short fade to avoid clicks when interrupting
 
 # Set up logging for OpenAI agents SDK
 # logging.basicConfig(
@@ -108,14 +108,18 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
 
             samples, item_id, content_index = self.current_audio_chunk
             samples_filled = 0
-            while samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples:
+            while (
+                samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples
+            ):
                 remaining_output = len(outdata) - samples_filled
                 remaining_fade = self.fade_total_samples - self.fade_done_samples
                 n = min(remaining_output, remaining_fade)
 
                 src = samples[self.chunk_position : self.chunk_position + n].astype(np.float32)
                 # Linear ramp from current level down to 0 across remaining fade samples
-                idx = np.arange(self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32)
+                idx = np.arange(
+                    self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32
+                )
                 gain = 1.0 - (idx / float(self.fade_total_samples))
                 ramped = np.clip(src * gain, -32768.0, 32767.0).astype(np.int16)
                 outdata[samples_filled : samples_filled + n, 0] = ramped
@@ -155,7 +159,10 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
         if self.current_audio_chunk is None:
             try:
                 # Respect a small jitter buffer before starting playback
-                if self.prebuffering and self.output_queue.qsize() < self.prebuffer_target_chunks:
+                if (
+                    self.prebuffering
+                    and self.output_queue.qsize() < self.prebuffer_target_chunks
+                ):
                     break
                 self.prebuffering = False
                 self.current_audio_chunk = self.output_queue.get_nowait()
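The reflowed fade-out loop above applies a linear gain ramp to int16 audio samples so playback can be interrupted without clicks. A standalone sketch of that ramp (fade_out is a hypothetical helper, not part of the example):

import numpy as np


def fade_out(samples: np.ndarray, fade_done: int, fade_total: int) -> np.ndarray:
    """Apply a linear ramp from the current fade position down toward silence."""
    n = len(samples)
    src = samples.astype(np.float32)
    # Gain starts at (1 - fade_done / fade_total) and decreases across these n samples.
    idx = np.arange(fade_done, fade_done + n, dtype=np.float32)
    gain = 1.0 - (idx / float(fade_total))
    return np.clip(src * gain, -32768.0, 32767.0).astype(np.int16)


# 10 ms of a sine tone at 48 kHz, faded over 480 samples.
chunk = (np.sin(np.linspace(0, 40 * np.pi, 480)) * 8000).astype(np.int16)
faded = fade_out(chunk, fade_done=0, fade_total=480)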

src/agents/extensions/memory/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,11 +1,11 @@
-
 """Session memory backends living in the extensions namespace.
 
 This package contains optional, production-grade session implementations that
 introduce extra third-party dependencies (database drivers, ORMs, etc.). They
 conform to the :class:`agents.memory.session.Session` protocol so they can be
 used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
 """
+
 from __future__ import annotations
 
 from .sqlalchemy_session import SQLAlchemySession  # noqa: F401
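As the docstring notes, these extension sessions are drop-in replacements for SQLiteSession. A usage sketch, assuming the extension's documented from_url constructor and an async SQLAlchemy URL (sqlite+aiosqlite here for demonstration; any supported database URL should work the same way):

import asyncio

from agents import Agent, Runner
from agents.extensions.memory import SQLAlchemySession


async def main() -> None:
    agent = Agent(name="Assistant", instructions="Reply concisely.")
    # In-memory SQLite for demonstration; point the URL at Postgres/MySQL in production.
    session = SQLAlchemySession.from_url(
        "user-123",
        url="sqlite+aiosqlite:///:memory:",
        create_tables=True,
    )
    result = await Runner.run(agent, "Hi, remember my name is Ada.", session=session)
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())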

src/agents/extensions/models/litellm_model.py

Lines changed: 2 additions & 2 deletions
@@ -413,9 +413,9 @@ def convert_message_to_openai(
         else:
             # Convert object to dict by accessing its attributes
             block_dict: dict[str, Any] = {}
-            if hasattr(block, '__dict__'):
+            if hasattr(block, "__dict__"):
                 block_dict = dict(block.__dict__.items())
-            elif hasattr(block, 'model_dump'):
+            elif hasattr(block, "model_dump"):
                 block_dict = block.model_dump()
             else:
                 # Last resort: convert to string representation
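The quoting fix above sits inside a fallback chain that turns provider-specific content blocks into plain dicts. A minimal standalone sketch of that pattern (block_to_dict is a hypothetical helper name, not the SDK's actual function):

from typing import Any


def block_to_dict(block: Any) -> dict[str, Any]:
    """Best-effort conversion of an arbitrary content block to a dict."""
    if isinstance(block, dict):
        return block
    if hasattr(block, "__dict__"):
        # Plain objects and most model classes expose their fields here.
        return dict(block.__dict__.items())
    if hasattr(block, "model_dump"):
        # Pydantic models without a usable __dict__.
        return block.model_dump()
    # Last resort: keep a string representation so nothing is silently dropped.
    return {"value": str(block)}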

src/agents/models/chatcmpl_converter.py

Lines changed: 2 additions & 3 deletions
@@ -106,6 +106,7 @@ def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TRespon
             # Store thinking blocks in the reasoning item's content
             # Convert thinking blocks to Content objects
             from openai.types.responses.response_reasoning_item import Content
+
             reasoning_item.content = [
                 Content(text=str(block.get("thinking", "")), type="reasoning_text")
                 for block in message.thinking_blocks
@@ -282,9 +283,7 @@ def extract_all_content(
                         f"Only file_data is supported for input_file {casted_file_param}"
                     )
                 if "filename" not in casted_file_param or not casted_file_param["filename"]:
-                    raise UserError(
-                        f"filename must be provided for input_file {casted_file_param}"
-                    )
+                    raise UserError(f"filename must be provided for input_file {casted_file_param}")
                 out.append(
                     File(
                         type="file",

src/agents/realtime/model_events.py

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,7 @@ class RealtimeModelInputAudioTranscriptionCompletedEvent:
 
     type: Literal["input_audio_transcription_completed"] = "input_audio_transcription_completed"
 
+
 @dataclass
 class RealtimeModelInputAudioTimeoutTriggeredEvent:
     """Input audio timeout triggered."""
@@ -94,6 +95,7 @@ class RealtimeModelInputAudioTimeoutTriggeredEvent:
 
     type: Literal["input_audio_timeout_triggered"] = "input_audio_timeout_triggered"
 
+
 @dataclass
 class RealtimeModelTranscriptDeltaEvent:
     """Partial transcript update."""

src/agents/result.py

Lines changed: 20 additions & 0 deletions
@@ -201,7 +201,11 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
                 break
 
             if isinstance(item, QueueCompleteSentinel):
+                # Await input guardrails if they are still running, so late exceptions are captured.
+                await self._await_task_safely(self._input_guardrails_task)
+
                 self._event_queue.task_done()
+
                 # Check for errors, in case the queue was completed due to an exception
                 self._check_errors()
                 break
@@ -274,3 +278,19 @@ def _cleanup_tasks(self):
 
     def __str__(self) -> str:
         return pretty_print_run_result_streaming(self)
+
+    async def _await_task_safely(self, task: asyncio.Task[Any] | None) -> None:
+        """Await a task if present, ignoring cancellation and storing exceptions elsewhere.
+
+        This ensures we do not lose late guardrail exceptions while not surfacing
+        CancelledError to callers of stream_events.
+        """
+        if task and not task.done():
+            try:
+                await task
+            except asyncio.CancelledError:
+                # Task was cancelled (e.g., due to result.cancel()). Nothing to do here.
+                pass
+            except Exception:
+                # The exception will be surfaced via _check_errors() if needed.
+                pass
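The core of the fix is the pattern _await_task_safely implements: before declaring the stream finished, await the background guardrail task so any exception it raised is available to the subsequent error check rather than being dropped. A generic, self-contained sketch of that pattern (names are illustrative, not the SDK's):

import asyncio


async def await_task_safely(task: asyncio.Task | None) -> None:
    """Wait for a background task without propagating its errors directly."""
    if task and not task.done():
        try:
            await task
        except asyncio.CancelledError:
            pass  # cancelled elsewhere; nothing to surface
        except Exception:
            pass  # the exception stays on the task for the caller's error check


async def main() -> None:
    async def guardrail() -> None:
        await asyncio.sleep(0.05)
        raise RuntimeError("tripwire")

    task = asyncio.create_task(guardrail())
    await await_task_safely(task)  # without this, the check below could run too early
    if task.done() and not task.cancelled() and task.exception():
        print(f"guardrail failed: {task.exception()}")


asyncio.run(main())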

src/agents/run.py

Lines changed: 3 additions & 7 deletions
@@ -1127,14 +1127,11 @@ async def _run_single_turn_streamed(
 
             # Filter out HandoffCallItem to avoid duplicates (already sent earlier)
             items_to_filter = [
-                item for item in items_to_filter
-                if not isinstance(item, HandoffCallItem)
+                item for item in items_to_filter if not isinstance(item, HandoffCallItem)
             ]
 
             # Create filtered result and send to queue
-            filtered_result = _dc.replace(
-                single_step_result, new_step_items=items_to_filter
-            )
+            filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter)
             RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue)
             return single_step_result
 
@@ -1235,8 +1232,7 @@ async def _get_single_step_result_from_response(
         # Send handoff items immediately for streaming, but avoid duplicates
         if event_queue is not None and processed_response.new_items:
             handoff_items = [
-                item for item in processed_response.new_items
-                if isinstance(item, HandoffCallItem)
+                item for item in processed_response.new_items if isinstance(item, HandoffCallItem)
             ]
             if handoff_items:
                 RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue)
