@@ -9,6 +9,9 @@ use crate::{
     Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput,
 };
 #[cfg(feature = "online")]
+use anyhow::Context;
+use anyhow::Result;
+#[cfg(feature = "online")]
 use hf_hub::{
     api::sync::{ApiBuilder, ApiRepo},
     Cache,
@@ -40,7 +43,7 @@ impl TextEmbedding {
     ///
     /// Uses the total number of CPUs available as the number of intra-threads
     #[cfg(feature = "online")]
-    pub fn try_new(options: InitOptions) -> anyhow::Result<Self> {
+    pub fn try_new(options: InitOptions) -> Result<Self> {
         let InitOptions {
             model_name,
             execution_providers,
@@ -61,7 +64,7 @@ impl TextEmbedding {
         let model_file_name = &model_info.model_file;
         let model_file_reference = model_repo
             .get(model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {}", model_file_name))?;

         // TODO: If more models need .onnx_data, implement a better way to handle this
         // Probably by adding `additional_files` field in the `ModelInfo` struct
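
The change above replaces a panicking `unwrap_or_else` with anyhow's `Context`, so a missing model file now surfaces as an error the caller can handle instead of aborting the process. A minimal sketch of the same pattern outside the crate (the file-reading helper below is illustrative, not part of this PR):

use anyhow::{Context, Result};

// Illustrative only: attach a message to the failure and propagate it with `?`
// instead of panicking when the file cannot be read.
fn read_model_bytes(path: &str) -> Result<Vec<u8>> {
    std::fs::read(path).with_context(|| format!("Failed to retrieve {}", path))
}
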
@@ -95,7 +98,7 @@ impl TextEmbedding {
     pub fn try_new_from_user_defined(
         model: UserDefinedEmbeddingModel,
         options: InitOptionsUserDefined,
-    ) -> anyhow::Result<Self> {
+    ) -> Result<Self> {
         let InitOptionsUserDefined {
             execution_providers,
             max_length,
@@ -147,8 +150,7 @@ impl TextEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;

         let repo = api.model(model.to_string());
         Ok(repo)
@@ -160,7 +162,7 @@ impl TextEmbedding {
     }

     /// Get ModelInfo from EmbeddingModel
-    pub fn get_model_info(model: &EmbeddingModel) -> anyhow::Result<&ModelInfo<EmbeddingModel>> {
+    pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo<EmbeddingModel>> {
         get_model_info(model).ok_or_else(|| {
             anyhow::Error::msg(format!(
                 "Model {model:?} not found. Please check if the model is supported \
@@ -195,7 +197,7 @@ impl TextEmbedding {
         &'e self,
         texts: Vec<S>,
         batch_size: Option<usize>,
-    ) -> anyhow::Result<EmbeddingOutput<'r, 's>>
+    ) -> Result<EmbeddingOutput<'r, 's>>
     where
         'e: 'r,
         'e: 's,
@@ -223,72 +225,70 @@ impl TextEmbedding {
             _ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)),
         }?;

-        let batches =
-            anyhow::Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
-                // Encode the texts in the batch
-                let inputs = batch.iter().map(|text| text.as_ref()).collect();
-                let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
-                    anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
-                })?;
-
-                // Extract the encoding length and batch size
-                let encoding_length = encodings[0].len();
-                let batch_size = batch.len();
-
-                let max_size = encoding_length * batch_size;
-
-                // Preallocate arrays with the maximum size
-                let mut ids_array = Vec::with_capacity(max_size);
-                let mut mask_array = Vec::with_capacity(max_size);
-                let mut typeids_array = Vec::with_capacity(max_size);
-
-                // Not using par_iter because the closure needs to be FnMut
-                encodings.iter().for_each(|encoding| {
-                    let ids = encoding.get_ids();
-                    let mask = encoding.get_attention_mask();
-                    let typeids = encoding.get_type_ids();
-
-                    // Extend the preallocated arrays with the current encoding
-                    // Requires the closure to be FnMut
-                    ids_array.extend(ids.iter().map(|x| *x as i64));
-                    mask_array.extend(mask.iter().map(|x| *x as i64));
-                    typeids_array.extend(typeids.iter().map(|x| *x as i64));
-                });
-
-                // Create CowArrays from vectors
-                let inputs_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
-
-                let attention_mask_array =
-                    Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
-
-                let token_type_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
-
-                let mut session_inputs = ort::inputs![
-                    "input_ids" => Value::from_array(inputs_ids_array)?,
-                    "attention_mask" => Value::from_array(attention_mask_array.view())?,
-                ]?;
-
-                if self.need_token_type_ids {
-                    session_inputs.push((
-                        "token_type_ids".into(),
-                        Value::from_array(token_type_ids_array)?.into(),
-                    ));
-                }
+        let batches = Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
+            // Encode the texts in the batch
+            let inputs = batch.iter().map(|text| text.as_ref()).collect();
+            let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
+                anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
+            })?;
+
+            // Extract the encoding length and batch size
+            let encoding_length = encodings[0].len();
+            let batch_size = batch.len();
+
+            let max_size = encoding_length * batch_size;
+
+            // Preallocate arrays with the maximum size
+            let mut ids_array = Vec::with_capacity(max_size);
+            let mut mask_array = Vec::with_capacity(max_size);
+            let mut typeids_array = Vec::with_capacity(max_size);
+
+            // Not using par_iter because the closure needs to be FnMut
+            encodings.iter().for_each(|encoding| {
+                let ids = encoding.get_ids();
+                let mask = encoding.get_attention_mask();
+                let typeids = encoding.get_type_ids();
+
+                // Extend the preallocated arrays with the current encoding
+                // Requires the closure to be FnMut
+                ids_array.extend(ids.iter().map(|x| *x as i64));
+                mask_array.extend(mask.iter().map(|x| *x as i64));
+                typeids_array.extend(typeids.iter().map(|x| *x as i64));
+            });
+
+            // Create CowArrays from vectors
+            let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
+
+            let attention_mask_array =
+                Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
+
+            let token_type_ids_array =
+                Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
+
+            let mut session_inputs = ort::inputs![
+                "input_ids" => Value::from_array(inputs_ids_array)?,
+                "attention_mask" => Value::from_array(attention_mask_array.view())?,
+            ]?;
+
+            if self.need_token_type_ids {
+                session_inputs.push((
+                    "token_type_ids".into(),
+                    Value::from_array(token_type_ids_array)?.into(),
+                ));
+            }

-                Ok(
-                    // Package all the data required for post-processing (e.g. pooling)
-                    // into a SingleBatchOutput struct.
-                    SingleBatchOutput {
-                        session_outputs: self
-                            .session
-                            .run(session_inputs)
-                            .map_err(anyhow::Error::new)?,
-                        attention_mask_array,
-                    },
-                )
-            }))?;
+            Ok(
+                // Package all the data required for post-processing (e.g. pooling)
+                // into a SingleBatchOutput struct.
+                SingleBatchOutput {
+                    session_outputs: self
+                        .session
+                        .run(session_inputs)
+                        .map_err(anyhow::Error::new)?,
+                    attention_mask_array,
+                },
+            )
+        }))?;

         Ok(EmbeddingOutput::new(batches))
     }
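
Both the old and new versions of the batching loop collect a parallel iterator of per-batch results into a single `Result<Vec<_>>` through rayon's `FromParallelIterator` impl for `Result`, so one failing batch turns the whole collection into an error; the diff mainly drops the `anyhow::` path prefix and reindents now that `Result` is imported directly. A small standalone sketch of that collection pattern (the function and its inputs are made up for illustration):

use anyhow::Result;
use rayon::prelude::*;

// Illustrative only: a parallel map yielding Result<T, E> can be collected
// into Result<Vec<T>, E>; if any item fails, the collected value is an Err.
fn parse_all(inputs: &[&str]) -> Result<Vec<i64>> {
    Result::<Vec<_>>::from_par_iter(
        inputs
            .par_iter()
            .map(|s| s.parse::<i64>().map_err(anyhow::Error::new)),
    )
}
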
@@ -308,7 +308,7 @@ impl TextEmbedding {
         &self,
         texts: Vec<S>,
         batch_size: Option<usize>,
-    ) -> anyhow::Result<Vec<Embedding>> {
+    ) -> Result<Vec<Embedding>> {
         let batches = self.transform(texts, batch_size)?;

         batches.export_with_transformer(output::transformer_with_precedence(
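
With `try_new`, the repo helpers, and `embed` all returning `Result`, callers can propagate every failure with `?`. A rough end-to-end usage sketch; the `InitOptions` fields and the model variant below are assumptions that may not match the exact crate version this diff targets:

use anyhow::Result;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

// Rough sketch, not taken from the PR: initialization and embedding both
// return anyhow::Result, so errors bubble up instead of panicking.
fn main() -> Result<()> {
    let model = TextEmbedding::try_new(InitOptions {
        model_name: EmbeddingModel::AllMiniLML6V2,
        show_download_progress: true,
        ..Default::default()
    })?;

    let embeddings = model.embed(vec!["hello world", "fastembed-rs"], None)?;
    println!("embedded {} texts", embeddings.len());
    Ok(())
}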