Skip to content

Commit 4e41820

Browse files
authored
Retry form submissions on rate limit (64bit#100)
* retry form submission * update rate limit example * update doc * cargo fmt * document request_maker
1 parent db03584 commit 4e41820

File tree

15 files changed

+280
-247
lines changed

15 files changed

+280
-247
lines changed

async-openai/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ tokio-stream = "0.1.11"
3838
tokio-util = { version = "0.7.7", features = ["codec", "io-util"] }
3939
tracing = "0.1.37"
4040
derive_builder = "0.12.0"
41+
async-convert = "1.0.0"
4142

4243
[dev-dependencies]
4344
tokio-test = "0.4.2"

async-openai/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
- [x] Microsoft Azure Endpoints
3535
- [x] Models
3636
- [x] Moderations
37-
- Non-streaming requests are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits) by the API server.
37+
- All requests including form submissions (except SSE streaming) are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits) by the API server.
3838
- Ergonomic Rust library with builder pattern for all request objects.
3939

4040
**Note on Azure OpenAI Service**: `async-openai` primarily implements OpenAI APIs, and exposes same library for Azure OpenAI Service too. In reality Azure OpenAI Service provides only subset of OpenAI APIs.

async-openai/src/audio.rs

Lines changed: 4 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use crate::{
55
CreateTranscriptionRequest, CreateTranscriptionResponse, CreateTranslationRequest,
66
CreateTranslationResponse,
77
},
8-
util::create_file_part,
98
Client,
109
};
1110

@@ -25,50 +24,16 @@ impl<'c, C: Config> Audio<'c, C> {
2524
&self,
2625
request: CreateTranscriptionRequest,
2726
) -> Result<CreateTranscriptionResponse, OpenAIError> {
28-
let audio_part = create_file_part(&request.file.path).await?;
29-
30-
let mut form = reqwest::multipart::Form::new()
31-
.part("file", audio_part)
32-
.text("model", request.model);
33-
34-
if let Some(prompt) = request.prompt {
35-
form = form.text("prompt", prompt);
36-
}
37-
38-
if let Some(response_format) = request.response_format {
39-
form = form.text("response_format", response_format.to_string())
40-
}
41-
42-
if let Some(temperature) = request.temperature {
43-
form = form.text("temperature", temperature.to_string())
44-
}
45-
46-
self.client.post_form("/audio/transcriptions", form).await
27+
self.client
28+
.post_form("/audio/transcriptions", request)
29+
.await
4730
}
4831

4932
/// Translates audio into into English.
5033
pub async fn translate(
5134
&self,
5235
request: CreateTranslationRequest,
5336
) -> Result<CreateTranslationResponse, OpenAIError> {
54-
let audio_part = create_file_part(&request.file.path).await?;
55-
56-
let mut form = reqwest::multipart::Form::new()
57-
.part("file", audio_part)
58-
.text("model", request.model);
59-
60-
if let Some(prompt) = request.prompt {
61-
form = form.text("prompt", prompt);
62-
}
63-
64-
if let Some(response_format) = request.response_format {
65-
form = form.text("response_format", response_format.to_string())
66-
}
67-
68-
if let Some(temperature) = request.temperature {
69-
form = form.text("temperature", temperature.to_string())
70-
}
71-
72-
self.client.post_form("/audio/translations", form).await
37+
self.client.post_form("/audio/translations", request).await
7338
}
7439
}

async-openai/src/client.rs

Lines changed: 94 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ impl<C: Config> Client<C> {
5353
}
5454

5555
/// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
56-
/// Form submissions are not retried.
5756
pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
5857
self.backoff = backoff;
5958
self
@@ -116,29 +115,33 @@ impl<C: Config> Client<C> {
116115
where
117116
O: DeserializeOwned,
118117
{
119-
let request = self
120-
.http_client
121-
.get(self.config.url(path))
122-
.query(&self.config.query())
123-
.headers(self.config.headers())
124-
.build()?;
125-
126-
self.execute(request).await
118+
let request_maker = || async {
119+
Ok(self
120+
.http_client
121+
.get(self.config.url(path))
122+
.query(&self.config.query())
123+
.headers(self.config.headers())
124+
.build()?)
125+
};
126+
127+
self.execute(request_maker).await
127128
}
128129

129130
/// Make a DELETE request to {path} and deserialize the response body
130131
pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
131132
where
132133
O: DeserializeOwned,
133134
{
134-
let request = self
135-
.http_client
136-
.delete(self.config.url(path))
137-
.query(&self.config.query())
138-
.headers(self.config.headers())
139-
.build()?;
140-
141-
self.execute(request).await
135+
let request_maker = || async {
136+
Ok(self
137+
.http_client
138+
.delete(self.config.url(path))
139+
.query(&self.config.query())
140+
.headers(self.config.headers())
141+
.build()?)
142+
};
143+
144+
self.execute(request_maker).await
142145
}
143146

144147
/// Make a POST request to {path} and deserialize the response body
@@ -147,117 +150,96 @@ impl<C: Config> Client<C> {
147150
I: Serialize,
148151
O: DeserializeOwned,
149152
{
150-
let request = self
151-
.http_client
152-
.post(self.config.url(path))
153-
.query(&self.config.query())
154-
.headers(self.config.headers())
155-
.json(&request)
156-
.build()?;
157-
158-
self.execute(request).await
153+
let request_maker = || async {
154+
Ok(self
155+
.http_client
156+
.post(self.config.url(path))
157+
.query(&self.config.query())
158+
.headers(self.config.headers())
159+
.json(&request)
160+
.build()?)
161+
};
162+
163+
self.execute(request_maker).await
159164
}
160165

161166
/// POST a form at {path} and deserialize the response body
162-
pub(crate) async fn post_form<O>(
163-
&self,
164-
path: &str,
165-
form: reqwest::multipart::Form,
166-
) -> Result<O, OpenAIError>
167+
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
167168
where
168169
O: DeserializeOwned,
170+
reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
171+
F: Clone,
169172
{
170-
let request = self
171-
.http_client
172-
.post(self.config.url(path))
173-
.query(&self.config.query())
174-
.headers(self.config.headers())
175-
.multipart(form)
176-
.build()?;
177-
178-
self.execute(request).await
173+
let request_maker = || async {
174+
Ok(self
175+
.http_client
176+
.post(self.config.url(path))
177+
.query(&self.config.query())
178+
.headers(self.config.headers())
179+
.multipart(async_convert::TryInto::try_into(form.clone()).await?)
180+
.build()?)
181+
};
182+
183+
self.execute(request_maker).await
179184
}
180185

181-
/// Deserialize response body from either error object or actual response object
182-
async fn process_response<O>(&self, response: reqwest::Response) -> Result<O, OpenAIError>
183-
where
184-
O: DeserializeOwned,
185-
{
186-
let status = response.status();
187-
let bytes = response.bytes().await?;
188-
189-
if !status.is_success() {
190-
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
191-
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
192-
193-
return Err(OpenAIError::ApiError(wrapped_error.error));
194-
}
195-
196-
let response: O = serde_json::from_slice(bytes.as_ref())
197-
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
198-
Ok(response)
199-
}
200-
201-
/// Execute any HTTP requests and retry on rate limit, except streaming ones as they cannot be cloned for retrying.
202-
async fn execute<O>(&self, request: reqwest::Request) -> Result<O, OpenAIError>
186+
/// Execute a HTTP request and retry on rate limit
187+
///
188+
/// request_maker serves one purpose: to be able to create request again
189+
/// to retry API call after getting rate limited. request_maker is async because
190+
/// reqwest::multipart::Form is created by async calls to read files for uploads.
191+
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
203192
where
204193
O: DeserializeOwned,
194+
M: Fn() -> Fut,
195+
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
205196
{
206197
let client = self.http_client.clone();
207198

208-
match request.try_clone() {
209-
// Only clone-able requests can be retried
210-
Some(request) => {
211-
backoff::future::retry(self.backoff.clone(), || async {
212-
let response = client
213-
.execute(request.try_clone().unwrap())
214-
.await
215-
.map_err(OpenAIError::Reqwest)
216-
.map_err(backoff::Error::Permanent)?;
217-
218-
let status = response.status();
219-
let bytes = response
220-
.bytes()
221-
.await
222-
.map_err(OpenAIError::Reqwest)
223-
.map_err(backoff::Error::Permanent)?;
224-
225-
// Deserialize response body from either error object or actual response object
226-
if !status.is_success() {
227-
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
228-
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
229-
.map_err(backoff::Error::Permanent)?;
230-
231-
if status.as_u16() == 429
232-
// API returns 429 also when:
233-
// "You exceeded your current quota, please check your plan and billing details."
234-
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
235-
{
236-
// Rate limited retry...
237-
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
238-
return Err(backoff::Error::Transient {
239-
err: OpenAIError::ApiError(wrapped_error.error),
240-
retry_after: None,
241-
});
242-
} else {
243-
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
244-
wrapped_error.error,
245-
)));
246-
}
247-
}
199+
backoff::future::retry(self.backoff.clone(), || async {
200+
let response = client
201+
.execute(request_maker().await?)
202+
.await
203+
.map_err(OpenAIError::Reqwest)
204+
.map_err(backoff::Error::Permanent)?;
248205

249-
let response: O = serde_json::from_slice(bytes.as_ref())
250-
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
251-
.map_err(backoff::Error::Permanent)?;
252-
Ok(response)
253-
})
206+
let status = response.status();
207+
let bytes = response
208+
.bytes()
254209
.await
210+
.map_err(OpenAIError::Reqwest)
211+
.map_err(backoff::Error::Permanent)?;
212+
213+
// Deserialize response body from either error object or actual response object
214+
if !status.is_success() {
215+
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
216+
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
217+
.map_err(backoff::Error::Permanent)?;
218+
219+
if status.as_u16() == 429
220+
// API returns 429 also when:
221+
// "You exceeded your current quota, please check your plan and billing details."
222+
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
223+
{
224+
// Rate limited retry...
225+
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
226+
return Err(backoff::Error::Transient {
227+
err: OpenAIError::ApiError(wrapped_error.error),
228+
retry_after: None,
229+
});
230+
} else {
231+
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
232+
wrapped_error.error,
233+
)));
234+
}
255235
}
256-
None => {
257-
let response = client.execute(request).await?;
258-
self.process_response(response).await
259-
}
260-
}
236+
237+
let response: O = serde_json::from_slice(bytes.as_ref())
238+
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
239+
.map_err(backoff::Error::Permanent)?;
240+
Ok(response)
241+
})
242+
.await
261243
}
262244

263245
/// Make HTTP POST request to receive SSE

async-openai/src/file.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use crate::{
22
config::Config,
33
error::OpenAIError,
44
types::{CreateFileRequest, DeleteFileResponse, ListFilesResponse, OpenAIFile},
5-
util::create_file_part,
65
Client,
76
};
87

@@ -18,11 +17,7 @@ impl<'c, C: Config> Files<'c, C> {
1817

1918
/// Upload a file that contains document(s) to be used across various endpoints/features. Currently, the size of all the files uploaded by one organization can be up to 1 GB. Please contact us if you need to increase the storage limit.
2019
pub async fn create(&self, request: CreateFileRequest) -> Result<OpenAIFile, OpenAIError> {
21-
let file_part = create_file_part(&request.file.path).await?;
22-
let form = reqwest::multipart::Form::new()
23-
.part("file", file_part)
24-
.text("purpose", request.purpose);
25-
self.client.post_form("/files", form).await
20+
self.client.post_form("/files", request).await
2621
}
2722

2823
/// Returns a list of files that belong to the user's organization.

0 commit comments

Comments
 (0)