Skip to content

Commit db4c213

Browse files
authored
Update Audio APIs from updated spec (64bit#202)
* Implement CreateTranscriptRequest::response_granularities This PR adds support for `AudioResponseFormat::VerboseJson` and `TimestampGranularity`, including updated example code. These were defined as types before, but not fully implemented. Implements 64bit#201. * Modify transcription API to be more like spec - Rename `CreateTranscriptionRespose` to `CreateTranscriptionResponseJson` (to match API spec) - Add `CreateTranscriptionResponseVerboseJson` and `transcribe_verbose_json` - Add `transcribe_raw` for SRT output - Add `post_form_raw` - Update example code
1 parent e4a428f commit db4c213

File tree

5 files changed

+186
-5
lines changed

5 files changed

+186
-5
lines changed

async-openai/src/audio.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
use bytes::Bytes;
2+
13
use crate::{
24
config::Config,
35
error::OpenAIError,
46
types::{
57
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
6-
CreateTranscriptionResponse, CreateTranslationRequest, CreateTranslationResponse,
8+
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson, CreateTranslationRequest, CreateTranslationResponse,
79
},
810
Client,
911
};
@@ -23,12 +25,32 @@ impl<'c, C: Config> Audio<'c, C> {
2325
pub async fn transcribe(
2426
&self,
2527
request: CreateTranscriptionRequest,
26-
) -> Result<CreateTranscriptionResponse, OpenAIError> {
28+
) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
29+
self.client
30+
.post_form("/audio/transcriptions", request)
31+
.await
32+
}
33+
34+
/// Transcribes audio into the input language.
35+
pub async fn transcribe_verbose_json(
36+
&self,
37+
request: CreateTranscriptionRequest,
38+
) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
2739
self.client
2840
.post_form("/audio/transcriptions", request)
2941
.await
3042
}
3143

44+
/// Transcribes audio into the input language.
45+
pub async fn transcribe_raw(
46+
&self,
47+
request: CreateTranscriptionRequest,
48+
) -> Result<Bytes, OpenAIError> {
49+
self.client
50+
.post_form_raw("/audio/transcriptions", request)
51+
.await
52+
}
53+
3254
/// Translates audio into into English.
3355
pub async fn translate(
3456
&self,

async-openai/src/client.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,25 @@ impl<C: Config> Client<C> {
222222
self.execute(request_maker).await
223223
}
224224

225+
/// POST a form at {path} and return the response body
226+
pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
227+
where
228+
reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
229+
F: Clone,
230+
{
231+
let request_maker = || async {
232+
Ok(self
233+
.http_client
234+
.post(self.config.url(path))
235+
.query(&self.config.query())
236+
.headers(self.config.headers())
237+
.multipart(async_convert::TryFrom::try_from(form.clone()).await?)
238+
.build()?)
239+
};
240+
241+
self.execute_raw(request_maker).await
242+
}
243+
225244
/// POST a form at {path} and deserialize the response body
226245
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
227246
where

async-openai/src/types/audio.rs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,84 @@ pub struct CreateTranscriptionRequest {
9696
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
9797
}
9898

99+
/// Represents a transcription response returned by model, based on the provided
100+
/// input.
99101
#[derive(Debug, Deserialize, Clone, Serialize)]
100-
pub struct CreateTranscriptionResponse {
102+
pub struct CreateTranscriptionResponseJson {
103+
/// The transcribed text.
101104
pub text: String,
102105
}
103106

107+
/// Represents a verbose json transcription response returned by model, based on
108+
/// the provided input.
109+
#[derive(Debug, Deserialize, Clone, Serialize)]
110+
pub struct CreateTranscriptionResponseVerboseJson {
111+
/// The language of the input audio.
112+
pub language: String,
113+
114+
/// The duration of the input audio.
115+
pub duration: f32,
116+
117+
/// The transcribed text.
118+
pub text: String,
119+
120+
/// Extracted words and their corresponding timestamps.
121+
#[serde(skip_serializing_if = "Option::is_none")]
122+
pub words: Option<Vec<TranscriptionWord>>,
123+
124+
/// Segments of the transcribed text and their corresponding details.
125+
#[serde(skip_serializing_if = "Option::is_none")]
126+
pub segments: Option<Vec<TranscriptionSegment>>,
127+
}
128+
129+
#[derive(Debug, Deserialize, Clone, Serialize)]
130+
pub struct TranscriptionWord {
131+
/// The text content of the word.
132+
pub word: String,
133+
134+
/// Start time of the word in seconds.
135+
pub start: f32,
136+
137+
/// End time of the word in seconds.
138+
pub end: f32,
139+
}
140+
141+
#[derive(Debug, Deserialize, Clone, Serialize)]
142+
pub struct TranscriptionSegment {
143+
/// Unique identifier of the segment.
144+
pub id: i32,
145+
146+
// Seek offset of the segment.
147+
pub seek: i32,
148+
149+
/// Start time of the segment in seconds.
150+
pub start: f32,
151+
152+
/// End time of the segment in seconds.
153+
pub end: f32,
154+
155+
/// Text content of the segment.
156+
pub text: String,
157+
158+
/// Array of token IDs for the text content.
159+
pub tokens: Vec<i32>,
160+
161+
/// Temperature parameter used for generating the segment.
162+
pub temperature: f32,
163+
164+
/// Average logprob of the segment. If the value is lower than -1, consider
165+
/// the logprobs failed.
166+
pub avg_logprob: f32,
167+
168+
/// Compression ratio of the segment. If the value is greater than 2.4,
169+
/// consider the compression failed.
170+
pub compression_ratio: f32,
171+
172+
/// Probability of no speech in the segment. If the value is higher than 1.0
173+
/// and the `avg_logprob` is below -1, consider this segment silent.
174+
pub no_speech_prob: f32,
175+
}
176+
104177
#[derive(Clone, Default, Debug, Builder, PartialEq, Serialize)]
105178
#[builder(name = "CreateSpeechRequestArgs")]
106179
#[builder(pattern = "mutable")]

async-openai/src/types/impls.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use super::{
2323
CreateImageEditRequest, CreateImageVariationRequest, CreateSpeechResponse,
2424
CreateTranscriptionRequest, CreateTranslationRequest, DallE2ImageSize, EmbeddingInput,
2525
FileInput, FunctionName, Image, ImageInput, ImageModel, ImageSize, ImageUrl, ImagesResponse,
26-
ModerationInput, Prompt, ResponseFormat, Role, Stop,
26+
ModerationInput, Prompt, ResponseFormat, Role, Stop, TimestampGranularity,
2727
};
2828

2929
/// for `impl_from!(T, Enum)`, implements
@@ -228,6 +228,19 @@ impl Display for AudioResponseFormat {
228228
}
229229
}
230230

231+
impl Display for TimestampGranularity {
232+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233+
write!(
234+
f,
235+
"{}",
236+
match self {
237+
TimestampGranularity::Word => "word",
238+
TimestampGranularity::Segment => "segment",
239+
}
240+
)
241+
}
242+
}
243+
231244
impl Display for Role {
232245
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233246
write!(
@@ -642,6 +655,12 @@ impl async_convert::TryFrom<CreateTranscriptionRequest> for reqwest::multipart::
642655
form = form.text("language", language);
643656
}
644657

658+
if let Some(timestamp_granularities) = request.timestamp_granularities {
659+
for tg in timestamp_granularities {
660+
form = form.text("timestamp_granularities[]", tg.to_string());
661+
}
662+
}
663+
645664
Ok(form)
646665
}
647666
}
Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,68 @@
1-
use async_openai::{types::CreateTranscriptionRequestArgs, Client};
1+
use async_openai::{
2+
types::{AudioResponseFormat, CreateTranscriptionRequestArgs, TimestampGranularity},
3+
Client
4+
};
25
use std::error::Error;
36

47
#[tokio::main]
58
async fn main() -> Result<(), Box<dyn Error>> {
9+
transcribe_json().await?;
10+
transcribe_verbose_json().await?;
11+
transcribe_srt().await?;
12+
Ok(())
13+
}
14+
15+
async fn transcribe_json() -> Result<(), Box<dyn Error>> {
616
let client = Client::new();
717
// Credits and Source for audio: https://www.youtube.com/watch?v=oQnDVqGIv4s
818
let request = CreateTranscriptionRequestArgs::default()
919
.file(
1020
"./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
1121
)
1222
.model("whisper-1")
23+
.response_format(AudioResponseFormat::Json)
1324
.build()?;
1425

1526
let response = client.audio().transcribe(request).await?;
27+
println!("{}", response.text);
28+
Ok(())
29+
}
30+
31+
async fn transcribe_verbose_json() -> Result<(), Box<dyn Error>> {
32+
let client = Client::new();
33+
let request = CreateTranscriptionRequestArgs::default()
34+
.file(
35+
"./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
36+
)
37+
.model("whisper-1")
38+
.response_format(AudioResponseFormat::VerboseJson)
39+
.timestamp_granularities(vec![TimestampGranularity::Word, TimestampGranularity::Segment])
40+
.build()?;
41+
42+
let response = client.audio().transcribe_verbose_json(request).await?;
1643

1744
println!("{}", response.text);
45+
if let Some(words) = &response.words {
46+
println!("- {} words", words.len());
47+
}
48+
if let Some(segments) = &response.segments {
49+
println!("- {} segments", segments.len());
50+
}
51+
52+
Ok(())
53+
}
54+
55+
async fn transcribe_srt() -> Result<(), Box<dyn Error>> {
56+
let client = Client::new();
57+
let request = CreateTranscriptionRequestArgs::default()
58+
.file(
59+
"./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
60+
)
61+
.model("whisper-1")
62+
.response_format(AudioResponseFormat::Srt)
63+
.build()?;
1864

65+
let response = client.audio().transcribe_raw(request).await?;
66+
println!("{}", String::from_utf8_lossy(response.as_ref()));
1967
Ok(())
2068
}

0 commit comments

Comments
 (0)