diff --git a/src/anthropic/lib/streaming/_beta_messages.py b/src/anthropic/lib/streaming/_beta_messages.py index 7e6a774..b6241d2 100644 --- a/src/anthropic/lib/streaming/_beta_messages.py +++ b/src/anthropic/lib/streaming/_beta_messages.py @@ -5,11 +5,13 @@ from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never import httpx +from pydantic import BaseModel from ..._utils import consume_sync_iterator, consume_async_iterator from ..._models import build, construct_type from ._beta_types import ( BetaTextEvent, + BetaCitationEvent, BetaInputJsonEvent, BetaMessageStopEvent, BetaMessageStreamEvent, @@ -314,24 +316,40 @@ def build_events( events_to_fire.append(event) content_block = message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content_block.type == "text": - events_to_fire.append( - build( - BetaTextEvent, - type="text", - text=event.delta.text, - snapshot=content_block.text, + if event.delta.type == "text_delta": + if content_block.type == "text": + events_to_fire.append( + build( + BetaTextEvent, + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) ) - ) - elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": - events_to_fire.append( - build( - BetaInputJsonEvent, - type="input_json", - partial_json=event.delta.partial_json, - snapshot=content_block.input, + elif event.delta.type == "input_json_delta": + if content_block.type == "tool_use": + events_to_fire.append( + build( + BetaInputJsonEvent, + type="input_json", + partial_json=event.delta.partial_json, + snapshot=content_block.input, + ) ) - ) + elif event.delta.type == "citations_delta": + if content_block.type == "text": + events_to_fire.append( + build( + BetaCitationEvent, + type="citation", + citation=event.delta.citation, + snapshot=content_block.citations or [], + ) + ) + else: + # we only want exhaustive checking for linters, not at runtime + if TYPE_CHECKING: # type: ignore[unreachable] + assert_never(event.delta) elif event.type == "content_block_stop": content_block = message_snapshot.content[event.index] @@ -354,6 +372,9 @@ def accumulate_event( event: BetaRawMessageStreamEvent, current_snapshot: BetaMessage | None, ) -> BetaMessage: + if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError(f"Unexpected event runtime type - {event}") + if current_snapshot is None: if event.type == "message_start": return BetaMessage.construct(**cast(Any, event.message.to_dict())) @@ -370,21 +391,33 @@ def accumulate_event( ) elif event.type == "content_block_delta": content = current_snapshot.content[event.index] - if content.type == "text" and event.delta.type == "text_delta": - content.text += event.delta.text - elif content.type == "tool_use" and event.delta.type == "input_json_delta": - from jiter import from_json - - # we need to keep track of the raw JSON string as well so that we can - # re-parse it for each delta, for now we just store it as an untyped - # property on the snapshot - json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b"")) - json_buf += bytes(event.delta.partial_json, "utf-8") - - if json_buf: - content.input = from_json(json_buf, partial_mode=True) - - setattr(content, JSON_BUF_PROPERTY, json_buf) + if event.delta.type == "text_delta": + if content.type == "text": + content.text += event.delta.text + elif event.delta.type == "input_json_delta": + if content.type == "tool_use": + from jiter import from_json + + # we need to keep track of the raw JSON string as well so that we can + # re-parse it for each delta, for now we just store it as an untyped + # property on the snapshot + json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b"")) + json_buf += bytes(event.delta.partial_json, "utf-8") + + if json_buf: + content.input = from_json(json_buf, partial_mode=True) + + setattr(content, JSON_BUF_PROPERTY, json_buf) + elif event.delta.type == "citations_delta": + if content.type == "text": + if not content.citations: + content.citations = [event.delta.citation] + else: + content.citations.append(event.delta.citation) + else: + # we only want exhaustive checking for linters, not at runtime + if TYPE_CHECKING: # type: ignore[unreachable] + assert_never(event.delta) elif event.type == "message_delta": current_snapshot.stop_reason = event.delta.stop_reason current_snapshot.stop_sequence = event.delta.stop_sequence diff --git a/src/anthropic/lib/streaming/_beta_types.py b/src/anthropic/lib/streaming/_beta_types.py index c3ee61f..4ef7e13 100644 --- a/src/anthropic/lib/streaming/_beta_types.py +++ b/src/anthropic/lib/streaming/_beta_types.py @@ -1,5 +1,5 @@ from typing import Union -from typing_extensions import Literal, Annotated +from typing_extensions import List, Literal, Annotated from ..._models import BaseModel from ...types.beta import ( @@ -13,6 +13,7 @@ BetaRawContentBlockStartEvent, ) from ..._utils._transform import PropertyInfo +from ...types.beta.beta_citations_delta import Citation class BetaTextEvent(BaseModel): @@ -25,6 +26,16 @@ class BetaTextEvent(BaseModel): """The entire accumulated text""" +class BetaCitationEvent(BaseModel): + type: Literal["citation"] + + citation: Citation + """The new citation""" + + snapshot: List[Citation] + """All of the accumulated citations""" + + class BetaInputJsonEvent(BaseModel): type: Literal["input_json"] @@ -57,6 +68,7 @@ class BetaContentBlockStopEvent(BetaRawContentBlockStopEvent): BetaMessageStreamEvent = Annotated[ Union[ BetaTextEvent, + BetaCitationEvent, BetaInputJsonEvent, BetaRawMessageStartEvent, BetaRawMessageDeltaEvent, diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index ece0a16..146a1ba 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -9,6 +9,7 @@ from ._types import ( TextEvent, + CitationEvent, InputJsonEvent, MessageStopEvent, MessageStreamEvent, @@ -315,24 +316,40 @@ def build_events( events_to_fire.append(event) content_block = message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content_block.type == "text": - events_to_fire.append( - build( - TextEvent, - type="text", - text=event.delta.text, - snapshot=content_block.text, + if event.delta.type == "text_delta": + if content_block.type == "text": + events_to_fire.append( + build( + TextEvent, + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) ) - ) - elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": - events_to_fire.append( - build( - InputJsonEvent, - type="input_json", - partial_json=event.delta.partial_json, - snapshot=content_block.input, + elif event.delta.type == "input_json_delta": + if content_block.type == "tool_use": + events_to_fire.append( + build( + InputJsonEvent, + type="input_json", + partial_json=event.delta.partial_json, + snapshot=content_block.input, + ) ) - ) + elif event.delta.type == "citations_delta": + if content_block.type == "text": + events_to_fire.append( + build( + CitationEvent, + type="citation", + citation=event.delta.citation, + snapshot=content_block.citations or [], + ) + ) + else: + # we only want exhaustive checking for linters, not at runtime + if TYPE_CHECKING: # type: ignore[unreachable] + assert_never(event.delta) elif event.type == "content_block_stop": content_block = message_snapshot.content[event.index] @@ -374,21 +391,33 @@ def accumulate_event( ) elif event.type == "content_block_delta": content = current_snapshot.content[event.index] - if content.type == "text" and event.delta.type == "text_delta": - content.text += event.delta.text - elif content.type == "tool_use" and event.delta.type == "input_json_delta": - from jiter import from_json - - # we need to keep track of the raw JSON string as well so that we can - # re-parse it for each delta, for now we just store it as an untyped - # property on the snapshot - json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b"")) - json_buf += bytes(event.delta.partial_json, "utf-8") - - if json_buf: - content.input = from_json(json_buf, partial_mode=True) - - setattr(content, JSON_BUF_PROPERTY, json_buf) + if event.delta.type == "text_delta": + if content.type == "text": + content.text += event.delta.text + elif event.delta.type == "input_json_delta": + if content.type == "tool_use": + from jiter import from_json + + # we need to keep track of the raw JSON string as well so that we can + # re-parse it for each delta, for now we just store it as an untyped + # property on the snapshot + json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b"")) + json_buf += bytes(event.delta.partial_json, "utf-8") + + if json_buf: + content.input = from_json(json_buf, partial_mode=True) + + setattr(content, JSON_BUF_PROPERTY, json_buf) + elif event.delta.type == "citations_delta": + if content.type == "text": + if not content.citations: + content.citations = [event.delta.citation] + else: + content.citations.append(event.delta.citation) + else: + # we only want exhaustive checking for linters, not at runtime + if TYPE_CHECKING: # type: ignore[unreachable] + assert_never(event.delta) elif event.type == "message_delta": current_snapshot.stop_reason = event.delta.stop_reason current_snapshot.stop_sequence = event.delta.stop_sequence diff --git a/src/anthropic/lib/streaming/_types.py b/src/anthropic/lib/streaming/_types.py index 59ee779..40af5ee 100644 --- a/src/anthropic/lib/streaming/_types.py +++ b/src/anthropic/lib/streaming/_types.py @@ -1,5 +1,5 @@ from typing import Union -from typing_extensions import Literal, Annotated +from typing_extensions import List, Literal, Annotated from ...types import ( Message, @@ -13,6 +13,7 @@ ) from ..._models import BaseModel from ..._utils._transform import PropertyInfo +from ...types.citations_delta import Citation class TextEvent(BaseModel): @@ -25,6 +26,16 @@ class TextEvent(BaseModel): """The entire accumulated text""" +class CitationEvent(BaseModel): + type: Literal["citation"] + + citation: Citation + """The new citation""" + + snapshot: List[Citation] + """All of the accumulated citations""" + + class InputJsonEvent(BaseModel): type: Literal["input_json"] @@ -57,6 +68,7 @@ class ContentBlockStopEvent(RawContentBlockStopEvent): MessageStreamEvent = Annotated[ Union[ TextEvent, + CitationEvent, InputJsonEvent, RawMessageStartEvent, RawMessageDeltaEvent,