diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index 436e7534..ae924aeb 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-openai" -version = "0.29.0" +version = "0.29.3" authors = ["Himanshu Neema"] categories = ["api-bindings", "web-programming", "asynchronous"] keywords = ["openai", "async", "openapi", "ai"] @@ -31,7 +31,7 @@ async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" } backoff = { version = "0.4.0", features = ["tokio"] } base64 = "0.22.1" futures = "0.3.31" -rand = "0.8.5" +rand = "0.9.0" reqwest = { version = "0.12.12", features = [ "json", "stream", diff --git a/async-openai/src/download.rs b/async-openai/src/download.rs index 087ba6f3..e3d9a1e4 100644 --- a/async-openai/src/download.rs +++ b/async-openai/src/download.rs @@ -1,7 +1,7 @@ use std::path::{Path, PathBuf}; use base64::{engine::general_purpose, Engine as _}; -use rand::{distributions::Alphanumeric, Rng}; +use rand::{distr::Alphanumeric, Rng}; use reqwest::Url; use crate::error::OpenAIError; @@ -57,7 +57,7 @@ pub(crate) async fn download_url>( } pub(crate) async fn save_b64>(b64: &str, dir: P) -> Result { - let filename: String = rand::thread_rng() + let filename: String = rand::rng() .sample_iter(&Alphanumeric) .take(10) .map(char::from) diff --git a/async-openai/src/responses.rs b/async-openai/src/responses.rs index 5c2689a3..9160b7be 100644 --- a/async-openai/src/responses.rs +++ b/async-openai/src/responses.rs @@ -1,13 +1,13 @@ use crate::{ config::Config, error::OpenAIError, - types::responses::{CreateResponse, Response}, + types::responses::{CreateResponse, Response, ResponseStream}, Client, }; /// Given text input or a list of context items, the model will generate a response. 
/// -/// Related guide: [Responses API](https://platform.openai.com/docs/guides/responses) +/// Related guide: [Responses](https://platform.openai.com/docs/api-reference/responses) pub struct Responses<'c, C: Config> { client: &'c Client, } @@ -26,4 +26,30 @@ impl<'c, C: Config> Responses<'c, C> { pub async fn create(&self, request: CreateResponse) -> Result { self.client.post("/responses", request).await } + + /// Creates a model response for the given input with streaming. + /// + /// Response events will be sent as server-sent events as they become available. + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned, + stream = "true", + where_clause = "R: std::marker::Send + 'static" + )] + #[allow(unused_mut)] + pub async fn create_stream( + &self, + mut request: CreateResponse, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if matches!(request.stream, Some(false)) { + return Err(OpenAIError::InvalidArgument( + "When stream is false, use Responses::create".into(), + )); + } + request.stream = Some(true); + } + Ok(self.client.post_stream("/responses", request).await) + } } diff --git a/async-openai/src/types/assistant_stream.rs b/async-openai/src/types/assistant_stream.rs index 755a322d..fca835cf 100644 --- a/async-openai/src/types/assistant_stream.rs +++ b/async-openai/src/types/assistant_stream.rs @@ -35,7 +35,7 @@ use super::{ pub enum AssistantStreamEvent { /// Occurs when a new [thread](https://platform.openai.com/docs/api-reference/threads/object) is created. #[serde(rename = "thread.created")] - TreadCreated(ThreadObject), + ThreadCreated(ThreadObject), /// Occurs when a new [run](https://platform.openai.com/docs/api-reference/runs/object) is created. 
#[serde(rename = "thread.run.created")] ThreadRunCreated(RunObject), @@ -119,7 +119,7 @@ impl TryFrom for AssistantStreamEvent { match value.event.as_str() { "thread.created" => serde_json::from_str::(value.data.as_str()) .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) - .map(AssistantStreamEvent::TreadCreated), + .map(AssistantStreamEvent::ThreadCreated), "thread.run.created" => serde_json::from_str::(value.data.as_str()) .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) .map(AssistantStreamEvent::ThreadRunCreated), diff --git a/async-openai/src/types/chat.rs b/async-openai/src/types/chat.rs index 68c93e19..d9373db6 100644 --- a/async-openai/src/types/chat.rs +++ b/async-openai/src/types/chat.rs @@ -43,7 +43,9 @@ pub enum CompletionFinishReason { pub struct Choice { pub text: String, pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub finish_reason: Option, } @@ -94,8 +96,10 @@ pub struct CompletionUsage { /// Total number of tokens used in the request (prompt + completion). pub total_tokens: u32, /// Breakdown of tokens used in the prompt. + #[serde(skip_serializing_if = "Option::is_none")] pub prompt_tokens_details: Option, /// Breakdown of tokens used in a completion. + #[serde(skip_serializing_if = "Option::is_none")] pub completion_tokens_details: Option, } @@ -414,10 +418,13 @@ pub struct ChatCompletionResponseMessageAudio { #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct ChatCompletionResponseMessage { /// The contents of the message. + #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, /// The refusal message generated by the model. + #[serde(skip_serializing_if = "Option::is_none")] pub refusal: Option, /// The tool calls generated by the model, such as function calls. + #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, /// The role of the author of this message. 
@@ -425,10 +432,12 @@ pub struct ChatCompletionResponseMessage { /// Deprecated and replaced by `tool_calls`. /// The name and arguments of a function that should be called, as generated by the model. + #[serde(skip_serializing_if = "Option::is_none")] #[deprecated] pub function_call: Option, /// If the audio output modality is requested, this object contains data about the audio response from the model. [Learn more](https://platform.openai.com/docs/guides/audio). + #[serde(skip_serializing_if = "Option::is_none")] pub audio: Option, } @@ -542,7 +551,7 @@ pub struct ChatCompletionNamedToolChoice { /// `required` means the model must call one or more tools. /// Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. /// -/// `none` is the default when no tools are present. `auto` is the default if tools are present.present. +/// `none` is the default when no tools are present. `auto` is the default if tools are present. #[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ChatCompletionToolChoiceOption { @@ -607,6 +616,8 @@ pub enum ServiceTier { Auto, Default, Flex, + Scale, + Priority, } #[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] @@ -615,11 +626,13 @@ pub enum ServiceTierResponse { Scale, Default, Flex, + Priority, } #[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ReasoningEffort { + Minimal, Low, Medium, High, @@ -927,8 +940,10 @@ pub struct ChatChoice { /// `length` if the maximum number of tokens specified in the request was reached, /// `content_filter` if content was omitted due to a flag from our content filters, /// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + #[serde(skip_serializing_if = "Option::is_none")] pub finish_reason: Option, /// Log probability information for the choice. 
+ #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, } @@ -944,10 +959,12 @@ pub struct CreateChatCompletionResponse { /// The model used for the chat completion. pub model: String, /// The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + #[serde(skip_serializing_if = "Option::is_none")] pub service_tier: Option, /// This fingerprint represents the backend configuration that the model runs with. /// /// Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + #[serde(skip_serializing_if = "Option::is_none")] pub system_fingerprint: Option, /// The object type, which is always `chat.completion`. @@ -1011,8 +1028,10 @@ pub struct ChatChoiceStream { /// content filters, /// `tool_calls` if the model called a tool, or `function_call` /// (deprecated) if the model called a function. + #[serde(skip_serializing_if = "Option::is_none")] pub finish_reason: Option, /// Log probability information for the choice. 
+ #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, } diff --git a/async-openai/src/types/image.rs b/async-openai/src/types/image.rs index 86169c46..b3c30d74 100644 --- a/async-openai/src/types/image.rs +++ b/async-openai/src/types/image.rs @@ -57,6 +57,10 @@ pub enum ImageQuality { #[default] Standard, HD, + High, + Medium, + Low, + Auto, } #[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] @@ -67,6 +71,14 @@ pub enum ImageStyle { Natural, } +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ImageModeration { + #[default] + Auto, + Low, +} + #[derive(Debug, Clone, Serialize, Deserialize, Default, Builder, PartialEq)] #[builder(name = "CreateImageRequestArgs")] #[builder(pattern = "mutable")] @@ -110,6 +122,11 @@ pub struct CreateImageRequest { /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). #[serde(skip_serializing_if = "Option::is_none")] pub user: Option, + + /// Control the content-moderation level for images generated by gpt-image-1. + /// Must be either `low` for less restrictive filtering or `auto` (default value). + #[serde(skip_serializing_if = "Option::is_none")] + pub moderation: Option, } #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] diff --git a/async-openai/src/types/responses.rs b/async-openai/src/types/responses.rs index 4e0eeec7..6b2762c3 100644 --- a/async-openai/src/types/responses.rs +++ b/async-openai/src/types/responses.rs @@ -4,9 +4,11 @@ pub use crate::types::{ ResponseFormatJsonSchema, }; use derive_builder::Builder; +use futures::Stream; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +use std::pin::Pin; /// Role of messages in the API. 
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] @@ -327,6 +329,8 @@ pub enum ServiceTier { Auto, Default, Flex, + Scale, + Priority, } /// Truncation strategies. @@ -540,7 +544,7 @@ pub enum ComparisonType { pub struct CompoundFilter { /// Type of operation #[serde(rename = "type")] - pub op: ComparisonType, + pub op: CompoundType, /// Array of filters to combine. Items can be ComparisonFilter or CompoundFilter. pub filters: Vec, } @@ -1181,15 +1185,17 @@ pub struct ImageGenerationCallOutput { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct CodeInterpreterCallOutput { /// The code that was executed. - pub code: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, /// Unique ID of the call. pub id: String, /// Status of the tool call. pub status: String, /// ID of the container used to run the code. pub container_id: String, - /// The results of the execution: logs or files. - pub results: Vec, + /// The outputs of the execution: logs or files. + #[serde(skip_serializing_if = "Option::is_none")] + pub outputs: Option>, } /// Individual result from a code interpreter: either logs or files. 
@@ -1434,3 +1440,734 @@ pub enum Status { InProgress, Incomplete, } + +/// Event types for streaming responses from the Responses API +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +#[non_exhaustive] // Future-proof against breaking changes +pub enum ResponseEvent { + /// Response creation started + #[serde(rename = "response.created")] + ResponseCreated(ResponseCreated), + /// Processing in progress + #[serde(rename = "response.in_progress")] + ResponseInProgress(ResponseInProgress), + /// Response completed (different from done) + #[serde(rename = "response.completed")] + ResponseCompleted(ResponseCompleted), + /// Response failed + #[serde(rename = "response.failed")] + ResponseFailed(ResponseFailed), + /// Response incomplete + #[serde(rename = "response.incomplete")] + ResponseIncomplete(ResponseIncomplete), + /// Response queued + #[serde(rename = "response.queued")] + ResponseQueued(ResponseQueued), + /// Output item added + #[serde(rename = "response.output_item.added")] + ResponseOutputItemAdded(ResponseOutputItemAdded), + /// Content part added + #[serde(rename = "response.content_part.added")] + ResponseContentPartAdded(ResponseContentPartAdded), + /// Text delta update + #[serde(rename = "response.output_text.delta")] + ResponseOutputTextDelta(ResponseOutputTextDelta), + /// Text output completed + #[serde(rename = "response.output_text.done")] + ResponseOutputTextDone(ResponseOutputTextDone), + /// Refusal delta update + #[serde(rename = "response.refusal.delta")] + ResponseRefusalDelta(ResponseRefusalDelta), + /// Refusal completed + #[serde(rename = "response.refusal.done")] + ResponseRefusalDone(ResponseRefusalDone), + /// Content part completed + #[serde(rename = "response.content_part.done")] + ResponseContentPartDone(ResponseContentPartDone), + /// Output item completed + #[serde(rename = "response.output_item.done")] + ResponseOutputItemDone(ResponseOutputItemDone), + /// Function call arguments delta + 
#[serde(rename = "response.function_call_arguments.delta")] + ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDelta), + /// Function call arguments completed + #[serde(rename = "response.function_call_arguments.done")] + ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDone), + /// File search call in progress + #[serde(rename = "response.file_search_call.in_progress")] + ResponseFileSearchCallInProgress(ResponseFileSearchCallInProgress), + /// File search call searching + #[serde(rename = "response.file_search_call.searching")] + ResponseFileSearchCallSearching(ResponseFileSearchCallSearching), + /// File search call completed + #[serde(rename = "response.file_search_call.completed")] + ResponseFileSearchCallCompleted(ResponseFileSearchCallCompleted), + /// Web search call in progress + #[serde(rename = "response.web_search_call.in_progress")] + ResponseWebSearchCallInProgress(ResponseWebSearchCallInProgress), + /// Web search call searching + #[serde(rename = "response.web_search_call.searching")] + ResponseWebSearchCallSearching(ResponseWebSearchCallSearching), + /// Web search call completed + #[serde(rename = "response.web_search_call.completed")] + ResponseWebSearchCallCompleted(ResponseWebSearchCallCompleted), + /// Reasoning summary part added + #[serde(rename = "response.reasoning_summary_part.added")] + ResponseReasoningSummaryPartAdded(ResponseReasoningSummaryPartAdded), + /// Reasoning summary part done + #[serde(rename = "response.reasoning_summary_part.done")] + ResponseReasoningSummaryPartDone(ResponseReasoningSummaryPartDone), + /// Reasoning summary text delta + #[serde(rename = "response.reasoning_summary_text.delta")] + ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDelta), + /// Reasoning summary text done + #[serde(rename = "response.reasoning_summary_text.done")] + ResponseReasoningSummaryTextDone(ResponseReasoningSummaryTextDone), + /// Reasoning summary delta + #[serde(rename = 
"response.reasoning_summary.delta")] + ResponseReasoningSummaryDelta(ResponseReasoningSummaryDelta), + /// Reasoning summary done + #[serde(rename = "response.reasoning_summary.done")] + ResponseReasoningSummaryDone(ResponseReasoningSummaryDone), + /// Image generation call in progress + #[serde(rename = "response.image_generation_call.in_progress")] + ResponseImageGenerationCallInProgress(ResponseImageGenerationCallInProgress), + /// Image generation call generating + #[serde(rename = "response.image_generation_call.generating")] + ResponseImageGenerationCallGenerating(ResponseImageGenerationCallGenerating), + /// Image generation call partial image + #[serde(rename = "response.image_generation_call.partial_image")] + ResponseImageGenerationCallPartialImage(ResponseImageGenerationCallPartialImage), + /// Image generation call completed + #[serde(rename = "response.image_generation_call.completed")] + ResponseImageGenerationCallCompleted(ResponseImageGenerationCallCompleted), + /// MCP call arguments delta + #[serde(rename = "response.mcp_call_arguments.delta")] + ResponseMcpCallArgumentsDelta(ResponseMcpCallArgumentsDelta), + /// MCP call arguments done + #[serde(rename = "response.mcp_call_arguments.done")] + ResponseMcpCallArgumentsDone(ResponseMcpCallArgumentsDone), + /// MCP call completed + #[serde(rename = "response.mcp_call.completed")] + ResponseMcpCallCompleted(ResponseMcpCallCompleted), + /// MCP call failed + #[serde(rename = "response.mcp_call.failed")] + ResponseMcpCallFailed(ResponseMcpCallFailed), + /// MCP call in progress + #[serde(rename = "response.mcp_call.in_progress")] + ResponseMcpCallInProgress(ResponseMcpCallInProgress), + /// MCP list tools completed + #[serde(rename = "response.mcp_list_tools.completed")] + ResponseMcpListToolsCompleted(ResponseMcpListToolsCompleted), + /// MCP list tools failed + #[serde(rename = "response.mcp_list_tools.failed")] + ResponseMcpListToolsFailed(ResponseMcpListToolsFailed), + /// MCP list tools in progress 
+ #[serde(rename = "response.mcp_list_tools.in_progress")] + ResponseMcpListToolsInProgress(ResponseMcpListToolsInProgress), + /// Code interpreter call in progress + #[serde(rename = "response.code_interpreter_call.in_progress")] + ResponseCodeInterpreterCallInProgress(ResponseCodeInterpreterCallInProgress), + /// Code interpreter call interpreting + #[serde(rename = "response.code_interpreter_call.interpreting")] + ResponseCodeInterpreterCallInterpreting(ResponseCodeInterpreterCallInterpreting), + /// Code interpreter call completed + #[serde(rename = "response.code_interpreter_call.completed")] + ResponseCodeInterpreterCallCompleted(ResponseCodeInterpreterCallCompleted), + /// Code interpreter call code delta + #[serde(rename = "response.code_interpreter_call_code.delta")] + ResponseCodeInterpreterCallCodeDelta(ResponseCodeInterpreterCallCodeDelta), + /// Code interpreter call code done + #[serde(rename = "response.code_interpreter_call_code.done")] + ResponseCodeInterpreterCallCodeDone(ResponseCodeInterpreterCallCodeDone), + /// Output text annotation added + #[serde(rename = "response.output_text.annotation.added")] + ResponseOutputTextAnnotationAdded(ResponseOutputTextAnnotationAdded), + /// Error occurred + #[serde(rename = "error")] + ResponseError(ResponseError), + + /// Unknown event type + #[serde(untagged)] + Unknown(serde_json::Value), +} + +/// Stream of response events +pub type ResponseStream = Pin> + Send>>; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCreated { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseInProgress { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputItemAdded { + pub sequence_number: u64, + pub output_index: u32, + pub item: OutputItem, 
+} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseContentPartAdded { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputTextDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub delta: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub logprobs: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseContentPartDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputItemDone { + pub sequence_number: u64, + pub output_index: u32, + pub item: OutputItem, +} + +/// Response completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCompleted { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Response failed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFailed { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Response incomplete event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseIncomplete { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Response queued event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseQueued { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Text output completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct 
ResponseOutputTextDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub text: String, + pub logprobs: Option>, +} + +/// Refusal delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseRefusalDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub delta: String, +} + +/// Refusal done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseRefusalDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub refusal: String, +} + +/// Function call arguments delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFunctionCallArgumentsDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub delta: String, +} + +/// Function call arguments done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFunctionCallArgumentsDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub arguments: String, +} + +/// Error event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseError { + pub sequence_number: u64, + pub code: Option, + pub message: String, + pub param: Option, +} + +/// File search call in progress event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFileSearchCallInProgress { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// File search call searching event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFileSearchCallSearching { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// File search call completed 
event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseFileSearchCallCompleted {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Web search call in progress event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseWebSearchCallInProgress {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Web search call searching event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseWebSearchCallSearching {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Web search call completed event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseWebSearchCallCompleted {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Reasoning summary part added event
+///
+/// `part` is kept as untyped JSON, so its shape is not checked when the event is deserialized.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseReasoningSummaryPartAdded {
+    pub sequence_number: u64,
+    pub item_id: String,
+    pub output_index: u32,
+    pub summary_index: u32,
+    pub part: serde_json::Value, // Could be more specific but using Value for flexibility
+}
+
+/// Reasoning summary part done event
+///
+/// As in `ResponseReasoningSummaryPartAdded`, `part` is untyped JSON.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseReasoningSummaryPartDone {
+    pub sequence_number: u64,
+    pub item_id: String,
+    pub output_index: u32,
+    pub summary_index: u32,
+    pub part: serde_json::Value,
+}
+
+/// Reasoning summary text delta event
+///
+/// `delta` is a plain `String` here; compare `ResponseReasoningSummaryDelta`, whose `delta`
+/// is untyped JSON.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseReasoningSummaryTextDelta {
+    pub sequence_number: u64,
+    pub item_id: String,
+    pub output_index: u32,
+    pub summary_index: u32,
+    pub delta: String,
+}
+
+/// Reasoning summary text done event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseReasoningSummaryTextDone {
+    pub sequence_number: u64,
+    pub item_id: String,
+    pub output_index: u32,
+    pub summary_index: u32,
+    pub text: String,
+}
+
+/// Reasoning summary delta event
+///
+/// Unlike `ResponseReasoningSummaryTextDelta`, `delta` is untyped JSON rather than a `String`.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseReasoningSummaryDelta {
+    pub sequence_number: u64,
+    pub item_id: String,
+    pub output_index: u32,
+    pub summary_index: u32,
+    pub delta: serde_json::Value,
+}
+
+/// Reasoning summary done event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseReasoningSummaryDone {
+    pub sequence_number: u64,
+    pub item_id: String,
+    pub output_index: u32,
+    pub summary_index: u32,
+    pub text: String,
+}
+
+/// Image generation call in progress event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseImageGenerationCallInProgress {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Image generation call generating event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseImageGenerationCallGenerating {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Image generation call partial image event
+///
+/// NOTE(review): `partial_image_b64` is presumably base64-encoded image data, going by the
+/// field name; nothing in this file decodes it — confirm against the API reference.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseImageGenerationCallPartialImage {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+    pub partial_image_index: u32,
+    pub partial_image_b64: String,
+}
+
+/// Image generation call completed event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseImageGenerationCallCompleted {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// MCP call arguments delta event
+///
+/// `delta` is carried as a plain `String`; it is not parsed by this type.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpCallArgumentsDelta {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+    pub delta: String,
+}
+
+/// MCP call arguments done event
+///
+/// `arguments` is carried as a plain `String`; it is not parsed by this type.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpCallArgumentsDone {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+    pub arguments: String,
+}
+
+/// MCP call completed event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpCallCompleted {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// MCP call failed event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpCallFailed {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// MCP call in progress event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpCallInProgress {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// MCP list tools completed event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpListToolsCompleted {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// MCP list tools failed event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpListToolsFailed {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// MCP list tools in progress event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMcpListToolsInProgress {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Code interpreter call in progress event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseCodeInterpreterCallInProgress {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Code interpreter call interpreting event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseCodeInterpreterCallInterpreting {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Code interpreter call completed event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseCodeInterpreterCallCompleted {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+}
+
+/// Code interpreter call code delta event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseCodeInterpreterCallCodeDelta {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+    pub delta: String,
+}
+
+/// Code interpreter call code done event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseCodeInterpreterCallCodeDone {
+    pub sequence_number: u64,
+    pub output_index: u32,
+    pub item_id: String,
+    pub code: String,
+}
+
+// NOTE(review): the type arguments of the `Option` fields in this struct (`Option,`,
+// `Option>,`) appear to have been lost when this patch text was extracted. As shown the
+// struct does not compile; restore the arguments from the upstream file before applying.
+/// Response metadata
+///
+/// Only `id`, `created_at` and `status` are required. Every other field is optional and is
+/// left out of the serialized output when it is `None` (`skip_serializing_if`).
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseMetadata {
+    pub id: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub object: Option,
+    pub created_at: u64,
+    pub status: Status,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub incomplete_details: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub instructions: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_output_tokens: Option,
+    /// Whether the model was run in background mode
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub background: Option,
+    /// The service tier that was actually used
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option,
+    /// The effective value of top_logprobs parameter
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option,
+    /// The effective value of max_tool_calls parameter
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tool_calls: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output: Option>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub previous_response_id: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub store: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub text: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub truncation: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option>,
+    /// Prompt cache key for improved performance
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_cache_key: Option,
+    /// Safety identifier for content filtering
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub safety_identifier: Option,
+}
+
+/// Output item
+///
+/// Internally tagged: the variant is chosen by the JSON `type` field, whose value is the
+/// variant name in snake_case (for example `FileSearchCall` <-> `"file_search_call"`).
+/// The remaining fields of the object belong to the variant's payload type.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[serde(tag = "type")]
+#[serde(rename_all = "snake_case")]
+#[non_exhaustive]
+pub enum OutputItem {
+    Message(OutputMessage),
+    FileSearchCall(FileSearchCallOutput),
+    FunctionCall(FunctionCall),
+    WebSearchCall(WebSearchCallOutput),
+    ComputerCall(ComputerCallOutput),
+    Reasoning(ReasoningItem),
+    ImageGenerationCall(ImageGenerationCallOutput),
+    CodeInterpreterCall(CodeInterpreterCallOutput),
+    LocalShellCall(LocalShellCallOutput),
+    McpCall(McpCallOutput),
+    McpListTools(McpListToolsOutput),
+    McpApprovalRequest(McpApprovalRequestOutput),
+    CustomToolCall(CustomToolCallOutput),
+}
+
+/// Payload of the `OutputItem::CustomToolCall` variant.
+///
+/// NOTE(review): presumably `id` identifies the output item and `call_id` the tool call
+/// itself — the distinction is not visible here; confirm against the API reference.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct CustomToolCallOutput {
+    pub call_id: String,
+    pub input: String,
+    pub name: String,
+    pub id: String,
+}
+
+// NOTE(review): as in `ResponseMetadata`, the `Option` type arguments below look stripped by
+// extraction (`Option,`, `Option>,`); restore them from the upstream file.
+/// Content part
+///
+/// `part_type` is read from and written to the JSON key `type`. `text` carries no serde
+/// attributes, so a `None` is serialized as `null`. `annotations` and `logprobs` are omitted
+/// when `None` and fall back to `None` when the key is absent.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ContentPart {
+    #[serde(rename = "type")]
+    pub part_type: String,
+    pub text: Option,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub annotations: Option>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option>,
+}
+
+// ===== RESPONSE COLLECTOR =====
+
+// Collects streaming response events into a complete response
+// NOTE(review): the line above was a `///` doc comment with no item of its own, so rustdoc
+// would have attached it to `ResponseOutputTextAnnotationAdded` below. It is kept as a plain
+// comment; no collector type is visible in this part of the file.
+
+/// Output text annotation added event
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct ResponseOutputTextAnnotationAdded {
+    pub sequence_number: u64,
+    pub item_id: String,
+    pub output_index: u32,
+    pub content_index: u32,
+    pub annotation_index: u32,
+    pub annotation: TextAnnotation,
+}
+
+/// Text annotation object for output text
+///
+/// `annotation_type` is read from and written to the JSON key `type`.
+/// NOTE(review): `start` / `end` are presumably offsets into the annotated text; their unit
+/// (bytes vs. characters) and whether `end` is inclusive are not shown here — confirm.
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
+#[non_exhaustive]
+pub struct TextAnnotation {
+    #[serde(rename = "type")]
+    pub annotation_type: String,
+    pub text: String,
+    pub start: u32,
+    pub end: u32,
+}
diff --git a/examples/responses-stream/Cargo.toml b/examples/responses-stream/Cargo.toml
new file mode 100644
index 00000000..82eb90a7
--- /dev/null
+++ b/examples/responses-stream/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "responses-stream"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+async-openai = { path = "../../async-openai" }
+tokio = { version = "1.0", features = ["full"] }
+futures = "0.3"
+serde_json = "1.0"
diff --git a/examples/responses-stream/src/main.rs b/examples/responses-stream/src/main.rs
new file mode 100644
index 00000000..5b565cd8
--- /dev/null
+++ b/examples/responses-stream/src/main.rs
@@ -0,0 +1,63 @@
+//! Example: stream a Responses API request and print the generated text as it arrives.
+
+use std::io::{Write, stdout};
+
+use async_openai::{
+    Client,
+    types::responses::{
+        CreateResponseArgs, Input, InputContent, InputItem, InputMessageArgs, ResponseEvent, Role,
+    },
+};
+use futures::StreamExt;
+
+/// Builds one streaming request, then prints events until a terminal event or a stream error.
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let client = Client::new();
+
+    let request = CreateResponseArgs::default()
+        .model("gpt-4.1")
+        .stream(true)
+        .input(Input::Items(vec![InputItem::Message(
+            InputMessageArgs::default()
+                .role(Role::User)
+                .content(InputContent::TextInput(
+                    "Write a haiku about programming.".to_string(),
+                ))
+                .build()?,
+        )]))
+        .build()?;
+
+    let mut stream = client.responses().create_stream(request).await?;
+
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response_event) => match &response_event {
+                ResponseEvent::ResponseOutputTextDelta(delta) => {
+                    print!("{}", delta.delta);
+                    // `print!` does not flush and stdout is line-buffered, so push each delta
+                    // out explicitly; otherwise text only shows up at the next newline.
+                    stdout().flush()?;
+                }
+                // Stop reading once the response completed, was cut short, or failed.
+                ResponseEvent::ResponseCompleted(_)
+                | ResponseEvent::ResponseIncomplete(_)
+                | ResponseEvent::ResponseFailed(_) => {
+                    break;
+                }
+                // Every other event is dumped in pretty-printed debug form.
+                _ => {
+                    println!("{response_event:#?}");
+                }
+            },
+            Err(e) => {
+                eprintln!("{e:#?}");
+                // When a stream ends, it returns Err(OpenAIError::StreamError("Stream ended"))
+                // Without this, the stream will never end
+                break;
+            }
+        }
+    }
+
+    Ok(())
+}