From a313a1797c1de14fa011e36b44b4b0f304835696 Mon Sep 17 00:00:00 2001
From: Eric Kidd
Date: Sat, 16 Mar 2024 09:31:18 -0400
Subject: [PATCH 1/2] Implement CreateTranscriptionRequest::timestamp_granularities

This PR adds support for `AudioResponseFormat::VerboseJson` and
`TimestampGranularity`, including updated example code. Both types were
defined before, but not fully implemented.

Implements 64bit/async-openai#201.
---
 async-openai/src/types/audio.rs       | 77 +++++++++++++++++++++++++++
 async-openai/src/types/impls.rs       | 21 +++++++-
 examples/audio-transcribe/src/main.rs | 13 ++++-
 3 files changed, 109 insertions(+), 2 deletions(-)

diff --git a/async-openai/src/types/audio.rs b/async-openai/src/types/audio.rs
index ca655c27..21320743 100644
--- a/async-openai/src/types/audio.rs
+++ b/async-openai/src/types/audio.rs
@@ -98,7 +98,84 @@ pub struct CreateTranscriptionRequest {
 
 #[derive(Debug, Deserialize, Clone, Serialize)]
 pub struct CreateTranscriptionResponse {
+    /// Transcribed text.
     pub text: String,
+
+    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
+    /// [`AudioResponseFormat::VerboseJson`], this field will be populated with
+    /// the name of the language detected in the audio.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub language: Option<String>,
+
+    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
+    /// [`AudioResponseFormat::VerboseJson`], this field will be populated with
+    /// the duration of the audio in seconds.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub duration: Option<f32>,
+
+    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
+    /// [`AudioResponseFormat::VerboseJson`] and
+    /// [`CreateTranscriptionRequestArgs::timestamp_granularities`] contains
+    /// [`TimestampGranularity::Word`], this field will be populated with the
+    /// word-level information.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub words: Option<Vec<CreateTranscriptionResponseWord>>,
+
+    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
+    /// [`AudioResponseFormat::VerboseJson`] and
+    /// [`CreateTranscriptionRequestArgs::timestamp_granularities`] contains
+    /// [`TimestampGranularity::Segment`], this field will be populated with the
+    /// segment-level information.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub segments: Option<Vec<CreateTranscriptionResponseSegment>>,
+
+    #[serde(flatten)]
+    pub extra: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Deserialize, Clone, Serialize)]
+pub struct CreateTranscriptionResponseWord {
+    /// The word.
+    pub word: String,
+
+    /// The start time of the word in seconds.
+    pub start: f32,
+
+    /// The end time of the word in seconds.
+    pub end: f32,
+}
+
+#[derive(Debug, Deserialize, Clone, Serialize)]
+pub struct CreateTranscriptionResponseSegment {
+    /// Unique identifier of the segment.
+    pub id: i32,
+
+    // Seek offset of the segment.
+    pub seek: i32,
+
+    /// Start time of the segment in seconds.
+    pub start: f32,
+
+    /// End time of the segment in seconds.
+    pub end: f32,
+
+    /// Transcribed text of the segment.
+    pub text: String,
+
+    /// Token IDs.
+    pub tokens: Vec<u32>,
+
+    /// Temperature parameter used for generating the segment.
+    pub temperature: f32,
+
+    /// Average log probability of the segment.
+    pub avg_logprob: f32,
+
+    /// Compression ratio of the segment.
+    pub compression_ratio: f32,
+
+    /// Probability of no speech in the segment.
+    pub no_speech_prob: f32,
 }
 
 #[derive(Clone, Default, Debug, Builder, PartialEq, Serialize)]
diff --git a/async-openai/src/types/impls.rs b/async-openai/src/types/impls.rs
index a77c1cba..991dfbc8 100644
--- a/async-openai/src/types/impls.rs
+++ b/async-openai/src/types/impls.rs
@@ -23,7 +23,7 @@ use super::{
     CreateImageEditRequest, CreateImageVariationRequest, CreateSpeechResponse,
     CreateTranscriptionRequest, CreateTranslationRequest, DallE2ImageSize, EmbeddingInput,
     FileInput, FunctionName, Image, ImageInput, ImageModel, ImageSize, ImageUrl, ImagesResponse,
-    ModerationInput, Prompt, ResponseFormat, Role, Stop,
+    ModerationInput, Prompt, ResponseFormat, Role, Stop, TimestampGranularity,
 };
 
 /// for `impl_from!(T, Enum)`, implements
@@ -228,6 +228,19 @@ impl Display for AudioResponseFormat {
     }
 }
 
+impl Display for TimestampGranularity {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                TimestampGranularity::Word => "word",
+                TimestampGranularity::Segment => "segment",
+            }
+        )
+    }
+}
+
 impl Display for Role {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(
@@ -642,6 +655,12 @@ impl async_convert::TryFrom<CreateTranscriptionRequest> for reqwest::multipart::
             form = form.text("language", language);
         }
 
+        if let Some(timestamp_granularities) = request.timestamp_granularities {
+            for tg in timestamp_granularities {
+                form = form.text("timestamp_granularities[]", tg.to_string());
+            }
+        }
+
         Ok(form)
     }
 }
diff --git a/examples/audio-transcribe/src/main.rs b/examples/audio-transcribe/src/main.rs
index 851ae2eb..808cb054 100644
--- a/examples/audio-transcribe/src/main.rs
+++ b/examples/audio-transcribe/src/main.rs
@@ -1,4 +1,7 @@
-use async_openai::{types::CreateTranscriptionRequestArgs, Client};
+use async_openai::{
+    types::{AudioResponseFormat, CreateTranscriptionRequestArgs, TimestampGranularity},
+    Client
+};
 use std::error::Error;
 
 #[tokio::main]
@@ -10,11 +13,19 @@ async fn main() -> Result<(), Box<dyn Error>> {
             "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
         )
         .model("whisper-1")
+        .response_format(AudioResponseFormat::VerboseJson)
+        .timestamp_granularities(vec![TimestampGranularity::Word, TimestampGranularity::Segment])
         .build()?;
 
     let response = client.audio().transcribe(request).await?;
 
     println!("{}", response.text);
+    if let Some(words) = &response.words {
+        println!("- {} words", words.len());
+    }
+    if let Some(segments) = &response.segments {
+        println!("- {} segments", segments.len());
+    }
 
     Ok(())
 }

From 930778ba2892a92830544a7b4a5611b72cdbfa71 Mon Sep 17 00:00:00 2001
From: Eric Kidd
Date: Sun, 24 Mar 2024 14:24:36 -0400
Subject: [PATCH 2/2] Modify transcription API to be more like the spec

- Rename `CreateTranscriptionResponse` to `CreateTranscriptionResponseJson`
  (to match the API spec)
- Add `CreateTranscriptionResponseVerboseJson` and `transcribe_verbose_json`
- Add `transcribe_raw` for SRT output
- Add `post_form_raw`
- Update example code
---
 async-openai/src/audio.rs             | 26 +++++++++++++++++++++++++-
 async-openai/src/client.rs            | 19 +++++++++++++++++++
 async-openai/src/types/audio.rs       | 70 +++++++++++++--------------
 examples/audio-transcribe/src/main.rs | 39 ++++++++++++++++++++++++++++++++++-
 4 files changed, 114 insertions(+), 40 deletions(-)

diff --git a/async-openai/src/audio.rs b/async-openai/src/audio.rs
index 08ea7ac0..0c3a8add 100644
--- a/async-openai/src/audio.rs
+++ b/async-openai/src/audio.rs
@@ -1,9 +1,11 @@
+use bytes::Bytes;
+
 use crate::{
     config::Config,
     error::OpenAIError,
     types::{
         CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
-        CreateTranscriptionResponse, CreateTranslationRequest, CreateTranslationResponse,
+        CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson, CreateTranslationRequest, CreateTranslationResponse,
     },
     Client,
 };
@@ -23,12 +25,32 @@ impl<'c, C: Config> Audio<'c, C> {
     pub async fn transcribe(
         &self,
         request: CreateTranscriptionRequest,
-    ) -> Result<CreateTranscriptionResponse, OpenAIError> {
+    ) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
+        self.client
+            .post_form("/audio/transcriptions", request)
+            .await
+    }
+
+    /// Transcribes audio into the input language.
+    pub async fn transcribe_verbose_json(
+        &self,
+        request: CreateTranscriptionRequest,
+    ) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
         self.client
             .post_form("/audio/transcriptions", request)
             .await
     }
 
+    /// Transcribes audio into the input language.
+    pub async fn transcribe_raw(
+        &self,
+        request: CreateTranscriptionRequest,
+    ) -> Result<Bytes, OpenAIError> {
+        self.client
+            .post_form_raw("/audio/transcriptions", request)
+            .await
+    }
+
     /// Translates audio into into English.
     pub async fn translate(
         &self,
diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs
index a4480daa..e4b567eb 100644
--- a/async-openai/src/client.rs
+++ b/async-openai/src/client.rs
@@ -222,6 +222,25 @@ impl<C: Config> Client<C> {
         self.execute(request_maker).await
     }
 
+    /// POST a form at {path} and return the response body
+    pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
+    where
+        reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
+        F: Clone,
+    {
+        let request_maker = || async {
+            Ok(self
+                .http_client
+                .post(self.config.url(path))
+                .query(&self.config.query())
+                .headers(self.config.headers())
+                .multipart(async_convert::TryFrom::try_from(form.clone()).await?)
+                .build()?)
+        };
+
+        self.execute_raw(request_maker).await
+    }
+
     /// POST a form at {path} and deserialize the response body
     pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
     where
diff --git a/async-openai/src/types/audio.rs b/async-openai/src/types/audio.rs
index 21320743..15519ace 100644
--- a/async-openai/src/types/audio.rs
+++ b/async-openai/src/types/audio.rs
@@ -96,57 +96,50 @@ pub struct CreateTranscriptionRequest {
     pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
 }
 
+/// Represents a transcription response returned by model, based on the provided
+/// input.
 #[derive(Debug, Deserialize, Clone, Serialize)]
-pub struct CreateTranscriptionResponse {
-    /// Transcribed text.
+pub struct CreateTranscriptionResponseJson {
+    /// The transcribed text.
     pub text: String,
+}
 
-    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
-    /// [`AudioResponseFormat::VerboseJson`], this field will be populated with
-    /// the name of the language detected in the audio.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub language: Option<String>,
+/// Represents a verbose json transcription response returned by model, based on
+/// the provided input.
+#[derive(Debug, Deserialize, Clone, Serialize)]
+pub struct CreateTranscriptionResponseVerboseJson {
+    /// The language of the input audio.
+    pub language: String,
 
-    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
-    /// [`AudioResponseFormat::VerboseJson`], this field will be populated with
-    /// the duration of the audio in seconds.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub duration: Option<f32>,
+    /// The duration of the input audio.
+    pub duration: f32,
 
-    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
-    /// [`AudioResponseFormat::VerboseJson`] and
-    /// [`CreateTranscriptionRequestArgs::timestamp_granularities`] contains
-    /// [`TimestampGranularity::Word`], this field will be populated with the
-    /// word-level information.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub words: Option<Vec<CreateTranscriptionResponseWord>>,
+    /// The transcribed text.
+    pub text: String,
 
-    /// If [`CreateTranscriptionRequestArgs::response_format`] is set to
-    /// [`AudioResponseFormat::VerboseJson`] and
-    /// [`CreateTranscriptionRequestArgs::timestamp_granularities`] contains
-    /// [`TimestampGranularity::Segment`], this field will be populated with the
-    /// segment-level information.
+    /// Extracted words and their corresponding timestamps.
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub segments: Option<Vec<CreateTranscriptionResponseSegment>>,
+    pub words: Option<Vec<TranscriptionWord>>,
 
-    #[serde(flatten)]
-    pub extra: Option<serde_json::Value>,
+    /// Segments of the transcribed text and their corresponding details.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub segments: Option<Vec<TranscriptionSegment>>,
 }
 
 #[derive(Debug, Deserialize, Clone, Serialize)]
-pub struct CreateTranscriptionResponseWord {
-    /// The word.
+pub struct TranscriptionWord {
+    /// The text content of the word.
     pub word: String,
 
-    /// The start time of the word in seconds.
+    /// Start time of the word in seconds.
     pub start: f32,
 
-    /// The end time of the word in seconds.
+    /// End time of the word in seconds.
     pub end: f32,
 }
 
 #[derive(Debug, Deserialize, Clone, Serialize)]
-pub struct CreateTranscriptionResponseSegment {
+pub struct TranscriptionSegment {
     /// Unique identifier of the segment.
     pub id: i32,
 
@@ -159,22 +152,25 @@ pub struct CreateTranscriptionResponseSegment {
     /// End time of the segment in seconds.
     pub end: f32,
 
-    /// Transcribed text of the segment.
+    /// Text content of the segment.
     pub text: String,
 
-    /// Token IDs.
+    /// Array of token IDs for the text content.
     pub tokens: Vec<u32>,
 
     /// Temperature parameter used for generating the segment.
     pub temperature: f32,
 
-    /// Average log probability of the segment.
+    /// Average logprob of the segment. If the value is lower than -1, consider
+    /// the logprobs failed.
     pub avg_logprob: f32,
 
-    /// Compression ratio of the segment.
+    /// Compression ratio of the segment. If the value is greater than 2.4,
+    /// consider the compression failed.
     pub compression_ratio: f32,
 
-    /// Probability of no speech in the segment.
+    /// Probability of no speech in the segment. If the value is higher than 1.0
+    /// and the `avg_logprob` is below -1, consider this segment silent.
     pub no_speech_prob: f32,
 }
 
diff --git a/examples/audio-transcribe/src/main.rs b/examples/audio-transcribe/src/main.rs
index 808cb054..de414b69 100644
--- a/examples/audio-transcribe/src/main.rs
+++ b/examples/audio-transcribe/src/main.rs
@@ -6,8 +6,30 @@ use std::error::Error;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
+    transcribe_json().await?;
+    transcribe_verbose_json().await?;
+    transcribe_srt().await?;
+    Ok(())
+}
+
+async fn transcribe_json() -> Result<(), Box<dyn Error>> {
     let client = Client::new();
     // Credits and Source for audio: https://www.youtube.com/watch?v=oQnDVqGIv4s
+    let request = CreateTranscriptionRequestArgs::default()
+        .file(
+            "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
+        )
+        .model("whisper-1")
+        .response_format(AudioResponseFormat::Json)
+        .build()?;
+
+    let response = client.audio().transcribe(request).await?;
+    println!("{}", response.text);
+    Ok(())
+}
+
+async fn transcribe_verbose_json() -> Result<(), Box<dyn Error>> {
+    let client = Client::new();
     let request = CreateTranscriptionRequestArgs::default()
         .file(
             "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
         )
         .model("whisper-1")
@@ -17,7 +39,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
         .timestamp_granularities(vec![TimestampGranularity::Word, TimestampGranularity::Segment])
         .build()?;
 
-    let response = client.audio().transcribe(request).await?;
+    let response = client.audio().transcribe_verbose_json(request).await?;
 
     println!("{}", response.text);
     if let Some(words) = &response.words {
@@ -29,3 +51,18 @@ async fn main() -> Result<(), Box<dyn Error>> {
 
     Ok(())
 }
+
+async fn transcribe_srt() -> Result<(), Box<dyn Error>> {
+    let client = Client::new();
+    let request = CreateTranscriptionRequestArgs::default()
+        .file(
+            "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
+        )
+        .model("whisper-1")
+        .response_format(AudioResponseFormat::Srt)
+        .build()?;
+
+    let response = client.audio().transcribe_raw(request).await?;
+    println!("{}", String::from_utf8_lossy(response.as_ref()));
+    Ok(())
+}