diff --git a/Cargo.toml b/Cargo.toml index 433f32af..c0601e7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace] -members = [ "async-openai", "async-openai-*", "examples/*" ] +members = [ "clia-async-openai", "async-openai", "async-openai-*", "examples/*" ] # Only check / build main crates by default (check all with `--workspace`) -default-members = ["async-openai", "async-openai-*"] +default-members = ["clia-async-openai", "async-openai", "async-openai-*"] resolver = "2" [workspace.package] diff --git a/clia-async-openai/Cargo.toml b/clia-async-openai/Cargo.toml new file mode 100644 index 00000000..253cbe0a --- /dev/null +++ b/clia-async-openai/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "clia-async-openai" +version = "0.28.0" +authors = ["Himanshu Neema"] +categories = ["api-bindings", "web-programming", "asynchronous"] +keywords = ["openai", "async", "openapi", "ai"] +description = "Rust library for OpenAI (with rustls)" +edition = "2021" +rust-version = { workspace = true } +license = "MIT" +readme = "README.md" +homepage = "https://github.com/64bit/async-openai" +repository = "https://github.com/clia-mod/async-openai" + +[features] +default = ["rustls-webpki-roots"] +# Enable rustls for TLS support +rustls = ["reqwest/rustls-tls-native-roots"] +# Enable rustls and webpki-roots +rustls-webpki-roots = ["reqwest/rustls-tls-webpki-roots"] +# Enable native-tls for TLS support +native-tls = ["reqwest/native-tls"] +# Remove dependency on OpenSSL +native-tls-vendored = ["reqwest/native-tls-vendored"] +realtime = ["dep:tokio-tungstenite"] +# Bring your own types +byot = [] + +[dependencies] +async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" } +backoff = { version = "0.4.0", features = ["tokio"] } +base64 = "0.22.1" +futures = "0.3.31" +rand = "0.8.5" +reqwest = { version = "0.12.12", features = [ + "json", + "stream", + "multipart", + "rustls-tls" +], default-features = false } +reqwest-eventsource = "0.6.0" +serde = { version = "1.0.217", features = ["derive", "rc"] } +serde_json = "1.0.135" +thiserror = "2.0.11" +tokio = { version = "1.43.0", features = ["fs", "macros"] } +tokio-stream = "0.1.17" +tokio-util = { version = "0.7.13", features = ["codec", "io-util"] } +tracing = "0.1.41" +derive_builder = "0.20.2" +secrecy = { version = "0.10.3", features = ["serde"] } +bytes = "1.9.0" +eventsource-stream = "0.2.3" +tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false } + +[dev-dependencies] +tokio-test = "0.4.4" +serde_json = "1.0" + +[[test]] +name = "bring-your-own-type" +required-features = ["byot"] + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/clia-async-openai/README.md b/clia-async-openai/README.md new file mode 100644 index 00000000..cd2abd5b --- /dev/null +++ b/clia-async-openai/README.md @@ -0,0 +1,169 @@ +
+ + + +
+

async-openai

+

Async Rust library for OpenAI

+
+ + + + + + +
+
+Logo created by this repo itself +
+ +## Overview + +`async-openai` is an unofficial Rust library for OpenAI. + +- It's based on [OpenAI OpenAPI spec](https://github.com/openai/openai-openapi) +- Current features: + - [x] Assistants (v2) + - [x] Audio + - [x] Batch + - [x] Chat + - [x] Completions (Legacy) + - [x] Embeddings + - [x] Files + - [x] Fine-Tuning + - [x] Images + - [x] Models + - [x] Moderations + - [x] Organizations | Administration (partially implemented) + - [x] Realtime (Beta) (partially implemented) + - [x] Uploads +- Bring your own custom types for Request or Response objects. +- SSE streaming on available APIs +- Requests (except SSE streaming) including form submissions are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits). +- Ergonomic builder pattern for all request objects. +- Microsoft Azure OpenAI Service (only for APIs matching OpenAI spec) + +## Usage + +The library reads [API key](https://platform.openai.com/account/api-keys) from the environment variable `OPENAI_API_KEY`. + +```bash +# On macOS/Linux +export OPENAI_API_KEY='sk-...' +``` + +```powershell +# On Windows Powershell +$Env:OPENAI_API_KEY='sk-...' +``` + +- Visit [examples](https://github.com/64bit/async-openai/tree/main/examples) directory on how to use `async-openai`. +- Visit [docs.rs/async-openai](https://docs.rs/async-openai) for docs. + +## Realtime API + +Only types for Realtime API are implemented, and can be enabled with feature flag `realtime`. +These types were written before OpenAI released official specs. + +## Image Generation Example + +```rust +use async_openai::{ + types::{CreateImageRequestArgs, ImageSize, ImageResponseFormat}, + Client, +}; +use std::error::Error; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // create client, reads OPENAI_API_KEY environment variable for API key. + let client = Client::new(); + + let request = CreateImageRequestArgs::default() + .prompt("cats on sofa and carpet in living room") + .n(2) + .response_format(ImageResponseFormat::Url) + .size(ImageSize::S256x256) + .user("async-openai") + .build()?; + + let response = client.images().create(request).await?; + + // Download and save images to ./data directory. + // Each url is downloaded and saved in dedicated Tokio task. + // Directory is created if it doesn't exist. + let paths = response.save("./data").await?; + + paths + .iter() + .for_each(|path| println!("Image file path: {}", path.display())); + + Ok(()) +} +``` + +
+ + +
+ Scaled up for README, actual size 256x256 +
+ +## Bring Your Own Types + +Enable methods whose input and outputs are generics with `byot` feature. It creates a new method with same name and `_byot` suffix. + +For example, to use `serde_json::Value` as request and response type: +```rust +let response: Value = client + .chat() + .create_byot(json!({ + "messages": [ + { + "role": "developer", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "What do you think about life?" + } + ], + "model": "gpt-4o", + "store": false + })) + .await?; +``` + +This can be useful in many scenarios: +- To use this library with other OpenAI compatible APIs whose types don't exactly match OpenAI. +- Extend existing types in this crate with new fields with `serde`. +- To avoid verbose types. +- To escape deserialization errors. + +Visit [examples/bring-your-own-type](https://github.com/64bit/async-openai/tree/main/examples/bring-your-own-type) directory to learn more. + +## Contributing + +Thank you for taking the time to contribute and improve the project. I'd be happy to have you! + +All forms of contributions, such as new features requests, bug fixes, issues, documentation, testing, comments, [examples](../examples) etc. are welcome. + +A good starting point would be to look at existing [open issues](https://github.com/64bit/async-openai/issues). + +To maintain quality of the project, a minimum of the following is a must for code contribution: + +- **Names & Documentation**: All struct names, field names and doc comments are from OpenAPI spec. Nested objects in spec without names leaves room for making appropriate name. +- **Tested**: For changes supporting test(s) and/or example is required. Existing examples, doc tests, unit tests, and integration tests should be made to work with the changes if applicable. +- **Scope**: Keep scope limited to APIs available in official documents such as [API Reference](https://platform.openai.com/docs/api-reference) or [OpenAPI spec](https://github.com/openai/openai-openapi/). Other LLMs or AI Providers offer OpenAI-compatible APIs, yet they may not always have full parity. In such cases, the OpenAI spec takes precedence. +- **Consistency**: Keep code style consistent across all the "APIs" that library exposes; it creates a great developer experience. + +This project adheres to [Rust Code of Conduct](https://www.rust-lang.org/policies/code-of-conduct) + +## Complimentary Crates + +- [openai-func-enums](https://github.com/frankfralick/openai-func-enums) provides procedural macros that make it easier to use this library with OpenAI API's tool calling feature. It also provides derive macros you can add to existing [clap](https://github.com/clap-rs/clap) application subcommands for natural language use of command line tools. It also supports openai's [parallel tool calls](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling) and allows you to choose between running multiple tool calls concurrently or own their own OS threads. +- [async-openai-wasm](https://github.com/ifsheldon/async-openai-wasm) provides WASM support. + +## License + +This project is licensed under [MIT license](https://github.com/64bit/async-openai/blob/main/LICENSE). diff --git a/clia-async-openai/src/assistants.rs b/clia-async-openai/src/assistants.rs new file mode 100644 index 00000000..494f2dec --- /dev/null +++ b/clia-async-openai/src/assistants.rs @@ -0,0 +1,70 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + AssistantObject, CreateAssistantRequest, DeleteAssistantResponse, ListAssistantsResponse, + ModifyAssistantRequest, + }, + Client, +}; + +/// Build assistants that can call models and use tools to perform tasks. +/// +/// [Get started with the Assistants API](https://platform.openai.com/docs/assistants) +pub struct Assistants<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Assistants<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Create an assistant with a model and instructions. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateAssistantRequest, + ) -> Result { + self.client.post("/assistants", request).await + } + + /// Retrieves an assistant. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, assistant_id: &str) -> Result { + self.client + .get(&format!("/assistants/{assistant_id}")) + .await + } + + /// Modifies an assistant. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn update( + &self, + assistant_id: &str, + request: ModifyAssistantRequest, + ) -> Result { + self.client + .post(&format!("/assistants/{assistant_id}"), request) + .await + } + + /// Delete an assistant. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, assistant_id: &str) -> Result { + self.client + .delete(&format!("/assistants/{assistant_id}")) + .await + } + + /// Returns a list of assistants. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client.get_with_query("/assistants", &query).await + } +} diff --git a/clia-async-openai/src/audio.rs b/clia-async-openai/src/audio.rs new file mode 100644 index 00000000..1ee4631d --- /dev/null +++ b/clia-async-openai/src/audio.rs @@ -0,0 +1,111 @@ +use bytes::Bytes; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest, + CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson, + CreateTranslationRequest, CreateTranslationResponseJson, + CreateTranslationResponseVerboseJson, + }, + Client, +}; + +/// Turn audio into text or text into audio. +/// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text) +pub struct Audio<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Audio<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Transcribes audio into the input language. + #[crate::byot( + T0 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom", + )] + pub async fn transcribe( + &self, + request: CreateTranscriptionRequest, + ) -> Result { + self.client + .post_form("/audio/transcriptions", request) + .await + } + + /// Transcribes audio into the input language. + #[crate::byot( + T0 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom", + )] + pub async fn transcribe_verbose_json( + &self, + request: CreateTranscriptionRequest, + ) -> Result { + self.client + .post_form("/audio/transcriptions", request) + .await + } + + /// Transcribes audio into the input language. + pub async fn transcribe_raw( + &self, + request: CreateTranscriptionRequest, + ) -> Result { + self.client + .post_form_raw("/audio/transcriptions", request) + .await + } + + /// Translates audio into English. + #[crate::byot( + T0 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom", + )] + pub async fn translate( + &self, + request: CreateTranslationRequest, + ) -> Result { + self.client.post_form("/audio/translations", request).await + } + + /// Translates audio into English. + #[crate::byot( + T0 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom", + )] + pub async fn translate_verbose_json( + &self, + request: CreateTranslationRequest, + ) -> Result { + self.client.post_form("/audio/translations", request).await + } + + /// Transcribes audio into the input language. + pub async fn translate_raw( + &self, + request: CreateTranslationRequest, + ) -> Result { + self.client + .post_form_raw("/audio/translations", request) + .await + } + + /// Generates audio from the input text. + pub async fn speech( + &self, + request: CreateSpeechRequest, + ) -> Result { + let bytes = self.client.post_raw("/audio/speech", request).await?; + + Ok(CreateSpeechResponse { bytes }) + } +} diff --git a/clia-async-openai/src/audit_logs.rs b/clia-async-openai/src/audit_logs.rs new file mode 100644 index 00000000..753c318b --- /dev/null +++ b/clia-async-openai/src/audit_logs.rs @@ -0,0 +1,27 @@ +use serde::Serialize; + +use crate::{config::Config, error::OpenAIError, types::ListAuditLogsResponse, Client}; + +/// Logs of user actions and configuration changes within this organization. +/// To log events, you must activate logging in the [Organization Settings](https://platform.openai.com/settings/organization/general). +/// Once activated, for security reasons, logging cannot be deactivated. +pub struct AuditLogs<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> AuditLogs<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// List user actions and configuration changes within this organization. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn get(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query("/organization/audit_logs", &query) + .await + } +} diff --git a/clia-async-openai/src/batches.rs b/clia-async-openai/src/batches.rs new file mode 100644 index 00000000..57910490 --- /dev/null +++ b/clia-async-openai/src/batches.rs @@ -0,0 +1,53 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{Batch, BatchRequest, ListBatchesResponse}, + Client, +}; + +/// Create large batches of API requests for asynchronous processing. The Batch API returns completions within 24 hours for a 50% discount. +/// +/// Related guide: [Batch](https://platform.openai.com/docs/guides/batch) +pub struct Batches<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Batches<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Creates and executes a batch from an uploaded file of requests + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create(&self, request: BatchRequest) -> Result { + self.client.post("/batches", request).await + } + + /// List your organization's batches. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client.get_with_query("/batches", &query).await + } + + /// Retrieves a batch. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, batch_id: &str) -> Result { + self.client.get(&format!("/batches/{batch_id}")).await + } + + /// Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, where it will have partial results (if any) available in the output file. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn cancel(&self, batch_id: &str) -> Result { + self.client + .post( + &format!("/batches/{batch_id}/cancel"), + serde_json::json!({}), + ) + .await + } +} diff --git a/clia-async-openai/src/chat.rs b/clia-async-openai/src/chat.rs new file mode 100644 index 00000000..28c89f9d --- /dev/null +++ b/clia-async-openai/src/chat.rs @@ -0,0 +1,88 @@ +use crate::{ + config::Config, + error::OpenAIError, + types::{ + ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse, + }, + Client, +}; + +/// Given a list of messages comprising a conversation, the model will return a response. +/// +/// Related guide: [Chat completions](https://platform.openai.com//docs/guides/text-generation) +pub struct Chat<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Chat<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Creates a model response for the given chat conversation. Learn more in + /// the + /// + /// [text generation](https://platform.openai.com/docs/guides/text-generation), + /// [vision](https://platform.openai.com/docs/guides/vision), + /// + /// and [audio](https://platform.openai.com/docs/guides/audio) guides. + /// + /// + /// Parameter support can differ depending on the model used to generate the + /// response, particularly for newer reasoning models. Parameters that are + /// only supported for reasoning models are noted below. For the current state + /// of unsupported parameters in reasoning models, + /// + /// [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning). + /// + /// byot: You must ensure "stream: false" in serialized `request` + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned + )] + pub async fn create( + &self, + request: CreateChatCompletionRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if request.stream.is_some() && request.stream.unwrap() { + return Err(OpenAIError::InvalidArgument( + "When stream is true, use Chat::create_stream".into(), + )); + } + } + self.client.post("/chat/completions", request).await + } + + /// Creates a completion for the chat message + /// + /// partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. + /// + /// [ChatCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server. + /// + /// byot: You must ensure "stream: true" in serialized `request` + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned, + stream = "true", + where_clause = "R: std::marker::Send + 'static" + )] + #[allow(unused_mut)] + pub async fn create_stream( + &self, + mut request: CreateChatCompletionRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if request.stream.is_some() && !request.stream.unwrap() { + return Err(OpenAIError::InvalidArgument( + "When stream is false, use Chat::create".into(), + )); + } + + request.stream = Some(true); + } + Ok(self.client.post_stream("/chat/completions", request).await) + } +} diff --git a/clia-async-openai/src/client.rs b/clia-async-openai/src/client.rs new file mode 100644 index 00000000..11046817 --- /dev/null +++ b/clia-async-openai/src/client.rs @@ -0,0 +1,546 @@ +use std::pin::Pin; + +use bytes::Bytes; +use futures::{stream::StreamExt, Stream}; +use reqwest::multipart::Form; +use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; +use serde::{de::DeserializeOwned, Serialize}; + +use crate::{ + config::{Config, OpenAIConfig}, + error::{map_deserialization_error, OpenAIError, WrappedError}, + file::Files, + image::Images, + moderation::Moderations, + traits::AsyncTryFrom, + Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites, + Models, Projects, Threads, Uploads, Users, VectorStores, +}; + +#[derive(Debug, Clone, Default)] +/// Client is a container for config, backoff and http_client +/// used to make API calls. +pub struct Client { + http_client: reqwest::Client, + config: C, + backoff: backoff::ExponentialBackoff, +} + +impl Client { + /// Client with default [OpenAIConfig] + pub fn new() -> Self { + Self::default() + } +} + +impl Client { + /// Create client with a custom HTTP client, OpenAI config, and backoff. + pub fn build( + http_client: reqwest::Client, + config: C, + backoff: backoff::ExponentialBackoff, + ) -> Self { + Self { + http_client, + config, + backoff, + } + } + + /// Create client with [OpenAIConfig] or [crate::config::AzureConfig] + pub fn with_config(config: C) -> Self { + Self { + http_client: reqwest::Client::new(), + config, + backoff: Default::default(), + } + } + + /// Provide your own [client] to make HTTP requests with. + /// + /// [client]: reqwest::Client + pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self { + self.http_client = http_client; + self + } + + /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests. + pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self { + self.backoff = backoff; + self + } + + // API groups + + /// To call [Models] group related APIs using this client. + pub fn models(&self) -> Models { + Models::new(self) + } + + /// To call [Completions] group related APIs using this client. + pub fn completions(&self) -> Completions { + Completions::new(self) + } + + /// To call [Chat] group related APIs using this client. + pub fn chat(&self) -> Chat { + Chat::new(self) + } + + /// To call [Images] group related APIs using this client. + pub fn images(&self) -> Images { + Images::new(self) + } + + /// To call [Moderations] group related APIs using this client. + pub fn moderations(&self) -> Moderations { + Moderations::new(self) + } + + /// To call [Files] group related APIs using this client. + pub fn files(&self) -> Files { + Files::new(self) + } + + /// To call [Uploads] group related APIs using this client. + pub fn uploads(&self) -> Uploads { + Uploads::new(self) + } + + /// To call [FineTuning] group related APIs using this client. + pub fn fine_tuning(&self) -> FineTuning { + FineTuning::new(self) + } + + /// To call [Embeddings] group related APIs using this client. + pub fn embeddings(&self) -> Embeddings { + Embeddings::new(self) + } + + /// To call [Audio] group related APIs using this client. + pub fn audio(&self) -> Audio { + Audio::new(self) + } + + /// To call [Assistants] group related APIs using this client. + pub fn assistants(&self) -> Assistants { + Assistants::new(self) + } + + /// To call [Threads] group related APIs using this client. + pub fn threads(&self) -> Threads { + Threads::new(self) + } + + /// To call [VectorStores] group related APIs using this client. + pub fn vector_stores(&self) -> VectorStores { + VectorStores::new(self) + } + + /// To call [Batches] group related APIs using this client. + pub fn batches(&self) -> Batches { + Batches::new(self) + } + + /// To call [AuditLogs] group related APIs using this client. + pub fn audit_logs(&self) -> AuditLogs { + AuditLogs::new(self) + } + + /// To call [Invites] group related APIs using this client. + pub fn invites(&self) -> Invites { + Invites::new(self) + } + + /// To call [Users] group related APIs using this client. + pub fn users(&self) -> Users { + Users::new(self) + } + + /// To call [Projects] group related APIs using this client. + pub fn projects(&self) -> Projects { + Projects::new(self) + } + + pub fn config(&self) -> &C { + &self.config + } + + /// Make a GET request to {path} and deserialize the response body + pub(crate) async fn get(&self, path: &str) -> Result + where + O: DeserializeOwned, + { + let request_maker = || async { + Ok(self + .http_client + .get(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .build()?) + }; + + self.execute(request_maker).await + } + + /// Make a GET request to {path} with given Query and deserialize the response body + pub(crate) async fn get_with_query(&self, path: &str, query: &Q) -> Result + where + O: DeserializeOwned, + Q: Serialize + ?Sized, + { + let request_maker = || async { + Ok(self + .http_client + .get(self.config.url(path)) + .query(&self.config.query()) + .query(query) + .headers(self.config.headers()) + .build()?) + }; + + self.execute(request_maker).await + } + + /// Make a DELETE request to {path} and deserialize the response body + pub(crate) async fn delete(&self, path: &str) -> Result + where + O: DeserializeOwned, + { + let request_maker = || async { + Ok(self + .http_client + .delete(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .build()?) + }; + + self.execute(request_maker).await + } + + /// Make a GET request to {path} and return the response body + pub(crate) async fn get_raw(&self, path: &str) -> Result { + let request_maker = || async { + Ok(self + .http_client + .get(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .build()?) + }; + + self.execute_raw(request_maker).await + } + + /// Make a POST request to {path} and return the response body + pub(crate) async fn post_raw(&self, path: &str, request: I) -> Result + where + I: Serialize, + { + let request_maker = || async { + Ok(self + .http_client + .post(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .json(&request) + .build()?) + }; + + self.execute_raw(request_maker).await + } + + /// Make a POST request to {path} and deserialize the response body + pub(crate) async fn post(&self, path: &str, request: I) -> Result + where + I: Serialize, + O: DeserializeOwned, + { + let request_maker = || async { + Ok(self + .http_client + .post(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .json(&request) + .build()?) + }; + + self.execute(request_maker).await + } + + /// POST a form at {path} and return the response body + pub(crate) async fn post_form_raw(&self, path: &str, form: F) -> Result + where + Form: AsyncTryFrom, + F: Clone, + { + let request_maker = || async { + Ok(self + .http_client + .post(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .multipart(
>::try_from(form.clone()).await?) + .build()?) + }; + + self.execute_raw(request_maker).await + } + + /// POST a form at {path} and deserialize the response body + pub(crate) async fn post_form(&self, path: &str, form: F) -> Result + where + O: DeserializeOwned, + Form: AsyncTryFrom, + F: Clone, + { + let request_maker = || async { + Ok(self + .http_client + .post(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .multipart(>::try_from(form.clone()).await?) + .build()?) + }; + + self.execute(request_maker).await + } + + /// Execute a HTTP request and retry on rate limit + /// + /// request_maker serves one purpose: to be able to create request again + /// to retry API call after getting rate limited. request_maker is async because + /// reqwest::multipart::Form is created by async calls to read files for uploads. + async fn execute_raw(&self, request_maker: M) -> Result + where + M: Fn() -> Fut, + Fut: core::future::Future>, + { + let client = self.http_client.clone(); + + backoff::future::retry(self.backoff.clone(), || async { + let request = request_maker().await.map_err(backoff::Error::Permanent)?; + let response = client + .execute(request) + .await + .map_err(OpenAIError::Reqwest) + .map_err(backoff::Error::Permanent)?; + + let status = response.status(); + let bytes = response + .bytes() + .await + .map_err(OpenAIError::Reqwest) + .map_err(backoff::Error::Permanent)?; + + // Deserialize response body from either error object or actual response object + if !status.is_success() { + let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref())) + .map_err(backoff::Error::Permanent)?; + + if status.as_u16() == 429 + // API returns 429 also when: + // "You exceeded your current quota, please check your plan and billing details." + && wrapped_error.error.r#type != Some("insufficient_quota".to_string()) + { + // Rate limited retry... + tracing::warn!("Rate limited: {}", wrapped_error.error.message); + return Err(backoff::Error::Transient { + err: OpenAIError::ApiError(wrapped_error.error), + retry_after: None, + }); + } else { + return Err(backoff::Error::Permanent(OpenAIError::ApiError( + wrapped_error.error, + ))); + } + } + + Ok(bytes) + }) + .await + } + + /// Execute a HTTP request and retry on rate limit + /// + /// request_maker serves one purpose: to be able to create request again + /// to retry API call after getting rate limited. request_maker is async because + /// reqwest::multipart::Form is created by async calls to read files for uploads. + async fn execute(&self, request_maker: M) -> Result + where + O: DeserializeOwned, + M: Fn() -> Fut, + Fut: core::future::Future>, + { + let bytes = self.execute_raw(request_maker).await?; + + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + + Ok(response) + } + + /// Make HTTP POST request to receive SSE + pub(crate) async fn post_stream( + &self, + path: &str, + request: I, + ) -> Pin> + Send>> + where + I: Serialize, + O: DeserializeOwned + std::marker::Send + 'static, + { + let event_source = self + .http_client + .post(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .json(&request) + .eventsource() + .unwrap(); + + stream(event_source).await + } + + pub(crate) async fn post_stream_mapped_raw_events( + &self, + path: &str, + request: I, + event_mapper: impl Fn(eventsource_stream::Event) -> Result + Send + 'static, + ) -> Pin> + Send>> + where + I: Serialize, + O: DeserializeOwned + std::marker::Send + 'static, + { + let event_source = self + .http_client + .post(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .json(&request) + .eventsource() + .unwrap(); + + stream_mapped_raw_events(event_source, event_mapper).await + } + + /// Make HTTP GET request to receive SSE + pub(crate) async fn _get_stream( + &self, + path: &str, + query: &Q, + ) -> Pin> + Send>> + where + Q: Serialize + ?Sized, + O: DeserializeOwned + std::marker::Send + 'static, + { + let event_source = self + .http_client + .get(self.config.url(path)) + .query(query) + .query(&self.config.query()) + .headers(self.config.headers()) + .eventsource() + .unwrap(); + + stream(event_source).await + } +} + +/// Request which responds with SSE. +/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format) +pub(crate) async fn stream( + mut event_source: EventSource, +) -> Pin> + Send>> +where + O: DeserializeOwned + std::marker::Send + 'static, +{ + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + tokio::spawn(async move { + while let Some(ev) = event_source.next().await { + match ev { + Err(e) => { + if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + // rx dropped + break; + } + } + Ok(event) => match event { + Event::Message(message) => { + if message.data == "[DONE]" { + break; + } + + let response = match serde_json::from_str::(&message.data) { + Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())), + Ok(output) => Ok(output), + }; + + if let Err(_e) = tx.send(response) { + // rx dropped + break; + } + } + Event::Open => continue, + }, + } + } + + event_source.close(); + }); + + Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)) +} + +pub(crate) async fn stream_mapped_raw_events( + mut event_source: EventSource, + event_mapper: impl Fn(eventsource_stream::Event) -> Result + Send + 'static, +) -> Pin> + Send>> +where + O: DeserializeOwned + std::marker::Send + 'static, +{ + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + tokio::spawn(async move { + while let Some(ev) = event_source.next().await { + match ev { + Err(e) => { + if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + // rx dropped + break; + } + } + Ok(event) => match event { + Event::Message(message) => { + let mut done = false; + + if message.data == "[DONE]" { + done = true; + } + + let response = event_mapper(message); + + if let Err(_e) = tx.send(response) { + // rx dropped + break; + } + + if done { + break; + } + } + Event::Open => continue, + }, + } + } + + event_source.close(); + }); + + Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)) +} diff --git a/clia-async-openai/src/completion.rs b/clia-async-openai/src/completion.rs new file mode 100644 index 00000000..432201c3 --- /dev/null +++ b/clia-async-openai/src/completion.rs @@ -0,0 +1,77 @@ +use crate::{ + client::Client, + config::Config, + error::OpenAIError, + types::{CompletionResponseStream, CreateCompletionRequest, CreateCompletionResponse}, +}; + +/// Given a prompt, the model will return one or more predicted completions, +/// and can also return the probabilities of alternative tokens at each position. +/// We recommend most users use our Chat completions API. +/// [Learn more](https://platform.openai.com/docs/deprecations/2023-07-06-gpt-and-embeddings) +/// +/// Related guide: [Legacy Completions](https://platform.openai.com/docs/guides/gpt/completions-api) +pub struct Completions<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Completions<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Creates a completion for the provided prompt and parameters + /// + /// You must ensure that "stream: false" in serialized `request` + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned + )] + pub async fn create( + &self, + request: CreateCompletionRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if request.stream.is_some() && request.stream.unwrap() { + return Err(OpenAIError::InvalidArgument( + "When stream is true, use Completion::create_stream".into(), + )); + } + } + self.client.post("/completions", request).await + } + + /// Creates a completion request for the provided prompt and parameters + /// + /// Stream back partial progress. Tokens will be sent as data-only + /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format) + /// as they become available, with the stream terminated by a data: \[DONE\] message. + /// + /// [CompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server. + /// + /// You must ensure that "stream: true" in serialized `request` + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned, + stream = "true", + where_clause = "R: std::marker::Send + 'static" + )] + #[allow(unused_mut)] + pub async fn create_stream( + &self, + mut request: CreateCompletionRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if request.stream.is_some() && !request.stream.unwrap() { + return Err(OpenAIError::InvalidArgument( + "When stream is false, use Completion::create".into(), + )); + } + + request.stream = Some(true); + } + Ok(self.client.post_stream("/completions", request).await) + } +} diff --git a/clia-async-openai/src/config.rs b/clia-async-openai/src/config.rs new file mode 100644 index 00000000..4c5468c2 --- /dev/null +++ b/clia-async-openai/src/config.rs @@ -0,0 +1,213 @@ +//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service. +use reqwest::header::{HeaderMap, AUTHORIZATION}; +use secrecy::{ExposeSecret, SecretString}; +use serde::Deserialize; + +/// Default v1 API base url +pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1"; +/// Organization header +pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization"; +/// Project header +pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project"; + +/// Calls to the Assistants API require that you pass a Beta header +pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta"; + +/// [crate::Client] relies on this for every API call on OpenAI +/// or Azure OpenAI service +pub trait Config: Clone { + fn headers(&self) -> HeaderMap; + fn url(&self, path: &str) -> String; + fn query(&self) -> Vec<(&str, &str)>; + + fn api_base(&self) -> &str; + + fn api_key(&self) -> &SecretString; +} + +/// Configuration for OpenAI API +#[derive(Clone, Debug, Deserialize)] +#[serde(default)] +pub struct OpenAIConfig { + api_base: String, + api_key: SecretString, + org_id: String, + project_id: String, +} + +impl Default for OpenAIConfig { + fn default() -> Self { + Self { + api_base: OPENAI_API_BASE.to_string(), + api_key: std::env::var("OPENAI_API_KEY") + .unwrap_or_else(|_| "".to_string()) + .into(), + org_id: Default::default(), + project_id: Default::default(), + } + } +} + +impl OpenAIConfig { + /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var + pub fn new() -> Self { + Default::default() + } + + /// To use a different organization id other than default + pub fn with_org_id>(mut self, org_id: S) -> Self { + self.org_id = org_id.into(); + self + } + + /// Non default project id + pub fn with_project_id>(mut self, project_id: S) -> Self { + self.project_id = project_id.into(); + self + } + + /// To use a different API key different from default OPENAI_API_KEY env var + pub fn with_api_key>(mut self, api_key: S) -> Self { + self.api_key = SecretString::from(api_key.into()); + self + } + + /// To use a API base url different from default [OPENAI_API_BASE] + pub fn with_api_base>(mut self, api_base: S) -> Self { + self.api_base = api_base.into(); + self + } + + pub fn org_id(&self) -> &str { + &self.org_id + } +} + +impl Config for OpenAIConfig { + fn headers(&self) -> HeaderMap { + let mut headers = HeaderMap::new(); + if !self.org_id.is_empty() { + headers.insert( + OPENAI_ORGANIZATION_HEADER, + self.org_id.as_str().parse().unwrap(), + ); + } + + if !self.project_id.is_empty() { + headers.insert( + OPENAI_PROJECT_HEADER, + self.project_id.as_str().parse().unwrap(), + ); + } + + headers.insert( + AUTHORIZATION, + format!("Bearer {}", self.api_key.expose_secret()) + .as_str() + .parse() + .unwrap(), + ); + + // hack for Assistants APIs + // Calls to the Assistants API require that you pass a Beta header + headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap()); + + headers + } + + fn url(&self, path: &str) -> String { + format!("{}{}", self.api_base, path) + } + + fn api_base(&self) -> &str { + &self.api_base + } + + fn api_key(&self) -> &SecretString { + &self.api_key + } + + fn query(&self) -> Vec<(&str, &str)> { + vec![] + } +} + +/// Configuration for Azure OpenAI Service +#[derive(Clone, Debug, Deserialize)] +#[serde(default)] +pub struct AzureConfig { + api_version: String, + deployment_id: String, + api_base: String, + api_key: SecretString, +} + +impl Default for AzureConfig { + fn default() -> Self { + Self { + api_base: Default::default(), + api_key: std::env::var("OPENAI_API_KEY") + .unwrap_or_else(|_| "".to_string()) + .into(), + deployment_id: Default::default(), + api_version: Default::default(), + } + } +} + +impl AzureConfig { + pub fn new() -> Self { + Default::default() + } + + pub fn with_api_version>(mut self, api_version: S) -> Self { + self.api_version = api_version.into(); + self + } + + pub fn with_deployment_id>(mut self, deployment_id: S) -> Self { + self.deployment_id = deployment_id.into(); + self + } + + /// To use a different API key different from default OPENAI_API_KEY env var + pub fn with_api_key>(mut self, api_key: S) -> Self { + self.api_key = SecretString::from(api_key.into()); + self + } + + /// API base url in form of + pub fn with_api_base>(mut self, api_base: S) -> Self { + self.api_base = api_base.into(); + self + } +} + +impl Config for AzureConfig { + fn headers(&self) -> HeaderMap { + let mut headers = HeaderMap::new(); + + headers.insert("api-key", self.api_key.expose_secret().parse().unwrap()); + + headers + } + + fn url(&self, path: &str) -> String { + format!( + "{}/openai/deployments/{}{}", + self.api_base, self.deployment_id, path + ) + } + + fn api_base(&self) -> &str { + &self.api_base + } + + fn api_key(&self) -> &SecretString { + &self.api_key + } + + fn query(&self) -> Vec<(&str, &str)> { + vec![("api-version", &self.api_version)] + } +} diff --git a/clia-async-openai/src/download.rs b/clia-async-openai/src/download.rs new file mode 100644 index 00000000..087ba6f3 --- /dev/null +++ b/clia-async-openai/src/download.rs @@ -0,0 +1,80 @@ +use std::path::{Path, PathBuf}; + +use base64::{engine::general_purpose, Engine as _}; +use rand::{distributions::Alphanumeric, Rng}; +use reqwest::Url; + +use crate::error::OpenAIError; + +fn create_paths>(url: &Url, base_dir: P) -> (PathBuf, PathBuf) { + let mut dir = PathBuf::from(base_dir.as_ref()); + let mut path = dir.clone(); + let segments = url.path_segments().map(|c| c.collect::>()); + if let Some(segments) = segments { + for (idx, segment) in segments.iter().enumerate() { + if idx != segments.len() - 1 { + dir.push(segment); + } + path.push(segment); + } + } + + (dir, path) +} + +pub(crate) async fn download_url>( + url: &str, + dir: P, +) -> Result { + let parsed_url = Url::parse(url).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; + let response = reqwest::get(url) + .await + .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; + + if !response.status().is_success() { + return Err(OpenAIError::FileSaveError(format!( + "couldn't download file, status: {}, url: {url}", + response.status() + ))); + } + + let (dir, file_path) = create_paths(&parsed_url, dir); + + tokio::fs::create_dir_all(dir.as_path()) + .await + .map_err(|e| OpenAIError::FileSaveError(format!("{}, dir: {}", e, dir.display())))?; + + tokio::fs::write( + file_path.as_path(), + response.bytes().await.map_err(|e| { + OpenAIError::FileSaveError(format!("{}, file path: {}", e, file_path.display())) + })?, + ) + .await + .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; + + Ok(file_path) +} + +pub(crate) async fn save_b64>(b64: &str, dir: P) -> Result { + let filename: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(10) + .map(char::from) + .collect(); + + let filename = format!("{filename}.png"); + + let path = PathBuf::from(dir.as_ref()).join(filename); + + tokio::fs::write( + path.as_path(), + general_purpose::STANDARD + .decode(b64) + .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?, + ) + .await + .map_err(|e| OpenAIError::FileSaveError(format!("{}, path: {}", e, path.display())))?; + + Ok(path) +} diff --git a/clia-async-openai/src/embedding.rs b/clia-async-openai/src/embedding.rs new file mode 100644 index 00000000..f5759296 --- /dev/null +++ b/clia-async-openai/src/embedding.rs @@ -0,0 +1,219 @@ +use crate::{ + config::Config, + error::OpenAIError, + types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse}, + Client, +}; + +#[cfg(not(feature = "byot"))] +use crate::types::EncodingFormat; + +/// Get a vector representation of a given input that can be easily +/// consumed by machine learning models and algorithms. +/// +/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings) +pub struct Embeddings<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Embeddings<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Creates an embedding vector representing the input text. + /// + /// byot: In serialized `request` you must ensure "encoding_format" is not "base64" + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateEmbeddingRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if matches!(request.encoding_format, Some(EncodingFormat::Base64)) { + return Err(OpenAIError::InvalidArgument( + "When encoding_format is base64, use Embeddings::create_base64".into(), + )); + } + } + self.client.post("/embeddings", request).await + } + + /// Creates an embedding vector representing the input text. + /// + /// The response will contain the embedding in base64 format. + /// + /// byot: In serialized `request` you must ensure "encoding_format" is "base64" + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create_base64( + &self, + request: CreateEmbeddingRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) { + return Err(OpenAIError::InvalidArgument( + "When encoding_format is not base64, use Embeddings::create".into(), + )); + } + } + self.client.post("/embeddings", request).await + } +} + +#[cfg(test)] +mod tests { + use crate::error::OpenAIError; + use crate::types::{CreateEmbeddingResponse, Embedding, EncodingFormat}; + use crate::{types::CreateEmbeddingRequestArgs, Client}; + + #[tokio::test] + async fn test_embedding_string() { + let client = Client::new(); + + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-ada-002") + .input("The food was delicious and the waiter...") + .build() + .unwrap(); + + let response = client.embeddings().create(request).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn test_embedding_string_array() { + let client = Client::new(); + + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-ada-002") + .input(["The food was delicious", "The waiter was good"]) + .build() + .unwrap(); + + let response = client.embeddings().create(request).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn test_embedding_integer_array() { + let client = Client::new(); + + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-ada-002") + .input([1, 2, 3]) + .build() + .unwrap(); + + let response = client.embeddings().create(request).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn test_embedding_array_of_integer_array_matrix() { + let client = Client::new(); + + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-ada-002") + .input([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + .build() + .unwrap(); + + let response = client.embeddings().create(request).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn test_embedding_array_of_integer_array() { + let client = Client::new(); + + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-ada-002") + .input([vec![1, 2, 3], vec![4, 5, 6, 7], vec![7, 8, 10, 11, 100257]]) + .build() + .unwrap(); + + let response = client.embeddings().create(request).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn test_embedding_with_reduced_dimensions() { + let client = Client::new(); + let dimensions = 256u32; + let request = CreateEmbeddingRequestArgs::default() + .model("text-embedding-3-small") + .input("The food was delicious and the waiter...") + .dimensions(dimensions) + .build() + .unwrap(); + + let response = client.embeddings().create(request).await; + + assert!(response.is_ok()); + + let CreateEmbeddingResponse { mut data, .. } = response.unwrap(); + assert_eq!(data.len(), 1); + let Embedding { embedding, .. } = data.pop().unwrap(); + assert_eq!(embedding.len(), dimensions as usize); + } + + #[tokio::test] + #[cfg(not(feature = "byot"))] + async fn test_cannot_use_base64_encoding_with_normal_create_request() { + let client = Client::new(); + + const MODEL: &str = "text-embedding-ada-002"; + const INPUT: &str = "You shall not pass."; + + let b64_request = CreateEmbeddingRequestArgs::default() + .model(MODEL) + .input(INPUT) + .encoding_format(EncodingFormat::Base64) + .build() + .unwrap(); + let b64_response = client.embeddings().create(b64_request).await; + assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_)))); + } + + #[tokio::test] + async fn test_embedding_create_base64() { + let client = Client::new(); + + const MODEL: &str = "text-embedding-ada-002"; + const INPUT: &str = "CoLoop will eat the other qual research tools..."; + + let b64_request = CreateEmbeddingRequestArgs::default() + .model(MODEL) + .input(INPUT) + .encoding_format(EncodingFormat::Base64) + .build() + .unwrap(); + let b64_response = client + .embeddings() + .create_base64(b64_request) + .await + .unwrap(); + let b64_embedding = b64_response.data.into_iter().next().unwrap().embedding; + let b64_embedding: Vec = b64_embedding.into(); + + let request = CreateEmbeddingRequestArgs::default() + .model(MODEL) + .input(INPUT) + .build() + .unwrap(); + let response = client.embeddings().create(request).await.unwrap(); + let embedding = response.data.into_iter().next().unwrap().embedding; + + assert_eq!(b64_embedding.len(), embedding.len()); + for (b64, normal) in b64_embedding.iter().zip(embedding.iter()) { + assert!((b64 - normal).abs() < 1e-6); + } + } +} diff --git a/clia-async-openai/src/error.rs b/clia-async-openai/src/error.rs new file mode 100644 index 00000000..eea51c10 --- /dev/null +++ b/clia-async-openai/src/error.rs @@ -0,0 +1,76 @@ +//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system. +use serde::Deserialize; + +#[derive(Debug, thiserror::Error)] +pub enum OpenAIError { + /// Underlying error from reqwest library after an API call was made + #[error("http error: {0}")] + Reqwest(#[from] reqwest::Error), + /// OpenAI returns error object with details of API call failure + #[error("{0}")] + ApiError(ApiError), + /// Error when a response cannot be deserialized into a Rust type + #[error("failed to deserialize api response: {0}")] + JSONDeserialize(serde_json::Error), + /// Error on the client side when saving file to file system + #[error("failed to save file: {0}")] + FileSaveError(String), + /// Error on the client side when reading file from file system + #[error("failed to read file: {0}")] + FileReadError(String), + /// Error on SSE streaming + #[error("stream failed: {0}")] + StreamError(String), + /// Error from client side validation + /// or when builder fails to build request before making API call + #[error("invalid args: {0}")] + InvalidArgument(String), +} + +/// OpenAI API returns error object on failure +#[derive(Debug, Deserialize, Clone)] +pub struct ApiError { + pub message: String, + pub r#type: Option, + pub param: Option, + pub code: Option, +} + +impl std::fmt::Display for ApiError { + /// If all fields are available, `ApiError` is formatted as: + /// `{type}: {message} (param: {param}) (code: {code})` + /// Otherwise, missing fields will be ignored. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut parts = Vec::new(); + + if let Some(r#type) = &self.r#type { + parts.push(format!("{}:", r#type)); + } + + parts.push(self.message.clone()); + + if let Some(param) = &self.param { + parts.push(format!("(param: {param})")); + } + + if let Some(code) = &self.code { + parts.push(format!("(code: {code})")); + } + + write!(f, "{}", parts.join(" ")) + } +} + +/// Wrapper to deserialize the error object nested in "error" JSON key +#[derive(Debug, Deserialize)] +pub(crate) struct WrappedError { + pub(crate) error: ApiError, +} + +pub(crate) fn map_deserialization_error(e: serde_json::Error, bytes: &[u8]) -> OpenAIError { + tracing::error!( + "failed deserialization of: {}", + String::from_utf8_lossy(bytes) + ); + OpenAIError::JSONDeserialize(e) +} diff --git a/clia-async-openai/src/file.rs b/clia-async-openai/src/file.rs new file mode 100644 index 00000000..cfca19c7 --- /dev/null +++ b/clia-async-openai/src/file.rs @@ -0,0 +1,131 @@ +use bytes::Bytes; +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{CreateFileRequest, DeleteFileResponse, ListFilesResponse, OpenAIFile}, + Client, +}; + +/// Files are used to upload documents that can be used with features like Assistants and Fine-tuning. +pub struct Files<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Files<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Upload a file that can be used across various endpoints. Individual files can be up to 512 MB, and the size of all files uploaded by one organization can be up to 100 GB. + /// + /// The Assistants API supports files up to 2 million tokens and of specific file types. See the [Assistants Tools guide](https://platform.openai.com/docs/assistants/tools) for details. + /// + /// The Fine-tuning API only supports `.jsonl` files. The input also has certain required formats for fine-tuning [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input) or [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) models. + /// + ///The Batch API only supports `.jsonl` files up to 100 MB in size. The input also has a specific required [format](https://platform.openai.com/docs/api-reference/batch/request-input). + /// + /// Please [contact us](https://help.openai.com/) if you need to increase these storage limits. + #[crate::byot( + T0 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom", + )] + pub async fn create(&self, request: CreateFileRequest) -> Result { + self.client.post_form("/files", request).await + } + + /// Returns a list of files that belong to the user's organization. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client.get_with_query("/files", &query).await + } + + /// Returns information about a specific file. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, file_id: &str) -> Result { + self.client.get(format!("/files/{file_id}").as_str()).await + } + + /// Delete a file. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, file_id: &str) -> Result { + self.client + .delete(format!("/files/{file_id}").as_str()) + .await + } + + /// Returns the contents of the specified file + pub async fn content(&self, file_id: &str) -> Result { + self.client + .get_raw(format!("/files/{file_id}/content").as_str()) + .await + } +} + +#[cfg(test)] +mod tests { + use crate::{ + types::{CreateFileRequestArgs, FilePurpose}, + Client, + }; + + #[tokio::test] + async fn test_file_mod() { + let test_file_path = "/tmp/test.jsonl"; + let contents = concat!( + "{\"prompt\": \"\", \"completion\": \"\"}\n", // \n is to make it valid jsonl + "{\"prompt\": \"\", \"completion\": \"\"}" + ); + + tokio::fs::write(test_file_path, contents).await.unwrap(); + + let client = Client::new(); + + let request = CreateFileRequestArgs::default() + .file(test_file_path) + .purpose(FilePurpose::FineTune) + .build() + .unwrap(); + + let openai_file = client.files().create(request).await.unwrap(); + + assert_eq!(openai_file.bytes, 135); + assert_eq!(openai_file.filename, "test.jsonl"); + //assert_eq!(openai_file.purpose, "fine-tune"); + + //assert_eq!(openai_file.status, Some("processed".to_owned())); // uploaded or processed + let query = [("purpose", "fine-tune")]; + + let list_files = client.files().list(&query).await.unwrap(); + + assert_eq!(list_files.data.into_iter().last().unwrap(), openai_file); + + let retrieved_file = client.files().retrieve(&openai_file.id).await.unwrap(); + + assert_eq!(openai_file.created_at, retrieved_file.created_at); + assert_eq!(openai_file.bytes, retrieved_file.bytes); + assert_eq!(openai_file.filename, retrieved_file.filename); + assert_eq!(openai_file.purpose, retrieved_file.purpose); + + /* + // "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." + let retrieved_contents = client.files().retrieve_content(&openai_file.id) + .await + .unwrap(); + + assert_eq!(contents, retrieved_contents); + */ + + // Sleep to prevent "File is still processing. Check back later." + tokio::time::sleep(std::time::Duration::from_secs(15)).await; + let delete_response = client.files().delete(&openai_file.id).await.unwrap(); + + assert_eq!(openai_file.id, delete_response.id); + assert!(delete_response.deleted); + } +} diff --git a/clia-async-openai/src/fine_tuning.rs b/clia-async-openai/src/fine_tuning.rs new file mode 100644 index 00000000..c599ae63 --- /dev/null +++ b/clia-async-openai/src/fine_tuning.rs @@ -0,0 +1,107 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + CreateFineTuningJobRequest, FineTuningJob, ListFineTuningJobCheckpointsResponse, + ListFineTuningJobEventsResponse, ListPaginatedFineTuningJobsResponse, + }, + Client, +}; + +/// Manage fine-tuning jobs to tailor a model to your specific training data. +/// +/// Related guide: [Fine-tune models](https://platform.openai.com/docs/guides/fine-tuning) +pub struct FineTuning<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> FineTuning<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Creates a job that fine-tunes a specified model from a given dataset. + /// + /// Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete. + /// + /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning) + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateFineTuningJobRequest, + ) -> Result { + self.client.post("/fine_tuning/jobs", request).await + } + + /// List your organization's fine-tuning jobs + #[crate::byot(T0 = serde::Serialize, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list_paginated( + &self, + query: &Q, + ) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query("/fine_tuning/jobs", &query) + .await + } + + /// Gets info about the fine-tune job. + /// + /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning) + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, fine_tuning_job_id: &str) -> Result { + self.client + .get(format!("/fine_tuning/jobs/{fine_tuning_job_id}").as_str()) + .await + } + + /// Immediately cancel a fine-tune job. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn cancel(&self, fine_tuning_job_id: &str) -> Result { + self.client + .post( + format!("/fine_tuning/jobs/{fine_tuning_job_id}/cancel").as_str(), + (), + ) + .await + } + + /// Get fine-grained status updates for a fine-tune job. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list_events( + &self, + fine_tuning_job_id: &str, + query: &Q, + ) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + format!("/fine_tuning/jobs/{fine_tuning_job_id}/events").as_str(), + &query, + ) + .await + } + + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list_checkpoints( + &self, + fine_tuning_job_id: &str, + query: &Q, + ) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + format!("/fine_tuning/jobs/{fine_tuning_job_id}/checkpoints").as_str(), + &query, + ) + .await + } +} diff --git a/clia-async-openai/src/image.rs b/clia-async-openai/src/image.rs new file mode 100644 index 00000000..fd7394a8 --- /dev/null +++ b/clia-async-openai/src/image.rs @@ -0,0 +1,53 @@ +use crate::{ + config::Config, + error::OpenAIError, + types::{ + CreateImageEditRequest, CreateImageRequest, CreateImageVariationRequest, ImagesResponse, + }, + Client, +}; + +/// Given a prompt and/or an input image, the model will generate a new image. +/// +/// Related guide: [Image generation](https://platform.openai.com/docs/guides/images) +pub struct Images<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Images<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Creates an image given a prompt. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create(&self, request: CreateImageRequest) -> Result { + self.client.post("/images/generations", request).await + } + + /// Creates an edited or extended image given an original image and a prompt. + #[crate::byot( + T0 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom", + )] + pub async fn create_edit( + &self, + request: CreateImageEditRequest, + ) -> Result { + self.client.post_form("/images/edits", request).await + } + + /// Creates a variation of a given image. + #[crate::byot( + T0 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom", + )] + pub async fn create_variation( + &self, + request: CreateImageVariationRequest, + ) -> Result { + self.client.post_form("/images/variations", request).await + } +} diff --git a/clia-async-openai/src/invites.rs b/clia-async-openai/src/invites.rs new file mode 100644 index 00000000..83600176 --- /dev/null +++ b/clia-async-openai/src/invites.rs @@ -0,0 +1,52 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{Invite, InviteDeleteResponse, InviteListResponse, InviteRequest}, + Client, +}; + +/// Invite and manage invitations for an organization. Invited users are automatically added to the Default project. +pub struct Invites<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Invites<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Returns a list of invites in the organization. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query("/organization/invites", &query) + .await + } + + /// Retrieves an invite. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, invite_id: &str) -> Result { + self.client + .get(format!("/organization/invites/{invite_id}").as_str()) + .await + } + + /// Create an invite for a user to the organization. The invite must be accepted by the user before they have access to the organization. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create(&self, request: InviteRequest) -> Result { + self.client.post("/organization/invites", request).await + } + + /// Delete an invite. If the invite has already been accepted, it cannot be deleted. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, invite_id: &str) -> Result { + self.client + .delete(format!("/organization/invites/{invite_id}").as_str()) + .await + } +} diff --git a/clia-async-openai/src/lib.rs b/clia-async-openai/src/lib.rs new file mode 100644 index 00000000..182e58ae --- /dev/null +++ b/clia-async-openai/src/lib.rs @@ -0,0 +1,187 @@ +//! Rust library for OpenAI +//! +//! ## Creating client +//! +//! ``` +//! use async_openai::{Client, config::OpenAIConfig}; +//! +//! // Create a OpenAI client with api key from env var OPENAI_API_KEY and default base url. +//! let client = Client::new(); +//! +//! // Above is shortcut for +//! let config = OpenAIConfig::default(); +//! let client = Client::with_config(config); +//! +//! // OR use API key from different source and a non default organization +//! let api_key = "sk-..."; // This secret could be from a file, or environment variable. +//! let config = OpenAIConfig::new() +//! .with_api_key(api_key) +//! .with_org_id("the-continental"); +//! +//! let client = Client::with_config(config); +//! +//! // Use custom reqwest client +//! let http_client = reqwest::ClientBuilder::new().user_agent("async-openai").build().unwrap(); +//! let client = Client::new().with_http_client(http_client); +//! ``` +//! +//! +//! ## Making requests +//! +//!``` +//!# tokio_test::block_on(async { +//! +//! use async_openai::{Client, types::{CreateCompletionRequestArgs}}; +//! +//! // Create client +//! let client = Client::new(); +//! +//! // Create request using builder pattern +//! // Every request struct has companion builder struct with same name + Args suffix +//! let request = CreateCompletionRequestArgs::default() +//! .model("gpt-3.5-turbo-instruct") +//! .prompt("Tell me the recipe of alfredo pasta") +//! .max_tokens(40_u32) +//! .build() +//! .unwrap(); +//! +//! // Call API +//! let response = client +//! .completions() // Get the API "group" (completions, images, etc.) from the client +//! .create(request) // Make the API call in that "group" +//! .await +//! .unwrap(); +//! +//! println!("{}", response.choices.first().unwrap().text); +//! # }); +//!``` +//! +//! ## Bring Your Own Types +//! +//! To use custom types for inputs and outputs, enable `byot` feature which provides additional generic methods with same name and `_byot` suffix. +//! This feature is available on methods whose return type is not `Bytes` +//! +//!``` +//!# #[cfg(feature = "byot")] +//!# tokio_test::block_on(async { +//! use async_openai::Client; +//! use serde_json::{Value, json}; +//! +//! let client = Client::new(); +//! +//! let response: Value = client +//! .chat() +//! .create_byot(json!({ +//! "messages": [ +//! { +//! "role": "developer", +//! "content": "You are a helpful assistant" +//! }, +//! { +//! "role": "user", +//! "content": "What do you think about life?" +//! } +//! ], +//! "model": "gpt-4o", +//! "store": false +//! })) +//! .await +//! .unwrap(); +//! +//! if let Some(content) = response["choices"][0]["message"]["content"].as_str() { +//! println!("{}", content); +//! } +//! # }); +//!``` +//! +//! ## Microsoft Azure +//! +//! ``` +//! use async_openai::{Client, config::AzureConfig}; +//! +//! let config = AzureConfig::new() +//! .with_api_base("https://my-resource-name.openai.azure.com") +//! .with_api_version("2023-03-15-preview") +//! .with_deployment_id("deployment-id") +//! .with_api_key("..."); +//! +//! let client = Client::with_config(config); +//! +//! // Note that `async-openai` only implements OpenAI spec +//! // and doesn't maintain parity with the spec of Azure OpenAI service. +//! +//! ``` +//! +//! +//! ## Examples +//! For full working examples for all supported features see [examples](https://github.com/64bit/async-openai/tree/main/examples) directory in the repository. +//! +#![cfg_attr(docsrs, feature(doc_cfg))] + +#[cfg(feature = "byot")] +pub(crate) use async_openai_macros::byot; + +#[cfg(not(feature = "byot"))] +pub(crate) use async_openai_macros::byot_passthrough as byot; + +mod assistants; +mod audio; +mod audit_logs; +mod batches; +mod chat; +mod client; +mod completion; +pub mod config; +mod download; +mod embedding; +pub mod error; +mod file; +mod fine_tuning; +mod image; +mod invites; +mod messages; +mod model; +mod moderation; +mod project_api_keys; +mod project_service_accounts; +mod project_users; +mod projects; +mod runs; +mod steps; +mod threads; +pub mod traits; +pub mod types; +mod uploads; +mod users; +mod util; +mod vector_store_file_batches; +mod vector_store_files; +mod vector_stores; + +pub use assistants::Assistants; +pub use audio::Audio; +pub use audit_logs::AuditLogs; +pub use batches::Batches; +pub use chat::Chat; +pub use client::Client; +pub use completion::Completions; +pub use embedding::Embeddings; +pub use file::Files; +pub use fine_tuning::FineTuning; +pub use image::Images; +pub use invites::Invites; +pub use messages::Messages; +pub use model::Models; +pub use moderation::Moderations; +pub use project_api_keys::ProjectAPIKeys; +pub use project_service_accounts::ProjectServiceAccounts; +pub use project_users::ProjectUsers; +pub use projects::Projects; +pub use runs::Runs; +pub use steps::Steps; +pub use threads::Threads; +pub use uploads::Uploads; +pub use users::Users; +pub use vector_store_file_batches::VectorStoreFileBatches; +pub use vector_store_files::VectorStoreFiles; +pub use vector_stores::VectorStores; diff --git a/clia-async-openai/src/messages.rs b/clia-async-openai/src/messages.rs new file mode 100644 index 00000000..9368e114 --- /dev/null +++ b/clia-async-openai/src/messages.rs @@ -0,0 +1,85 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + CreateMessageRequest, DeleteMessageResponse, ListMessagesResponse, MessageObject, + ModifyMessageRequest, + }, + Client, +}; + +/// Represents a message within a [thread](https://platform.openai.com/docs/api-reference/threads). +pub struct Messages<'c, C: Config> { + /// The ID of the [thread](https://platform.openai.com/docs/api-reference/threads) to create a message for. + pub thread_id: String, + client: &'c Client, +} + +impl<'c, C: Config> Messages<'c, C> { + pub fn new(client: &'c Client, thread_id: &str) -> Self { + Self { + client, + thread_id: thread_id.into(), + } + } + + /// Create a message. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateMessageRequest, + ) -> Result { + self.client + .post(&format!("/threads/{}/messages", self.thread_id), request) + .await + } + + /// Retrieve a message. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, message_id: &str) -> Result { + self.client + .get(&format!( + "/threads/{}/messages/{message_id}", + self.thread_id + )) + .await + } + + /// Modifies a message. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn update( + &self, + message_id: &str, + request: ModifyMessageRequest, + ) -> Result { + self.client + .post( + &format!("/threads/{}/messages/{message_id}", self.thread_id), + request, + ) + .await + } + + /// Returns a list of messages for a given thread. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query(&format!("/threads/{}/messages", self.thread_id), &query) + .await + } + + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, message_id: &str) -> Result { + self.client + .delete(&format!( + "/threads/{}/messages/{message_id}", + self.thread_id + )) + .await + } +} diff --git a/clia-async-openai/src/model.rs b/clia-async-openai/src/model.rs new file mode 100644 index 00000000..47cc8781 --- /dev/null +++ b/clia-async-openai/src/model.rs @@ -0,0 +1,41 @@ +use crate::{ + config::Config, + error::OpenAIError, + types::{DeleteModelResponse, ListModelResponse, Model}, + Client, +}; + +/// List and describe the various models available in the API. +/// You can refer to the [Models](https://platform.openai.com/docs/models) documentation to understand what +/// models are available and the differences between them. +pub struct Models<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Models<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Lists the currently available models, and provides basic information + /// about each one such as the owner and availability. + #[crate::byot(R = serde::de::DeserializeOwned)] + pub async fn list(&self) -> Result { + self.client.get("/models").await + } + + /// Retrieves a model instance, providing basic information about the model + /// such as the owner and permissioning. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, id: &str) -> Result { + self.client.get(format!("/models/{id}").as_str()).await + } + + /// Delete a fine-tuned model. You must have the Owner role in your organization. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, model: &str) -> Result { + self.client + .delete(format!("/models/{model}").as_str()) + .await + } +} diff --git a/clia-async-openai/src/moderation.rs b/clia-async-openai/src/moderation.rs new file mode 100644 index 00000000..6f831374 --- /dev/null +++ b/clia-async-openai/src/moderation.rs @@ -0,0 +1,29 @@ +use crate::{ + config::Config, + error::OpenAIError, + types::{CreateModerationRequest, CreateModerationResponse}, + Client, +}; + +/// Given text and/or image inputs, classifies if those inputs are potentially harmful across several categories. +/// +/// Related guide: [Moderations](https://platform.openai.com/docs/guides/moderation) +pub struct Moderations<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Moderations<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Classifies if text and/or image inputs are potentially harmful. Learn + /// more in the [moderation guide](https://platform.openai.com/docs/guides/moderation). + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateModerationRequest, + ) -> Result { + self.client.post("/moderations", request).await + } +} diff --git a/clia-async-openai/src/project_api_keys.rs b/clia-async-openai/src/project_api_keys.rs new file mode 100644 index 00000000..6f3778d0 --- /dev/null +++ b/clia-async-openai/src/project_api_keys.rs @@ -0,0 +1,66 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ProjectApiKey, ProjectApiKeyDeleteResponse, ProjectApiKeyListResponse}, + Client, +}; + +/// Manage API keys for a given project. Supports listing and deleting keys for users. +/// This API does not allow issuing keys for users, as users need to authorize themselves to generate keys. +pub struct ProjectAPIKeys<'c, C: Config> { + client: &'c Client, + pub project_id: String, +} + +impl<'c, C: Config> ProjectAPIKeys<'c, C> { + pub fn new(client: &'c Client, project_id: &str) -> Self { + Self { + client, + project_id: project_id.into(), + } + } + + /// Returns a list of API keys in the project. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + format!("/organization/projects/{}/api_keys", self.project_id).as_str(), + &query, + ) + .await + } + + /// Retrieves an API key in the project. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, api_key: &str) -> Result { + self.client + .get( + format!( + "/organization/projects/{}/api_keys/{api_key}", + self.project_id + ) + .as_str(), + ) + .await + } + + /// Deletes an API key from the project. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, api_key: &str) -> Result { + self.client + .delete( + format!( + "/organization/projects/{}/api_keys/{api_key}", + self.project_id + ) + .as_str(), + ) + .await + } +} diff --git a/clia-async-openai/src/project_service_accounts.rs b/clia-async-openai/src/project_service_accounts.rs new file mode 100644 index 00000000..04b02aaf --- /dev/null +++ b/clia-async-openai/src/project_service_accounts.rs @@ -0,0 +1,100 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + ProjectServiceAccount, ProjectServiceAccountCreateRequest, + ProjectServiceAccountCreateResponse, ProjectServiceAccountDeleteResponse, + ProjectServiceAccountListResponse, + }, + Client, +}; + +/// Manage service accounts within a project. A service account is a bot user that is not +/// associated with a user. If a user leaves an organization, their keys and membership in projects +/// will no longer work. Service accounts do not have this limitation. +/// However, service accounts can also be deleted from a project. +pub struct ProjectServiceAccounts<'c, C: Config> { + client: &'c Client, + pub project_id: String, +} + +impl<'c, C: Config> ProjectServiceAccounts<'c, C> { + pub fn new(client: &'c Client, project_id: &str) -> Self { + Self { + client, + project_id: project_id.into(), + } + } + + /// Returns a list of service accounts in the project. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + format!( + "/organization/projects/{}/service_accounts", + self.project_id + ) + .as_str(), + &query, + ) + .await + } + + /// Creates a new service account in the project. This also returns an unredacted API key for the service account. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: ProjectServiceAccountCreateRequest, + ) -> Result { + self.client + .post( + format!( + "/organization/projects/{}/service_accounts", + self.project_id + ) + .as_str(), + request, + ) + .await + } + + /// Retrieves a service account in the project. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve( + &self, + service_account_id: &str, + ) -> Result { + self.client + .get( + format!( + "/organization/projects/{}/service_accounts/{service_account_id}", + self.project_id + ) + .as_str(), + ) + .await + } + + /// Deletes a service account from the project. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete( + &self, + service_account_id: &str, + ) -> Result { + self.client + .delete( + format!( + "/organization/projects/{}/service_accounts/{service_account_id}", + self.project_id + ) + .as_str(), + ) + .await + } +} diff --git a/clia-async-openai/src/project_users.rs b/clia-async-openai/src/project_users.rs new file mode 100644 index 00000000..bd790d5a --- /dev/null +++ b/clia-async-openai/src/project_users.rs @@ -0,0 +1,86 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + ProjectUser, ProjectUserCreateRequest, ProjectUserDeleteResponse, ProjectUserListResponse, + ProjectUserUpdateRequest, + }, + Client, +}; + +/// Manage users within a project, including adding, updating roles, and removing users. +/// Users cannot be removed from the Default project, unless they are being removed from the organization. +pub struct ProjectUsers<'c, C: Config> { + client: &'c Client, + pub project_id: String, +} + +impl<'c, C: Config> ProjectUsers<'c, C> { + pub fn new(client: &'c Client, project_id: &str) -> Self { + Self { + client, + project_id: project_id.into(), + } + } + + /// Returns a list of users in the project. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + format!("/organization/projects/{}/users", self.project_id).as_str(), + &query, + ) + .await + } + + /// Adds a user to the project. Users must already be members of the organization to be added to a project. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: ProjectUserCreateRequest, + ) -> Result { + self.client + .post( + format!("/organization/projects/{}/users", self.project_id).as_str(), + request, + ) + .await + } + + /// Retrieves a user in the project. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, user_id: &str) -> Result { + self.client + .get(format!("/organization/projects/{}/users/{user_id}", self.project_id).as_str()) + .await + } + + /// Modifies a user's role in the project. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn modify( + &self, + user_id: &str, + request: ProjectUserUpdateRequest, + ) -> Result { + self.client + .post( + format!("/organization/projects/{}/users/{user_id}", self.project_id).as_str(), + request, + ) + .await + } + + /// Deletes a user from the project. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, user_id: &str) -> Result { + self.client + .delete(format!("/organization/projects/{}/users/{user_id}", self.project_id).as_str()) + .await + } +} diff --git a/clia-async-openai/src/projects.rs b/clia-async-openai/src/projects.rs new file mode 100644 index 00000000..5b058636 --- /dev/null +++ b/clia-async-openai/src/projects.rs @@ -0,0 +1,87 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + project_api_keys::ProjectAPIKeys, + types::{Project, ProjectCreateRequest, ProjectListResponse, ProjectUpdateRequest}, + Client, ProjectServiceAccounts, ProjectUsers, +}; + +/// Manage the projects within an organization includes creation, updating, and archiving or projects. +/// The Default project cannot be modified or archived. +pub struct Projects<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Projects<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + // call [ProjectUsers] group APIs + pub fn users(&self, project_id: &str) -> ProjectUsers { + ProjectUsers::new(self.client, project_id) + } + + // call [ProjectServiceAccounts] group APIs + pub fn service_accounts(&self, project_id: &str) -> ProjectServiceAccounts { + ProjectServiceAccounts::new(self.client, project_id) + } + + // call [ProjectAPIKeys] group APIs + pub fn api_keys(&self, project_id: &str) -> ProjectAPIKeys { + ProjectAPIKeys::new(self.client, project_id) + } + + /// Returns a list of projects. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query("/organization/projects", &query) + .await + } + + /// Create a new project in the organization. Projects can be created and archived, but cannot be deleted. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create(&self, request: ProjectCreateRequest) -> Result { + self.client.post("/organization/projects", request).await + } + + /// Retrieves a project. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, project_id: String) -> Result { + self.client + .get(format!("/organization/projects/{project_id}").as_str()) + .await + } + + /// Modifies a project in the organization. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn modify( + &self, + project_id: String, + request: ProjectUpdateRequest, + ) -> Result { + self.client + .post( + format!("/organization/projects/{project_id}").as_str(), + request, + ) + .await + } + + /// Archives a project in the organization. Archived projects cannot be used or updated. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn archive(&self, project_id: String) -> Result { + self.client + .post( + format!("/organization/projects/{project_id}/archive").as_str(), + (), + ) + .await + } +} diff --git a/clia-async-openai/src/runs.rs b/clia-async-openai/src/runs.rs new file mode 100644 index 00000000..4d022ec6 --- /dev/null +++ b/clia-async-openai/src/runs.rs @@ -0,0 +1,178 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + steps::Steps, + types::{ + AssistantEventStream, CreateRunRequest, ListRunsResponse, ModifyRunRequest, RunObject, + SubmitToolOutputsRunRequest, + }, + Client, +}; + +/// Represents an execution run on a thread. +/// +/// Related guide: [Assistants](https://platform.openai.com/docs/assistants/overview) +pub struct Runs<'c, C: Config> { + pub thread_id: String, + client: &'c Client, +} + +impl<'c, C: Config> Runs<'c, C> { + pub fn new(client: &'c Client, thread_id: &str) -> Self { + Self { + client, + thread_id: thread_id.into(), + } + } + + /// [Steps] API group + pub fn steps(&self, run_id: &str) -> Steps { + Steps::new(self.client, &self.thread_id, run_id) + } + + /// Create a run. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create(&self, request: CreateRunRequest) -> Result { + self.client + .post(&format!("/threads/{}/runs", self.thread_id), request) + .await + } + + /// Create a run. + /// + /// byot: You must ensure "stream: true" in serialized `request` + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned, + stream = "true", + where_clause = "R: std::marker::Send + 'static + TryFrom" + )] + #[allow(unused_mut)] + pub async fn create_stream( + &self, + mut request: CreateRunRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if request.stream.is_some() && !request.stream.unwrap() { + return Err(OpenAIError::InvalidArgument( + "When stream is false, use Runs::create".into(), + )); + } + + request.stream = Some(true); + } + + Ok(self + .client + .post_stream_mapped_raw_events( + &format!("/threads/{}/runs", self.thread_id), + request, + TryFrom::try_from, + ) + .await) + } + + /// Retrieves a run. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, run_id: &str) -> Result { + self.client + .get(&format!("/threads/{}/runs/{run_id}", self.thread_id)) + .await + } + + /// Modifies a run. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn update( + &self, + run_id: &str, + request: ModifyRunRequest, + ) -> Result { + self.client + .post( + &format!("/threads/{}/runs/{run_id}", self.thread_id), + request, + ) + .await + } + + /// Returns a list of runs belonging to a thread. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query(&format!("/threads/{}/runs", self.thread_id), &query) + .await + } + + /// When a run has the status: "requires_action" and required_action.type is submit_tool_outputs, this endpoint can be used to submit the outputs from the tool calls once they're all completed. All outputs must be submitted in a single request. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn submit_tool_outputs( + &self, + run_id: &str, + request: SubmitToolOutputsRunRequest, + ) -> Result { + self.client + .post( + &format!( + "/threads/{}/runs/{run_id}/submit_tool_outputs", + self.thread_id + ), + request, + ) + .await + } + + /// byot: You must ensure "stream: true" in serialized `request` + #[crate::byot( + T0 = std::fmt::Display, + T1 = serde::Serialize, + R = serde::de::DeserializeOwned, + stream = "true", + where_clause = "R: std::marker::Send + 'static + TryFrom" + )] + #[allow(unused_mut)] + pub async fn submit_tool_outputs_stream( + &self, + run_id: &str, + mut request: SubmitToolOutputsRunRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if request.stream.is_some() && !request.stream.unwrap() { + return Err(OpenAIError::InvalidArgument( + "When stream is false, use Runs::submit_tool_outputs".into(), + )); + } + + request.stream = Some(true); + } + + Ok(self + .client + .post_stream_mapped_raw_events( + &format!( + "/threads/{}/runs/{run_id}/submit_tool_outputs", + self.thread_id + ), + request, + TryFrom::try_from, + ) + .await) + } + + /// Cancels a run that is `in_progress` + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn cancel(&self, run_id: &str) -> Result { + self.client + .post( + &format!("/threads/{}/runs/{run_id}/cancel", self.thread_id), + (), + ) + .await + } +} diff --git a/clia-async-openai/src/steps.rs b/clia-async-openai/src/steps.rs new file mode 100644 index 00000000..924cda82 --- /dev/null +++ b/clia-async-openai/src/steps.rs @@ -0,0 +1,50 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ListRunStepsResponse, RunStepObject}, + Client, +}; + +/// Represents a step in execution of a run. +pub struct Steps<'c, C: Config> { + pub thread_id: String, + pub run_id: String, + client: &'c Client, +} + +impl<'c, C: Config> Steps<'c, C> { + pub fn new(client: &'c Client, thread_id: &str, run_id: &str) -> Self { + Self { + client, + thread_id: thread_id.into(), + run_id: run_id.into(), + } + } + + /// Retrieves a run step. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, step_id: &str) -> Result { + self.client + .get(&format!( + "/threads/{}/runs/{}/steps/{step_id}", + self.thread_id, self.run_id + )) + .await + } + + /// Returns a list of run steps belonging to a run. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + &format!("/threads/{}/runs/{}/steps", self.thread_id, self.run_id), + &query, + ) + .await + } +} diff --git a/clia-async-openai/src/threads.rs b/clia-async-openai/src/threads.rs new file mode 100644 index 00000000..8e738a67 --- /dev/null +++ b/clia-async-openai/src/threads.rs @@ -0,0 +1,101 @@ +use crate::{ + config::Config, + error::OpenAIError, + types::{ + AssistantEventStream, CreateThreadAndRunRequest, CreateThreadRequest, DeleteThreadResponse, + ModifyThreadRequest, RunObject, ThreadObject, + }, + Client, Messages, Runs, +}; + +/// Create threads that assistants can interact with. +/// +/// Related guide: [Assistants](https://platform.openai.com/docs/assistants/overview) +pub struct Threads<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Threads<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Call [Messages] group API to manage message in [thread_id] thread. + pub fn messages(&self, thread_id: &str) -> Messages { + Messages::new(self.client, thread_id) + } + + /// Call [Runs] group API to manage runs in [thread_id] thread. + pub fn runs(&self, thread_id: &str) -> Runs { + Runs::new(self.client, thread_id) + } + + /// Create a thread and run it in one request. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create_and_run( + &self, + request: CreateThreadAndRunRequest, + ) -> Result { + self.client.post("/threads/runs", request).await + } + + /// Create a thread and run it in one request (streaming). + /// + /// byot: You must ensure "stream: true" in serialized `request` + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned, + stream = "true", + where_clause = "R: std::marker::Send + 'static + TryFrom" + )] + #[allow(unused_mut)] + pub async fn create_and_run_stream( + &self, + mut request: CreateThreadAndRunRequest, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if request.stream.is_some() && !request.stream.unwrap() { + return Err(OpenAIError::InvalidArgument( + "When stream is false, use Threads::create_and_run".into(), + )); + } + + request.stream = Some(true); + } + Ok(self + .client + .post_stream_mapped_raw_events("/threads/runs", request, TryFrom::try_from) + .await) + } + + /// Create a thread. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create(&self, request: CreateThreadRequest) -> Result { + self.client.post("/threads", request).await + } + + /// Retrieves a thread. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, thread_id: &str) -> Result { + self.client.get(&format!("/threads/{thread_id}")).await + } + + /// Modifies a thread. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn update( + &self, + thread_id: &str, + request: ModifyThreadRequest, + ) -> Result { + self.client + .post(&format!("/threads/{thread_id}"), request) + .await + } + + /// Delete a thread. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, thread_id: &str) -> Result { + self.client.delete(&format!("/threads/{thread_id}")).await + } +} diff --git a/clia-async-openai/src/traits.rs b/clia-async-openai/src/traits.rs new file mode 100644 index 00000000..62e8ae3c --- /dev/null +++ b/clia-async-openai/src/traits.rs @@ -0,0 +1,7 @@ +pub trait AsyncTryFrom: Sized { + /// The type returned in the event of a conversion error. + type Error; + + /// Performs the conversion. + fn try_from(value: T) -> impl std::future::Future> + Send; +} diff --git a/clia-async-openai/src/types/assistant.rs b/clia-async-openai/src/types/assistant.rs new file mode 100644 index 00000000..cd0aba47 --- /dev/null +++ b/clia-async-openai/src/types/assistant.rs @@ -0,0 +1,326 @@ +use std::collections::HashMap; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +use super::{FunctionName, FunctionObject, ResponseFormat}; + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct AssistantToolCodeInterpreterResources { + ///A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool. + pub file_ids: Vec, // maxItems: 20 +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct AssistantToolFileSearchResources { + /// The ID of the [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + pub vector_store_ids: Vec, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct AssistantToolResources { + #[serde(skip_serializing_if = "Option::is_none")] + pub code_interpreter: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_search: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct CreateAssistantToolResources { + #[serde(skip_serializing_if = "Option::is_none")] + pub code_interpreter: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_search: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct CreateAssistantToolFileSearchResources { + /// The [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + pub vector_store_ids: Option>, + /// A helper to create a [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant. + pub vector_stores: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct AssistantVectorStore { + /// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store. + pub file_ids: Vec, + + /// The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + pub chunking_strategy: Option, + + /// Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + pub metadata: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +#[serde(tag = "type")] +pub enum AssistantVectorStoreChunkingStrategy { + /// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + #[default] + #[serde(rename = "auto")] + Auto, + #[serde(rename = "static")] + Static { r#static: StaticChunkingStrategy }, +} + +/// Static Chunking Strategy +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct StaticChunkingStrategy { + /// The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + pub max_chunk_size_tokens: u16, + /// The number of tokens that overlap between chunks. The default value is `400`. + /// + /// Note that the overlap must not exceed half of `max_chunk_size_tokens`. + pub chunk_overlap_tokens: u16, +} + +/// Represents an `assistant` that can call the model and use tools. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct AssistantObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `assistant`. + pub object: String, + /// The Unix timestamp (in seconds) for when the assistant was created. + pub created_at: i32, + /// The name of the assistant. The maximum length is 256 characters. + pub name: Option, + /// The description of the assistant. The maximum length is 512 characters. + pub description: Option, + /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models) for descriptions of them. + pub model: String, + /// The system instructions that the assistant uses. The maximum length is 256,000 characters. + pub instructions: Option, + /// A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`. + #[serde(default)] + pub tools: Vec, + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + pub temperature: Option, + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// We generally recommend altering this or temperature but not both. + pub top_p: Option, + + pub response_format: Option, +} + +/// Specifies the format that the model must output. Compatible with [GPT-4o](https://platform.openai.com/docs/models/gpt-4o), [GPT-4 Turbo](https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. +/// +/// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). +/// +/// Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. +/// +/// **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub enum AssistantsApiResponseFormatOption { + #[default] + #[serde(rename = "auto")] + Auto, + #[serde(untagged)] + Format(ResponseFormat), +} + +/// Retrieval tool +#[derive(Clone, Serialize, Debug, Default, Deserialize, PartialEq)] +pub struct AssistantToolsFileSearch { + /// Overrides for the file search tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub file_search: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct AssistantToolsFileSearchOverrides { + /// The maximum number of results the file search tool should output. The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. This number should be between 1 and 50 inclusive. + /// + //// Note that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search/customizing-file-search-settings) for more information. + pub max_num_results: Option, + pub ranking_options: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub enum FileSearchRanker { + #[serde(rename = "auto")] + Auto, + #[serde(rename = "default_2024_08_21")] + Default2024_08_21, +} + +/// The ranking options for the file search. If not specified, the file search tool will use the `auto` ranker and a score_threshold of 0. +/// +/// See the [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings) for more information. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FileSearchRankingOptions { + /// The ranker to use for the file search. If not specified will use the `auto` ranker. + #[serde(skip_serializing_if = "Option::is_none")] + pub ranker: Option, + + /// The score threshold for the file search. All values must be a floating point number between 0 and 1. + pub score_threshold: f32, +} + +/// Function tool +#[derive(Clone, Serialize, Debug, Default, Deserialize, PartialEq)] +pub struct AssistantToolsFunction { + pub function: FunctionObject, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum AssistantTools { + CodeInterpreter, + FileSearch(AssistantToolsFileSearch), + Function(AssistantToolsFunction), +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "CreateAssistantRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateAssistantRequest { + /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them. + pub model: String, + + /// The name of the assistant. The maximum length is 256 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// The description of the assistant. The maximum length is 512 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + + /// The system instructions that the assistant uses. The maximum length is 256,000 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_resources: Option, + + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or temperature but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "ModifyAssistantRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ModifyAssistantRequest { + /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// The name of the assistant. The maximum length is 256 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// The description of the assistant. The maximum length is 512 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + + /// The system instructions that the assistant uses. The maximum length is 256,000 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or temperature but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct DeleteAssistantResponse { + pub id: String, + pub deleted: bool, + pub object: String, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct ListAssistantsResponse { + pub object: String, + pub data: Vec, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, +} + +/// Controls which (if any) tool is called by the model. +/// `none` means the model will not call any tools and instead generates a message. +/// `auto` is the default value and means the model can pick between generating a message or calling one or more tools. +/// `required` means the model must call one or more tools before responding to the user. +/// Specifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum AssistantsApiToolChoiceOption { + #[default] + None, + Auto, + Required, + #[serde(untagged)] + Named(AssistantsNamedToolChoice), +} + +/// Specifies a tool the model should use. Use to force the model to call a specific tool. +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct AssistantsNamedToolChoice { + /// The type of the tool. If type is `function`, the function name must be set + pub r#type: AssistantToolType, + + pub function: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum AssistantToolType { + #[default] + Function, + CodeInterpreter, + FileSearch, +} diff --git a/clia-async-openai/src/types/assistant_impls.rs b/clia-async-openai/src/types/assistant_impls.rs new file mode 100644 index 00000000..bd8d4bf7 --- /dev/null +++ b/clia-async-openai/src/types/assistant_impls.rs @@ -0,0 +1,65 @@ +use super::{ + AssistantToolCodeInterpreterResources, AssistantToolFileSearchResources, + AssistantToolResources, AssistantTools, AssistantToolsFileSearch, AssistantToolsFunction, + CreateAssistantToolFileSearchResources, CreateAssistantToolResources, FunctionObject, +}; + +impl From for AssistantTools { + fn from(value: AssistantToolsFileSearch) -> Self { + Self::FileSearch(value) + } +} + +impl From for AssistantTools { + fn from(value: AssistantToolsFunction) -> Self { + Self::Function(value) + } +} + +impl From for AssistantToolsFunction { + fn from(value: FunctionObject) -> Self { + Self { function: value } + } +} + +impl From for AssistantTools { + fn from(value: FunctionObject) -> Self { + Self::Function(value.into()) + } +} + +impl From for CreateAssistantToolResources { + fn from(value: CreateAssistantToolFileSearchResources) -> Self { + Self { + code_interpreter: None, + file_search: Some(value), + } + } +} + +impl From for CreateAssistantToolResources { + fn from(value: AssistantToolCodeInterpreterResources) -> Self { + Self { + code_interpreter: Some(value), + file_search: None, + } + } +} + +impl From for AssistantToolResources { + fn from(value: AssistantToolCodeInterpreterResources) -> Self { + Self { + code_interpreter: Some(value), + file_search: None, + } + } +} + +impl From for AssistantToolResources { + fn from(value: AssistantToolFileSearchResources) -> Self { + Self { + code_interpreter: None, + file_search: Some(value), + } + } +} diff --git a/clia-async-openai/src/types/assistant_stream.rs b/clia-async-openai/src/types/assistant_stream.rs new file mode 100644 index 00000000..755a322d --- /dev/null +++ b/clia-async-openai/src/types/assistant_stream.rs @@ -0,0 +1,215 @@ +use std::pin::Pin; + +use futures::Stream; +use serde::Deserialize; + +use crate::error::{map_deserialization_error, ApiError, OpenAIError}; + +use super::{ + MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject, +}; + +/// Represents an event emitted when streaming a Run. +/// +/// Each event in a server-sent events stream has an `event` and `data` property: +/// +/// ```text +/// event: thread.created +/// data: {"id": "thread_123", "object": "thread", ...} +/// ``` +/// +/// We emit events whenever a new object is created, transitions to a new state, or is being +/// streamed in parts (deltas). For example, we emit `thread.run.created` when a new run +/// is created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses +/// to create a message during a run, we emit a `thread.message.created event`, a +/// `thread.message.in_progress` event, many `thread.message.delta` events, and finally a +/// `thread.message.completed` event. +/// +/// We may add additional events over time, so we recommend handling unknown events gracefully +/// in your code. See the [Assistants API quickstart](https://platform.openai.com/docs/assistants/overview) to learn how to +/// integrate the Assistants API with streaming. + +#[derive(Debug, Deserialize, Clone)] +#[serde(tag = "event", content = "data")] +#[non_exhaustive] +pub enum AssistantStreamEvent { + /// Occurs when a new [thread](https://platform.openai.com/docs/api-reference/threads/object) is created. + #[serde(rename = "thread.created")] + TreadCreated(ThreadObject), + /// Occurs when a new [run](https://platform.openai.com/docs/api-reference/runs/object) is created. + #[serde(rename = "thread.run.created")] + ThreadRunCreated(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to a `queued` status. + #[serde(rename = "thread.run.queued")] + ThreadRunQueued(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to an `in_progress` status. + #[serde(rename = "thread.run.in_progress")] + ThreadRunInProgress(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to a `requires_action` status. + #[serde(rename = "thread.run.requires_action")] + ThreadRunRequiresAction(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) is completed. + #[serde(rename = "thread.run.completed")] + ThreadRunCompleted(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) ends with status `incomplete`. + #[serde(rename = "thread.run.incomplete")] + ThreadRunIncomplete(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) fails. + #[serde(rename = "thread.run.failed")] + ThreadRunFailed(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to a `cancelling` status. + #[serde(rename = "thread.run.cancelling")] + ThreadRunCancelling(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) is cancelled. + #[serde(rename = "thread.run.cancelled")] + ThreadRunCancelled(RunObject), + /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) expires. + #[serde(rename = "thread.run.expired")] + ThreadRunExpired(RunObject), + /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) is created. + #[serde(rename = "thread.run.step.created")] + ThreadRunStepCreated(RunStepObject), + /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) moves to an `in_progress` state. + #[serde(rename = "thread.run.step.in_progress")] + ThreadRunStepInProgress(RunStepObject), + /// Occurs when parts of a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) are being streamed. + #[serde(rename = "thread.run.step.delta")] + ThreadRunStepDelta(RunStepDeltaObject), + /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) is completed. + #[serde(rename = "thread.run.step.completed")] + ThreadRunStepCompleted(RunStepObject), + /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) fails. + #[serde(rename = "thread.run.step.failed")] + ThreadRunStepFailed(RunStepObject), + /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) is cancelled. + #[serde(rename = "thread.run.step.cancelled")] + ThreadRunStepCancelled(RunStepObject), + /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) expires. + #[serde(rename = "thread.run.step.expired")] + ThreadRunStepExpired(RunStepObject), + /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) is created. + #[serde(rename = "thread.message.created")] + ThreadMessageCreated(MessageObject), + /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) moves to an `in_progress` state. + #[serde(rename = "thread.message.in_progress")] + ThreadMessageInProgress(MessageObject), + /// Occurs when parts of a [Message](https://platform.openai.com/docs/api-reference/messages/object) are being streamed. + #[serde(rename = "thread.message.delta")] + ThreadMessageDelta(MessageDeltaObject), + /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) is completed. + #[serde(rename = "thread.message.completed")] + ThreadMessageCompleted(MessageObject), + /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) ends before it is completed. + #[serde(rename = "thread.message.incomplete")] + ThreadMessageIncomplete(MessageObject), + /// Occurs when an [error](https://platform.openai.com/docs/guides/error-codes/api-errors) occurs. This can happen due to an internal server error or a timeout. + #[serde(rename = "error")] + ErrorEvent(ApiError), + /// Occurs when a stream ends. + #[serde(rename = "done")] + Done(String), +} + +pub type AssistantEventStream = + Pin> + Send>>; + +impl TryFrom for AssistantStreamEvent { + type Error = OpenAIError; + fn try_from(value: eventsource_stream::Event) -> Result { + match value.event.as_str() { + "thread.created" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::TreadCreated), + "thread.run.created" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunCreated), + "thread.run.queued" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunQueued), + "thread.run.in_progress" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunInProgress), + "thread.run.requires_action" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunRequiresAction), + "thread.run.completed" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunCompleted), + "thread.run.incomplete" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunIncomplete), + "thread.run.failed" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunFailed), + "thread.run.cancelling" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunCancelling), + "thread.run.cancelled" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunCancelled), + "thread.run.expired" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunExpired), + "thread.run.step.created" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunStepCreated), + "thread.run.step.in_progress" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunStepInProgress) + } + "thread.run.step.delta" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunStepDelta) + } + "thread.run.step.completed" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunStepCompleted) + } + "thread.run.step.failed" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunStepFailed), + "thread.run.step.cancelled" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunStepCancelled) + } + "thread.run.step.expired" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadRunStepExpired), + "thread.message.created" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadMessageCreated), + "thread.message.in_progress" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadMessageInProgress) + } + "thread.message.delta" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadMessageDelta) + } + "thread.message.completed" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadMessageCompleted) + } + "thread.message.incomplete" => { + serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ThreadMessageIncomplete) + } + "error" => serde_json::from_str::(value.data.as_str()) + .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map(AssistantStreamEvent::ErrorEvent), + "done" => Ok(AssistantStreamEvent::Done(value.data)), + + _ => Err(OpenAIError::StreamError( + "Unrecognized event: {value:?#}".into(), + )), + } + } +} diff --git a/clia-async-openai/src/types/audio.rs b/clia-async-openai/src/types/audio.rs new file mode 100644 index 00000000..e84f21db --- /dev/null +++ b/clia-async-openai/src/types/audio.rs @@ -0,0 +1,249 @@ +use bytes::Bytes; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use super::InputSource; +use crate::error::OpenAIError; + +#[derive(Debug, Default, Clone, PartialEq)] +pub struct AudioInput { + pub source: InputSource, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum AudioResponseFormat { + #[default] + Json, + Text, + Srt, + VerboseJson, + Vtt, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum SpeechResponseFormat { + #[default] + Mp3, + Opus, + Aac, + Flac, + Pcm, + Wav, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +#[non_exhaustive] +pub enum Voice { + #[default] + Alloy, + Ash, + Coral, + Echo, + Fable, + Onyx, + Nova, + Sage, + Shimmer, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub enum SpeechModel { + #[default] + #[serde(rename = "tts-1")] + Tts1, + #[serde(rename = "tts-1-hd")] + Tts1Hd, + #[serde(untagged)] + Other(String), +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum TimestampGranularity { + Word, + #[default] + Segment, +} + +#[derive(Clone, Default, Debug, Builder, PartialEq)] +#[builder(name = "CreateTranscriptionRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateTranscriptionRequest { + /// The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm. + pub file: AudioInput, + + /// ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + pub model: String, + + /// An optional text to guide the model's style or continue a previous audio segment. The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) should match the audio language. + pub prompt: Option, + + /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. + pub response_format: Option, + + /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + pub temperature: Option, // default: 0 + + /// The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency. + pub language: Option, + + /// The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency. + pub timestamp_granularities: Option>, +} + +/// Represents a transcription response returned by model, based on the provided +/// input. +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct CreateTranscriptionResponseJson { + /// The transcribed text. + pub text: String, +} + +/// Represents a verbose json transcription response returned by model, based on +/// the provided input. +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct CreateTranscriptionResponseVerboseJson { + /// The language of the input audio. + pub language: String, + + /// The duration of the input audio. + pub duration: f32, + + /// The transcribed text. + pub text: String, + + /// Extracted words and their corresponding timestamps. + #[serde(skip_serializing_if = "Option::is_none")] + pub words: Option>, + + /// Segments of the transcribed text and their corresponding details. + #[serde(skip_serializing_if = "Option::is_none")] + pub segments: Option>, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct TranscriptionWord { + /// The text content of the word. + pub word: String, + + /// Start time of the word in seconds. + pub start: f32, + + /// End time of the word in seconds. + pub end: f32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct TranscriptionSegment { + /// Unique identifier of the segment. + pub id: i32, + + // Seek offset of the segment. + pub seek: i32, + + /// Start time of the segment in seconds. + pub start: f32, + + /// End time of the segment in seconds. + pub end: f32, + + /// Text content of the segment. + pub text: String, + + /// Array of token IDs for the text content. + pub tokens: Vec, + + /// Temperature parameter used for generating the segment. + pub temperature: f32, + + /// Average logprob of the segment. If the value is lower than -1, consider + /// the logprobs failed. + pub avg_logprob: f32, + + /// Compression ratio of the segment. If the value is greater than 2.4, + /// consider the compression failed. + pub compression_ratio: f32, + + /// Probability of no speech in the segment. If the value is higher than 1.0 + /// and the `avg_logprob` is below -1, consider this segment silent. + pub no_speech_prob: f32, +} + +#[derive(Clone, Default, Debug, Builder, PartialEq, Serialize, Deserialize)] +#[builder(name = "CreateSpeechRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateSpeechRequest { + /// The text to generate audio for. The maximum length is 4096 characters. + pub input: String, + + /// One of the available [TTS models](https://platform.openai.com/docs/models/tts): `tts-1` or `tts-1-hd` + pub model: SpeechModel, + + /// The voice to use when generating the audio. Supported voices are `alloy`, `ash`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage` and `shimmer`. + /// Previews of the voices are available in the [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options). + pub voice: Voice, + + /// The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + /// The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default. + #[serde(skip_serializing_if = "Option::is_none")] + pub speed: Option, // default: 1.0 +} + +#[derive(Clone, Default, Debug, Builder, PartialEq)] +#[builder(name = "CreateTranslationRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateTranslationRequest { + /// The audio file object (not file name) translate, in one of these + ///formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + pub file: AudioInput, + + /// ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + pub model: String, + + /// An optional text to guide the model's style or continue a previous audio segment. The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) should be in English. + pub prompt: Option, + + /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. + pub response_format: Option, + + /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + pub temperature: Option, // default: 0 +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct CreateTranslationResponseJson { + pub text: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct CreateTranslationResponseVerboseJson { + /// The language of the output translation (always `english`). + pub language: String, + /// The duration of the input audio. + pub duration: String, + /// The translated text. + pub text: String, + /// Segments of the translated text and their corresponding details. + #[serde(skip_serializing_if = "Option::is_none")] + pub segments: Option>, +} + +#[derive(Debug, Clone)] +pub struct CreateSpeechResponse { + pub bytes: Bytes, +} diff --git a/clia-async-openai/src/types/audit_log.rs b/clia-async-openai/src/types/audit_log.rs new file mode 100644 index 00000000..d652cf40 --- /dev/null +++ b/clia-async-openai/src/types/audit_log.rs @@ -0,0 +1,434 @@ +use serde::{Deserialize, Serialize}; + +/// The event type. +#[derive(Debug, Serialize, Deserialize)] +pub enum AuditLogEventType { + #[serde(rename = "api_key.created")] + ApiKeyCreated, + #[serde(rename = "api_key.updated")] + ApiKeyUpdated, + #[serde(rename = "api_key.deleted")] + ApiKeyDeleted, + #[serde(rename = "invite.sent")] + InviteSent, + #[serde(rename = "invite.accepted")] + InviteAccepted, + #[serde(rename = "invite.deleted")] + InviteDeleted, + #[serde(rename = "login.succeeded")] + LoginSucceeded, + #[serde(rename = "login.failed")] + LoginFailed, + #[serde(rename = "logout.succeeded")] + LogoutSucceeded, + #[serde(rename = "logout.failed")] + LogoutFailed, + #[serde(rename = "organization.updated")] + OrganizationUpdated, + #[serde(rename = "project.created")] + ProjectCreated, + #[serde(rename = "project.updated")] + ProjectUpdated, + #[serde(rename = "project.archived")] + ProjectArchived, + #[serde(rename = "service_account.created")] + ServiceAccountCreated, + #[serde(rename = "service_account.updated")] + ServiceAccountUpdated, + #[serde(rename = "service_account.deleted")] + ServiceAccountDeleted, + #[serde(rename = "user.added")] + UserAdded, + #[serde(rename = "user.updated")] + UserUpdated, + #[serde(rename = "user.deleted")] + UserDeleted, +} + +/// Represents a list of audit logs. +#[derive(Debug, Serialize, Deserialize)] +pub struct ListAuditLogsResponse { + /// The object type, which is always `list`. + pub object: String, + /// A list of `AuditLog` objects. + pub data: Vec, + /// The first `audit_log_id` in the retrieved `list`. + pub first_id: String, + /// The last `audit_log_id` in the retrieved `list`. + pub last_id: String, + /// The `has_more` property is used for pagination to indicate there are additional results. + pub has_more: bool, +} + +/// The project that the action was scoped to. Absent for actions not scoped to projects. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogProject { + /// The project ID. + pub id: String, + /// The project title. + pub name: String, +} + +/// The actor who performed the audit logged action. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogActor { + /// The type of actor. Is either `session` or `api_key`. + pub r#type: String, + /// The session in which the audit logged action was performed. + pub session: Option, + /// The API Key used to perform the audit logged action. + pub api_key: Option, +} + +/// The session in which the audit logged action was performed. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogActorSession { + /// The user who performed the audit logged action. + pub user: AuditLogActorUser, + /// The IP address from which the action was performed. + pub ip_address: String, +} + +/// The API Key used to perform the audit logged action. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogActorApiKey { + /// The tracking id of the API key. + pub id: String, + /// The type of API key. Can be either `user` or `service_account`. + pub r#type: AuditLogActorApiKeyType, + /// The user who performed the audit logged action, if applicable. + pub user: Option, + /// The service account that performed the audit logged action, if applicable. + pub service_account: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AuditLogActorApiKeyType { + User, + ServiceAccount, +} + +/// The user who performed the audit logged action. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogActorUser { + /// The user id. + pub id: String, + /// The user email. + pub email: String, +} + +/// The service account that performed the audit logged action. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogActorServiceAccount { + /// The service account id. + pub id: String, +} + +/// A log of a user action or configuration change within this organization. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLog { + /// The ID of this log. + pub id: String, + /// The event type. + pub r#type: AuditLogEventType, + /// The Unix timestamp (in seconds) of the event. + pub effective_at: u32, + /// The project that the action was scoped to. Absent for actions not scoped to projects. + pub project: Option, + /// The actor who performed the audit logged action. + pub actor: AuditLogActor, + /// The details for events with the type `api_key.created`. + #[serde(rename = "api_key.created")] + pub api_key_created: Option, + /// The details for events with the type `api_key.updated`. + #[serde(rename = "api_key.updated")] + pub api_key_updated: Option, + /// The details for events with the type `api_key.deleted`. + #[serde(rename = "api_key.deleted")] + pub api_key_deleted: Option, + /// The details for events with the type `invite.sent`. + #[serde(rename = "invite.sent")] + pub invite_sent: Option, + /// The details for events with the type `invite.accepted`. + #[serde(rename = "invite.accepted")] + pub invite_accepted: Option, + /// The details for events with the type `invite.deleted`. + #[serde(rename = "invite.deleted")] + pub invite_deleted: Option, + /// The details for events with the type `login.failed`. + #[serde(rename = "login.failed")] + pub login_failed: Option, + /// The details for events with the type `logout.failed`. + #[serde(rename = "logout.failed")] + pub logout_failed: Option, + /// The details for events with the type `organization.updated`. + #[serde(rename = "organization.updated")] + pub organization_updated: Option, + /// The details for events with the type `project.created`. + #[serde(rename = "project.created")] + pub project_created: Option, + /// The details for events with the type `project.updated`. + #[serde(rename = "project.updated")] + pub project_updated: Option, + /// The details for events with the type `project.archived`. + #[serde(rename = "project.archived")] + pub project_archived: Option, + /// The details for events with the type `service_account.created`. + #[serde(rename = "service_account.created")] + pub service_account_created: Option, + /// The details for events with the type `service_account.updated`. + #[serde(rename = "service_account.updated")] + pub service_account_updated: Option, + /// The details for events with the type `service_account.deleted`. + #[serde(rename = "service_account.deleted")] + pub service_account_deleted: Option, + /// The details for events with the type `user.added`. + #[serde(rename = "user.added")] + pub user_added: Option, + /// The details for events with the type `user.updated`. + #[serde(rename = "user.updated")] + pub user_updated: Option, + /// The details for events with the type `user.deleted`. + #[serde(rename = "user.deleted")] + pub user_deleted: Option, +} + +/// The details for events with the type `api_key.created`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogApiKeyCreated { + /// The tracking ID of the API key. + pub id: String, + /// The payload used to create the API key. + pub data: Option, +} + +/// The payload used to create the API key. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogApiKeyCreatedData { + /// A list of scopes allowed for the API key, e.g. `["api.model.request"]`. + pub scopes: Option>, +} + +/// The details for events with the type `api_key.updated`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogApiKeyUpdated { + /// The tracking ID of the API key. + pub id: String, + /// The payload used to update the API key. + pub changes_requested: Option, +} + +/// The payload used to update the API key. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogApiKeyUpdatedChangesRequested { + /// A list of scopes allowed for the API key, e.g. `["api.model.request"]`. + pub scopes: Option>, +} + +/// The details for events with the type `api_key.deleted`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogApiKeyDeleted { + /// The tracking ID of the API key. + pub id: String, +} + +/// The details for events with the type `invite.sent`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogInviteSent { + /// The ID of the invite. + pub id: String, + /// The payload used to create the invite. + pub data: Option, +} + +/// The payload used to create the invite. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogInviteSentData { + /// The email invited to the organization. + pub email: String, + /// The role the email was invited to be. Is either `owner` or `member`. + pub role: String, +} + +/// The details for events with the type `invite.accepted`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogInviteAccepted { + /// The ID of the invite. + pub id: String, +} + +/// The details for events with the type `invite.deleted`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogInviteDeleted { + /// The ID of the invite. + pub id: String, +} + +/// The details for events with the type `login.failed`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogLoginFailed { + /// The error code of the failure. + pub error_code: String, + /// The error message of the failure. + pub error_message: String, +} + +/// The details for events with the type `logout.failed`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogLogoutFailed { + /// The error code of the failure. + pub error_code: String, + /// The error message of the failure. + pub error_message: String, +} + +/// The details for events with the type `organization.updated`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogOrganizationUpdated { + /// The organization ID. + pub id: String, + /// The payload used to update the organization settings. + pub changes_requested: Option, +} + +/// The payload used to update the organization settings. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogOrganizationUpdatedChangesRequested { + /// The organization title. + pub title: Option, + /// The organization description. + pub description: Option, + /// The organization name. + pub name: Option, + /// The organization settings. + pub settings: Option, +} + +/// The organization settings. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogOrganizationUpdatedChangesRequestedSettings { + /// Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`. + pub threads_ui_visibility: Option, + /// Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`. + pub usage_dashboard_visibility: Option, +} + +/// The details for events with the type `project.created`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogProjectCreated { + /// The project ID. + pub id: String, + /// The payload used to create the project. + pub data: Option, +} + +/// The payload used to create the project. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogProjectCreatedData { + /// The project name. + pub name: String, + /// The title of the project as seen on the dashboard. + pub title: Option, +} + +/// The details for events with the type `project.updated`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogProjectUpdated { + /// The project ID. + pub id: String, + /// The payload used to update the project. + pub changes_requested: Option, +} + +/// The payload used to update the project. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogProjectUpdatedChangesRequested { + /// The title of the project as seen on the dashboard. + pub title: Option, +} + +/// The details for events with the type `project.archived`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogProjectArchived { + /// The project ID. + pub id: String, +} + +/// The details for events with the type `service_account.created`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogServiceAccountCreated { + /// The service account ID. + pub id: String, + /// The payload used to create the service account. + pub data: Option, +} + +/// The payload used to create the service account. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogServiceAccountCreatedData { + /// The role of the service account. Is either `owner` or `member`. + pub role: String, +} + +/// The details for events with the type `service_account.updated`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogServiceAccountUpdated { + /// The service account ID. + pub id: String, + /// The payload used to updated the service account. + pub changes_requested: Option, +} + +/// The payload used to updated the service account. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogServiceAccountUpdatedChangesRequested { + /// The role of the service account. Is either `owner` or `member`. + pub role: String, +} + +/// The details for events with the type `service_account.deleted`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogServiceAccountDeleted { + /// The service account ID. + pub id: String, +} + +/// The details for events with the type `user.added`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogUserAdded { + /// The user ID. + pub id: String, + /// The payload used to add the user to the project. + pub data: Option, +} + +/// The payload used to add the user to the project. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogUserAddedData { + /// The role of the user. Is either `owner` or `member`. + pub role: String, +} + +/// The details for events with the type `user.updated`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogUserUpdated { + /// The project ID. + pub id: String, + /// The payload used to update the user. + pub changes_requested: Option, +} + +/// The payload used to update the user. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogUserUpdatedChangesRequested { + /// The role of the user. Is either `owner` or `member`. + pub role: String, +} + +/// The details for events with the type `user.deleted`. +#[derive(Debug, Serialize, Deserialize)] +pub struct AuditLogUserDeleted { + /// The user ID. + pub id: String, +} diff --git a/clia-async-openai/src/types/batch.rs b/clia-async-openai/src/types/batch.rs new file mode 100644 index 00000000..5546285e --- /dev/null +++ b/clia-async-openai/src/types/batch.rs @@ -0,0 +1,188 @@ +use std::collections::HashMap; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)] +#[builder(name = "BatchRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct BatchRequest { + /// The ID of an uploaded file that contains requests for the new batch. + /// + /// See [upload file](https://platform.openai.com/docs/api-reference/files/create) for how to upload a file. + /// + /// Your input file must be formatted as a [JSONL file](https://platform.openai.com/docs/api-reference/batch/request-input), and must be uploaded with the purpose `batch`. The file can contain up to 50,000 requests, and can be up to 100 MB in size. + pub input_file_id: String, + + /// The endpoint to be used for all requests in the batch. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. Note that `/v1/embeddings` batches are also restricted to a maximum of 50,000 embedding inputs across all requests in the batch. + pub endpoint: BatchEndpoint, + + /// The time frame within which the batch should be processed. Currently only `24h` is supported. + pub completion_window: BatchCompletionWindow, + + /// Optional custom metadata for the batch. + pub metadata: Option>, +} + +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Default)] +pub enum BatchEndpoint { + #[default] + #[serde(rename = "/v1/chat/completions")] + V1ChatCompletions, + #[serde(rename = "/v1/embeddings")] + V1Embeddings, + #[serde(rename = "/v1/completions")] + V1Completions, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Default, Deserialize)] +pub enum BatchCompletionWindow { + #[default] + #[serde(rename = "24h")] + W24H, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct Batch { + pub id: String, + /// The object type, which is always `batch`. + pub object: String, + /// The OpenAI API endpoint used by the batch. + pub endpoint: String, + pub errors: Option, + /// The ID of the input file for the batch. + pub input_file_id: String, + /// The time frame within which the batch should be processed. + pub completion_window: String, + /// The current status of the batch. + pub status: BatchStatus, + /// The ID of the file containing the outputs of successfully executed requests. + pub output_file_id: Option, + /// The ID of the file containing the outputs of requests with errors. + pub error_file_id: Option, + /// The Unix timestamp (in seconds) for when the batch was created. + pub created_at: u32, + /// The Unix timestamp (in seconds) for when the batch started processing. + pub in_progress_at: Option, + /// The Unix timestamp (in seconds) for when the batch will expire. + pub expires_at: Option, + /// The Unix timestamp (in seconds) for when the batch started finalizing. + pub finalizing_at: Option, + /// The Unix timestamp (in seconds) for when the batch was completed. + pub completed_at: Option, + /// The Unix timestamp (in seconds) for when the batch failed. + pub failed_at: Option, + /// he Unix timestamp (in seconds) for when the batch expired. + pub expired_at: Option, + /// The Unix timestamp (in seconds) for when the batch started cancelling. + pub cancelling_at: Option, + /// The Unix timestamp (in seconds) for when the batch was cancelled. + pub cancelled_at: Option, + /// The request counts for different statuses within the batch. + pub request_counts: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct BatchErrors { + /// The object type, which is always `list`. + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct BatchError { + /// An error code identifying the error type. + pub code: String, + /// A human-readable message providing more details about the error. + pub message: String, + /// The name of the parameter that caused the error, if applicable. + pub param: Option, + /// The line number of the input file where the error occurred, if applicable. + pub line: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum BatchStatus { + Validating, + Failed, + InProgress, + Finalizing, + Completed, + Expired, + Cancelling, + Cancelled, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct BatchRequestCounts { + /// Total number of requests in the batch. + pub total: u32, + /// Number of requests that have been completed successfully. + pub completed: u32, + /// Number of requests that have failed. + pub failed: u32, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct ListBatchesResponse { + pub data: Vec, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, + pub object: String, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +#[serde(rename_all = "UPPERCASE")] +pub enum BatchRequestInputMethod { + POST, +} + +/// The per-line object of the batch input file +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct BatchRequestInput { + /// A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch. + pub custom_id: String, + /// The HTTP method to be used for the request. Currently only `POST` is supported. + pub method: BatchRequestInputMethod, + /// The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. + pub url: BatchEndpoint, + pub body: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct BatchRequestOutputResponse { + /// The HTTP status code of the response + pub status_code: u16, + /// An unique identifier for the OpenAI API request. Please include this request ID when contacting support. + pub request_id: String, + /// The JSON body of the response + pub body: serde_json::Value, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct BatchRequestOutputError { + /// A machine-readable error code. + pub code: String, + /// A human-readable error message. + pub message: String, +} + +/// The per-line object of the batch output and error files +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct BatchRequestOutput { + pub id: String, + /// A developer-provided per-request id that will be used to match outputs to inputs. + pub custom_id: String, + pub response: Option, + /// For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure. + pub error: Option, +} diff --git a/clia-async-openai/src/types/chat.rs b/clia-async-openai/src/types/chat.rs new file mode 100644 index 00000000..b60011d0 --- /dev/null +++ b/clia-async-openai/src/types/chat.rs @@ -0,0 +1,988 @@ +use std::{collections::HashMap, pin::Pin}; + +use derive_builder::Builder; +use futures::Stream; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum Prompt { + String(String), + StringArray(Vec), + // Minimum value is 0, maximum value is 50256 (inclusive). + IntegerArray(Vec), + ArrayOfIntegerArray(Vec>), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum Stop { + String(String), // nullable: true + StringArray(Vec), // minItems: 1; maxItems: 4 +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct Logprobs { + pub tokens: Vec, + pub token_logprobs: Vec>, // Option is to account for null value in the list + pub top_logprobs: Vec, + pub text_offset: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum CompletionFinishReason { + Stop, + Length, + ContentFilter, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct Choice { + pub text: String, + pub index: u32, + pub logprobs: Option, + pub finish_reason: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub enum ChatCompletionFunctionCall { + /// The model does not call a function, and responds to the end-user. + #[serde(rename = "none")] + None, + /// The model can pick between an end-user or calling a function. + #[serde(rename = "auto")] + Auto, + + // In spec this is ChatCompletionFunctionCallOption + // based on feedback from @m1guelpf in https://github.com/64bit/async-openai/pull/118 + // it is diverged from the spec + /// Forces the model to call the specified function. + #[serde(untagged)] + Function { name: String }, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + #[default] + User, + Assistant, + Tool, + Function, +} + +/// The name and arguments of a function that should be called, as generated by the model. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct FunctionCall { + /// The name of the function to call. + pub name: String, + /// The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + pub arguments: String, +} + +/// Usage statistics for the completion request. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct CompletionUsage { + /// Number of tokens in the prompt. + pub prompt_tokens: u32, + /// Number of tokens in the generated completion. + pub completion_tokens: u32, + /// Total number of tokens used in the request (prompt + completion). + pub total_tokens: u32, + /// Breakdown of tokens used in the prompt. + pub prompt_tokens_details: Option, + /// Breakdown of tokens used in a completion. + pub completion_tokens_details: Option, +} + +/// Breakdown of tokens used in a completion. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct PromptTokensDetails { + /// Audio input tokens present in the prompt. + pub audio_tokens: Option, + /// Cached tokens present in the prompt. + pub cached_tokens: Option, +} + +/// Breakdown of tokens used in a completion. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct CompletionTokensDetails { + pub accepted_prediction_tokens: Option, + /// Audio input tokens generated by the model. + pub audio_tokens: Option, + /// Tokens generated by the model for reasoning. + pub reasoning_tokens: Option, + /// When using Predicted Outputs, the number of tokens in the + /// prediction that did not appear in the completion. However, like + /// reasoning tokens, these tokens are still counted in the total + /// completion tokens for purposes of billing, output, and context + /// window limits. + pub rejected_prediction_tokens: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestDeveloperMessageArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestDeveloperMessage { + /// The contents of the developer message. + pub content: ChatCompletionRequestDeveloperMessageContent, + + /// An optional name for the participant. Provides the model information to differentiate between participants of the same role. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum ChatCompletionRequestDeveloperMessageContent { + Text(String), + Array(Vec), +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestSystemMessageArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestSystemMessage { + /// The contents of the system message. + pub content: ChatCompletionRequestSystemMessageContent, + /// An optional name for the participant. Provides the model information to differentiate between participants of the same role. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestMessageContentPartTextArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestMessageContentPartText { + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +pub struct ChatCompletionRequestMessageContentPartRefusal { + /// The refusal message generated by the model. + pub refusal: String, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ImageDetail { + #[default] + Auto, + Low, + High, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ImageUrlArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ImageUrl { + /// Either a URL of the image or the base64 encoded image data. + pub url: String, + /// Specifies the detail level of the image. Learn more in the [Vision guide](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding). + pub detail: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestMessageContentPartImageArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestMessageContentPartImage { + pub image_url: ImageUrl, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum InputAudioFormat { + Wav, + #[default] + Mp3, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +pub struct InputAudio { + /// Base64 encoded audio data. + pub data: String, + /// The format of the encoded audio data. Currently supports "wav" and "mp3". + pub format: InputAudioFormat, +} + +/// Learn about [audio inputs](https://platform.openai.com/docs/guides/audio). +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestMessageContentPartAudioArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestMessageContentPartAudio { + pub input_audio: InputAudio, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ChatCompletionRequestUserMessageContentPart { + Text(ChatCompletionRequestMessageContentPartText), + ImageUrl(ChatCompletionRequestMessageContentPartImage), + InputAudio(ChatCompletionRequestMessageContentPartAudio), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ChatCompletionRequestSystemMessageContentPart { + Text(ChatCompletionRequestMessageContentPartText), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ChatCompletionRequestAssistantMessageContentPart { + Text(ChatCompletionRequestMessageContentPartText), + Refusal(ChatCompletionRequestMessageContentPartRefusal), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ChatCompletionRequestToolMessageContentPart { + Text(ChatCompletionRequestMessageContentPartText), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum ChatCompletionRequestSystemMessageContent { + /// The text contents of the system message. + Text(String), + /// An array of content parts with a defined type. For system messages, only type `text` is supported. + Array(Vec), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum ChatCompletionRequestUserMessageContent { + /// The text contents of the message. + Text(String), + /// An array of content parts with a defined type. Supported options differ based on the [model](https://platform.openai.com/docs/models) being used to generate the response. Can contain text, image, or audio inputs. + Array(Vec), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum ChatCompletionRequestAssistantMessageContent { + /// The text contents of the message. + Text(String), + /// An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`. + Array(Vec), +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum ChatCompletionRequestToolMessageContent { + /// The text contents of the tool message. + Text(String), + /// An array of content parts with a defined type. For tool messages, only type `text` is supported. + Array(Vec), +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestUserMessageArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestUserMessage { + /// The contents of the user message. + pub content: ChatCompletionRequestUserMessageContent, + /// An optional name for the participant. Provides the model information to differentiate between participants of the same role. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +pub struct ChatCompletionRequestAssistantMessageAudio { + /// Unique identifier for a previous audio response from the model. + pub id: String, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestAssistantMessageArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestAssistantMessage { + /// The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// The refusal message by the assistant. + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option, + /// An optional name for the participant. Provides the model information to differentiate between participants of the same role. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Data about a previous audio response from the model. + /// [Learn more](https://platform.openai.com/docs/guides/audio). + #[serde(skip_serializing_if = "Option::is_none")] + pub audio: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + /// Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model. + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, +} + +/// Tool message +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestToolMessageArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestToolMessage { + /// The contents of the tool message. + pub content: ChatCompletionRequestToolMessageContent, + pub tool_call_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "ChatCompletionRequestFunctionMessageArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionRequestFunctionMessage { + /// The return value from the function call, to return to the model. + pub content: Option, + /// The name of the function to call. + pub name: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "role")] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionRequestMessage { + Developer(ChatCompletionRequestDeveloperMessage), + System(ChatCompletionRequestSystemMessage), + User(ChatCompletionRequestUserMessage), + Assistant(ChatCompletionRequestAssistantMessage), + Tool(ChatCompletionRequestToolMessage), + Function(ChatCompletionRequestFunctionMessage), +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatCompletionMessageToolCall { + /// The ID of the tool call. + pub id: String, + /// The type of the tool. Currently, only `function` is supported. + pub r#type: ChatCompletionToolType, + /// The function that the model called. + pub function: FunctionCall, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +pub struct ChatCompletionResponseMessageAudio { + /// Unique identifier for this audio response. + pub id: String, + /// The Unix timestamp (in seconds) for when this audio response will no longer be accessible on the server for use in multi-turn conversations. + pub expires_at: u32, + /// Base64 encoded audio bytes generated by the model, in the format specified in the request. + pub data: String, + /// Transcript of the audio generated by the model. + pub transcript: String, +} + +/// A chat completion message generated by the model. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatCompletionResponseMessage { + /// The contents of the message. + pub content: Option, + /// The refusal message generated by the model. + pub refusal: Option, + /// The tool calls generated by the model, such as function calls. + pub tool_calls: Option>, + + /// The role of the author of this message. + pub role: Role, + + /// Deprecated and replaced by `tool_calls`. + /// The name and arguments of a function that should be called, as generated by the model. + #[deprecated] + pub function_call: Option, + + /// If the audio output modality is requested, this object contains data about the audio response from the model. [Learn more](https://platform.openai.com/docs/guides/audio). + pub audio: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "ChatCompletionFunctionsArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +#[deprecated] +pub struct ChatCompletionFunctions { + /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + pub name: String, + /// A description of what the function does, used by the model to choose when and how to call the function. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The parameters the functions accepts, described as a JSON Schema object. See the [guide](https://platform.openai.com/docs/guides/text-generation/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. + /// + /// Omitting `parameters` defines a function with an empty parameter list. + pub parameters: serde_json::Value, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "FunctionObjectArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct FunctionObject { + /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + pub name: String, + /// A description of what the function does, used by the model to choose when and how to call the function. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The parameters the functions accepts, described as a JSON Schema object. See the [guide](https://platform.openai.com/docs/guides/text-generation/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. + /// + /// Omitting `parameters` defines a function with an empty parameter list. + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, + + /// Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](https://platform.openai.com/docs/guides/function-calling). + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponseFormat { + /// The type of response format being defined: `text` + Text, + /// The type of response format being defined: `json_object` + JsonObject, + /// The type of response format being defined: `json_schema` + JsonSchema { + json_schema: ResponseFormatJsonSchema, + }, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ResponseFormatJsonSchema { + /// A description of what the response format is for, used by the model to determine how to respond in the format. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + pub name: String, + /// The schema for the response format, described as a JSON Schema object. + #[serde(skip_serializing_if = "Option::is_none")] + pub schema: Option, + /// Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionToolType { + #[default] + Function, +} + +#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize, PartialEq)] +#[builder(name = "ChatCompletionToolArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionTool { + #[builder(default = "ChatCompletionToolType::Function")] + pub r#type: ChatCompletionToolType, + pub function: FunctionObject, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct FunctionName { + /// The name of the function to call. + pub name: String, +} + +/// Specifies a tool the model should use. Use to force the model to call a specific function. +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct ChatCompletionNamedToolChoice { + /// The type of the tool. Currently, only `function` is supported. + pub r#type: ChatCompletionToolType, + + pub function: FunctionName, +} + +/// Controls which (if any) tool is called by the model. +/// `none` means the model will not call any tool and instead generates a message. +/// `auto` means the model can pick between generating a message or calling one or more tools. +/// `required` means the model must call one or more tools. +/// Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. +/// +/// `none` is the default when no tools are present. `auto` is the default if tools are present.present. +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionToolChoiceOption { + #[default] + None, + Auto, + Required, + #[serde(untagged)] + Named(ChatCompletionNamedToolChoice), +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ServiceTier { + Auto, + Default, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ServiceTierResponse { + Scale, + Default, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ReasoningEffort { + Low, + Medium, + High, +} + +/// Output types that you would like the model to generate for this request. +/// +/// Most models are capable of generating text, which is the default: `["text"]` +/// +/// The `gpt-4o-audio-preview` model can also be used to [generate +/// audio](https://platform.openai.com/docs/guides/audio). To request that this model generate both text and audio responses, you can use: `["text", "audio"]` +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionModalities { + Text, + Audio, +} + +/// The content that should be matched when generating a model response. If generated tokens would match this content, the entire model response can be returned much more quickly. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum PredictionContentContent { + /// The content used for a Predicted Output. This is often the text of a file you are regenerating with minor changes. + Text(String), + /// An array of content parts with a defined type. Supported options differ based on the [model](https://platform.openai.com/docs/models) being used to generate the response. Can contain text inputs. + Array(Vec), +} + +/// Static predicted output content, such as the content of a text file that is being regenerated. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase", content = "content")] +pub enum PredictionContent { + /// The type of the predicted content you want to provide. This type is + /// currently always `content`. + Content(PredictionContentContent), +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionAudioVoice { + Alloy, + Ash, + Ballad, + Coral, + Echo, + Sage, + Shimmer, + Verse, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionAudioFormat { + Wav, + Mp3, + Flac, + Opus, + Pcm16, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatCompletionAudio { + /// The voice the model uses to respond. Supported voices are `ash`, `ballad`, `coral`, `sage`, and `verse` (also supported but not recommended are `alloy`, `echo`, and `shimmer`; these voices are less expressive). + pub voice: ChatCompletionAudioVoice, + /// Specifies the output audio format. Must be one of `wav`, `mp3`, `flac`, `opus`, or `pcm16`. + pub format: ChatCompletionAudioFormat, +} + +#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize, PartialEq)] +#[builder(name = "CreateChatCompletionRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateChatCompletionRequest { + /// A list of messages comprising the conversation so far. Depending on the [model](https://platform.openai.com/docs/models) you use, different message types (modalities) are supported, like [text](https://platform.openai.com/docs/guides/text-generation), [images](https://platform.openai.com/docs/guides/vision), and [audio](https://platform.openai.com/docs/guides/audio). + pub messages: Vec, // min: 1 + + /// ID of the model to use. + /// See the [model endpoint compatibility](https://platform.openai.com/docs/models#model-endpoint-compatibility) table for details on which models work with the Chat API. + pub model: String, + + /// Whether or not to store the output of this chat completion request + /// + /// for use in our [model distillation](https://platform.openai.com/docs/guides/distillation) or [evals](https://platform.openai.com/docs/guides/evals) products. + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, // nullable: true, default: false + + /// **o1 models only** + /// + /// Constrains effort on reasoning for + /// [reasoning models](https://platform.openai.com/docs/guides/reasoning). + /// + /// Currently supported values are `low`, `medium`, and `high`. Reducing + /// + /// reasoning effort can result in faster responses and fewer tokens + /// used on reasoning in a response. + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + + /// Developer-defined tags and values used for filtering completions in the [dashboard](https://platform.openai.com/chat-completions). + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, // nullable: true + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, // min: -2.0, max: 2.0, default: 0 + + /// Modify the likelihood of specified tokens appearing in the completion. + /// + /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. + /// Mathematically, the bias is added to the logits generated by the model prior to sampling. + /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; + /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, // default: null + + /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`. + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + + /// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option, + + /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the chat completion. + /// + /// This value can be used to control [costs](https://openai.com/api/pricing/) for text generated via API. + /// This value is now deprecated in favor of `max_completion_tokens`, and is + /// not compatible with [o1 series models](https://platform.openai.com/docs/guides/reasoning). + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and [reasoning tokens](https://platform.openai.com/docs/guides/reasoning). + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + /// How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // min:1, max: 128, default: 1 + + #[serde(skip_serializing_if = "Option::is_none")] + pub modalities: Option>, + + /// Configuration for a [Predicted Output](https://platform.openai.com/docs/guides/predicted-outputs),which can greatly improve response times when large parts of the model response are known ahead of time. This is most common when you are regenerating a file with only minor changes to most of the content. + #[serde(skip_serializing_if = "Option::is_none")] + pub prediction: Option, + + /// Parameters for audio output. Required when audio output is requested with `modalities: ["audio"]`. [Learn more](https://platform.openai.com/docs/guides/audio). + #[serde(skip_serializing_if = "Option::is_none")] + pub audio: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, // min: -2.0, max: 2.0, default 0 + + /// An object specifying the format that the model must output. Compatible with [GPT-4o](https://platform.openai.com/docs/models/gpt-4o), [GPT-4o mini](https://platform.openai.com/docs/models/gpt-4o-mini), [GPT-4 Turbo](https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`. + /// + /// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + /// + /// Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. + /// + /// **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + /// This feature is in Beta. + /// If specified, our system will make a best effort to sample deterministically, such that repeated requests + /// with the same `seed` and parameters should return the same result. + /// Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend. + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + /// Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service: + /// - If set to 'auto', the system will utilize scale tier credits until they are exhausted. + /// - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee. + /// - When not set, the default behavior is 'auto'. + /// + /// When this parameter is set, the response body will include the `service_tier` utilized. + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option, + + /// Up to 4 sequences where the API will stop generating further tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// If set, partial message deltas will be sent, like in ChatGPT. + /// Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + /// as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, + /// while lower values like 0.2 will make it more focused and deterministic. + /// + /// We generally recommend altering this or `top_p` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, // min: 0, max: 2, default: 1, + + /// An alternative to sampling with temperature, called nucleus sampling, + /// where the model considers the results of the tokens with top_p probability mass. + /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or `temperature` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, // min: 0, max: 1, default: 1 + + /// A list of tools the model may call. Currently, only functions are supported as a tool. + /// Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Whether to enable [parallel function calling](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling) during tool use. + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids). + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// Deprecated in favor of `tool_choice`. + /// + /// Controls which (if any) function is called by the model. + /// `none` means the model will not call a function and instead generates a message. + /// `auto` means the model can pick between generating a message or calling a function. + /// Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + /// + /// `none` is the default when no functions are present. `auto` is the default if functions are present. + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + + /// Deprecated in favor of `tools`. + /// + /// A list of functions the model may generate JSON inputs for. + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub functions: Option>, +} + +/// Options for streaming response. Only set this when you set `stream: true`. +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +pub struct ChatCompletionStreamOptions { + /// If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value. + pub include_usage: bool, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + ToolCalls, + ContentFilter, + FunctionCall, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct TopLogprobs { + /// The token. + pub token: String, + /// The log probability of this token. + pub logprob: f32, + /// A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. + pub bytes: Option>, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatCompletionTokenLogprob { + /// The token. + pub token: String, + /// The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. + pub logprob: f32, + /// A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. + pub bytes: Option>, + /// List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned. + pub top_logprobs: Vec, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatChoiceLogprobs { + /// A list of message content tokens with log probability information. + pub content: Option>, + pub refusal: Option>, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatChoice { + /// The index of the choice in the list of choices. + pub index: u32, + pub message: ChatCompletionResponseMessage, + /// The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, + /// `length` if the maximum number of tokens specified in the request was reached, + /// `content_filter` if content was omitted due to a flag from our content filters, + /// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + pub finish_reason: Option, + /// Log probability information for the choice. + pub logprobs: Option, +} + +/// Represents a chat completion response returned by model, based on the provided input. +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct CreateChatCompletionResponse { + /// A unique identifier for the chat completion. + pub id: String, + /// A list of chat completion choices. Can be more than one if `n` is greater than 1. + pub choices: Vec, + /// The Unix timestamp (in seconds) of when the chat completion was created. + pub created: u32, + /// The model used for the chat completion. + pub model: String, + /// The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + pub service_tier: Option, + /// This fingerprint represents the backend configuration that the model runs with. + /// + /// Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + pub system_fingerprint: Option, + + /// The object type, which is always `chat.completion`. + pub object: String, + pub usage: Option, +} + +/// Parsed server side events stream until an \[DONE\] is received from server. +pub type ChatCompletionResponseStream = + Pin> + Send>>; + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct FunctionCallStream { + /// The name of the function to call. + pub name: Option, + /// The arguments to call the function with, as generated by the model in JSON format. + /// Note that the model does not always generate valid JSON, and may hallucinate + /// parameters not defined by your function schema. Validate the arguments in your + /// code before calling your function. + pub arguments: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatCompletionMessageToolCallChunk { + pub index: u32, + /// The ID of the tool call. + pub id: Option, + /// The type of the tool. Currently, only `function` is supported. + pub r#type: Option, + pub function: Option, +} + +/// A chat completion delta generated by streamed model responses. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatCompletionStreamResponseDelta { + /// The contents of the chunk message. + pub content: Option, + /// Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model. + #[deprecated] + pub function_call: Option, + + pub tool_calls: Option>, + /// The role of the author of this message. + pub role: Option, + /// The refusal message generated by the model. + pub refusal: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ChatChoiceStream { + /// The index of the choice in the list of choices. + pub index: u32, + pub delta: ChatCompletionStreamResponseDelta, + /// The reason the model stopped generating tokens. This will be + /// `stop` if the model hit a natural stop point or a provided + /// stop sequence, + /// + /// `length` if the maximum number of tokens specified in the + /// request was reached, + /// `content_filter` if content was omitted due to a flag from our + /// content filters, + /// `tool_calls` if the model called a tool, or `function_call` + /// (deprecated) if the model called a function. + pub finish_reason: Option, + /// Log probability information for the choice. + pub logprobs: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +/// Represents a streamed chunk of a chat completion response returned by model, based on the provided input. +pub struct CreateChatCompletionStreamResponse { + /// A unique identifier for the chat completion. Each chunk has the same ID. + pub id: String, + /// A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the last chunk if you set `stream_options: {"include_usage": true}`. + pub choices: Vec, + + /// The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp. + pub created: u32, + /// The model to generate the completion. + pub model: String, + /// The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + pub service_tier: Option, + /// This fingerprint represents the backend configuration that the model runs with. + /// Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + pub system_fingerprint: Option, + /// The object type, which is always `chat.completion.chunk`. + pub object: String, + + /// An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request. + /// When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request. + pub usage: Option, +} diff --git a/clia-async-openai/src/types/common.rs b/clia-async-openai/src/types/common.rs new file mode 100644 index 00000000..1fc9017d --- /dev/null +++ b/clia-async-openai/src/types/common.rs @@ -0,0 +1,18 @@ +use std::path::PathBuf; + +use bytes::Bytes; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq)] +pub enum InputSource { + Path { path: PathBuf }, + Bytes { filename: String, bytes: Bytes }, + VecU8 { filename: String, vec: Vec }, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum OrganizationRole { + Owner, + Reader, +} diff --git a/clia-async-openai/src/types/completion.rs b/clia-async-openai/src/types/completion.rs new file mode 100644 index 00000000..3c15dd6b --- /dev/null +++ b/clia-async-openai/src/types/completion.rs @@ -0,0 +1,141 @@ +use std::{collections::HashMap, pin::Pin}; + +use derive_builder::Builder; +use futures::Stream; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +use super::{ChatCompletionStreamOptions, Choice, CompletionUsage, Prompt, Stop}; + +#[derive(Clone, Serialize, Deserialize, Default, Debug, Builder, PartialEq)] +#[builder(name = "CreateCompletionRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateCompletionRequest { + /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them. + pub model: String, + + /// The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays. + /// + /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document. + pub prompt: Prompt, + + /// The suffix that comes after a completion of inserted text. + /// + /// This parameter is only supported for `gpt-3.5-turbo-instruct`. + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, // default: null + + /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the completion. + /// + /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// We generally recommend altering this or `top_p` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, // min: 0, max: 2, default: 1, + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or `temperature` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, // min: 0, max: 1, default: 1 + + /// How many completions to generate for each prompt. + + /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + /// + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // min:1 max: 128, default: 1 + + /// Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + /// as they become available, with the stream terminated by a `data: [DONE]` message. + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, // nullable: true + + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + + /// Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response. + /// + /// The maximum value for `logprobs` is 5. + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, // min:0 , max: 5, default: null, nullable: true + + /// Echo back the prompt in addition to the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub echo: Option, + + /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details) + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, // min: -2.0, max: 2.0, default 0 + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/text-generation/parameter-details) + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, // min: -2.0, max: 2.0, default: 0 + + /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed. + /// + /// When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`. + /// + /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, //min: 0, max: 20, default: 1 + + /// Modify the likelihood of specified tokens appearing in the completion. + /// + /// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + /// + /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated. + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, // default: null + + /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result. + /// + /// Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend. + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct CreateCompletionResponse { + /// A unique identifier for the completion. + pub id: String, + pub choices: Vec, + /// The Unix timestamp (in seconds) of when the completion was created. + pub created: u32, + + /// The model used for completion. + pub model: String, + /// This fingerprint represents the backend configuration that the model runs with. + /// + /// Can be used in conjunction with the `seed` request parameter to understand when backend changes have been + /// made that might impact determinism. + pub system_fingerprint: Option, + + /// The object type, which is always "text_completion" + pub object: String, + pub usage: Option, +} + +/// Parsed server side events stream until an \[DONE\] is received from server. +pub type CompletionResponseStream = + Pin> + Send>>; diff --git a/clia-async-openai/src/types/embedding.rs b/clia-async-openai/src/types/embedding.rs new file mode 100644 index 00000000..ea05ac3b --- /dev/null +++ b/clia-async-openai/src/types/embedding.rs @@ -0,0 +1,122 @@ +use base64::engine::{general_purpose, Engine}; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)] +#[serde(untagged)] +pub enum EmbeddingInput { + String(String), + StringArray(Vec), + // Minimum value is 0, maximum value is 100257 (inclusive). + IntegerArray(Vec), + ArrayOfIntegerArray(Vec>), +} + +#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum EncodingFormat { + #[default] + Float, + Base64, +} + +#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)] +#[builder(name = "CreateEmbeddingRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateEmbeddingRequest { + /// ID of the model to use. You can use the + /// [List models](https://platform.openai.com/docs/api-reference/models/list) + /// API to see all of your available models, or see our + /// [Model overview](https://platform.openai.com/docs/models/overview) + /// for descriptions of them. + pub model: String, + + /// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + pub input: EmbeddingInput, + + /// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). Defaults to float + #[serde(skip_serializing_if = "Option::is_none")] + pub encoding_format: Option, + + /// A unique identifier representing your end-user, which will help OpenAI + /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. + #[serde(skip_serializing_if = "Option::is_none")] + pub dimensions: Option, +} + +/// Represents an embedding vector returned by embedding endpoint. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct Embedding { + /// The index of the embedding in the list of embeddings. + pub index: u32, + /// The object type, which is always "embedding". + pub object: String, + /// The embedding vector, which is a list of floats. The length of vector + /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings). + pub embedding: Vec, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct Base64EmbeddingVector(pub String); + +impl From for Vec { + fn from(value: Base64EmbeddingVector) -> Self { + let bytes = general_purpose::STANDARD + .decode(value.0) + .expect("openai base64 encoding to be valid"); + let chunks = bytes.chunks_exact(4); + chunks + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect() + } +} + +/// Represents an base64-encoded embedding vector returned by embedding endpoint. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct Base64Embedding { + /// The index of the embedding in the list of embeddings. + pub index: u32, + /// The object type, which is always "embedding". + pub object: String, + /// The embedding vector, encoded in base64. + pub embedding: Base64EmbeddingVector, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct EmbeddingUsage { + /// The number of tokens used by the prompt. + pub prompt_tokens: u32, + /// The total number of tokens used by the request. + pub total_tokens: u32, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct CreateEmbeddingResponse { + pub object: String, + /// The name of the model used to generate the embedding. + pub model: String, + /// The list of embeddings generated by the model. + pub data: Vec, + /// The usage information for the request. + pub usage: EmbeddingUsage, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct CreateBase64EmbeddingResponse { + pub object: String, + /// The name of the model used to generate the embedding. + pub model: String, + /// The list of embeddings generated by the model. + pub data: Vec, + /// The usage information for the request. + pub usage: EmbeddingUsage, +} diff --git a/clia-async-openai/src/types/file.rs b/clia-async-openai/src/types/file.rs new file mode 100644 index 00000000..9a2e5090 --- /dev/null +++ b/clia-async-openai/src/types/file.rs @@ -0,0 +1,90 @@ +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +use super::InputSource; + +#[derive(Debug, Default, Clone, PartialEq)] +pub struct FileInput { + pub source: InputSource, +} + +#[derive(Debug, Default, Clone, PartialEq)] +pub enum FilePurpose { + Assistants, + Batch, + #[default] + FineTune, + Vision, +} + +#[derive(Debug, Default, Clone, Builder, PartialEq)] +#[builder(name = "CreateFileRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateFileRequest { + /// The File object (not file name) to be uploaded. + pub file: FileInput, + + /// The intended purpose of the uploaded file. + /// + /// Use "assistants" for [Assistants](https://platform.openai.com/docs/api-reference/assistants) and [Message](https://platform.openai.com/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](https://platform.openai.com/docs/guides/batch), and "fine-tune" for [Fine-tuning](https://platform.openai.com/docs/api-reference/fine-tuning). + pub purpose: FilePurpose, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct ListFilesResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct DeleteFileResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub enum OpenAIFilePurpose { + #[serde(rename = "assistants")] + Assistants, + #[serde(rename = "assistants_output")] + AssistantsOutput, + #[serde(rename = "batch")] + Batch, + #[serde(rename = "batch_output")] + BatchOutput, + #[serde(rename = "fine-tune")] + FineTune, + #[serde(rename = "fine-tune-results")] + FineTuneResults, + #[serde(rename = "vision")] + Vision, +} + +/// The `File` object represents a document that has been uploaded to OpenAI. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct OpenAIFile { + /// The file identifier, which can be referenced in the API endpoints. + pub id: String, + /// The object type, which is always "file". + pub object: String, + /// The size of the file in bytes. + pub bytes: u32, + /// The Unix timestamp (in seconds) for when the file was created. + pub created_at: u32, + /// The name of the file. + pub filename: String, + /// The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`. + pub purpose: OpenAIFilePurpose, + /// Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`. + #[deprecated] + pub status: Option, + /// Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`. + #[deprecated] + pub status_details: Option, // nullable: true +} diff --git a/clia-async-openai/src/types/fine_tuning.rs b/clia-async-openai/src/types/fine_tuning.rs new file mode 100644 index 00000000..a5c6d321 --- /dev/null +++ b/clia-async-openai/src/types/fine_tuning.rs @@ -0,0 +1,348 @@ +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] +#[serde(untagged)] +pub enum NEpochs { + NEpochs(u8), + #[default] + #[serde(rename = "auto")] + Auto, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] +#[serde(untagged)] +pub enum BatchSize { + BatchSize(u16), + #[default] + #[serde(rename = "auto")] + Auto, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] +#[serde(untagged)] +pub enum LearningRateMultiplier { + LearningRateMultiplier(f32), + #[default] + #[serde(rename = "auto")] + Auto, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] +pub struct Hyperparameters { + /// Number of examples in each batch. A larger batch size means that model parameters + /// are updated less frequently, but with lower variance. + pub batch_size: BatchSize, + /// Scaling factor for the learning rate. A smaller learning rate may be useful to avoid + /// overfitting. + pub learning_rate_multiplier: LearningRateMultiplier, + /// The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. + pub n_epochs: NEpochs, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] +#[serde(untagged)] +pub enum Beta { + Beta(f32), + #[default] + #[serde(rename = "auto")] + Auto, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)] +pub struct DPOHyperparameters { + /// The beta value for the DPO method. A higher beta value will increase the weight of the penalty between the policy and reference model. + pub beta: Beta, + /// Number of examples in each batch. A larger batch size means that model parameters + /// are updated less frequently, but with lower variance. + pub batch_size: BatchSize, + /// Scaling factor for the learning rate. A smaller learning rate may be useful to avoid + /// overfitting. + pub learning_rate_multiplier: LearningRateMultiplier, + /// The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. + pub n_epochs: NEpochs, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default, Builder, PartialEq)] +#[builder(name = "CreateFineTuningJobRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateFineTuningJobRequest { + /// The name of the model to fine-tune. You can select one of the + /// [supported models](https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned). + pub model: String, + + /// The ID of an uploaded file that contains training data. + /// + /// See [upload file](https://platform.openai.com/docs/api-reference/files/create) for how to upload a file. + /// + /// Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`. + /// + /// The contents of the file should differ depending on if the model uses the [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input), [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) format, or if the fine-tuning method uses the [preference](https://platform.openai.com/docs/api-reference/fine-tuning/preference-input) format. + /// + /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) for more details. + pub training_file: String, + + /// The hyperparameters used for the fine-tuning job. + /// This value is now deprecated in favor of `method`, and should be passed in under the `method` parameter. + #[deprecated] + pub hyperparameters: Option, + + /// A string of up to 64 characters that will be added to your fine-tuned model name. + /// + /// For example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`. + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, // default: null, minLength:1, maxLength:40 + + /// The ID of an uploaded file that contains validation data. + /// + /// If you provide this file, the data is used to generate validation + /// metrics periodically during fine-tuning. These metrics can be viewed in + /// the fine-tuning results file. + /// The same data should not be present in both train and validation files. + /// + /// Your dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`. + /// + /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) for more details. + #[serde(skip_serializing_if = "Option::is_none")] + pub validation_file: Option, + + /// A list of integrations to enable for your fine-tuning job. + #[serde(skip_serializing_if = "Option::is_none")] + pub integrations: Option>, + + /// The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. + /// If a seed is not specified, one will be generated for you. + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, // min:0, max: 2147483647 + + #[serde(skip_serializing_if = "Option::is_none")] + pub method: Option, +} + +/// The method used for fine-tuning. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum FineTuneMethod { + Supervised { + supervised: FineTuneSupervisedMethod, + }, + DPO { + dpo: FineTuneDPOMethod, + }, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct FineTuneSupervisedMethod { + pub hyperparameters: Hyperparameters, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct FineTuneDPOMethod { + pub hyperparameters: DPOHyperparameters, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum FineTuningJobIntegrationType { + #[default] + Wandb, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct FineTuningIntegration { + /// The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported. + pub r#type: FineTuningJobIntegrationType, + + /// The settings for your integration with Weights and Biases. This payload specifies the project that + /// metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags + /// to your run, and set a default entity (team, username, etc) to be associated with your run. + pub wandb: WandB, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct WandB { + /// The name of the project that the new run will be created under. + pub project: String, + /// A display name to set for the run. If not set, we will use the Job ID as the name. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// The entity to use for the run. This allows you to set the team or username of the WandB user that you would + /// like associated with the run. If not set, the default entity for the registered WandB API key is used. + #[serde(skip_serializing_if = "Option::is_none")] + pub entity: Option, + /// A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some + /// default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + #[serde(skip_serializing_if = "Option::is_none")] + pub tags: Option>, +} + +/// For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct FineTuneJobError { + /// A machine-readable error code. + pub code: String, + /// A human-readable error message. + pub message: String, + /// The parameter that was invalid, usually `training_file` or `validation_file`. + /// This field will be null if the failure was not parameter-specific. + pub param: Option, // nullable true +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum FineTuningJobStatus { + ValidatingFiles, + Queued, + Running, + Succeeded, + Failed, + Cancelled, +} + +/// The `fine_tuning.job` object represents a fine-tuning job that has been created through the API. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct FineTuningJob { + /// The object identifier, which can be referenced in the API endpoints. + pub id: String, + /// The Unix timestamp (in seconds) for when the fine-tuning job was created. + pub created_at: u32, + /// For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure. + pub error: Option, + /// The name of the fine-tuned model that is being created. + /// The value will be null if the fine-tuning job is still running. + pub fine_tuned_model: Option, // nullable: true + /// The Unix timestamp (in seconds) for when the fine-tuning job was finished. + /// The value will be null if the fine-tuning job is still running. + pub finished_at: Option, // nullable true + + /// The hyperparameters used for the fine-tuning job. + /// See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + pub hyperparameters: Hyperparameters, + + /// The base model that is being fine-tuned. + pub model: String, + + /// The object type, which is always "fine_tuning.job". + pub object: String, + /// The organization that owns the fine-tuning job. + pub organization_id: String, + + /// The compiled results file ID(s) for the fine-tuning job. + /// You can retrieve the results with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + pub result_files: Vec, + + /// The current status of the fine-tuning job, which can be either + /// `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`. + pub status: FineTuningJobStatus, + + /// The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running. + pub trained_tokens: Option, + + /// The file ID used for training. You can retrieve the training data with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + pub training_file: String, + + /// The file ID used for validation. You can retrieve the validation results with the [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + pub validation_file: Option, + + /// A list of integrations to enable for this fine-tuning job. + pub integrations: Option>, // maxItems: 5 + + /// The seed used for the fine-tuning job. + pub seed: u32, + + /// The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running. + pub estimated_finish: Option, + + pub method: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ListPaginatedFineTuningJobsResponse { + pub data: Vec, + pub has_more: bool, + pub object: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ListFineTuningJobEventsResponse { + pub data: Vec, + pub object: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ListFineTuningJobCheckpointsResponse { + pub data: Vec, + pub object: String, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Level { + Info, + Warn, + Error, +} + +///Fine-tuning job event object +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct FineTuningJobEvent { + /// The object identifier. + pub id: String, + /// The Unix timestamp (in seconds) for when the fine-tuning job event was created. + pub created_at: u32, + /// The log level of the event. + pub level: Level, + /// The message of the event. + pub message: String, + /// The object type, which is always "fine_tuning.job.event". + pub object: String, + /// The type of event. + pub r#type: Option, + /// The data associated with the event. + pub data: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum FineTuningJobEventType { + Message, + Metrics, +} + +/// The `fine_tuning.job.checkpoint` object represents a model checkpoint for a fine-tuning job that is ready to use. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct FineTuningJobCheckpoint { + /// The checkpoint identifier, which can be referenced in the API endpoints. + pub id: String, + /// The Unix timestamp (in seconds) for when the checkpoint was created. + pub created_at: u32, + /// The name of the fine-tuned checkpoint model that is created. + pub fine_tuned_model_checkpoint: String, + /// The step number that the checkpoint was created at. + pub step_number: u32, + /// Metrics at the step number during the fine-tuning job. + pub metrics: FineTuningJobCheckpointMetrics, + /// The name of the fine-tuning job that this checkpoint was created from. + pub fine_tuning_job_id: String, + /// The object type, which is always "fine_tuning.job.checkpoint". + pub object: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct FineTuningJobCheckpointMetrics { + pub step: u32, + pub train_loss: f32, + pub train_mean_token_accuracy: f32, + pub valid_loss: f32, + pub valid_mean_token_accuracy: f32, + pub full_valid_loss: f32, + pub full_valid_mean_token_accuracy: f32, +} diff --git a/clia-async-openai/src/types/image.rs b/clia-async-openai/src/types/image.rs new file mode 100644 index 00000000..86169c46 --- /dev/null +++ b/clia-async-openai/src/types/image.rs @@ -0,0 +1,197 @@ +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +use super::InputSource; + +#[derive(Default, Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +pub enum ImageSize { + #[serde(rename = "256x256")] + S256x256, + #[serde(rename = "512x512")] + S512x512, + #[default] + #[serde(rename = "1024x1024")] + S1024x1024, + #[serde(rename = "1792x1024")] + S1792x1024, + #[serde(rename = "1024x1792")] + S1024x1792, +} + +#[derive(Default, Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +pub enum DallE2ImageSize { + #[serde(rename = "256x256")] + S256x256, + #[serde(rename = "512x512")] + S512x512, + #[default] + #[serde(rename = "1024x1024")] + S1024x1024, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ImageResponseFormat { + #[default] + Url, + #[serde(rename = "b64_json")] + B64Json, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +pub enum ImageModel { + #[default] + #[serde(rename = "dall-e-2")] + DallE2, + #[serde(rename = "dall-e-3")] + DallE3, + #[serde(untagged)] + Other(String), +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ImageQuality { + #[default] + Standard, + HD, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ImageStyle { + #[default] + Vivid, + Natural, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder, PartialEq)] +#[builder(name = "CreateImageRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateImageRequest { + /// A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` + /// and 4000 characters for `dall-e-3`. + pub prompt: String, + + /// The model to use for image generation. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // min:1 max:10 default:1 + + /// The quality of the image that will be generated. `hd` creates images with finer details and greater + /// consistency across the image. This param is only supported for `dall-e-3`. + #[serde(skip_serializing_if = "Option::is_none")] + pub quality: Option, + + /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. + /// Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models. + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, + + /// The style of the generated images. Must be one of `vivid` or `natural`. + /// Vivid causes the model to lean towards generating hyper-real and dramatic images. + /// Natural causes the model to produce more natural, less hyper-real looking images. + /// This param is only supported for `dall-e-3`. + #[serde(skip_serializing_if = "Option::is_none")] + pub style: Option, + + /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +#[serde(untagged)] +pub enum Image { + /// The URL of the generated image, if `response_format` is `url` (default). + Url { + url: String, + revised_prompt: Option, + }, + /// The base64-encoded JSON of the generated image, if `response_format` is `b64_json`. + B64Json { + b64_json: std::sync::Arc, + revised_prompt: Option, + }, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ImagesResponse { + pub created: u32, + pub data: Vec>, +} + +#[derive(Debug, Default, Clone, PartialEq)] +pub struct ImageInput { + pub source: InputSource, +} + +#[derive(Debug, Clone, Default, Builder, PartialEq)] +#[builder(name = "CreateImageEditRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateImageEditRequest { + /// The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask. + pub image: ImageInput, + + /// A text description of the desired image(s). The maximum length is 1000 characters. + pub prompt: String, + + /// An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`. + pub mask: Option, + + /// The model to use for image generation. Only `dall-e-2` is supported at this time. + pub model: Option, + + /// The number of images to generate. Must be between 1 and 10. + pub n: Option, // min:1 max:10 default:1 + + /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + pub size: Option, + + /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. + pub response_format: Option, + + /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). + pub user: Option, +} + +#[derive(Debug, Default, Clone, Builder, PartialEq)] +#[builder(name = "CreateImageVariationRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateImageVariationRequest { + /// The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. + pub image: ImageInput, + + /// The model to use for image generation. Only `dall-e-2` is supported at this time. + pub model: Option, + + /// The number of images to generate. Must be between 1 and 10. + pub n: Option, // min:1 max:10 default:1 + + /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + pub size: Option, + + /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. + pub response_format: Option, + + /// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). + pub user: Option, +} diff --git a/clia-async-openai/src/types/impls.rs b/clia-async-openai/src/types/impls.rs new file mode 100644 index 00000000..cc11c0bf --- /dev/null +++ b/clia-async-openai/src/types/impls.rs @@ -0,0 +1,989 @@ +use std::{ + fmt::Display, + path::{Path, PathBuf}, +}; + +use crate::{ + download::{download_url, save_b64}, + error::OpenAIError, + traits::AsyncTryFrom, + types::InputSource, + util::{create_all_dir, create_file_part}, +}; + +use bytes::Bytes; + +use super::{ + AddUploadPartRequest, AudioInput, AudioResponseFormat, ChatCompletionFunctionCall, + ChatCompletionFunctions, ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessage, + ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestDeveloperMessage, + ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestFunctionMessage, + ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartAudio, + ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, + ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, + ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, + ChatCompletionRequestUserMessageContentPart, ChatCompletionToolChoiceOption, CreateFileRequest, + CreateImageEditRequest, CreateImageVariationRequest, CreateMessageRequestContent, + CreateSpeechResponse, CreateTranscriptionRequest, CreateTranslationRequest, DallE2ImageSize, + EmbeddingInput, FileInput, FilePurpose, FunctionName, Image, ImageInput, ImageModel, + ImageResponseFormat, ImageSize, ImageUrl, ImagesResponse, ModerationInput, Prompt, Role, Stop, + TimestampGranularity, +}; + +/// for `impl_from!(T, Enum)`, implements +/// - `From` +/// - `From>` +/// - `From<&Vec>` +/// - `From<[T; N]>` +/// - `From<&[T; N]>` +/// +/// for `T: Into` and `Enum` having variants `String(String)` and `StringArray(Vec)` +macro_rules! impl_from { + ($from_typ:ty, $to_typ:ty) => { + // From -> String variant + impl From<$from_typ> for $to_typ { + fn from(value: $from_typ) -> Self { + <$to_typ>::String(value.into()) + } + } + + // From> -> StringArray variant + impl From> for $to_typ { + fn from(value: Vec<$from_typ>) -> Self { + <$to_typ>::StringArray(value.iter().map(|v| v.to_string()).collect()) + } + } + + // From<&Vec> -> StringArray variant + impl From<&Vec<$from_typ>> for $to_typ { + fn from(value: &Vec<$from_typ>) -> Self { + <$to_typ>::StringArray(value.iter().map(|v| v.to_string()).collect()) + } + } + + // From<[T; N]> -> StringArray variant + impl From<[$from_typ; N]> for $to_typ { + fn from(value: [$from_typ; N]) -> Self { + <$to_typ>::StringArray(value.into_iter().map(|v| v.to_string()).collect()) + } + } + + // From<&[T; N]> -> StringArray variatn + impl From<&[$from_typ; N]> for $to_typ { + fn from(value: &[$from_typ; N]) -> Self { + <$to_typ>::StringArray(value.into_iter().map(|v| v.to_string()).collect()) + } + } + }; +} + +// From String "family" to Prompt +impl_from!(&str, Prompt); +impl_from!(String, Prompt); +impl_from!(&String, Prompt); + +// From String "family" to Stop +impl_from!(&str, Stop); +impl_from!(String, Stop); +impl_from!(&String, Stop); + +// From String "family" to ModerationInput +impl_from!(&str, ModerationInput); +impl_from!(String, ModerationInput); +impl_from!(&String, ModerationInput); + +// From String "family" to EmbeddingInput +impl_from!(&str, EmbeddingInput); +impl_from!(String, EmbeddingInput); +impl_from!(&String, EmbeddingInput); + +/// for `impl_default!(Enum)`, implements `Default` for `Enum` as `Enum::String("")` where `Enum` has `String` variant +macro_rules! impl_default { + ($for_typ:ty) => { + impl Default for $for_typ { + fn default() -> Self { + Self::String("".into()) + } + } + }; +} + +impl_default!(Prompt); +impl_default!(ModerationInput); +impl_default!(EmbeddingInput); + +impl Default for InputSource { + fn default() -> Self { + InputSource::Path { + path: PathBuf::new(), + } + } +} + +/// for `impl_input!(Struct)` where +/// ```text +/// Struct { +/// source: InputSource +/// } +/// ``` +/// implements methods `from_bytes` and `from_vec_u8`, +/// and `From

` for `P: AsRef` +macro_rules! impl_input { + ($for_typ:ty) => { + impl $for_typ { + pub fn from_bytes(filename: String, bytes: Bytes) -> Self { + Self { + source: InputSource::Bytes { filename, bytes }, + } + } + + pub fn from_vec_u8(filename: String, vec: Vec) -> Self { + Self { + source: InputSource::VecU8 { filename, vec }, + } + } + } + + impl> From

for $for_typ { + fn from(path: P) -> Self { + let path_buf = path.as_ref().to_path_buf(); + Self { + source: InputSource::Path { path: path_buf }, + } + } + } + }; +} + +impl_input!(AudioInput); +impl_input!(FileInput); +impl_input!(ImageInput); + +impl Display for ImageSize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::S256x256 => "256x256", + Self::S512x512 => "512x512", + Self::S1024x1024 => "1024x1024", + Self::S1792x1024 => "1792x1024", + Self::S1024x1792 => "1024x1792", + } + ) + } +} + +impl Display for DallE2ImageSize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::S256x256 => "256x256", + Self::S512x512 => "512x512", + Self::S1024x1024 => "1024x1024", + } + ) + } +} + +impl Display for ImageModel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::DallE2 => "dall-e-2", + Self::DallE3 => "dall-e-3", + Self::Other(other) => other, + } + ) + } +} + +impl Display for ImageResponseFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Url => "url", + Self::B64Json => "b64_json", + } + ) + } +} + +impl Display for AudioResponseFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + AudioResponseFormat::Json => "json", + AudioResponseFormat::Srt => "srt", + AudioResponseFormat::Text => "text", + AudioResponseFormat::VerboseJson => "verbose_json", + AudioResponseFormat::Vtt => "vtt", + } + ) + } +} + +impl Display for TimestampGranularity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + TimestampGranularity::Word => "word", + TimestampGranularity::Segment => "segment", + } + ) + } +} + +impl Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Role::User => "user", + Role::System => "system", + Role::Assistant => "assistant", + Role::Function => "function", + Role::Tool => "tool", + } + ) + } +} + +impl Display for FilePurpose { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Assistants => "assistants", + Self::Batch => "batch", + Self::FineTune => "fine-tune", + Self::Vision => "vision", + } + ) + } +} + +impl ImagesResponse { + /// Save each image in a dedicated Tokio task and return paths to saved files. + /// For [ResponseFormat::Url] each file is downloaded in dedicated Tokio task. + pub async fn save>(&self, dir: P) -> Result, OpenAIError> { + create_all_dir(dir.as_ref())?; + + let mut handles = vec![]; + for id in self.data.clone() { + let dir_buf = PathBuf::from(dir.as_ref()); + handles.push(tokio::spawn(async move { id.save(dir_buf).await })); + } + + let results = futures::future::join_all(handles).await; + let mut errors = vec![]; + let mut paths = vec![]; + + for result in results { + match result { + Ok(inner) => match inner { + Ok(path) => paths.push(path), + Err(e) => errors.push(e), + }, + Err(e) => errors.push(OpenAIError::FileSaveError(e.to_string())), + } + } + + if errors.is_empty() { + Ok(paths) + } else { + Err(OpenAIError::FileSaveError( + errors + .into_iter() + .map(|e| e.to_string()) + .collect::>() + .join("; "), + )) + } + } +} + +impl CreateSpeechResponse { + pub async fn save>(&self, file_path: P) -> Result<(), OpenAIError> { + let dir = file_path.as_ref().parent(); + + if let Some(dir) = dir { + create_all_dir(dir)?; + } + + tokio::fs::write(file_path, &self.bytes) + .await + .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; + + Ok(()) + } +} + +impl Image { + async fn save>(&self, dir: P) -> Result { + match self { + Image::Url { url, .. } => download_url(url, dir).await, + Image::B64Json { b64_json, .. } => save_b64(b64_json, dir).await, + } + } +} + +macro_rules! impl_from_for_integer_array { + ($from_typ:ty, $to_typ:ty) => { + impl From<[$from_typ; N]> for $to_typ { + fn from(value: [$from_typ; N]) -> Self { + Self::IntegerArray(value.to_vec()) + } + } + + impl From<&[$from_typ; N]> for $to_typ { + fn from(value: &[$from_typ; N]) -> Self { + Self::IntegerArray(value.to_vec()) + } + } + + impl From> for $to_typ { + fn from(value: Vec<$from_typ>) -> Self { + Self::IntegerArray(value) + } + } + + impl From<&Vec<$from_typ>> for $to_typ { + fn from(value: &Vec<$from_typ>) -> Self { + Self::IntegerArray(value.clone()) + } + } + }; +} + +impl_from_for_integer_array!(u32, EmbeddingInput); +impl_from_for_integer_array!(u16, Prompt); + +macro_rules! impl_from_for_array_of_integer_array { + ($from_typ:ty, $to_typ:ty) => { + impl From>> for $to_typ { + fn from(value: Vec>) -> Self { + Self::ArrayOfIntegerArray(value) + } + } + + impl From<&Vec>> for $to_typ { + fn from(value: &Vec>) -> Self { + Self::ArrayOfIntegerArray(value.clone()) + } + } + + impl From<[[$from_typ; N]; M]> for $to_typ { + fn from(value: [[$from_typ; N]; M]) -> Self { + Self::ArrayOfIntegerArray(value.iter().map(|inner| inner.to_vec()).collect()) + } + } + + impl From<[&[$from_typ; N]; M]> for $to_typ { + fn from(value: [&[$from_typ; N]; M]) -> Self { + Self::ArrayOfIntegerArray(value.iter().map(|inner| inner.to_vec()).collect()) + } + } + + impl From<&[[$from_typ; N]; M]> for $to_typ { + fn from(value: &[[$from_typ; N]; M]) -> Self { + Self::ArrayOfIntegerArray(value.iter().map(|inner| inner.to_vec()).collect()) + } + } + + impl From<&[&[$from_typ; N]; M]> for $to_typ { + fn from(value: &[&[$from_typ; N]; M]) -> Self { + Self::ArrayOfIntegerArray(value.iter().map(|inner| inner.to_vec()).collect()) + } + } + + impl From<[Vec<$from_typ>; N]> for $to_typ { + fn from(value: [Vec<$from_typ>; N]) -> Self { + Self::ArrayOfIntegerArray(value.to_vec()) + } + } + + impl From<&[Vec<$from_typ>; N]> for $to_typ { + fn from(value: &[Vec<$from_typ>; N]) -> Self { + Self::ArrayOfIntegerArray(value.to_vec()) + } + } + + impl From<[&Vec<$from_typ>; N]> for $to_typ { + fn from(value: [&Vec<$from_typ>; N]) -> Self { + Self::ArrayOfIntegerArray(value.into_iter().map(|inner| inner.clone()).collect()) + } + } + + impl From<&[&Vec<$from_typ>; N]> for $to_typ { + fn from(value: &[&Vec<$from_typ>; N]) -> Self { + Self::ArrayOfIntegerArray( + value + .to_vec() + .into_iter() + .map(|inner| inner.clone()) + .collect(), + ) + } + } + + impl From> for $to_typ { + fn from(value: Vec<[$from_typ; N]>) -> Self { + Self::ArrayOfIntegerArray(value.into_iter().map(|inner| inner.to_vec()).collect()) + } + } + + impl From<&Vec<[$from_typ; N]>> for $to_typ { + fn from(value: &Vec<[$from_typ; N]>) -> Self { + Self::ArrayOfIntegerArray(value.into_iter().map(|inner| inner.to_vec()).collect()) + } + } + + impl From> for $to_typ { + fn from(value: Vec<&[$from_typ; N]>) -> Self { + Self::ArrayOfIntegerArray(value.into_iter().map(|inner| inner.to_vec()).collect()) + } + } + + impl From<&Vec<&[$from_typ; N]>> for $to_typ { + fn from(value: &Vec<&[$from_typ; N]>) -> Self { + Self::ArrayOfIntegerArray(value.into_iter().map(|inner| inner.to_vec()).collect()) + } + } + }; +} + +impl_from_for_array_of_integer_array!(u32, EmbeddingInput); +impl_from_for_array_of_integer_array!(u16, Prompt); + +impl From<&str> for ChatCompletionFunctionCall { + fn from(value: &str) -> Self { + match value { + "auto" => Self::Auto, + "none" => Self::None, + _ => Self::Function { name: value.into() }, + } + } +} + +impl From<&str> for FunctionName { + fn from(value: &str) -> Self { + Self { name: value.into() } + } +} + +impl From for FunctionName { + fn from(value: String) -> Self { + Self { name: value } + } +} + +impl From<&str> for ChatCompletionNamedToolChoice { + fn from(value: &str) -> Self { + Self { + r#type: super::ChatCompletionToolType::Function, + function: value.into(), + } + } +} + +impl From for ChatCompletionNamedToolChoice { + fn from(value: String) -> Self { + Self { + r#type: super::ChatCompletionToolType::Function, + function: value.into(), + } + } +} + +impl From<&str> for ChatCompletionToolChoiceOption { + fn from(value: &str) -> Self { + match value { + "auto" => Self::Auto, + "none" => Self::None, + _ => Self::Named(value.into()), + } + } +} + +impl From for ChatCompletionToolChoiceOption { + fn from(value: String) -> Self { + match value.as_str() { + "auto" => Self::Auto, + "none" => Self::None, + _ => Self::Named(value.into()), + } + } +} + +impl From<(String, serde_json::Value)> for ChatCompletionFunctions { + fn from(value: (String, serde_json::Value)) -> Self { + Self { + name: value.0, + description: None, + parameters: value.1, + } + } +} + +// todo: write macro for bunch of same looking From trait implementations below + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestUserMessage) -> Self { + Self::User(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestSystemMessage) -> Self { + Self::System(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestDeveloperMessage) -> Self { + Self::Developer(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestAssistantMessage) -> Self { + Self::Assistant(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestFunctionMessage) -> Self { + Self::Function(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestToolMessage) -> Self { + Self::Tool(value) + } +} + +impl From for ChatCompletionRequestUserMessage { + fn from(value: ChatCompletionRequestUserMessageContent) -> Self { + Self { + content: value, + name: None, + } + } +} + +impl From for ChatCompletionRequestSystemMessage { + fn from(value: ChatCompletionRequestSystemMessageContent) -> Self { + Self { + content: value, + name: None, + } + } +} + +impl From for ChatCompletionRequestDeveloperMessage { + fn from(value: ChatCompletionRequestDeveloperMessageContent) -> Self { + Self { + content: value, + name: None, + } + } +} + +impl From for ChatCompletionRequestAssistantMessage { + fn from(value: ChatCompletionRequestAssistantMessageContent) -> Self { + Self { + content: Some(value), + ..Default::default() + } + } +} + +impl From<&str> for ChatCompletionRequestUserMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestUserMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestUserMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestUserMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestSystemMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestSystemMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestSystemMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestSystemMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestDeveloperMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestDeveloperMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestDeveloperMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestDeveloperMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestAssistantMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestAssistantMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestAssistantMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestAssistantMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestToolMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestToolMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestToolMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestToolMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestUserMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestUserMessageContent::Text(value.into()).into() + } +} + +impl From for ChatCompletionRequestUserMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From<&str> for ChatCompletionRequestSystemMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestSystemMessageContent::Text(value.into()).into() + } +} + +impl From<&str> for ChatCompletionRequestDeveloperMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestDeveloperMessageContent::Text(value.into()).into() + } +} + +impl From for ChatCompletionRequestSystemMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From for ChatCompletionRequestDeveloperMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From<&str> for ChatCompletionRequestAssistantMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestAssistantMessageContent::Text(value.into()).into() + } +} + +impl From for ChatCompletionRequestAssistantMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From> + for ChatCompletionRequestUserMessageContent +{ + fn from(value: Vec) -> Self { + ChatCompletionRequestUserMessageContent::Array(value) + } +} + +impl From + for ChatCompletionRequestUserMessageContentPart +{ + fn from(value: ChatCompletionRequestMessageContentPartText) -> Self { + ChatCompletionRequestUserMessageContentPart::Text(value) + } +} + +impl From + for ChatCompletionRequestUserMessageContentPart +{ + fn from(value: ChatCompletionRequestMessageContentPartImage) -> Self { + ChatCompletionRequestUserMessageContentPart::ImageUrl(value) + } +} + +impl From + for ChatCompletionRequestUserMessageContentPart +{ + fn from(value: ChatCompletionRequestMessageContentPartAudio) -> Self { + ChatCompletionRequestUserMessageContentPart::InputAudio(value) + } +} + +impl From<&str> for ChatCompletionRequestMessageContentPartText { + fn from(value: &str) -> Self { + ChatCompletionRequestMessageContentPartText { text: value.into() } + } +} + +impl From for ChatCompletionRequestMessageContentPartText { + fn from(value: String) -> Self { + ChatCompletionRequestMessageContentPartText { text: value } + } +} + +impl From<&str> for ImageUrl { + fn from(value: &str) -> Self { + Self { + url: value.into(), + detail: Default::default(), + } + } +} + +impl From for ImageUrl { + fn from(value: String) -> Self { + Self { + url: value, + detail: Default::default(), + } + } +} + +impl From for CreateMessageRequestContent { + fn from(value: String) -> Self { + Self::Content(value) + } +} + +impl From<&str> for CreateMessageRequestContent { + fn from(value: &str) -> Self { + Self::Content(value.to_string()) + } +} + +impl Default for ChatCompletionRequestUserMessageContent { + fn default() -> Self { + ChatCompletionRequestUserMessageContent::Text("".into()) + } +} + +impl Default for CreateMessageRequestContent { + fn default() -> Self { + Self::Content("".into()) + } +} + +impl Default for ChatCompletionRequestDeveloperMessageContent { + fn default() -> Self { + ChatCompletionRequestDeveloperMessageContent::Text("".into()) + } +} + +impl Default for ChatCompletionRequestSystemMessageContent { + fn default() -> Self { + ChatCompletionRequestSystemMessageContent::Text("".into()) + } +} + +impl Default for ChatCompletionRequestToolMessageContent { + fn default() -> Self { + ChatCompletionRequestToolMessageContent::Text("".into()) + } +} + +// start: types to multipart from + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateTranscriptionRequest) -> Result { + let audio_part = create_file_part(request.file.source).await?; + + let mut form = reqwest::multipart::Form::new() + .part("file", audio_part) + .text("model", request.model); + + if let Some(prompt) = request.prompt { + form = form.text("prompt", prompt); + } + + if let Some(response_format) = request.response_format { + form = form.text("response_format", response_format.to_string()) + } + + if let Some(temperature) = request.temperature { + form = form.text("temperature", temperature.to_string()) + } + + if let Some(language) = request.language { + form = form.text("language", language); + } + + if let Some(timestamp_granularities) = request.timestamp_granularities { + for tg in timestamp_granularities { + form = form.text("timestamp_granularities[]", tg.to_string()); + } + } + + Ok(form) + } +} + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateTranslationRequest) -> Result { + let audio_part = create_file_part(request.file.source).await?; + + let mut form = reqwest::multipart::Form::new() + .part("file", audio_part) + .text("model", request.model); + + if let Some(prompt) = request.prompt { + form = form.text("prompt", prompt); + } + + if let Some(response_format) = request.response_format { + form = form.text("response_format", response_format.to_string()) + } + + if let Some(temperature) = request.temperature { + form = form.text("temperature", temperature.to_string()) + } + Ok(form) + } +} + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateImageEditRequest) -> Result { + let image_part = create_file_part(request.image.source).await?; + + let mut form = reqwest::multipart::Form::new() + .part("image", image_part) + .text("prompt", request.prompt); + + if let Some(mask) = request.mask { + let mask_part = create_file_part(mask.source).await?; + form = form.part("mask", mask_part); + } + + if let Some(model) = request.model { + form = form.text("model", model.to_string()) + } + + if request.n.is_some() { + form = form.text("n", request.n.unwrap().to_string()) + } + + if request.size.is_some() { + form = form.text("size", request.size.unwrap().to_string()) + } + + if request.response_format.is_some() { + form = form.text( + "response_format", + request.response_format.unwrap().to_string(), + ) + } + + if request.user.is_some() { + form = form.text("user", request.user.unwrap()) + } + Ok(form) + } +} + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateImageVariationRequest) -> Result { + let image_part = create_file_part(request.image.source).await?; + + let mut form = reqwest::multipart::Form::new().part("image", image_part); + + if let Some(model) = request.model { + form = form.text("model", model.to_string()) + } + + if request.n.is_some() { + form = form.text("n", request.n.unwrap().to_string()) + } + + if request.size.is_some() { + form = form.text("size", request.size.unwrap().to_string()) + } + + if request.response_format.is_some() { + form = form.text( + "response_format", + request.response_format.unwrap().to_string(), + ) + } + + if request.user.is_some() { + form = form.text("user", request.user.unwrap()) + } + Ok(form) + } +} + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateFileRequest) -> Result { + let file_part = create_file_part(request.file.source).await?; + let form = reqwest::multipart::Form::new() + .part("file", file_part) + .text("purpose", request.purpose.to_string()); + Ok(form) + } +} + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: AddUploadPartRequest) -> Result { + let file_part = create_file_part(request.data).await?; + let form = reqwest::multipart::Form::new().part("data", file_part); + Ok(form) + } +} + +// end: types to multipart form diff --git a/clia-async-openai/src/types/invites.rs b/clia-async-openai/src/types/invites.rs new file mode 100644 index 00000000..282b3e02 --- /dev/null +++ b/clia-async-openai/src/types/invites.rs @@ -0,0 +1,62 @@ +use crate::types::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use super::OrganizationRole; + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum InviteStatus { + Accepted, + Expired, + Pending, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Builder)] +#[builder(name = "InviteRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option))] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct InviteRequest { + pub email: String, + pub role: OrganizationRole, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct InviteListResponse { + pub object: String, + pub data: Vec, + pub first_id: Option, + pub last_id: Option, + pub has_more: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct InviteDeleteResponse { + /// The object type, which is always `organization.invite.deleted` + pub object: String, + pub id: String, + pub deleted: bool, +} + +/// Represents an individual `invite` to the organization. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct Invite { + /// The object type, which is always `organization.invite` + pub object: String, + /// The identifier, which can be referenced in API endpoints + pub id: String, + /// The email address of the individual to whom the invite was sent + pub email: String, + /// `owner` or `reader` + pub role: OrganizationRole, + /// `accepted`, `expired`, or `pending` + pub status: InviteStatus, + /// The Unix timestamp (in seconds) of when the invite was sent. + pub invited_at: u32, + /// The Unix timestamp (in seconds) of when the invite expires. + pub expires_at: u32, + /// The Unix timestamp (in seconds) of when the invite was accepted. + pub accepted_at: Option, +} diff --git a/clia-async-openai/src/types/message.rs b/clia-async-openai/src/types/message.rs new file mode 100644 index 00000000..af79ccc1 --- /dev/null +++ b/clia-async-openai/src/types/message.rs @@ -0,0 +1,355 @@ +use std::collections::HashMap; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +use super::{ImageDetail, ImageUrl}; + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +#[serde(rename_all = "lowercase")] +pub enum MessageRole { + #[default] + User, + Assistant, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum MessageStatus { + InProgress, + Incomplete, + Completed, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum MessageIncompleteDetailsType { + ContentFilter, + MaxTokens, + RunCancelled, + RunExpired, + RunFailed, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageIncompleteDetails { + /// The reason the message is incomplete. + pub reason: MessageIncompleteDetailsType, +} + +/// Represents a message within a [thread](https://platform.openai.com/docs/api-reference/threads). +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `thread.message`. + pub object: String, + /// The Unix timestamp (in seconds) for when the message was created. + pub created_at: i32, + /// The [thread](https://platform.openai.com/docs/api-reference/threads) ID that this message belongs to. + pub thread_id: String, + + /// The status of the message, which can be either `in_progress`, `incomplete`, or `completed`. + pub status: Option, + + /// On an incomplete message, details about why the message is incomplete. + pub incomplete_details: Option, + + /// The Unix timestamp (in seconds) for when the message was completed. + pub completed_at: Option, + + /// The Unix timestamp (in seconds) for when the message was marked as incomplete. + pub incomplete_at: Option, + + /// The entity that produced the message. One of `user` or `assistant`. + pub role: MessageRole, + + /// The content of the message in array of text and/or images. + pub content: Vec, + + /// If applicable, the ID of the [assistant](https://platform.openai.com/docs/api-reference/assistants) that authored this message. + pub assistant_id: Option, + + /// The ID of the [run](https://platform.openai.com/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints. + pub run_id: Option, + + /// A list of files attached to the message, and the tools they were added to. + pub attachments: Option>, + + pub metadata: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageAttachment { + /// The ID of the file to attach to the message. + pub file_id: String, + /// The tools to add this file to. + pub tools: Vec, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum MessageAttachmentTool { + CodeInterpreter, + FileSearch, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum MessageContent { + Text(MessageContentTextObject), + ImageFile(MessageContentImageFileObject), + ImageUrl(MessageContentImageUrlObject), + Refusal(MessageContentRefusalObject), +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageContentRefusalObject { + pub refusal: String, +} + +/// The text content that is part of a message. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageContentTextObject { + pub text: TextData, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct TextData { + /// The data that makes up the text. + pub value: String, + pub annotations: Vec, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum MessageContentTextAnnotations { + /// A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "retrieval" tool to search files. + FileCitation(MessageContentTextAnnotationsFileCitationObject), + /// A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. + FilePath(MessageContentTextAnnotationsFilePathObject), +} + +/// A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageContentTextAnnotationsFileCitationObject { + /// The text in the message content that needs to be replaced. + pub text: String, + pub file_citation: FileCitation, + pub start_index: u32, + pub end_index: u32, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct FileCitation { + /// The ID of the specific File the citation is from. + pub file_id: String, + /// The specific quote in the file. + pub quote: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageContentTextAnnotationsFilePathObject { + /// The text in the message content that needs to be replaced. + pub text: String, + pub file_path: FilePath, + pub start_index: u32, + pub end_index: u32, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct FilePath { + /// The ID of the file that was generated. + pub file_id: String, +} + +/// References an image [File](https://platform.openai.com/docs/api-reference/files) in the content of a message. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageContentImageFileObject { + pub image_file: ImageFile, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ImageFile { + /// The [File](https://platform.openai.com/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. + pub file_id: String, + /// Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. + pub detail: Option, +} + +/// References an image URL in the content of a message. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageContentImageUrlObject { + pub image_url: ImageUrl, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageRequestContentTextObject { + /// Text content to be sent to the model + pub text: String, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum CreateMessageRequestContent { + /// The text contents of the message. + Content(String), + /// An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](https://platform.openai.com/docs/models/overview). + ContentArray(Vec), +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum MessageContentInput { + Text(MessageRequestContentTextObject), + ImageFile(MessageContentImageFileObject), + ImageUrl(MessageContentImageUrlObject), +} +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "CreateMessageRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateMessageRequest { + /// The role of the entity that is creating the message. Allowed values include: + /// - `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages. + /// - `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation. + pub role: MessageRole, + /// The content of the message. + pub content: CreateMessageRequestContent, + + /// A list of files attached to the message, and the tools they should be added to. + pub attachments: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct ModifyMessageRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct DeleteMessageResponse { + pub id: String, + pub deleted: bool, + pub object: String, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct ListMessagesResponse { + pub object: String, + pub data: Vec, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, +} + +/// Represents a message delta i.e. any changed fields on a message during streaming. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaObject { + /// The identifier of the message, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `thread.message.delta`. + pub object: String, + /// The delta containing the fields that have changed on the Message. + pub delta: MessageDelta, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDelta { + /// The entity that produced the message. One of `user` or `assistant`. + pub role: Option, + /// The content of the message in array of text and/or images. + pub content: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum MessageDeltaContent { + ImageFile(MessageDeltaContentImageFileObject), + ImageUrl(MessageDeltaContentImageUrlObject), + Text(MessageDeltaContentTextObject), + Refusal(MessageDeltaContentRefusalObject), +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaContentRefusalObject { + /// The index of the refusal part in the message. + pub index: i32, + pub refusal: Option, +} + +/// The text content that is part of a message. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaContentTextObject { + /// The index of the content part in the message. + pub index: u32, + pub text: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaContentText { + /// The data that makes up the text. + pub value: Option, + pub annotations: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum MessageDeltaContentTextAnnotations { + FileCitation(MessageDeltaContentTextAnnotationsFileCitationObject), + FilePath(MessageDeltaContentTextAnnotationsFilePathObject), +} + +/// A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaContentTextAnnotationsFileCitationObject { + /// The index of the annotation in the text content part. + pub index: u32, + /// The text in the message content that needs to be replaced. + pub text: Option, + pub file_citation: Option, + pub start_index: Option, + pub end_index: Option, +} + +/// A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaContentTextAnnotationsFilePathObject { + /// The index of the annotation in the text content part. + pub index: u32, + /// The text in the message content that needs to be replaced. + pub text: Option, + pub file_path: Option, + pub start_index: Option, + pub end_index: Option, +} + +/// References an image [File](https://platform.openai.com/docs/api-reference/files) in the content of a message. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaContentImageFileObject { + /// The index of the content part in the message. + pub index: u32, + + pub image_file: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageDeltaContentImageUrlObject { + /// The index of the content part in the message. + pub index: u32, + + pub image_url: Option, +} diff --git a/clia-async-openai/src/types/mod.rs b/clia-async-openai/src/types/mod.rs new file mode 100644 index 00000000..f71b538a --- /dev/null +++ b/clia-async-openai/src/types/mod.rs @@ -0,0 +1,70 @@ +//! Types used in OpenAI API requests and responses. +//! These types are created from component schemas in the [OpenAPI spec](https://github.com/openai/openai-openapi) +mod assistant; +mod assistant_impls; +mod assistant_stream; +mod audio; +mod audit_log; +mod batch; +mod chat; +mod common; +mod completion; +mod embedding; +mod file; +mod fine_tuning; +mod image; +mod invites; +mod message; +mod model; +mod moderation; +mod project_api_key; +mod project_service_account; +mod project_users; +mod projects; +#[cfg_attr(docsrs, doc(cfg(feature = "realtime")))] +#[cfg(feature = "realtime")] +pub mod realtime; +mod run; +mod step; +mod thread; +mod upload; +mod users; +mod vector_store; + +pub use assistant::*; +pub use assistant_stream::*; +pub use audio::*; +pub use audit_log::*; +pub use batch::*; +pub use chat::*; +pub use common::*; +pub use completion::*; +pub use embedding::*; +pub use file::*; +pub use fine_tuning::*; +pub use image::*; +pub use invites::*; +pub use message::*; +pub use model::*; +pub use moderation::*; +pub use project_api_key::*; +pub use project_service_account::*; +pub use project_users::*; +pub use projects::*; +pub use run::*; +pub use step::*; +pub use thread::*; +pub use upload::*; +pub use users::*; +pub use vector_store::*; + +mod impls; +use derive_builder::UninitializedFieldError; + +use crate::error::OpenAIError; + +impl From for OpenAIError { + fn from(value: UninitializedFieldError) -> Self { + OpenAIError::InvalidArgument(value.to_string()) + } +} diff --git a/clia-async-openai/src/types/model.rs b/clia-async-openai/src/types/model.rs new file mode 100644 index 00000000..034213a6 --- /dev/null +++ b/clia-async-openai/src/types/model.rs @@ -0,0 +1,27 @@ +use serde::{Deserialize, Serialize}; + +/// Describes an OpenAI model offering that can be used with the API. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct Model { + /// The model identifier, which can be referenced in the API endpoints. + pub id: String, + /// The object type, which is always "model". + pub object: String, + /// The Unix timestamp (in seconds) when the model was created. + pub created: u32, + /// The organization that owns the model. + pub owned_by: String, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct ListModelResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct DeleteModelResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} diff --git a/clia-async-openai/src/types/moderation.rs b/clia-async-openai/src/types/moderation.rs new file mode 100644 index 00000000..f8c1c0ff --- /dev/null +++ b/clia-async-openai/src/types/moderation.rs @@ -0,0 +1,227 @@ +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)] +#[serde(untagged)] +pub enum ModerationInput { + /// A single string of text to classify for moderation + String(String), + + /// An array of strings to classify for moderation + StringArray(Vec), + + /// An array of multi-modal inputs to the moderation model + MultiModal(Vec), +} + +/// Content part for multi-modal moderation input +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +pub enum ModerationContentPart { + /// An object describing text to classify + #[serde(rename = "text")] + Text { + /// A string of text to classify + text: String, + }, + + /// An object describing an image to classify + #[serde(rename = "image_url")] + ImageUrl { + /// Contains either an image URL or a data URL for a base64 encoded image + image_url: ModerationImageUrl, + }, +} + +/// Image URL configuration for image moderation +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ModerationImageUrl { + /// Either a URL of the image or the base64 encoded image data + pub url: String, +} + +#[derive(Debug, Default, Clone, Serialize, Builder, PartialEq, Deserialize)] +#[builder(name = "CreateModerationRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateModerationRequest { + /// Input (or inputs) to classify. Can be a single string, an array of strings, or + /// an array of multi-modal input objects similar to other models. + pub input: ModerationInput, + + /// The content moderation model you would like to use. Learn more in the + /// [moderation guide](https://platform.openai.com/docs/guides/moderation), and learn about + /// available models [here](https://platform.openai.com/docs/models/moderation). + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct Category { + /// Content that expresses, incites, or promotes hate based on race, gender, + /// ethnicity, religion, nationality, sexual orientation, disability status, or + /// caste. Hateful content aimed at non-protected groups (e.g., chess players) + /// is harrassment. + pub hate: bool, + #[serde(rename = "hate/threatening")] + /// Hateful content that also includes violence or serious harm towards the + /// targeted group based on race, gender, ethnicity, religion, nationality, + /// sexual orientation, disability status, or caste. + pub hate_threatening: bool, + /// Content that expresses, incites, or promotes harassing language towards any target. + pub harassment: bool, + /// Harassment content that also includes violence or serious harm towards any target. + #[serde(rename = "harassment/threatening")] + pub harassment_threatening: bool, + /// Content that includes instructions or advice that facilitate the planning or execution of wrongdoing, or that gives advice or instruction on how to commit illicit acts. For example, "how to shoplift" would fit this category. + pub illicit: bool, + /// Content that includes instructions or advice that facilitate the planning or execution of wrongdoing that also includes violence, or that gives advice or instruction on the procurement of any weapon. + #[serde(rename = "illicit/violent")] + pub illicit_violent: bool, + /// Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders. + #[serde(rename = "self-harm")] + pub self_harm: bool, + /// Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders. + #[serde(rename = "self-harm/intent")] + pub self_harm_intent: bool, + /// Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts. + #[serde(rename = "self-harm/instructions")] + pub self_harm_instructions: bool, + /// Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness). + pub sexual: bool, + /// Sexual content that includes an individual who is under 18 years old. + #[serde(rename = "sexual/minors")] + pub sexual_minors: bool, + /// Content that depicts death, violence, or physical injury. + pub violence: bool, + /// Content that depicts death, violence, or physical injury in graphic detail. + #[serde(rename = "violence/graphic")] + pub violence_graphic: bool, +} + +/// A list of the categories along with their scores as predicted by model. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct CategoryScore { + /// The score for the category 'hate'. + pub hate: f32, + /// The score for the category 'hate/threatening'. + #[serde(rename = "hate/threatening")] + pub hate_threatening: f32, + /// The score for the category 'harassment'. + pub harassment: f32, + /// The score for the category 'harassment/threatening'. + #[serde(rename = "harassment/threatening")] + pub harassment_threatening: f32, + /// The score for the category 'illicit'. + pub illicit: f32, + /// The score for the category 'illicit/violent'. + #[serde(rename = "illicit/violent")] + pub illicit_violent: f32, + /// The score for the category 'self-harm'. + #[serde(rename = "self-harm")] + pub self_harm: f32, + /// The score for the category 'self-harm/intent'. + #[serde(rename = "self-harm/intent")] + pub self_harm_intent: f32, + /// The score for the category 'self-harm/instructions'. + #[serde(rename = "self-harm/instructions")] + pub self_harm_instructions: f32, + /// The score for the category 'sexual'. + pub sexual: f32, + /// The score for the category 'sexual/minors'. + #[serde(rename = "sexual/minors")] + pub sexual_minors: f32, + /// The score for the category 'violence'. + pub violence: f32, + /// The score for the category 'violence/graphic'. + #[serde(rename = "violence/graphic")] + pub violence_graphic: f32, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct ContentModerationResult { + /// Whether any of the below categories are flagged. + pub flagged: bool, + /// A list of the categories, and whether they are flagged or not. + pub categories: Category, + /// A list of the categories along with their scores as predicted by model. + pub category_scores: CategoryScore, + /// A list of the categories along with the input type(s) that the score applies to. + pub category_applied_input_types: CategoryAppliedInputTypes, +} + +/// Represents if a given text input is potentially harmful. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub struct CreateModerationResponse { + /// The unique identifier for the moderation request. + pub id: String, + /// The model used to generate the moderation results. + pub model: String, + /// A list of moderation objects. + pub results: Vec, +} + +/// A list of the categories along with the input type(s) that the score applies to. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct CategoryAppliedInputTypes { + /// The applied input type(s) for the category 'hate'. + pub hate: Vec, + + /// The applied input type(s) for the category 'hate/threatening'. + #[serde(rename = "hate/threatening")] + pub hate_threatening: Vec, + + /// The applied input type(s) for the category 'harassment'. + pub harassment: Vec, + + /// The applied input type(s) for the category 'harassment/threatening'. + #[serde(rename = "harassment/threatening")] + pub harassment_threatening: Vec, + + /// The applied input type(s) for the category 'illicit'. + pub illicit: Vec, + + /// The applied input type(s) for the category 'illicit/violent'. + #[serde(rename = "illicit/violent")] + pub illicit_violent: Vec, + + /// The applied input type(s) for the category 'self-harm'. + #[serde(rename = "self-harm")] + pub self_harm: Vec, + + /// The applied input type(s) for the category 'self-harm/intent'. + #[serde(rename = "self-harm/intent")] + pub self_harm_intent: Vec, + + /// The applied input type(s) for the category 'self-harm/instructions'. + #[serde(rename = "self-harm/instructions")] + pub self_harm_instructions: Vec, + + /// The applied input type(s) for the category 'sexual'. + pub sexual: Vec, + + /// The applied input type(s) for the category 'sexual/minors'. + #[serde(rename = "sexual/minors")] + pub sexual_minors: Vec, + + /// The applied input type(s) for the category 'violence'. + pub violence: Vec, + + /// The applied input type(s) for the category 'violence/graphic'. + #[serde(rename = "violence/graphic")] + pub violence_graphic: Vec, +} + +/// The type of input that was moderated +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ModInputType { + /// Text content that was moderated + Text, + /// Image content that was moderated + Image, +} diff --git a/clia-async-openai/src/types/project_api_key.rs b/clia-async-openai/src/types/project_api_key.rs new file mode 100644 index 00000000..96886581 --- /dev/null +++ b/clia-async-openai/src/types/project_api_key.rs @@ -0,0 +1,64 @@ +use serde::{Deserialize, Serialize}; + +use super::{ProjectServiceAccount, ProjectUser}; + +/// Represents an individual API key in a project. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectApiKey { + /// The object type, which is always `organization.project.api_key`. + pub object: String, + /// The redacted value of the API key. + pub redacted_value: String, + /// The name of the API key. + pub name: String, + /// The Unix timestamp (in seconds) of when the API key was created. + pub created_at: u32, + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The owner of the API key. + pub owner: ProjectApiKeyOwner, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename = "snake_case")] +pub enum ProjectApiKeyOwnerType { + User, + ServiceAccount, +} + +/// Represents the owner of a project API key. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectApiKeyOwner { + /// The type of owner, which is either `user` or `service_account`. + pub r#type: ProjectApiKeyOwnerType, + /// The user owner of the API key, if applicable. + pub user: Option, + /// The service account owner of the API key, if applicable. + pub service_account: Option, +} + +/// Represents the response object for listing project API keys. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectApiKeyListResponse { + /// The object type, which is always `list`. + pub object: String, + /// The list of project API keys. + pub data: Vec, + /// The ID of the first project API key in the list. + pub first_id: String, + /// The ID of the last project API key in the list. + pub last_id: String, + /// Indicates if there are more project API keys available. + pub has_more: bool, +} + +/// Represents the response object for deleting a project API key. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectApiKeyDeleteResponse { + /// The object type, which is always `organization.project.api_key.deleted`. + pub object: String, + /// The ID of the deleted API key. + pub id: String, + /// Indicates if the API key was successfully deleted. + pub deleted: bool, +} diff --git a/clia-async-openai/src/types/project_service_account.rs b/clia-async-openai/src/types/project_service_account.rs new file mode 100644 index 00000000..4449ddf8 --- /dev/null +++ b/clia-async-openai/src/types/project_service_account.rs @@ -0,0 +1,83 @@ +use serde::{Deserialize, Serialize}; + +use super::ProjectUserRole; + +/// Represents an individual service account in a project. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectServiceAccount { + /// The object type, which is always `organization.project.service_account`. + pub object: String, + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The name of the service account. + pub name: String, + /// `owner` or `member`. + pub role: ProjectUserRole, + /// The Unix timestamp (in seconds) of when the service account was created. + pub created_at: u32, +} + +/// Represents the response object for listing project service accounts. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectServiceAccountListResponse { + /// The object type, which is always `list`. + pub object: String, + /// The list of project service accounts. + pub data: Vec, + /// The ID of the first project service account in the list. + pub first_id: String, + /// The ID of the last project service account in the list. + pub last_id: String, + /// Indicates if there are more project service accounts available. + pub has_more: bool, +} + +/// Represents the request object for creating a project service account. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectServiceAccountCreateRequest { + /// The name of the service account being created. + pub name: String, +} + +/// Represents the response object for creating a project service account. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectServiceAccountCreateResponse { + /// The object type, which is always `organization.project.service_account`. + pub object: String, + /// The ID of the created service account. + pub id: String, + /// The name of the created service account. + pub name: String, + /// Service accounts can only have one role of type `member`. + pub role: String, + /// The Unix timestamp (in seconds) of when the service account was created. + pub created_at: u32, + /// The API key associated with the created service account. + pub api_key: ProjectServiceAccountApiKey, +} + +/// Represents the API key associated with a project service account. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectServiceAccountApiKey { + /// The object type, which is always `organization.project.service_account.api_key`. + pub object: String, + /// The value of the API key. + pub value: String, + /// The name of the API key. + pub name: String, + /// The Unix timestamp (in seconds) of when the API key was created. + pub created_at: u32, + /// The ID of the API key. + pub id: String, +} + +/// Represents the response object for deleting a project service account. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProjectServiceAccountDeleteResponse { + /// The object type, which is always `organization.project.service_account.deleted`. + pub object: String, + /// The ID of the deleted service account. + pub id: String, + /// Indicates if the service account was successfully deleted. + pub deleted: bool, +} diff --git a/clia-async-openai/src/types/project_users.rs b/clia-async-openai/src/types/project_users.rs new file mode 100644 index 00000000..5bedd26a --- /dev/null +++ b/clia-async-openai/src/types/project_users.rs @@ -0,0 +1,69 @@ +use crate::types::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +/// Represents an individual user in a project. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ProjectUser { + /// The object type, which is always `organization.project.user` + pub object: String, + /// The identifier, which can be referenced in API endpoints + pub id: String, + /// The name of the project + pub name: String, + /// The email address of the user + pub email: String, + /// `owner` or `member` + pub role: ProjectUserRole, + /// The Unix timestamp (in seconds) of when the project was added. + pub added_at: u32, +} + +/// `owner` or `member` +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ProjectUserRole { + Owner, + Member, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ProjectUserListResponse { + pub object: String, + pub data: Vec, + pub first_id: String, + pub last_id: String, + pub has_more: String, +} + +/// The project user create request payload. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Builder)] +#[builder(name = "ProjectUserCreateRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option))] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ProjectUserCreateRequest { + /// The ID of the user. + pub user_id: String, + /// `owner` or `member` + pub role: ProjectUserRole, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Builder)] +#[builder(name = "ProjectUserUpdateRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option))] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ProjectUserUpdateRequest { + /// `owner` or `member` + pub role: ProjectUserRole, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ProjectUserDeleteResponse { + pub object: String, + pub id: String, + pub deleted: bool, +} diff --git a/clia-async-openai/src/types/projects.rs b/clia-async-openai/src/types/projects.rs new file mode 100644 index 00000000..bb5ae3ff --- /dev/null +++ b/clia-async-openai/src/types/projects.rs @@ -0,0 +1,62 @@ +use crate::types::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +/// `active` or `archived` +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ProjectStatus { + Active, + Archived, +} + +/// Represents an individual project. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct Project { + /// The identifier, which can be referenced in API endpoints + pub id: String, + /// The object type, which is always `organization.project` + pub object: String, + /// The name of the project. This appears in reporting. + pub name: String, + /// The Unix timestamp (in seconds) of when the project was created. + pub created_at: u32, + /// The Unix timestamp (in seconds) of when the project was archived or `null`. + pub archived_at: Option, + /// `active` or `archived` + pub status: ProjectStatus, +} + +/// A list of Project objects. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ProjectListResponse { + pub object: String, + pub data: Vec, + pub first_id: String, + pub last_id: String, + pub has_more: String, +} + +/// The project create request payload. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Builder)] +#[builder(name = "ProjectCreateRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option))] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ProjectCreateRequest { + /// The friendly name of the project, this name appears in reports. + pub name: String, +} + +/// The project update request payload. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Builder)] +#[builder(name = "ProjectUpdateRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option))] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ProjectUpdateRequest { + /// The updated name of the project, this name appears in reports. + pub name: String, +} diff --git a/clia-async-openai/src/types/realtime/client_event.rs b/clia-async-openai/src/types/realtime/client_event.rs new file mode 100644 index 00000000..87ff7010 --- /dev/null +++ b/clia-async-openai/src/types/realtime/client_event.rs @@ -0,0 +1,220 @@ +use serde::{Deserialize, Serialize}; +use tokio_tungstenite::tungstenite::Message; + +use super::{item::Item, session_resource::SessionResource}; + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct SessionUpdateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + /// Session configuration to update. + pub session: SessionResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct InputAudioBufferAppendEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + /// Base64-encoded audio bytes. + pub audio: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct InputAudioBufferCommitEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct InputAudioBufferClearEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemCreateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// The ID of the preceding item after which the new item will be inserted. + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_item_id: Option, + + /// The item to add to the conversation. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ConversationItemTruncateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// The ID of the assistant message item to truncate. + pub item_id: String, + + /// The index of the content part to truncate. + pub content_index: u32, + + /// Inclusive duration up to which audio is truncated, in milliseconds. + pub audio_end_ms: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ConversationItemDeleteEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// The ID of the item to delete. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ResponseCreateEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, + + /// Configuration for the response. + pub response: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ResponseCancelEvent { + /// Optional client-generated ID used to identify this event. + #[serde(skip_serializing_if = "Option::is_none")] + pub event_id: Option, +} + +/// These are events that the OpenAI Realtime WebSocket server will accept from the client. +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ClientEvent { + /// Send this event to update the session’s default configuration. + #[serde(rename = "session.update")] + SessionUpdate(SessionUpdateEvent), + + /// Send this event to append audio bytes to the input audio buffer. + #[serde(rename = "input_audio_buffer.append")] + InputAudioBufferAppend(InputAudioBufferAppendEvent), + + /// Send this event to commit audio bytes to a user message. + #[serde(rename = "input_audio_buffer.commit")] + InputAudioBufferCommit(InputAudioBufferCommitEvent), + + /// Send this event to clear the audio bytes in the buffer. + #[serde(rename = "input_audio_buffer.clear")] + InputAudioBufferClear(InputAudioBufferClearEvent), + + /// Send this event when adding an item to the conversation. + #[serde(rename = "conversation.item.create")] + ConversationItemCreate(ConversationItemCreateEvent), + + /// Send this event when you want to truncate a previous assistant message’s audio. + #[serde(rename = "conversation.item.truncate")] + ConversationItemTruncate(ConversationItemTruncateEvent), + + /// Send this event when you want to remove any item from the conversation history. + #[serde(rename = "conversation.item.delete")] + ConversationItemDelete(ConversationItemDeleteEvent), + + /// Send this event to trigger a response generation. + #[serde(rename = "response.create")] + ResponseCreate(ResponseCreateEvent), + + /// Send this event to cancel an in-progress response. + #[serde(rename = "response.cancel")] + ResponseCancel(ResponseCancelEvent), +} + +impl From<&ClientEvent> for String { + fn from(value: &ClientEvent) -> Self { + serde_json::to_string(value).unwrap() + } +} + +impl From for Message { + fn from(value: ClientEvent) -> Self { + Message::Text(String::from(&value).into()) + } +} + +macro_rules! message_from_event { + ($from_typ:ty, $evt_typ:ty) => { + impl From<$from_typ> for Message { + fn from(value: $from_typ) -> Self { + Self::from(<$evt_typ>::from(value)) + } + } + }; +} + +macro_rules! event_from { + ($from_typ:ty, $evt_typ:ty, $variant:ident) => { + impl From<$from_typ> for $evt_typ { + fn from(value: $from_typ) -> Self { + <$evt_typ>::$variant(value) + } + } + }; +} + +event_from!(SessionUpdateEvent, ClientEvent, SessionUpdate); +event_from!( + InputAudioBufferAppendEvent, + ClientEvent, + InputAudioBufferAppend +); +event_from!( + InputAudioBufferCommitEvent, + ClientEvent, + InputAudioBufferCommit +); +event_from!( + InputAudioBufferClearEvent, + ClientEvent, + InputAudioBufferClear +); +event_from!( + ConversationItemCreateEvent, + ClientEvent, + ConversationItemCreate +); +event_from!( + ConversationItemTruncateEvent, + ClientEvent, + ConversationItemTruncate +); +event_from!( + ConversationItemDeleteEvent, + ClientEvent, + ConversationItemDelete +); +event_from!(ResponseCreateEvent, ClientEvent, ResponseCreate); +event_from!(ResponseCancelEvent, ClientEvent, ResponseCancel); + +message_from_event!(SessionUpdateEvent, ClientEvent); +message_from_event!(InputAudioBufferAppendEvent, ClientEvent); +message_from_event!(InputAudioBufferCommitEvent, ClientEvent); +message_from_event!(InputAudioBufferClearEvent, ClientEvent); +message_from_event!(ConversationItemCreateEvent, ClientEvent); +message_from_event!(ConversationItemTruncateEvent, ClientEvent); +message_from_event!(ConversationItemDeleteEvent, ClientEvent); +message_from_event!(ResponseCreateEvent, ClientEvent); +message_from_event!(ResponseCancelEvent, ClientEvent); + +impl From for ConversationItemCreateEvent { + fn from(value: Item) -> Self { + Self { + event_id: None, + previous_item_id: None, + item: value, + } + } +} diff --git a/clia-async-openai/src/types/realtime/content_part.rs b/clia-async-openai/src/types/realtime/content_part.rs new file mode 100644 index 00000000..eec93ab3 --- /dev/null +++ b/clia-async-openai/src/types/realtime/content_part.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { + /// The text content + text: String, + }, + #[serde(rename = "audio")] + Audio { + /// Base64-encoded audio data + audio: Option, + /// The transcript of the audio + transcript: String, + }, +} diff --git a/clia-async-openai/src/types/realtime/conversation.rs b/clia-async-openai/src/types/realtime/conversation.rs new file mode 100644 index 00000000..3ea43bd8 --- /dev/null +++ b/clia-async-openai/src/types/realtime/conversation.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Conversation { + /// The unique ID of the conversation. + pub id: String, + + /// The object type, must be "realtime.conversation". + pub object: String, +} diff --git a/clia-async-openai/src/types/realtime/error.rs b/clia-async-openai/src/types/realtime/error.rs new file mode 100644 index 00000000..6ce907c3 --- /dev/null +++ b/clia-async-openai/src/types/realtime/error.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RealtimeAPIError { + /// The type of error (e.g., "invalid_request_error", "server_error"). + pub r#type: String, + + /// Error code, if any. + pub code: Option, + + /// A human-readable error message. + pub message: String, + + /// Parameter related to the error, if any. + pub param: Option, + + /// The event_id of the client event that caused the error, if applicable. + pub event_id: Option, +} diff --git a/clia-async-openai/src/types/realtime/item.rs b/clia-async-openai/src/types/realtime/item.rs new file mode 100644 index 00000000..3af7d0d9 --- /dev/null +++ b/clia-async-openai/src/types/realtime/item.rs @@ -0,0 +1,99 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ItemType { + Message, + FunctionCall, + FunctionCallOutput, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ItemStatus { + Completed, + InProgress, + Incomplete, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ItemRole { + User, + Assistant, + System, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ItemContentType { + InputText, + InputAudio, + Text, + Audio, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ItemContent { + /// The content type ("input_text", "input_audio", "text", "audio"). + pub r#type: ItemContentType, + + /// The text content. + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + /// Base64-encoded audio bytes. + #[serde(skip_serializing_if = "Option::is_none")] + pub audio: Option, + + /// The transcript of the audio. + #[serde(skip_serializing_if = "Option::is_none")] + pub transcript: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Item { + /// The unique ID of the item. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + + /// The type of the item ("message", "function_call", "function_call_output"). + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, + + /// The status of the item ("completed", "in_progress", "incomplete"). + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, + + /// The role of the message sender ("user", "assistant", "system"). + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + + /// The content of the message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option>, + + /// The ID of the function call (for "function_call" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub call_id: Option, + + /// The name of the function being called (for "function_call" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// The arguments of the function call (for "function_call" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + + /// The output of the function call (for "function_call_output" items). + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option, +} + +impl TryFrom for Item { + type Error = serde_json::Error; + + fn try_from(value: serde_json::Value) -> Result { + serde_json::from_value(value) + } +} diff --git a/clia-async-openai/src/types/realtime/mod.rs b/clia-async-openai/src/types/realtime/mod.rs new file mode 100644 index 00000000..b47605f8 --- /dev/null +++ b/clia-async-openai/src/types/realtime/mod.rs @@ -0,0 +1,19 @@ +mod client_event; +mod content_part; +mod conversation; +mod error; +mod item; +mod rate_limit; +mod response_resource; +mod server_event; +mod session_resource; + +pub use client_event::*; +pub use content_part::*; +pub use conversation::*; +pub use error::*; +pub use item::*; +pub use rate_limit::*; +pub use response_resource::*; +pub use server_event::*; +pub use session_resource::*; diff --git a/clia-async-openai/src/types/realtime/rate_limit.rs b/clia-async-openai/src/types/realtime/rate_limit.rs new file mode 100644 index 00000000..f3fc4aa6 --- /dev/null +++ b/clia-async-openai/src/types/realtime/rate_limit.rs @@ -0,0 +1,13 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RateLimit { + /// The name of the rate limit ("requests", "tokens", "input_tokens", "output_tokens"). + pub name: String, + /// The maximum allowed value for the rate limit. + pub limit: u32, + /// The remaining value before the limit is reached. + pub remaining: u32, + /// Seconds until the rate limit resets. + pub reset_seconds: f32, +} diff --git a/clia-async-openai/src/types/realtime/response_resource.rs b/clia-async-openai/src/types/realtime/response_resource.rs new file mode 100644 index 00000000..4a500890 --- /dev/null +++ b/clia-async-openai/src/types/realtime/response_resource.rs @@ -0,0 +1,59 @@ +use serde::{Deserialize, Serialize}; + +use super::item::Item; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Usage { + pub total_tokens: u32, + pub input_tokens: u32, + pub output_tokens: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ResponseStatus { + InProgress, + Completed, + Cancelled, + Failed, + Incomplete, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FailedError { + pub code: String, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum IncompleteReason { + Interruption, + MaxOutputTokens, + ContentFilter, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ResponseStatusDetail { + #[serde(rename = "incomplete")] + Incomplete { reason: IncompleteReason }, + #[serde(rename = "failed")] + Failed { error: Option }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseResource { + /// The unique ID of the response. + pub id: String, + /// The object type, must be "realtime.response". + pub object: String, + /// The status of the response + pub status: ResponseStatus, + /// Additional details about the status. + pub status_details: Option, + /// The list of output items generated by the response. + pub output: Vec, + /// Usage statistics for the response. + pub usage: Option, +} diff --git a/clia-async-openai/src/types/realtime/server_event.rs b/clia-async-openai/src/types/realtime/server_event.rs new file mode 100644 index 00000000..3ba5f552 --- /dev/null +++ b/clia-async-openai/src/types/realtime/server_event.rs @@ -0,0 +1,459 @@ +use serde::{Deserialize, Serialize}; + +use super::{ + content_part::ContentPart, conversation::Conversation, error::RealtimeAPIError, item::Item, + rate_limit::RateLimit, response_resource::ResponseResource, session_resource::SessionResource, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ErrorEvent { + /// The unique ID of the server event. + pub event_id: String, + /// Details of the error. + pub error: RealtimeAPIError, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SessionCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The session resource. + pub session: SessionResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SessionUpdatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The updated session resource. + pub session: SessionResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The conversation resource. + pub conversation: Conversation, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferCommitedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the preceding item after which the new item will be inserted. + pub previous_item_id: String, + /// The ID of the user message item that will be created. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferClearedEvent { + /// The unique ID of the server event. + pub event_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferSpeechStartedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// Milliseconds since the session started when speech was detected. + pub audio_start_ms: u32, + /// The ID of the user message item that will be created when speech stops. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct InputAudioBufferSpeechStoppedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// Milliseconds since the session started when speech stopped. + pub audio_end_ms: u32, + /// The ID of the user message item that will be created. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the preceding item. + pub previous_item_id: Option, + /// The item that was created. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemInputAudioTranscriptionCompletedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the user message item. + pub item_id: String, + /// The index of the content part containing the audio. + pub content_index: u32, + /// The transcribed text. + pub transcript: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemInputAudioTranscriptionFailedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the user message item. + pub item_id: String, + /// The index of the content part containing the audio. + pub content_index: u32, + /// Details of the transcription error. + pub error: RealtimeAPIError, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemTruncatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the assistant message item that was truncated. + pub item_id: String, + /// The index of the content part that was truncated. + pub content_index: u32, + /// The duration up to which the audio was truncated, in milliseconds. + pub audio_end_ms: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ConversationItemDeletedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the item that was deleted. + pub item_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseCreatedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The response resource. + pub response: ResponseResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The response resource. + pub response: ResponseResource, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseOutputItemAddedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response to which the item belongs. + pub response_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The item that was added. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseOutputItemDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response to which the item belongs. + pub response_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The completed item. + pub item: Item, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseContentPartAddedEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item to which the content part was added. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The content part that was added. + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseContentPartDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item to which the content part was added. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The content part that is done. + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseTextDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The text delta. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseTextDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The final text content. + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioTranscriptDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// The text delta. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioTranscriptDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + ///The final transcript of the audio. + pub transcript: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, + /// Base64-encoded audio data delta. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseAudioDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The index of the content part in the item's content array. + pub content_index: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseFunctionCallArgumentsDeltaEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the function call item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The ID of the function call. + pub call_id: String, + /// The arguments delta as a JSON string. + pub delta: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseFunctionCallArgumentsDoneEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The ID of the response. + pub response_id: String, + /// The ID of the function call item. + pub item_id: String, + /// The index of the output item in the response. + pub output_index: u32, + /// The ID of the function call. + pub call_id: String, + /// The final arguments as a JSON string. + pub arguments: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RateLimitsUpdatedEvent { + /// The unique ID of the server event. + pub event_id: String, + pub rate_limits: Vec, +} + +/// These are events emitted from the OpenAI Realtime WebSocket server to the client. +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ServerEvent { + /// Returned when an error occurs. + #[serde(rename = "error")] + Error(ErrorEvent), + + /// Returned when a session is created. Emitted automatically when a new connection is established. + #[serde(rename = "session.created")] + SessionCreated(SessionCreatedEvent), + + /// Returned when a session is updated. + #[serde(rename = "session.updated")] + SessionUpdated(SessionUpdatedEvent), + + /// Returned when a conversation is created. Emitted right after session creation. + #[serde(rename = "conversation.created")] + ConversationCreated(ConversationCreatedEvent), + + /// Returned when an input audio buffer is committed, either by the client or automatically in server VAD mode. + #[serde(rename = "input_audio_buffer.committed")] + InputAudioBufferCommited(InputAudioBufferCommitedEvent), + + /// Returned when the input audio buffer is cleared by the client. + #[serde(rename = "input_audio_buffer.cleared")] + InputAudioBufferCleared(InputAudioBufferClearedEvent), + + /// Returned in server turn detection mode when speech is detected. + #[serde(rename = "input_audio_buffer.speech_started")] + InputAudioBufferSpeechStarted(InputAudioBufferSpeechStartedEvent), + + /// Returned in server turn detection mode when speech stops. + #[serde(rename = "input_audio_buffer.speech_stopped")] + InputAudioBufferSpeechStopped(InputAudioBufferSpeechStoppedEvent), + + /// Returned when a conversation item is created. + #[serde(rename = "conversation.item.created")] + ConversationItemCreated(ConversationItemCreatedEvent), + + /// Returned when input audio transcription is enabled and a transcription succeeds. + #[serde(rename = "conversation.item.input_audio_transcription.completed")] + ConversationItemInputAudioTranscriptionCompleted( + ConversationItemInputAudioTranscriptionCompletedEvent, + ), + + /// Returned when input audio transcription is configured, and a transcription request for a user message failed. + #[serde(rename = "conversation.item.input_audio_transcription.failed")] + ConversationItemInputAudioTranscriptionFailed( + ConversationItemInputAudioTranscriptionFailedEvent, + ), + + /// Returned when an earlier assistant audio message item is truncated by the client. + #[serde(rename = "conversation.item.truncated")] + ConversationItemTruncated(ConversationItemTruncatedEvent), + + /// Returned when an item in the conversation is deleted. + #[serde(rename = "conversation.item.deleted")] + ConversationItemDeleted(ConversationItemDeletedEvent), + + /// Returned when a new Response is created. The first event of response creation, where the response is in an initial state of "in_progress". + #[serde(rename = "response.created")] + ResponseCreated(ResponseCreatedEvent), + + /// Returned when a Response is done streaming. Always emitted, no matter the final state. + #[serde(rename = "response.done")] + ResponseDone(ResponseDoneEvent), + + /// Returned when a new Item is created during response generation. + #[serde(rename = "response.output_item.added")] + ResponseOutputItemAdded(ResponseOutputItemAddedEvent), + + /// Returned when an Item is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.output_item.done")] + ResponseOutputItemDone(ResponseOutputItemDoneEvent), + + /// Returned when a new content part is added to an assistant message item during response generation. + #[serde(rename = "response.content_part.added")] + ResponseContentPartAdded(ResponseContentPartAddedEvent), + + /// Returned when a content part is done streaming in an assistant message item. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.content_part.done")] + ResponseContentPartDone(ResponseContentPartDoneEvent), + + /// Returned when the text value of a "text" content part is updated. + #[serde(rename = "response.text.delta")] + ResponseTextDelta(ResponseTextDeltaEvent), + + /// Returned when the text value of a "text" content part is done streaming. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.text.done")] + ResponseTextDone(ResponseTextDoneEvent), + + /// Returned when the model-generated transcription of audio output is updated. + #[serde(rename = "response.audio_transcript.delta")] + ResponseAudioTranscriptDelta(ResponseAudioTranscriptDeltaEvent), + + /// Returned when the model-generated transcription of audio output is done streaming. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.audio_transcript.done")] + ResponseAudioTranscriptDone(ResponseAudioTranscriptDoneEvent), + + /// Returned when the model-generated audio is updated. + #[serde(rename = "response.audio.delta")] + ResponseAudioDelta(ResponseAudioDeltaEvent), + + /// Returned when the model-generated audio is done. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.audio.done")] + ResponseAudioDone(ResponseAudioDoneEvent), + + /// Returned when the model-generated function call arguments are updated. + #[serde(rename = "response.function_call_arguments.delta")] + ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent), + + /// Returned when the model-generated function call arguments are done streaming. + /// Also emitted when a Response is interrupted, incomplete, or cancelled. + #[serde(rename = "response.function_call_arguments.done")] + ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent), + + /// Emitted after every "response.done" event to indicate the updated rate limits. + #[serde(rename = "rate_limits.updated")] + RateLimitsUpdated(RateLimitsUpdatedEvent), +} diff --git a/clia-async-openai/src/types/realtime/session_resource.rs b/clia-async-openai/src/types/realtime/session_resource.rs new file mode 100644 index 00000000..10472414 --- /dev/null +++ b/clia-async-openai/src/types/realtime/session_resource.rs @@ -0,0 +1,136 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum AudioFormat { + #[serde(rename = "pcm16")] + PCM16, + #[serde(rename = "g711-ulaw")] + G711ULAW, + #[serde(rename = "g711-alaw")] + G711ALAW, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct AudioTranscription { + /// Whether to enable input audio transcription. + pub enabled: bool, + /// The model to use for transcription (e.g., "whisper-1"). + pub model: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum TurnDetection { + /// Type of turn detection, only "server_vad" is currently supported. + #[serde(rename = "server_vad")] + ServerVAD { + /// Activation threshold for VAD (0.0 to 1.0). + threshold: f32, + /// Amount of audio to include before speech starts (in milliseconds). + prefix_padding_ms: u32, + /// Duration of silence to detect speech stop (in milliseconds). + silence_duration_ms: u32, + }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum MaxResponseOutputTokens { + #[serde(rename = "inf")] + Inf, + #[serde(untagged)] + Num(u16), +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum ToolDefinition { + #[serde(rename = "function")] + Function { + /// The name of the function. + name: String, + /// The description of the function. + description: String, + /// Parameters of the function in JSON Schema. + parameters: serde_json::Value, + }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum FunctionType { + Function, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ToolChoice { + Auto, + None, + Required, + #[serde(untagged)] + Function { + r#type: FunctionType, + name: String, + }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum RealtimeVoice { + Alloy, + Shimmer, + Echo, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct SessionResource { + /// The default model used for this session. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// The set of modalities the model can respond with. To disable audio, set this to ["text"]. + #[serde(skip_serializing_if = "Option::is_none")] + pub modalities: Option>, + + //// The default system instructions prepended to model calls. + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// The voice the model uses to respond. Cannot be changed once the model has responded with audio at least once. + #[serde(skip_serializing_if = "Option::is_none")] + pub voice: Option, + + /// The format of input audio. Options are "pcm16", "g711_ulaw", or "g711_alaw". + #[serde(skip_serializing_if = "Option::is_none")] + pub input_audio_format: Option, + + /// The format of output audio. Options are "pcm16", "g711_ulaw", or "g711_alaw". + #[serde(skip_serializing_if = "Option::is_none")] + pub output_audio_format: Option, + + /// Configuration for input audio transcription. Can be set to null to turn off. + #[serde(skip_serializing_if = "Option::is_none")] + pub input_audio_transcription: Option, + + /// Configuration for turn detection. Can be set to null to turn off. + #[serde(skip_serializing_if = "Option::is_none")] + pub turn_detection: Option, + + /// Tools (functions) available to the model. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + /// How the model chooses tools. + pub tool_choice: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + /// Sampling temperature for the model. + pub temperature: Option, + + /// Maximum number of output tokens for a single assistant response, inclusive of tool calls. + /// Provide an integer between 1 and 4096 to limit output tokens, or "inf" for the maximum available tokens for a given model. + /// Defaults to "inf". + #[serde(skip_serializing_if = "Option::is_none")] + pub max_response_output_tokens: Option, +} diff --git a/clia-async-openai/src/types/run.rs b/clia-async-openai/src/types/run.rs new file mode 100644 index 00000000..8be4ad99 --- /dev/null +++ b/clia-async-openai/src/types/run.rs @@ -0,0 +1,285 @@ +use std::collections::HashMap; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::{error::OpenAIError, types::FunctionCall}; + +use super::{ + AssistantTools, AssistantsApiResponseFormatOption, AssistantsApiToolChoiceOption, + CreateMessageRequest, +}; + +/// Represents an execution run on a [thread](https://platform.openai.com/docs/api-reference/threads). +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `thread.run`. + pub object: String, + /// The Unix timestamp (in seconds) for when the run was created. + pub created_at: i32, + ///The ID of the [thread](https://platform.openai.com/docs/api-reference/threads) that was executed on as a part of this run. + pub thread_id: String, + + /// The ID of the [assistant](https://platform.openai.com/docs/api-reference/assistants) used for execution of this run. + pub assistant_id: Option, + + /// The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`. + pub status: RunStatus, + + /// Details on the action required to continue the run. Will be `null` if no action is required. + pub required_action: Option, + + /// The last error associated with this run. Will be `null` if there are no errors. + pub last_error: Option, + + /// The Unix timestamp (in seconds) for when the run will expire. + pub expires_at: Option, + /// The Unix timestamp (in seconds) for when the run was started. + pub started_at: Option, + /// The Unix timestamp (in seconds) for when the run was cancelled. + pub cancelled_at: Option, + /// The Unix timestamp (in seconds) for when the run failed. + pub failed_at: Option, + ///The Unix timestamp (in seconds) for when the run was completed. + pub completed_at: Option, + + /// Details on why the run is incomplete. Will be `null` if the run is not incomplete. + pub incomplete_details: Option, + + /// The model that the [assistant](https://platform.openai.com/docs/api-reference/assistants) used for this run. + pub model: String, + + /// The instructions that the [assistant](https://platform.openai.com/docs/api-reference/assistants) used for this run. + pub instructions: String, + + /// The list of tools that the [assistant](https://platform.openai.com/docs/api-reference/assistants) used for this run. + pub tools: Vec, + + pub metadata: Option>, + + /// Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.). + pub usage: Option, + + /// The sampling temperature used for this run. If not set, defaults to 1. + pub temperature: Option, + + /// The nucleus sampling value used for this run. If not set, defaults to 1. + pub top_p: Option, + + /// The maximum number of prompt tokens specified to have been used over the course of the run. + pub max_prompt_tokens: Option, + + /// The maximum number of completion tokens specified to have been used over the course of the run. + pub max_completion_tokens: Option, + + /// Controls for how a thread will be truncated prior to the run. Use this to control the intial context window of the run. + pub truncation_strategy: Option, + + pub tool_choice: Option, + + /// Whether to enable [parallel function calling](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling) during tool use. + pub parallel_tool_calls: bool, + + pub response_format: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +#[serde(rename_all = "snake_case")] +pub enum TruncationObjectType { + #[default] + Auto, + LastMessages, +} + +/// Thread Truncation Controls +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct TruncationObject { + /// The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`. + pub r#type: TruncationObjectType, + /// The number of most recent messages from the thread when constructing the context for the run. + pub last_messages: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunObjectIncompleteDetails { + /// The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run. + pub reason: RunObjectIncompleteDetailsReason, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum RunObjectIncompleteDetailsReason { + MaxCompletionTokens, + MaxPromptTokens, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum RunStatus { + Queued, + InProgress, + RequiresAction, + Cancelling, + Cancelled, + Failed, + Completed, + Incomplete, + Expired, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RequiredAction { + /// For now, this is always `submit_tool_outputs`. + pub r#type: String, + + pub submit_tool_outputs: SubmitToolOutputs, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct SubmitToolOutputs { + pub tool_calls: Vec, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunToolCallObject { + /// The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs) endpoint. + pub id: String, + /// The type of tool call the output is required for. For now, this is always `function`. + pub r#type: String, + /// The function definition. + pub function: FunctionCall, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct LastError { + /// One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`. + pub code: LastErrorCode, + /// A human-readable description of the error. + pub message: String, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum LastErrorCode { + ServerError, + RateLimitExceeded, + InvalidPrompt, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunCompletionUsage { + /// Number of completion tokens used over the course of the run. + pub completion_tokens: u32, + /// Number of prompt tokens used over the course of the run. + pub prompt_tokens: u32, + /// Total number of tokens used (prompt + completion). + pub total_tokens: u32, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "CreateRunRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateRunRequest { + /// The ID of the [assistant](https://platform.openai.com/docs/api-reference/assistants) to use to execute this run. + pub assistant_id: String, + + /// The ID of the [Model](https://platform.openai.com/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// Overrides the [instructions](https://platform.openai.com/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis. + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_instructions: Option, + + /// Adds additional messages to the thread before creating the run. + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_messages: Option>, + + /// Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// The sampling temperature used for this run. If not set, defaults to 1. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or temperature but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + + /// The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_prompt_tokens: Option, + + /// The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + /// Controls for how a thread will be truncated prior to the run. Use this to control the intial context window of the run. + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation_strategy: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Whether to enable [parallel function calling](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling) during tool use. + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct ModifyRunRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct ListRunsResponse { + pub object: String, + pub data: Vec, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct SubmitToolOutputsRunRequest { + /// A list of tools for which the outputs are being submitted. + pub tool_outputs: Vec, + /// If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + pub stream: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "ToolsOutputsArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ToolsOutputs { + /// The ID of the tool call in the `required_action` object within the run object the output is being submitted for. + pub tool_call_id: Option, + /// The output of the tool call to be submitted to continue the run. + pub output: Option, +} diff --git a/clia-async-openai/src/types/step.rs b/clia-async-openai/src/types/step.rs new file mode 100644 index 00000000..d95b3c18 --- /dev/null +++ b/clia-async-openai/src/types/step.rs @@ -0,0 +1,334 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use super::{FileSearchRankingOptions, ImageFile, LastError, RunStatus}; + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum RunStepType { + MessageCreation, + ToolCalls, +} + +/// Represents a step in execution of a run. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `thread.run.step`. + pub object: String, + /// The Unix timestamp (in seconds) for when the run step was created. + pub created_at: i32, + + /// The ID of the [assistant](https://platform.openai.com/docs/api-reference/assistants) associated with the run step. + pub assistant_id: Option, + + /// The ID of the [thread](https://platform.openai.com/docs/api-reference/threads) that was run. + pub thread_id: String, + + /// The ID of the [run](https://platform.openai.com/docs/api-reference/runs) that this run step is a part of. + pub run_id: String, + + /// The type of run step, which can be either `message_creation` or `tool_calls`. + pub r#type: RunStepType, + + /// The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`. + pub status: RunStatus, + + /// The details of the run step. + pub step_details: StepDetails, + + /// The last error associated with this run. Will be `null` if there are no errors. + pub last_error: Option, + + ///The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired. + pub expires_at: Option, + + /// The Unix timestamp (in seconds) for when the run step was cancelled. + pub cancelled_at: Option, + + /// The Unix timestamp (in seconds) for when the run step failed. + pub failed_at: Option, + + /// The Unix timestamp (in seconds) for when the run step completed. + pub completed_at: Option, + + pub metadata: Option>, + + /// Usage statistics related to the run step. This value will be `null` while the run step's status is `in_progress`. + pub usage: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepCompletionUsage { + /// Number of completion tokens used over the course of the run step. + pub completion_tokens: u32, + /// Number of prompt tokens used over the course of the run step. + pub prompt_tokens: u32, + /// Total number of tokens used (prompt + completion). + pub total_tokens: u32, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum StepDetails { + MessageCreation(RunStepDetailsMessageCreationObject), + ToolCalls(RunStepDetailsToolCallsObject), +} + +/// Details of the message creation by the run step. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsMessageCreationObject { + pub message_creation: MessageCreation, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct MessageCreation { + /// The ID of the message that was created by this run step. + pub message_id: String, +} + +/// Details of the tool call. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsObject { + /// An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + pub tool_calls: Vec, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum RunStepDetailsToolCalls { + /// Details of the Code Interpreter tool call the run step was involved in. + CodeInterpreter(RunStepDetailsToolCallsCodeObject), + FileSearch(RunStepDetailsToolCallsFileSearchObject), + Function(RunStepDetailsToolCallsFunctionObject), +} + +/// Code interpreter tool call +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsCodeObject { + /// The ID of the tool call. + pub id: String, + + /// The Code Interpreter tool call definition. + pub code_interpreter: CodeInterpreter, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct CodeInterpreter { + /// The input to the Code Interpreter tool call. + pub input: String, + /// The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. + pub outputs: Vec, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "lowercase")] +pub enum CodeInterpreterOutput { + /// Code interpreter log output + Logs(RunStepDetailsToolCallsCodeOutputLogsObject), + /// Code interpreter image output + Image(RunStepDetailsToolCallsCodeOutputImageObject), +} + +/// Text output from the Code Interpreter tool call as part of a run step. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsCodeOutputLogsObject { + /// The text output from the Code Interpreter tool call. + pub logs: String, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsCodeOutputImageObject { + /// The [file](https://platform.openai.com/docs/api-reference/files) ID of the image. + pub image: ImageFile, +} + +/// File search tool call +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsFileSearchObject { + /// The ID of the tool call object. + pub id: String, + /// For now, this is always going to be an empty object. + pub file_search: RunStepDetailsToolCallsFileSearchObjectFileSearch, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsFileSearchObjectFileSearch { + pub ranking_options: Option, + /// The results of the file search. + pub results: Option>, +} + +/// A result instance of the file search. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsFileSearchResultObject { + /// The ID of the file that result was found in. + pub file_id: String, + /// The name of the file that result was found in. + pub file_name: String, + /// The score of the result. All values must be a floating point number between 0 and 1. + pub score: f32, + /// The content of the result that was found. The content is only included if requested via the include query parameter. + pub content: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsFileSearchResultObjectContent { + // note: type is text hence omitted from struct + /// The text content of the file. + pub text: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDetailsToolCallsFunctionObject { + /// The ID of the tool call object. + pub id: String, + /// he definition of the function that was called. + pub function: RunStepFunctionObject, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepFunctionObject { + /// The name of the function. + pub name: String, + /// The arguments passed to the function. + pub arguments: String, + /// The output of the function. This will be `null` if the outputs have not been [submitted](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs) yet. + pub output: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepFunctionObjectDelta { + /// The name of the function. + pub name: Option, + /// The arguments passed to the function. + pub arguments: Option, + /// The output of the function. This will be `null` if the outputs have not been [submitted](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs) yet. + pub output: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ListRunStepsResponse { + pub object: String, + pub data: Vec, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, +} + +/// Represents a run step delta i.e. any changed fields on a run step during streaming. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaObject { + /// The identifier of the run step, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `thread.run.step.delta`. + pub object: String, + /// The delta containing the fields that have changed on the run step. + pub delta: RunStepDelta, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDelta { + pub step_details: DeltaStepDetails, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum DeltaStepDetails { + MessageCreation(RunStepDeltaStepDetailsMessageCreationObject), + ToolCalls(RunStepDeltaStepDetailsToolCallsObject), +} + +/// Details of the message creation by the run step. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaStepDetailsMessageCreationObject { + pub message_creation: Option, +} + +/// Details of the tool call. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaStepDetailsToolCallsObject { + /// An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + pub tool_calls: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum RunStepDeltaStepDetailsToolCalls { + CodeInterpreter(RunStepDeltaStepDetailsToolCallsCodeObject), + FileSearch(RunStepDeltaStepDetailsToolCallsFileSearchObject), + Function(RunStepDeltaStepDetailsToolCallsFunctionObject), +} + +/// Details of the Code Interpreter tool call the run step was involved in. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaStepDetailsToolCallsCodeObject { + /// The index of the tool call in the tool calls array. + pub index: u32, + /// The ID of the tool call. + pub id: Option, + /// The Code Interpreter tool call definition. + pub code_interpreter: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct DeltaCodeInterpreter { + /// The input to the Code Interpreter tool call. + pub input: Option, + /// The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. + pub outputs: Option>, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "lowercase")] +pub enum DeltaCodeInterpreterOutput { + Logs(RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject), + Image(RunStepDeltaStepDetailsToolCallsCodeOutputImageObject), +} + +/// Text output from the Code Interpreter tool call as part of a run step. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject { + /// The index of the output in the outputs array. + pub index: u32, + /// The text output from the Code Interpreter tool call. + pub logs: Option, +} + +/// Code interpreter image output +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaStepDetailsToolCallsCodeOutputImageObject { + /// The index of the output in the outputs array. + pub index: u32, + + pub image: Option, +} + +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaStepDetailsToolCallsFileSearchObject { + /// The index of the tool call in the tool calls array. + pub index: u32, + /// The ID of the tool call object. + pub id: Option, + /// For now, this is always going to be an empty object. + pub file_search: Option, +} + +/// Function tool call +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct RunStepDeltaStepDetailsToolCallsFunctionObject { + /// The index of the tool call in the tool calls array. + pub index: u32, + /// The ID of the tool call object. + pub id: Option, + /// The definition of the function that was called. + pub function: Option, +} diff --git a/clia-async-openai/src/types/thread.rs b/clia-async-openai/src/types/thread.rs new file mode 100644 index 00000000..199d8a46 --- /dev/null +++ b/clia-async-openai/src/types/thread.rs @@ -0,0 +1,134 @@ +use std::collections::HashMap; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +use super::{ + AssistantToolResources, AssistantTools, AssistantsApiResponseFormatOption, + AssistantsApiToolChoiceOption, CreateAssistantToolResources, CreateMessageRequest, + TruncationObject, +}; + +/// Represents a thread that contains [messages](https://platform.openai.com/docs/api-reference/messages). +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ThreadObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `thread`. + pub object: String, + /// The Unix timestamp (in seconds) for when the thread was created. + pub created_at: i32, + + /// A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + pub tool_resources: Option, + + pub metadata: Option>, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "CreateThreadRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateThreadRequest { + /// A list of [messages](https://platform.openai.com/docs/api-reference/messages) to start the thread with. + #[serde(skip_serializing_if = "Option::is_none")] + pub messages: Option>, + + /// A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_resources: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct ModifyThreadRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_resources: Option, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)] +pub struct DeleteThreadResponse { + pub id: String, + pub deleted: bool, + pub object: String, +} + +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "CreateThreadAndRunRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateThreadAndRunRequest { + /// The ID of the [assistant](https://platform.openai.com/docs/api-reference/assistants) to use to execute this run. + pub assistant_id: String, + + /// If no thread is provided, an empty thread will be created. + #[serde(skip_serializing_if = "Option::is_none")] + pub thread: Option, + + /// The ID of the [Model](https://platform.openai.com/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis. + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + + /// Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_resources: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or temperature but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + + /// The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_prompt_tokens: Option, + + /// The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + /// Controls for how a thread will be truncated prior to the run. Use this to control the intial context window of the run. + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation_strategy: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Whether to enable [parallel function calling](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling) during tool use. + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, +} diff --git a/clia-async-openai/src/types/upload.rs b/clia-async-openai/src/types/upload.rs new file mode 100644 index 00000000..eb91c0e1 --- /dev/null +++ b/clia-async-openai/src/types/upload.rs @@ -0,0 +1,126 @@ +use crate::error::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use super::{InputSource, OpenAIFile}; + +/// Request to create an upload object that can accept byte chunks in the form of Parts. +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "CreateUploadRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateUploadRequest { + /// The name of the file to upload. + pub filename: String, + + /// The intended purpose of the uploaded file. + /// + /// See the [documentation on File purposes](https://platform.openai.com/docs/api-reference/files/create#files-create-purpose). + pub purpose: UploadPurpose, + + /// The number of bytes in the file you are uploading. + pub bytes: u64, + + /// The MIME type of the file. + /// + /// This must fall within the supported MIME types for your file purpose. See the supported MIME + /// types for assistants and vision. + pub mime_type: String, +} + +/// The intended purpose of the uploaded file. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "snake_case")] +pub enum UploadPurpose { + /// For use with Assistants and Message files + Assistants, + /// For Assistants image file inputs + Vision, + /// For use with the Batch API + Batch, + /// For use with Fine-tuning + #[default] + FineTune, +} + +/// The Upload object can accept byte chunks in the form of Parts. +#[derive(Debug, Serialize, Deserialize)] +pub struct Upload { + /// The Upload unique identifier, which can be referenced in API endpoints + pub id: String, + + /// The Unix timestamp (in seconds) for when the Upload was created + pub created_at: u32, + + /// The name of the file to be uploaded + pub filename: String, + + /// The intended number of bytes to be uploaded + pub bytes: u64, + + /// The intended purpose of the file. [Pelase refer here]([Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values.) + pub purpose: UploadPurpose, + + /// The status of the Upload. + pub status: UploadStatus, + + /// The Unix timestamp (in seconds) for when the Upload was created + pub expires_at: u32, + + /// The object type, which is always "upload" + pub object: String, + + /// The ready File object after the Upload is completed + #[serde(skip_serializing_if = "Option::is_none")] + pub file: Option, +} + +/// The status of an upload +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum UploadStatus { + /// Upload is pending + Pending, + /// Upload has completed successfully + Completed, + /// Upload was cancelled + Cancelled, + /// Upload has expired + Expired, +} + +/// The upload Part represents a chunk of bytes we can add to an Upload object. +#[derive(Debug, Serialize, Deserialize)] +pub struct UploadPart { + /// The upload Part unique identifier, which can be referenced in API endpoints + pub id: String, + + /// The Unix timestamp (in seconds) for when the Part was created + pub created_at: u32, + + /// The ID of the Upload object that this Part was added to + pub upload_id: String, + + /// The object type, which is always `upload.part` + pub object: String, +} + +/// Request parameters for adding a part to an Upload +#[derive(Debug, Clone)] +pub struct AddUploadPartRequest { + /// The chunk of bytes for this Part + pub data: InputSource, +} + +/// Request parameters for completing an Upload +#[derive(Debug, Serialize)] +pub struct CompleteUploadRequest { + /// The ordered list of Part IDs + pub part_ids: Vec, + + /// The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect + #[serde(skip_serializing_if = "Option::is_none")] + pub md5: Option, +} diff --git a/clia-async-openai/src/types/users.rs b/clia-async-openai/src/types/users.rs new file mode 100644 index 00000000..5fd0760c --- /dev/null +++ b/clia-async-openai/src/types/users.rs @@ -0,0 +1,51 @@ +use crate::types::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use super::OrganizationRole; + +/// Represents an individual `user` within an organization. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct User { + /// The object type, which is always `organization.user` + pub object: String, + /// The identifier, which can be referenced in API endpoints + pub id: String, + /// The name of the user + pub name: String, + /// The email address of the user + pub email: String, + /// `owner` or `reader` + pub role: OrganizationRole, + /// The Unix timestamp (in seconds) of when the users was added. + pub added_at: u32, +} + +/// A list of `User` objects. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct UserListResponse { + pub object: String, + pub data: Vec, + pub first_id: String, + pub last_id: String, + pub has_more: bool, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Builder)] +#[builder(name = "UserRoleUpdateRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option))] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct UserRoleUpdateRequest { + /// `owner` or `reader` + pub role: OrganizationRole, +} + +/// Confirmation of the deleted user +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct UserDeleteResponse { + pub object: String, + pub id: String, + pub deleted: bool, +} diff --git a/clia-async-openai/src/types/vector_store.rs b/clia-async-openai/src/types/vector_store.rs new file mode 100644 index 00000000..c4c93481 --- /dev/null +++ b/clia-async-openai/src/types/vector_store.rs @@ -0,0 +1,271 @@ +use std::collections::HashMap; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +use super::StaticChunkingStrategy; + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "CreateVectorStoreRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateVectorStoreRequest { + /// A list of [File](https://platform.openai.com/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. + #[serde(skip_serializing_if = "Option::is_none")] + pub file_ids: Option>, + /// The name of the vector store. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// The expiration policy for a vector store. + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_after: Option, + + /// The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty. + #[serde(skip_serializing_if = "Option::is_none")] + pub chunking_strategy: Option, + + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +#[serde(tag = "type")] +pub enum VectorStoreChunkingStrategy { + /// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + #[default] + #[serde(rename = "auto")] + Auto, + #[serde(rename = "static")] + Static { + #[serde(rename = "static")] + config: StaticChunkingStrategy, + }, +} + +/// Vector store expiration policy +#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] +pub struct VectorStoreExpirationAfter { + /// Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`. + pub anchor: String, + /// The number of days after the anchor time that the vector store will expire. + pub days: u16, // min: 1, max: 365 +} + +/// A vector store is a collection of processed files can be used by the `file_search` tool. +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct VectorStoreObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `vector_store`. + pub object: String, + /// The Unix timestamp (in seconds) for when the vector store was created. + pub created_at: u32, + /// The name of the vector store. + pub name: Option, + /// The total number of bytes used by the files in the vector store. + pub usage_bytes: u64, + pub file_counts: VectorStoreFileCounts, + /// The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use. + pub status: VectorStoreStatus, + pub expires_after: Option, + /// The Unix timestamp (in seconds) for when the vector store will expire. + pub expires_at: Option, + /// The Unix timestamp (in seconds) for when the vector store was last active. + pub last_active_at: Option, + + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum VectorStoreStatus { + Expired, + InProgress, + Completed, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct VectorStoreFileCounts { + /// The number of files that are currently being processed. + pub in_progress: u32, + /// The number of files that have been successfully processed. + pub completed: u32, + /// The number of files that have failed to process. + pub failed: u32, + /// The number of files that were cancelled. + pub cancelled: u32, + /// The total number of files. + pub total: u32, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct ListVectorStoresResponse { + pub object: String, + pub data: Vec, + pub first_id: Option, + pub last_id: Option, + pub has_more: bool, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct DeleteVectorStoreResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "UpdateVectorStoreRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct UpdateVectorStoreRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_after: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct ListVectorStoreFilesResponse { + pub object: String, + pub data: Vec, + pub first_id: String, + pub last_id: String, + pub has_more: bool, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct VectorStoreFileObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `vector_store.file`. + pub object: String, + /// The total vector store usage in bytes. Note that this may be different from the original file size. + pub usage_bytes: u64, + /// The Unix timestamp (in seconds) for when the vector store file was created. + pub created_at: u32, + /// The ID of the [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) that the [File](https://platform.openai.com/docs/api-reference/files) is attached to. + pub vector_store_id: String, + /// The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use. + pub status: VectorStoreFileStatus, + /// The last error associated with this vector store file. Will be `null` if there are no errors. + pub last_error: Option, + /// The strategy used to chunk the file. + pub chunking_strategy: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum VectorStoreFileStatus { + InProgress, + Completed, + Cancelled, + Failed, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct VectorStoreFileError { + pub code: VectorStoreFileErrorCode, + /// A human-readable description of the error. + pub message: String, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum VectorStoreFileErrorCode { + ServerError, + UnsupportedFile, + InvalidFile, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "lowercase")] +pub enum VectorStoreFileObjectChunkingStrategy { + /// This is returned when the chunking strategy is unknown. Typically, this is because the file was indexed before the `chunking_strategy` concept was introduced in the API. + Other, + Static { + r#static: StaticChunkingStrategy, + }, +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] +#[builder(name = "CreateVectorStoreFileRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateVectorStoreFileRequest { + /// A [File](https://platform.openai.com/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files. + pub file_id: String, + pub chunking_strategy: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct DeleteVectorStoreFileResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} + +#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)] +#[builder(name = "CreateVectorStoreFileBatchRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateVectorStoreFileBatchRequest { + /// A list of [File](https://platform.openai.com/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. + pub file_ids: Vec, // minItems: 1, maxItems: 500 + pub chunking_strategy: Option, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum VectorStoreFileBatchStatus { + InProgress, + Completed, + Cancelled, + Failed, +} + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct VectorStoreFileBatchCounts { + /// The number of files that are currently being processed. + pub in_progress: u32, + /// The number of files that have been successfully processed. + pub completed: u32, + /// The number of files that have failed to process. + pub failed: u32, + /// The number of files that were cancelled. + pub cancelled: u32, + /// The total number of files. + pub total: u32, +} + +/// A batch of files attached to a vector store. +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct VectorStoreFileBatchObject { + /// The identifier, which can be referenced in API endpoints. + pub id: String, + /// The object type, which is always `vector_store.file_batch`. + pub object: String, + /// The Unix timestamp (in seconds) for when the vector store files batch was created. + pub created_at: u32, + /// The ID of the [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) that the [File](https://platform.openai.com/docs/api-reference/files) is attached to. + pub vector_store_id: String, + /// The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`. + pub status: VectorStoreFileBatchStatus, + pub file_counts: VectorStoreFileBatchCounts, +} diff --git a/clia-async-openai/src/uploads.rs b/clia-async-openai/src/uploads.rs new file mode 100644 index 00000000..ba3cced1 --- /dev/null +++ b/clia-async-openai/src/uploads.rs @@ -0,0 +1,90 @@ +use crate::{ + config::Config, + error::OpenAIError, + types::{AddUploadPartRequest, CompleteUploadRequest, CreateUploadRequest, Upload, UploadPart}, + Client, +}; + +/// Allows you to upload large files in multiple parts. +pub struct Uploads<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Uploads<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Creates an intermediate [Upload](https://platform.openai.com/docs/api-reference/uploads/object) object that + /// you can add [Parts](https://platform.openai.com/docs/api-reference/uploads/part-object) to. Currently, + /// an Upload can accept at most 8 GB in total and expires after an hour after you create it. + /// + /// Once you complete the Upload, we will create a [File](https://platform.openai.com/docs/api-reference/files/object) + /// object that contains all the parts you uploaded. This File is usable in the rest of our platform as a regular File object. + /// + /// For certain `purpose`s, the correct `mime_type` must be specified. Please refer to documentation for the + /// supported MIME types for your use case: + /// - [Assistants](https://platform.openai.com/docs/assistants/tools/file-search/supported-files) + /// + /// For guidance on the proper filename extensions for each purpose, please follow the documentation on + /// [creating a File](https://platform.openai.com/docs/api-reference/files/create). + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create(&self, request: CreateUploadRequest) -> Result { + self.client.post("/uploads", request).await + } + + /// Adds a [Part](https://platform.openai.com/docs/api-reference/uploads/part-object) to an + /// [Upload](https://platform.openai.com/docs/api-reference/uploads/object) object. + /// A Part represents a chunk of bytes from the file you are trying to upload. + /// + /// Each Part can be at most 64 MB, and you can add Parts until you hit the Upload maximum of 8 GB. + /// + /// It is possible to add multiple Parts in parallel. You can decide the intended order of the Parts + /// when you [complete the Upload](https://platform.openai.com/docs/api-reference/uploads/complete). + #[crate::byot( + T0 = std::fmt::Display, + T1 = Clone, + R = serde::de::DeserializeOwned, + where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom")] + pub async fn add_part( + &self, + upload_id: &str, + request: AddUploadPartRequest, + ) -> Result { + self.client + .post_form(&format!("/uploads/{upload_id}/parts"), request) + .await + } + + /// Completes the [Upload](https://platform.openai.com/docs/api-reference/uploads/object). + /// + /// Within the returned Upload object, there is a nested [File](https://platform.openai.com/docs/api-reference/files/object) + /// object that is ready to use in the rest of the platform. + /// + /// You can specify the order of the Parts by passing in an ordered list of the Part IDs. + /// + /// The number of bytes uploaded upon completion must match the number of bytes initially specified + /// when creating the Upload object. No Parts may be added after an Upload is completed. + + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn complete( + &self, + upload_id: &str, + request: CompleteUploadRequest, + ) -> Result { + self.client + .post(&format!("/uploads/{upload_id}/complete"), request) + .await + } + + /// Cancels the Upload. No Parts may be added after an Upload is cancelled. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn cancel(&self, upload_id: &str) -> Result { + self.client + .post( + &format!("/uploads/{upload_id}/cancel"), + serde_json::json!({}), + ) + .await + } +} diff --git a/clia-async-openai/src/users.rs b/clia-async-openai/src/users.rs new file mode 100644 index 00000000..727d3962 --- /dev/null +++ b/clia-async-openai/src/users.rs @@ -0,0 +1,58 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{User, UserDeleteResponse, UserListResponse, UserRoleUpdateRequest}, + Client, +}; + +/// Manage users and their role in an organization. Users will be automatically added to the Default project. +pub struct Users<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Users<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Lists all of the users in the organization. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query("/organization/users", &query) + .await + } + + /// Modifies a user's role in the organization. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn modify( + &self, + user_id: &str, + request: UserRoleUpdateRequest, + ) -> Result { + self.client + .post(format!("/organization/users/{user_id}").as_str(), request) + .await + } + + /// Retrieve a user by their identifier + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, user_id: &str) -> Result { + self.client + .get(format!("/organization/users/{user_id}").as_str()) + .await + } + + /// Deletes a user from the organization. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, user_id: &str) -> Result { + self.client + .delete(format!("/organizations/users/{user_id}").as_str()) + .await + } +} diff --git a/clia-async-openai/src/util.rs b/clia-async-openai/src/util.rs new file mode 100644 index 00000000..0668aec6 --- /dev/null +++ b/clia-async-openai/src/util.rs @@ -0,0 +1,75 @@ +use std::path::Path; + +use reqwest::Body; +use tokio::fs::File; +use tokio_util::codec::{BytesCodec, FramedRead}; + +use crate::error::OpenAIError; +use crate::types::InputSource; + +pub(crate) async fn file_stream_body(source: InputSource) -> Result { + let body = match source { + InputSource::Path { path } => { + let file = File::open(path) + .await + .map_err(|e| OpenAIError::FileReadError(e.to_string()))?; + let stream = FramedRead::new(file, BytesCodec::new()); + Body::wrap_stream(stream) + } + _ => { + return Err(OpenAIError::FileReadError( + "Cannot create stream from non-file source".to_string(), + )) + } + }; + Ok(body) +} + +/// Creates the part for the given file for multipart upload. +pub(crate) async fn create_file_part( + source: InputSource, +) -> Result { + let (stream, file_name) = match source { + InputSource::Path { path } => { + let file_name = path + .file_name() + .ok_or_else(|| { + OpenAIError::FileReadError(format!( + "cannot extract file name from {}", + path.display() + )) + })? + .to_str() + .unwrap() + .to_string(); + + ( + file_stream_body(InputSource::Path { path }).await?, + file_name, + ) + } + InputSource::Bytes { filename, bytes } => (Body::from(bytes), filename), + InputSource::VecU8 { filename, vec } => (Body::from(vec), filename), + }; + + let file_part = reqwest::multipart::Part::stream(stream) + .file_name(file_name) + .mime_str("application/octet-stream") + .unwrap(); + + Ok(file_part) +} + +pub(crate) fn create_all_dir>(dir: P) -> Result<(), OpenAIError> { + let exists = match Path::try_exists(dir.as_ref()) { + Ok(exists) => exists, + Err(e) => return Err(OpenAIError::FileSaveError(e.to_string())), + }; + + if !exists { + std::fs::create_dir_all(dir.as_ref()) + .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; + } + + Ok(()) +} diff --git a/clia-async-openai/src/vector_store_file_batches.rs b/clia-async-openai/src/vector_store_file_batches.rs new file mode 100644 index 00000000..8e1384a9 --- /dev/null +++ b/clia-async-openai/src/vector_store_file_batches.rs @@ -0,0 +1,90 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + CreateVectorStoreFileBatchRequest, ListVectorStoreFilesResponse, VectorStoreFileBatchObject, + }, + Client, +}; + +/// Vector store file batches represent operations to add multiple files to a vector store. +/// +/// Related guide: [File Search](https://platform.openai.com/docs/assistants/tools/file-search) +pub struct VectorStoreFileBatches<'c, C: Config> { + client: &'c Client, + pub vector_store_id: String, +} + +impl<'c, C: Config> VectorStoreFileBatches<'c, C> { + pub fn new(client: &'c Client, vector_store_id: &str) -> Self { + Self { + client, + vector_store_id: vector_store_id.into(), + } + } + + /// Create vector store file batch + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateVectorStoreFileBatchRequest, + ) -> Result { + self.client + .post( + &format!("/vector_stores/{}/file_batches", &self.vector_store_id), + request, + ) + .await + } + + /// Retrieves a vector store file batch. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve( + &self, + batch_id: &str, + ) -> Result { + self.client + .get(&format!( + "/vector_stores/{}/file_batches/{batch_id}", + &self.vector_store_id + )) + .await + } + + /// Cancel a vector store file batch. This attempts to cancel the processing of files in this batch as soon as possible. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn cancel(&self, batch_id: &str) -> Result { + self.client + .post( + &format!( + "/vector_stores/{}/file_batches/{batch_id}/cancel", + &self.vector_store_id + ), + serde_json::json!({}), + ) + .await + } + + /// Returns a list of vector store files in a batch. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list( + &self, + batch_id: &str, + query: &Q, + ) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + &format!( + "/vector_stores/{}/file_batches/{batch_id}/files", + &self.vector_store_id + ), + &query, + ) + .await + } +} diff --git a/clia-async-openai/src/vector_store_files.rs b/clia-async-openai/src/vector_store_files.rs new file mode 100644 index 00000000..b799eb0b --- /dev/null +++ b/clia-async-openai/src/vector_store_files.rs @@ -0,0 +1,134 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + CreateVectorStoreFileRequest, DeleteVectorStoreFileResponse, ListVectorStoreFilesResponse, + VectorStoreFileObject, + }, + Client, +}; + +/// Vector store files represent files inside a vector store. +/// +/// Related guide: [File Search](https://platform.openai.com/docs/assistants/tools/file-search) +pub struct VectorStoreFiles<'c, C: Config> { + client: &'c Client, + pub vector_store_id: String, +} + +impl<'c, C: Config> VectorStoreFiles<'c, C> { + pub fn new(client: &'c Client, vector_store_id: &str) -> Self { + Self { + client, + vector_store_id: vector_store_id.into(), + } + } + + /// Create a vector store file by attaching a [File](https://platform.openai.com/docs/api-reference/files) to a [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object). + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateVectorStoreFileRequest, + ) -> Result { + self.client + .post( + &format!("/vector_stores/{}/files", &self.vector_store_id), + request, + ) + .await + } + + /// Retrieves a vector store file. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, file_id: &str) -> Result { + self.client + .get(&format!( + "/vector_stores/{}/files/{file_id}", + &self.vector_store_id + )) + .await + } + + /// Delete a vector store file. This will remove the file from the vector store but the file itself will not be deleted. To delete the file, use the [delete file](https://platform.openai.com/docs/api-reference/files/delete) endpoint. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete( + &self, + file_id: &str, + ) -> Result { + self.client + .delete(&format!( + "/vector_stores/{}/files/{file_id}", + &self.vector_store_id + )) + .await + } + + /// Returns a list of vector store files. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query( + &format!("/vector_stores/{}/files", &self.vector_store_id), + &query, + ) + .await + } +} + +#[cfg(test)] +mod tests { + use crate::types::{CreateFileRequest, CreateVectorStoreRequest, FileInput, FilePurpose}; + use crate::Client; + + #[tokio::test] + async fn vector_store_file_creation_and_deletion( + ) -> Result<(), Box> { + let client = Client::new(); + + // Create a file + let file_handle = client + .files() + .create(CreateFileRequest { + file: FileInput::from_vec_u8( + String::from("meow.txt"), + String::from(":3").into_bytes(), + ), + purpose: FilePurpose::Assistants, + }) + .await?; + + // Create a vector store + let vector_store_handle = client + .vector_stores() + .create(CreateVectorStoreRequest { + file_ids: Some(vec![file_handle.id.clone()]), + name: None, + expires_after: None, + chunking_strategy: None, + metadata: None, + }) + .await?; + let vector_store_file = client + .vector_stores() + .files(&vector_store_handle.id) + .retrieve(&file_handle.id) + .await?; + + assert_eq!(vector_store_file.id, file_handle.id); + // Delete the vector store + client + .vector_stores() + .delete(&vector_store_handle.id) + .await?; + + // Delete the file + client.files().delete(&file_handle.id).await?; + + Ok(()) + } +} diff --git a/clia-async-openai/src/vector_stores.rs b/clia-async-openai/src/vector_stores.rs new file mode 100644 index 00000000..0fa4d1d8 --- /dev/null +++ b/clia-async-openai/src/vector_stores.rs @@ -0,0 +1,81 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::{ + CreateVectorStoreRequest, DeleteVectorStoreResponse, ListVectorStoresResponse, + UpdateVectorStoreRequest, VectorStoreObject, + }, + vector_store_file_batches::VectorStoreFileBatches, + Client, VectorStoreFiles, +}; + +pub struct VectorStores<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> VectorStores<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// [VectorStoreFiles] API group + pub fn files(&self, vector_store_id: &str) -> VectorStoreFiles { + VectorStoreFiles::new(self.client, vector_store_id) + } + + /// [VectorStoreFileBatches] API group + pub fn file_batches(&self, vector_store_id: &str) -> VectorStoreFileBatches { + VectorStoreFileBatches::new(self.client, vector_store_id) + } + + /// Create a vector store. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateVectorStoreRequest, + ) -> Result { + self.client.post("/vector_stores", request).await + } + + /// Retrieves a vector store. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, vector_store_id: &str) -> Result { + self.client + .get(&format!("/vector_stores/{vector_store_id}")) + .await + } + + /// Returns a list of vector stores. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client.get_with_query("/vector_stores", &query).await + } + + /// Delete a vector store. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete( + &self, + vector_store_id: &str, + ) -> Result { + self.client + .delete(&format!("/vector_stores/{vector_store_id}")) + .await + } + + /// Modifies a vector store. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn update( + &self, + vector_store_id: &str, + request: UpdateVectorStoreRequest, + ) -> Result { + self.client + .post(&format!("/vector_stores/{vector_store_id}"), request) + .await + } +} diff --git a/clia-async-openai/tests/boxed_future.rs b/clia-async-openai/tests/boxed_future.rs new file mode 100644 index 00000000..62893629 --- /dev/null +++ b/clia-async-openai/tests/boxed_future.rs @@ -0,0 +1,43 @@ +use futures::future::{BoxFuture, FutureExt}; +use futures::StreamExt; + +use clia_async_openai::types::{CompletionResponseStream, CreateCompletionRequestArgs}; +use clia_async_openai::Client; + +#[tokio::test] +async fn boxed_future_test() { + fn interpret_bool(token_stream: &mut CompletionResponseStream) -> BoxFuture<'_, bool> { + async move { + while let Some(response) = token_stream.next().await { + match response { + Ok(response) => { + let token_str = &response.choices[0].text.trim(); + if !token_str.is_empty() { + return token_str.contains("yes") || token_str.contains("Yes"); + } + } + Err(e) => eprintln!("Error: {e}"), + } + } + false + } + .boxed() + } + + let client = Client::new(); + + let request = CreateCompletionRequestArgs::default() + .model("gpt-3.5-turbo-instruct") + .n(1) + .prompt("does 2 and 2 add to four? (yes/no):\n") + .stream(true) + .logprobs(3) + .max_tokens(64_u32) + .build() + .unwrap(); + + let mut stream = client.completions().create_stream(request).await.unwrap(); + + let result = interpret_bool(&mut stream).await; + assert!(result); +} diff --git a/clia-async-openai/tests/bring-your-own-type.rs b/clia-async-openai/tests/bring-your-own-type.rs new file mode 100644 index 00000000..001a43c4 --- /dev/null +++ b/clia-async-openai/tests/bring-your-own-type.rs @@ -0,0 +1,444 @@ +#![allow(dead_code)] +//! The purpose of this test to make sure that all _byot methods compiles with custom types. +use std::pin::Pin; + +use async_openai::{error::OpenAIError, Client}; +use futures::Stream; +use serde_json::{json, Value}; + +impl async_openai::traits::AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + async fn try_from(_value: MyJson) -> Result { + Ok(reqwest::multipart::Form::new()) + } +} + +#[derive(Clone)] +pub struct MyJson(Value); + +type MyStreamingType = Pin> + Send>>; + +#[tokio::test] +async fn test_byot_files() { + let client = Client::new(); + + let _r: Result = client.files().create_byot(MyJson(json!({}))).await; + let _r: Result = client.files().list_byot([("limit", "2")]).await; + let _r: Result = client.files().retrieve_byot("file_id").await; + let _r: Result = client.files().delete_byot("file_id").await; +} + +#[tokio::test] +async fn test_byot_assistants() { + let client = Client::new(); + + let _r: Result = client.assistants().create_byot(json!({})).await; + let _r: Result = client.assistants().retrieve_byot("aid").await; + let _r: Result = client.assistants().update_byot("aid", json!({})).await; + let _r: Result = client.assistants().list_byot([("limit", 2)]).await; +} + +#[tokio::test] +async fn test_byot_models() { + let client = Client::new(); + + let _r: Result = client.models().list_byot().await; + let _r: Result = client.models().retrieve_byot("").await; + let _r: Result = client.models().delete_byot(String::new()).await; +} + +#[tokio::test] +async fn test_byot_moderations() { + let client = Client::new(); + + let _r: Result = client.moderations().create_byot(json!({})).await; +} + +#[tokio::test] +async fn test_byot_images() { + let client = Client::new(); + + let _r: Result = client.images().create_byot(json!({})).await; + let _r: Result = client.images().create_edit_byot(MyJson(json!({}))).await; + let _r: Result = client + .images() + .create_variation_byot(MyJson(json!({}))) + .await; +} + +#[tokio::test] +async fn test_byot_chat() { + let client = Client::new(); + + let _r: Result = client.chat().create_byot(json!({})).await; + let _r: Result = + client.chat().create_stream_byot(json!({})).await; +} + +#[tokio::test] +async fn test_byot_completions() { + let client = Client::new(); + + let _r: Result = client.completions().create_byot(json!({})).await; + let _r: Result = + client.completions().create_stream_byot(json!({})).await; +} + +#[tokio::test] +async fn test_byot_audio() { + let client = Client::new(); + + let _r: Result = client.audio().transcribe_byot(MyJson(json!({}))).await; + let _r: Result = client + .audio() + .transcribe_verbose_json_byot(MyJson(json!({}))) + .await; + let _r: Result = client.audio().translate_byot(MyJson(json!({}))).await; + let _r: Result = client + .audio() + .translate_verbose_json_byot(MyJson(json!({}))) + .await; +} + +#[tokio::test] +async fn test_byot_embeddings() { + let client = Client::new(); + + let _r: Result = client.embeddings().create_byot(json!({})).await; + let _r: Result = client.embeddings().create_base64_byot(json!({})).await; +} + +#[tokio::test] +async fn test_byot_fine_tunning() { + let client = Client::new(); + + let _r: Result = client.fine_tuning().create_byot(json!({})).await; + let _r: Result = client + .fine_tuning() + .list_paginated_byot([("limit", "2")]) + .await; + let _r: Result = client + .fine_tuning() + .retrieve_byot("fine_tunning_job_id") + .await; + let _r: Result = + client.fine_tuning().cancel_byot("fine_tuning_job_id").await; + let _r: Result = client + .fine_tuning() + .list_events_byot("fine_tuning_job_id", [("limit", "2")]) + .await; + let _r: Result = client + .fine_tuning() + .list_checkpoints_byot("fine_tuning_job_id", [("limit", "2")]) + .await; +} + +#[derive(Clone, serde::Deserialize)] +pub struct MyThreadJson(Value); + +impl TryFrom for MyThreadJson { + type Error = OpenAIError; + fn try_from(_value: eventsource_stream::Event) -> Result { + Ok(MyThreadJson(json!({}))) + } +} + +type MyThreadStreamingType = Pin> + Send>>; + +#[tokio::test] +async fn test_byot_threads() { + let client = Client::new(); + + let _r: Result = client.threads().create_and_run_byot(json!({})).await; + let _r: Result = + client.threads().create_and_run_stream_byot(json!({})).await; + let _r: Result = client.threads().create_byot(json!({})).await; + let _r: Result = client.threads().retrieve_byot("thread_id").await; + let _r: Result = client.threads().update_byot("thread_id", json!({})).await; + let _r: Result = client.threads().delete_byot("thread_id").await; +} + +#[tokio::test] +async fn test_byot_messages() { + let client = Client::new(); + + let _r: Result = client + .threads() + .messages("thread_id") + .create_byot(json!({})) + .await; + let _r: Result = client + .threads() + .messages("thread_id") + .retrieve_byot("message_id") + .await; + let _r: Result = client + .threads() + .messages("thread_id") + .update_byot("message_id", json!({})) + .await; + let _r: Result = client + .threads() + .messages("thread_id") + .list_byot([("limit", "2")]) + .await; + let _r: Result = client + .threads() + .messages("thread_id") + .delete_byot("message_id") + .await; +} + +#[tokio::test] +async fn test_byot_runs() { + let client = Client::new(); + + let _r: Result = client + .threads() + .runs("thread_id") + .create_byot(json!({})) + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .create_stream_byot(json!({})) + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .retrieve_byot("run_id") + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .update_byot("run_id", json!({})) + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .list_byot([("limit", "2")]) + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .submit_tool_outputs_byot("run_id", json!({})) + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .submit_tool_outputs_stream_byot("run_id", json!({})) + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .cancel_byot("run_id") + .await; +} + +#[tokio::test] +async fn test_byot_run_steps() { + let client = Client::new(); + + let _r: Result = client + .threads() + .runs("thread_id") + .steps("run_id") + .retrieve_byot("step_id") + .await; + let _r: Result = client + .threads() + .runs("thread_id") + .steps("run_id") + .list_byot([("limit", "2")]) + .await; +} + +#[tokio::test] +async fn test_byot_vector_store_files() { + let client = Client::new(); + let _r: Result = client + .vector_stores() + .files("vector_store_id") + .create_byot(json!({})) + .await; + let _r: Result = client + .vector_stores() + .files("vector_store_id") + .retrieve_byot("file_id") + .await; + let _r: Result = client + .vector_stores() + .files("vector_store_id") + .delete_byot("file_id") + .await; + let _r: Result = client + .vector_stores() + .files("vector_store_id") + .list_byot([("limit", "2")]) + .await; +} + +#[tokio::test] +async fn test_byot_vector_store_file_batches() { + let client = Client::new(); + let _r: Result = client + .vector_stores() + .file_batches("vector_store_id") + .create_byot(json!({})) + .await; + let _r: Result = client + .vector_stores() + .file_batches("vector_store_id") + .retrieve_byot("file_id") + .await; + let _r: Result = client + .vector_stores() + .file_batches("vector_store_id") + .cancel_byot("file_id") + .await; + let _r: Result = client + .vector_stores() + .file_batches("vector_store_id") + .list_byot("batch_id", [("limit", "2")]) + .await; +} + +#[tokio::test] +async fn test_byot_batches() { + let client = Client::new(); + let _r: Result = client.batches().create_byot(json!({})).await; + let _r: Result = client.batches().list_byot([("limit", "2")]).await; + let _r: Result = client.batches().retrieve_byot("batch_id").await; + let _r: Result = client.batches().cancel_byot("batch_id").await; +} + +#[tokio::test] +async fn test_byot_audit_logs() { + let client = Client::new(); + let _r: Result = client.audit_logs().get_byot([("limit", "2")]).await; +} + +#[tokio::test] +async fn test_byot_invites() { + let client = Client::new(); + let _r: Result = client.invites().create_byot(json!({})).await; + let _r: Result = client.invites().retrieve_byot("invite_id").await; + let _r: Result = client.invites().delete_byot("invite_id").await; + let _r: Result = client.invites().list_byot([("limit", "2")]).await; +} + +#[tokio::test] +async fn test_byot_projects() { + let client = Client::new(); + + let _r: Result = client.projects().list_byot([("limit", "2")]).await; + let _r: Result = client.projects().create_byot(json!({})).await; + let _r: Result = client.projects().retrieve_byot("project_id").await; + let _r: Result = + client.projects().modify_byot("project_id", json!({})).await; + let _r: Result = client.projects().archive_byot("project_id").await; +} + +#[tokio::test] +async fn test_byot_project_api_keys() { + let client = Client::new(); + + let _r: Result = client + .projects() + .api_keys("project_id") + .list_byot([("query", "2")]) + .await; + + let _r: Result = client + .projects() + .api_keys("project_id") + .retrieve_byot("api_key") + .await; + + let _r: Result = client + .projects() + .api_keys("project_id") + .delete_byot("api_key") + .await; +} + +#[tokio::test] +async fn test_byot_project_service_accounts() { + let client = Client::new(); + + let _r: Result = client + .projects() + .service_accounts("project_id") + .create_byot(json!({})) + .await; + + let _r: Result = client + .projects() + .service_accounts("project_id") + .delete_byot("service_account_id") + .await; + + let _r: Result = client + .projects() + .service_accounts("project_id") + .retrieve_byot("service_account_id") + .await; + + let _r: Result = client + .projects() + .service_accounts("project_id") + .list_byot([("limit", "2")]) + .await; +} + +#[tokio::test] +async fn test_byot_project_users() { + let client = Client::new(); + + let _r: Result = client + .projects() + .users("project_id") + .create_byot(json!({})) + .await; + let _r: Result = client + .projects() + .users("project_id") + .delete_byot("user_id") + .await; + + let _r: Result = client + .projects() + .users("project_id") + .list_byot([("limit", "2")]) + .await; + + let _r: Result = client + .projects() + .users("project_id") + .retrieve_byot("user_id") + .await; +} + +#[tokio::test] +async fn test_byot_uploads() { + let client = Client::new(); + + let _r: Result = client.uploads().create_byot(json!({})).await; + let _r: Result = client + .uploads() + .add_part_byot("upload_id", MyJson(json!({}))) + .await; + let _r: Result = + client.uploads().complete_byot("upload_id", json!({})).await; + let _r: Result = client.uploads().cancel_byot("upload_id").await; +} + +#[tokio::test] +async fn test_byot_users() { + let client = Client::new(); + + let _r: Result = client.users().list_byot([("limit", "2")]).await; + let _r: Result = client.users().modify_byot("user_id", json!({})).await; + let _r: Result = client.users().retrieve_byot("user_id").await; + let _r: Result = client.users().delete_byot("user_id").await; +} diff --git a/clia-async-openai/tests/completion.rs b/clia-async-openai/tests/completion.rs new file mode 100644 index 00000000..72f87284 --- /dev/null +++ b/clia-async-openai/tests/completion.rs @@ -0,0 +1,47 @@ +//! This test is primarily to make sure that macros_rules for From traits are correct. +use clia_async_openai::types::Prompt; + +fn prompt_input(input: T) -> Prompt +where + Prompt: From, +{ + input.into() +} + +#[test] +fn create_prompt_input() { + let prompt = "This is &str prompt"; + let _ = prompt_input(prompt); + + let prompt = "This is String".to_string(); + let _ = prompt_input(&prompt); + let _ = prompt_input(prompt); + + let prompt = vec!["This is first", "This is second"]; + let _ = prompt_input(&prompt); + let _ = prompt_input(prompt); + + let prompt = vec!["First string".to_string(), "Second string".to_string()]; + let _ = prompt_input(&prompt); + let _ = prompt_input(prompt); + + let first = "First".to_string(); + let second = "Second".to_string(); + let prompt = vec![&first, &second]; + let _ = prompt_input(&prompt); + let _ = prompt_input(prompt); + + let prompt = ["first", "second"]; + let _ = prompt_input(&prompt); + let _ = prompt_input(prompt); + + let prompt = ["first".to_string(), "second".to_string()]; + let _ = prompt_input(&prompt); + let _ = prompt_input(prompt); + + let first = "First".to_string(); + let second = "Second".to_string(); + let prompt = [&first, &second]; + let _ = prompt_input(&prompt); + let _ = prompt_input(prompt); +} diff --git a/clia-async-openai/tests/embeddings.rs b/clia-async-openai/tests/embeddings.rs new file mode 100644 index 00000000..305ae01b --- /dev/null +++ b/clia-async-openai/tests/embeddings.rs @@ -0,0 +1,46 @@ +//! This test is primarily to make sure that macros_rules for From traits are correct. +use clia_async_openai::types::EmbeddingInput; + +fn embedding_input(input: T) -> EmbeddingInput +where + EmbeddingInput: From, +{ + input.into() +} + +#[test] +fn create_embedding_input() { + let input = [1, 2, 3]; + let _ = embedding_input(&input); + let _ = embedding_input(input); + + let input = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + let _ = embedding_input(&input); + let _ = embedding_input(input); + + let (s1, s2, s3) = ([1, 2, 3], [4, 5, 6], [7, 8, 9]); + let input = [&s1, &s2, &s3]; + let _ = embedding_input(&input); + let _ = embedding_input(input); + + let input = vec![1, 2, 3]; + let _ = embedding_input(&input); + let _ = embedding_input(input); + + let input = vec![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + let _ = embedding_input(&input); + let _ = embedding_input(input); + + let input = vec![vec![1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11, 12]]; + let _ = embedding_input(&input); + let _ = embedding_input(input); + + let input = [vec![1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11, 12]]; + let _ = embedding_input(&input); + let _ = embedding_input(input); + + let (v1, v2, v3) = (vec![1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11, 12]); + let input = [&v1, &v2, &v3]; + let _ = embedding_input(&input); + let _ = embedding_input(input); +} diff --git a/clia-async-openai/tests/ser_de.rs b/clia-async-openai/tests/ser_de.rs new file mode 100644 index 00000000..02d1bd89 --- /dev/null +++ b/clia-async-openai/tests/ser_de.rs @@ -0,0 +1,28 @@ +use clia_async_openai::types::{ + ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs, + CreateChatCompletionRequest, CreateChatCompletionRequestArgs, +}; + +#[tokio::test] +async fn chat_types_serde() { + let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default() + .messages([ + ChatCompletionRequestSystemMessageArgs::default() + .content("your are a calculator") + .build() + .unwrap() + .into(), + ChatCompletionRequestUserMessageArgs::default() + .content("what is the result of 1+1") + .build() + .unwrap() + .into(), + ]) + .build() + .unwrap(); + // serialize the request + let serialized = serde_json::to_string(&request).unwrap(); + // deserialize the request + let deserialized: CreateChatCompletionRequest = serde_json::from_str(&serialized).unwrap(); + assert_eq!(request, deserialized); +} diff --git a/clia-async-openai/tests/whisper.rs b/clia-async-openai/tests/whisper.rs new file mode 100644 index 00000000..f74f61ee --- /dev/null +++ b/clia-async-openai/tests/whisper.rs @@ -0,0 +1,57 @@ +use clia_async_openai::types::CreateTranslationRequestArgs; +use clia_async_openai::{types::CreateTranscriptionRequestArgs, Client}; +use tokio_test::assert_err; + +#[tokio::test] +async fn transcribe_test() { + let client = Client::new(); + + let request = CreateTranscriptionRequestArgs::default().build().unwrap(); + + let response = client.audio().transcribe(request).await; + + assert_err!(response); // FileReadError("cannot extract file name from ") +} + +#[tokio::test] +async fn transcribe_sendable_test() { + let client = Client::new(); + + // https://github.com/64bit/async-openai/issues/140 + let transcribe = tokio::spawn(async move { + let request = CreateTranscriptionRequestArgs::default().build().unwrap(); + + client.audio().transcribe(request).await + }); + + let response = transcribe.await.unwrap(); + + assert_err!(response); // FileReadError("cannot extract file name from ") +} + +#[tokio::test] +async fn translate_test() { + let client = Client::new(); + + let request = CreateTranslationRequestArgs::default().build().unwrap(); + + let response = client.audio().translate(request).await; + + assert_err!(response); // FileReadError("cannot extract file name from ") +} + +#[tokio::test] +async fn translate_sendable_test() { + let client = Client::new(); + + // https://github.com/64bit/async-openai/issues/140 + let translate = tokio::spawn(async move { + let request = CreateTranslationRequestArgs::default().build().unwrap(); + + client.audio().translate(request).await + }); + + let response = translate.await.unwrap(); + + assert_err!(response); // FileReadError("cannot extract file name from ") +}