Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions async-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ tokio-stream = "0.1.11"
tokio-util = { version = "0.7.7", features = ["codec", "io-util"] }
tracing = "0.1.37"
derive_builder = "0.12.0"
async-convert = "1.0.0"

[dev-dependencies]
tokio-test = "0.4.2"
2 changes: 1 addition & 1 deletion async-openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
- [x] Microsoft Azure Endpoints
- [x] Models
- [x] Moderations
- Non-streaming requests are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits) by the API server.
- All requests including form submissions (except SSE streaming) are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits) by the API server.
- Ergonomic Rust library with builder pattern for all request objects.

**Note on Azure OpenAI Service**: `async-openai` primarily implements the OpenAI APIs and exposes the same library for Azure OpenAI Service too. In reality, Azure OpenAI Service provides only a subset of the OpenAI APIs.
Expand Down
43 changes: 4 additions & 39 deletions async-openai/src/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
CreateTranscriptionRequest, CreateTranscriptionResponse, CreateTranslationRequest,
CreateTranslationResponse,
},
util::create_file_part,
Client,
};

Expand All @@ -25,50 +24,16 @@ impl<'c, C: Config> Audio<'c, C> {
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponse, OpenAIError> {
let audio_part = create_file_part(&request.file.path).await?;

let mut form = reqwest::multipart::Form::new()
.part("file", audio_part)
.text("model", request.model);

if let Some(prompt) = request.prompt {
form = form.text("prompt", prompt);
}

if let Some(response_format) = request.response_format {
form = form.text("response_format", response_format.to_string())
}

if let Some(temperature) = request.temperature {
form = form.text("temperature", temperature.to_string())
}

self.client.post_form("/audio/transcriptions", form).await
self.client
.post_form("/audio/transcriptions", request)
.await
}

/// Translates audio into English.
pub async fn translate(
&self,
request: CreateTranslationRequest,
) -> Result<CreateTranslationResponse, OpenAIError> {
let audio_part = create_file_part(&request.file.path).await?;

let mut form = reqwest::multipart::Form::new()
.part("file", audio_part)
.text("model", request.model);

if let Some(prompt) = request.prompt {
form = form.text("prompt", prompt);
}

if let Some(response_format) = request.response_format {
form = form.text("response_format", response_format.to_string())
}

if let Some(temperature) = request.temperature {
form = form.text("temperature", temperature.to_string())
}

self.client.post_form("/audio/translations", form).await
self.client.post_form("/audio/translations", request).await
}
}
206 changes: 94 additions & 112 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ impl<C: Config> Client<C> {
}

/// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
/// Form submissions are not retried.
pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
self.backoff = backoff;
self
Expand Down Expand Up @@ -116,29 +115,33 @@ impl<C: Config> Client<C> {
where
O: DeserializeOwned,
{
let request = self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?;

self.execute(request).await
let request_maker = || async {
Ok(self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};

self.execute(request_maker).await
}

/// Make a DELETE request to {path} and deserialize the response body
pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
{
let request = self
.http_client
.delete(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?;

self.execute(request).await
let request_maker = || async {
Ok(self
.http_client
.delete(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};

self.execute(request_maker).await
}

/// Make a POST request to {path} and deserialize the response body
Expand All @@ -147,117 +150,96 @@ impl<C: Config> Client<C> {
I: Serialize,
O: DeserializeOwned,
{
let request = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.build()?;

self.execute(request).await
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.build()?)
};

self.execute(request_maker).await
}

/// POST a form at {path} and deserialize the response body
pub(crate) async fn post_form<O>(
&self,
path: &str,
form: reqwest::multipart::Form,
) -> Result<O, OpenAIError>
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
F: Clone,
{
let request = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(form)
.build()?;

self.execute(request).await
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(async_convert::TryInto::try_into(form.clone()).await?)
.build()?)
};

self.execute(request_maker).await
}

/// Deserialize response body from either error object or actual response object
///
/// Reads the HTTP status and the complete response body, then:
/// - on a non-success status, parses the body as a `WrappedError` and returns
///   its inner `error` wrapped in `OpenAIError::ApiError`;
/// - on a success status, parses the body as `O` and returns it.
///
/// # Errors
/// Propagates a failure to read the body via `?`. If the body cannot be
/// deserialized from JSON (as either the error object or `O`), returns the
/// error produced by `map_deserialization_error`, which is handed both the
/// serde error and the raw body bytes.
async fn process_response<O>(&self, response: reqwest::Response) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
{
// Capture the status before the body is read.
let status = response.status();
let bytes = response.bytes().await?;

if !status.is_success() {
// Non-success: the body is expected to carry the API's error object.
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

return Err(OpenAIError::ApiError(wrapped_error.error));
}

// Success: the body is expected to be the caller's requested type `O`.
let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
Ok(response)
}

/// Execute any HTTP requests and retry on rate limit, except streaming ones as they cannot be cloned for retrying.
async fn execute<O>(&self, request: reqwest::Request) -> Result<O, OpenAIError>
/// Execute an HTTP request and retry on rate limit
///
/// `request_maker` serves one purpose: to build the request again so that
/// the API call can be retried after being rate limited. `request_maker` is async because
/// `reqwest::multipart::Form` is created by async calls that read files for uploads.
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
let client = self.http_client.clone();

match request.try_clone() {
// Only clone-able requests can be retried
Some(request) => {
backoff::future::retry(self.backoff.clone(), || async {
let response = client
.execute(request.try_clone().unwrap())
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;

let status = response.status();
let bytes = response
.bytes()
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;

if status.as_u16() == 429
// API returns 429 also when:
// "You exceeded your current quota, please check your plan and billing details."
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
wrapped_error.error,
)));
}
}
backoff::future::retry(self.backoff.clone(), || async {
let response = client
.execute(request_maker().await?)
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;

let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
Ok(response)
})
let status = response.status();
let bytes = response
.bytes()
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;

if status.as_u16() == 429
// API returns 429 also when:
// "You exceeded your current quota, please check your plan and billing details."
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
wrapped_error.error,
)));
}
}
None => {
let response = client.execute(request).await?;
self.process_response(response).await
}
}

let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
Ok(response)
})
.await
}

/// Make HTTP POST request to receive SSE
Expand Down
7 changes: 1 addition & 6 deletions async-openai/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{
config::Config,
error::OpenAIError,
types::{CreateFileRequest, DeleteFileResponse, ListFilesResponse, OpenAIFile},
util::create_file_part,
Client,
};

Expand All @@ -18,11 +17,7 @@ impl<'c, C: Config> Files<'c, C> {

/// Upload a file that contains document(s) to be used across various endpoints/features. Currently, the size of all the files uploaded by one organization can be up to 1 GB. Please contact us if you need to increase the storage limit.
pub async fn create(&self, request: CreateFileRequest) -> Result<OpenAIFile, OpenAIError> {
let file_part = create_file_part(&request.file.path).await?;
let form = reqwest::multipart::Form::new()
.part("file", file_part)
.text("purpose", request.purpose);
self.client.post_form("/files", form).await
self.client.post_form("/files", request).await
}

/// Returns a list of files that belong to the user's organization.
Expand Down
Loading