@@ -53,7 +53,6 @@ impl<C: Config> Client<C> {
53
53
}
54
54
55
55
/// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
56
- /// Form submissions are not retried.
57
56
pub fn with_backoff ( mut self , backoff : backoff:: ExponentialBackoff ) -> Self {
58
57
self . backoff = backoff;
59
58
self
@@ -116,29 +115,33 @@ impl<C: Config> Client<C> {
116
115
where
117
116
O : DeserializeOwned ,
118
117
{
119
- let request = self
120
- . http_client
121
- . get ( self . config . url ( path) )
122
- . query ( & self . config . query ( ) )
123
- . headers ( self . config . headers ( ) )
124
- . build ( ) ?;
125
-
126
- self . execute ( request) . await
118
+ let request_maker = || async {
119
+ Ok ( self
120
+ . http_client
121
+ . get ( self . config . url ( path) )
122
+ . query ( & self . config . query ( ) )
123
+ . headers ( self . config . headers ( ) )
124
+ . build ( ) ?)
125
+ } ;
126
+
127
+ self . execute ( request_maker) . await
127
128
}
128
129
129
130
/// Make a DELETE request to {path} and deserialize the response body
130
131
pub ( crate ) async fn delete < O > ( & self , path : & str ) -> Result < O , OpenAIError >
131
132
where
132
133
O : DeserializeOwned ,
133
134
{
134
- let request = self
135
- . http_client
136
- . delete ( self . config . url ( path) )
137
- . query ( & self . config . query ( ) )
138
- . headers ( self . config . headers ( ) )
139
- . build ( ) ?;
140
-
141
- self . execute ( request) . await
135
+ let request_maker = || async {
136
+ Ok ( self
137
+ . http_client
138
+ . delete ( self . config . url ( path) )
139
+ . query ( & self . config . query ( ) )
140
+ . headers ( self . config . headers ( ) )
141
+ . build ( ) ?)
142
+ } ;
143
+
144
+ self . execute ( request_maker) . await
142
145
}
143
146
144
147
/// Make a POST request to {path} and deserialize the response body
@@ -147,117 +150,96 @@ impl<C: Config> Client<C> {
147
150
I : Serialize ,
148
151
O : DeserializeOwned ,
149
152
{
150
- let request = self
151
- . http_client
152
- . post ( self . config . url ( path) )
153
- . query ( & self . config . query ( ) )
154
- . headers ( self . config . headers ( ) )
155
- . json ( & request)
156
- . build ( ) ?;
157
-
158
- self . execute ( request) . await
153
+ let request_maker = || async {
154
+ Ok ( self
155
+ . http_client
156
+ . post ( self . config . url ( path) )
157
+ . query ( & self . config . query ( ) )
158
+ . headers ( self . config . headers ( ) )
159
+ . json ( & request)
160
+ . build ( ) ?)
161
+ } ;
162
+
163
+ self . execute ( request_maker) . await
159
164
}
160
165
161
166
/// POST a form at {path} and deserialize the response body
162
- pub ( crate ) async fn post_form < O > (
163
- & self ,
164
- path : & str ,
165
- form : reqwest:: multipart:: Form ,
166
- ) -> Result < O , OpenAIError >
167
+ pub ( crate ) async fn post_form < O , F > ( & self , path : & str , form : F ) -> Result < O , OpenAIError >
167
168
where
168
169
O : DeserializeOwned ,
170
+ reqwest:: multipart:: Form : async_convert:: TryFrom < F , Error = OpenAIError > ,
171
+ F : Clone ,
169
172
{
170
- let request = self
171
- . http_client
172
- . post ( self . config . url ( path) )
173
- . query ( & self . config . query ( ) )
174
- . headers ( self . config . headers ( ) )
175
- . multipart ( form)
176
- . build ( ) ?;
177
-
178
- self . execute ( request) . await
173
+ let request_maker = || async {
174
+ Ok ( self
175
+ . http_client
176
+ . post ( self . config . url ( path) )
177
+ . query ( & self . config . query ( ) )
178
+ . headers ( self . config . headers ( ) )
179
+ . multipart ( async_convert:: TryInto :: try_into ( form. clone ( ) ) . await ?)
180
+ . build ( ) ?)
181
+ } ;
182
+
183
+ self . execute ( request_maker) . await
179
184
}
180
185
181
- /// Deserialize response body from either error object or actual response object
182
- async fn process_response < O > ( & self , response : reqwest:: Response ) -> Result < O , OpenAIError >
183
- where
184
- O : DeserializeOwned ,
185
- {
186
- let status = response. status ( ) ;
187
- let bytes = response. bytes ( ) . await ?;
188
-
189
- if !status. is_success ( ) {
190
- let wrapped_error: WrappedError = serde_json:: from_slice ( bytes. as_ref ( ) )
191
- . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) ) ?;
192
-
193
- return Err ( OpenAIError :: ApiError ( wrapped_error. error ) ) ;
194
- }
195
-
196
- let response: O = serde_json:: from_slice ( bytes. as_ref ( ) )
197
- . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) ) ?;
198
- Ok ( response)
199
- }
200
-
201
- /// Execute any HTTP requests and retry on rate limit, except streaming ones as they cannot be cloned for retrying.
202
- async fn execute < O > ( & self , request : reqwest:: Request ) -> Result < O , OpenAIError >
186
+ /// Execute a HTTP request and retry on rate limit
187
+ ///
188
+ /// request_maker serves one purpose: to be able to create request again
189
+ /// to retry API call after getting rate limited. request_maker is async because
190
+ /// reqwest::multipart::Form is created by async calls to read files for uploads.
191
+ async fn execute < O , M , Fut > ( & self , request_maker : M ) -> Result < O , OpenAIError >
203
192
where
204
193
O : DeserializeOwned ,
194
+ M : Fn ( ) -> Fut ,
195
+ Fut : core:: future:: Future < Output = Result < reqwest:: Request , OpenAIError > > ,
205
196
{
206
197
let client = self . http_client . clone ( ) ;
207
198
208
- match request. try_clone ( ) {
209
- // Only clone-able requests can be retried
210
- Some ( request) => {
211
- backoff:: future:: retry ( self . backoff . clone ( ) , || async {
212
- let response = client
213
- . execute ( request. try_clone ( ) . unwrap ( ) )
214
- . await
215
- . map_err ( OpenAIError :: Reqwest )
216
- . map_err ( backoff:: Error :: Permanent ) ?;
217
-
218
- let status = response. status ( ) ;
219
- let bytes = response
220
- . bytes ( )
221
- . await
222
- . map_err ( OpenAIError :: Reqwest )
223
- . map_err ( backoff:: Error :: Permanent ) ?;
224
-
225
- // Deserialize response body from either error object or actual response object
226
- if !status. is_success ( ) {
227
- let wrapped_error: WrappedError = serde_json:: from_slice ( bytes. as_ref ( ) )
228
- . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) )
229
- . map_err ( backoff:: Error :: Permanent ) ?;
230
-
231
- if status. as_u16 ( ) == 429
232
- // API returns 429 also when:
233
- // "You exceeded your current quota, please check your plan and billing details."
234
- && wrapped_error. error . r#type != Some ( "insufficient_quota" . to_string ( ) )
235
- {
236
- // Rate limited retry...
237
- tracing:: warn!( "Rate limited: {}" , wrapped_error. error. message) ;
238
- return Err ( backoff:: Error :: Transient {
239
- err : OpenAIError :: ApiError ( wrapped_error. error ) ,
240
- retry_after : None ,
241
- } ) ;
242
- } else {
243
- return Err ( backoff:: Error :: Permanent ( OpenAIError :: ApiError (
244
- wrapped_error. error ,
245
- ) ) ) ;
246
- }
247
- }
199
+ backoff:: future:: retry ( self . backoff . clone ( ) , || async {
200
+ let response = client
201
+ . execute ( request_maker ( ) . await ?)
202
+ . await
203
+ . map_err ( OpenAIError :: Reqwest )
204
+ . map_err ( backoff:: Error :: Permanent ) ?;
248
205
249
- let response: O = serde_json:: from_slice ( bytes. as_ref ( ) )
250
- . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) )
251
- . map_err ( backoff:: Error :: Permanent ) ?;
252
- Ok ( response)
253
- } )
206
+ let status = response. status ( ) ;
207
+ let bytes = response
208
+ . bytes ( )
254
209
. await
210
+ . map_err ( OpenAIError :: Reqwest )
211
+ . map_err ( backoff:: Error :: Permanent ) ?;
212
+
213
+ // Deserialize response body from either error object or actual response object
214
+ if !status. is_success ( ) {
215
+ let wrapped_error: WrappedError = serde_json:: from_slice ( bytes. as_ref ( ) )
216
+ . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) )
217
+ . map_err ( backoff:: Error :: Permanent ) ?;
218
+
219
+ if status. as_u16 ( ) == 429
220
+ // API returns 429 also when:
221
+ // "You exceeded your current quota, please check your plan and billing details."
222
+ && wrapped_error. error . r#type != Some ( "insufficient_quota" . to_string ( ) )
223
+ {
224
+ // Rate limited retry...
225
+ tracing:: warn!( "Rate limited: {}" , wrapped_error. error. message) ;
226
+ return Err ( backoff:: Error :: Transient {
227
+ err : OpenAIError :: ApiError ( wrapped_error. error ) ,
228
+ retry_after : None ,
229
+ } ) ;
230
+ } else {
231
+ return Err ( backoff:: Error :: Permanent ( OpenAIError :: ApiError (
232
+ wrapped_error. error ,
233
+ ) ) ) ;
234
+ }
255
235
}
256
- None => {
257
- let response = client. execute ( request) . await ?;
258
- self . process_response ( response) . await
259
- }
260
- }
236
+
237
+ let response: O = serde_json:: from_slice ( bytes. as_ref ( ) )
238
+ . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) )
239
+ . map_err ( backoff:: Error :: Permanent ) ?;
240
+ Ok ( response)
241
+ } )
242
+ . await
261
243
}
262
244
263
245
/// Make HTTP POST request to receive SSE
0 commit comments