From 3489e69b6463e811158318da4e56f3a2dc0c8fbb Mon Sep 17 00:00:00 2001
From: Anush
Date: Mon, 7 Jul 2025 15:32:21 +0530
Subject: [PATCH 1/2] refactor!: Updated ort usage (#170)

* refactor: Bump ort

Signed-off-by: Anush008

* docs: Updated README.md

Signed-off-by: Anush008

* docs: Updated README.md

Signed-off-by: Anush008

---------

Signed-off-by: Anush008
---
 .github/workflows/test.yml        |   2 +-
 Cargo.toml                        |   8 +-
 README.md                         |  48 ++++++++---
 src/common.rs                     |  23 ++---
 src/image_embedding/impl.rs       |  40 +++++----
 src/lib.rs                        |   4 +-
 src/output/embedding_output.rs    |  38 +++-----
 src/reranking/impl.rs             | 130 ++++++++++++----------------
 src/sparse_text_embedding/impl.rs |  26 +++---
 src/text_embedding/impl.rs        | 138 ++++++++++++++----------------
 tests/embeddings.rs               |  98 +++++++++------------
 11 files changed, 258 insertions(+), 297 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 256129b..a7a01c9 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -7,7 +7,7 @@ on:
 env:
   CARGO_TERM_COLOR: always
   RUSTFLAGS: "-Dwarnings"
-  ONNX_VERSION: v1.20.1
+  ONNX_VERSION: v1.22.0
 
 jobs:
   test:
diff --git a/Cargo.toml b/Cargo.toml
index a4213a2..9e3690a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,13 +26,11 @@ anyhow = { version = "1" }
 hf-hub = { version = "0.4.1", default-features = false, optional = true }
 image = "0.25.2"
 ndarray = { version = "0.16", default-features = false }
-ort = { version = "=2.0.0-rc.9", default-features = false, features = [
-    "ndarray",
+ort = { version = "=2.0.0-rc.10", default-features = false, features = [
+    "ndarray", "std"
 ] }
-rayon = { version = "1.10", default-features = false }
 serde_json = { version = "1" }
-tokenizers = { version = "0.21", default-features = false, features = ["onig"] }
-ort-sys = { version = "=2.0.0-rc.9", default-features = false }
+tokenizers = { version = "0.21.2", default-features = false, features = ["onig"] }
 
 [features]
 default = ["ort-download-binaries", "hf-hub-native-tls"]
diff --git a/README.md b/README.md
index def24e3..b65e586 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,6 @@
 - Supports synchronous usage. No dependency on Tokio.
 - Uses [@pykeio/ort](https://github.com/pykeio/ort) for performant ONNX inference.
 - Uses [@huggingface/tokenizers](https://github.com/huggingface/tokenizers) for fast encodings.
-- Supports batch embeddings generation with parallelism using [@rayon-rs/rayon](https://github.com/rayon-rs/rayon).
 
 ## 🔍 Not looking for Rust?
@@ -64,7 +63,7 @@
 
 ## 🚀 Installation
 
-Run the following command in your project directory:
+Run the following in your project directory:
 
 ```bash
 cargo add fastembed
@@ -84,11 +83,11 @@ fastembed = "4"
 ```rust
 use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
 
-// With default InitOptions
-let model = TextEmbedding::try_new(Default::default())?;
+// With default options
+let mut model = TextEmbedding::try_new(Default::default())?;
 
-// With custom InitOptions
-let model = TextEmbedding::try_new(
+// With custom options
+let mut model = TextEmbedding::try_new(
     InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true),
 )?;
 
@@ -105,7 +104,30 @@ let documents = vec![
 println!("Embeddings length: {}", embeddings.len()); // -> Embeddings length: 4
 println!("Embedding dimension: {}", embeddings[0].len()); // -> Embedding dimension: 384
+```
+
+### Sparse Text Embeddings
+
+```rust
+use fastembed::{SparseEmbedding, SparseInitOptions, SparseModel, SparseTextEmbedding};
+
+// With default options
+let mut model = SparseTextEmbedding::try_new(Default::default())?;
+
+// With custom options
+let mut model = SparseTextEmbedding::try_new(
+    SparseInitOptions::new(SparseModel::SPLADEPPV1).with_show_download_progress(true),
+)?;
+
+let documents = vec![
+    "passage: Hello, World!",
+    "query: Hello, World!",
+    "passage: This is an example passage.",
+    "fastembed-rs is licensed under Apache 2.0"
+    ];
+// Generate embeddings with the default batch size, 256
+let embeddings: Vec<SparseEmbedding> = model.embed(documents, None)?;
 ```
 
 ### Image Embeddings
@@ -113,11 +135,11 @@ let documents = vec![
 ```rust
 use fastembed::{ImageEmbedding, ImageInitOptions, ImageEmbeddingModel};
 
-// With default InitOptions
-let model = ImageEmbedding::try_new(Default::default())?;
+// With default options
+let mut model = ImageEmbedding::try_new(Default::default())?;
 
-// With custom InitOptions
-let model = ImageEmbedding::try_new(
+// With custom options
+let mut model = ImageEmbedding::try_new(
     ImageInitOptions::new(ImageEmbeddingModel::ClipVitB32).with_show_download_progress(true),
 )?;
 
@@ -135,7 +157,11 @@ println!("Embedding dimension: {}", embeddings[0].len()); // -> Embedding dimens
 ```rust
 use fastembed::{TextRerank, RerankInitOptions, RerankerModel};
 
-let model = TextRerank::try_new(
+// With default options
+let mut model = TextRerank::try_new(Default::default())?;
+
+// With custom options
+let mut model = TextRerank::try_new(
     RerankInitOptions::new(RerankerModel::BGERerankerBase).with_show_download_progress(true),
 )?;
 
diff --git a/src/common.rs b/src/common.rs
index 425c451..ff5de2d 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -1,8 +1,7 @@
 use anyhow::Result;
 #[cfg(feature = "hf-hub")]
 use hf_hub::api::sync::{ApiBuilder, ApiRepo};
-use std::io::Read;
-use std::{fs::File, path::PathBuf};
+use std::path::PathBuf;
 use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
 
 const DEFAULT_CACHE_DIR: &str = ".fastembed_cache";
@@ -36,11 +35,11 @@ pub struct TokenizerFiles {
 #[cfg(feature = "hf-hub")]
 pub fn load_tokenizer_hf_hub(model_repo: ApiRepo, max_length: usize) -> Result<Tokenizer> {
     let tokenizer_files: TokenizerFiles = TokenizerFiles {
-        tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json")?)?,
-        config_file: read_file_to_bytes(&model_repo.get("config.json")?)?,
-        special_tokens_map_file: read_file_to_bytes(&model_repo.get("special_tokens_map.json")?)?,
+        tokenizer_file: std::fs::read(model_repo.get("tokenizer.json")?)?,
+        config_file: std::fs::read(&model_repo.get("config.json")?)?,
+        special_tokens_map_file: std::fs::read(&model_repo.get("special_tokens_map.json")?)?,
 
-        tokenizer_config_file: read_file_to_bytes(&model_repo.get("tokenizer_config.json")?)?,
+        tokenizer_config_file: std::fs::read(&model_repo.get("tokenizer_config.json")?)?,
     };
 
     load_tokenizer(tokenizer_files, max_length)
@@ -140,18 +139,6 @@ pub fn normalize(v: &[f32]) -> Vec<f32> {
     v.iter().map(|&val| val / (norm + epsilon)).collect()
 }
 
-/// Public function to read a file to bytes.
-/// To be used when loading local model files.
-///
-/// Could be used to read the onnx file from a local cache in order to constitute a UserDefinedEmbeddingModel.
-pub fn read_file_to_bytes(file: &PathBuf) -> Result<Vec<u8>> {
-    let mut file = File::open(file)?;
-    let file_size = file.metadata()?.len() as usize;
-    let mut buffer = Vec::with_capacity(file_size);
-    file.read_to_end(&mut buffer)?;
-    Ok(buffer)
-}
-
 /// Pulls a model repo from HuggingFace..
 /// HF_HOME decides the location of the cache folder
 /// HF_ENDPOINT modifies the URL for the HuggingFace location.
diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs
index 81401cf..304969c 100644
--- a/src/image_embedding/impl.rs
+++ b/src/image_embedding/impl.rs
@@ -25,7 +25,6 @@ use super::{
     utils::{Compose, Transform, TransformData},
     ImageEmbedding, DEFAULT_BATCH_SIZE,
 };
-use rayon::prelude::*;
 
 impl ImageEmbedding {
     /// Try to generate a new ImageEmbedding Instance
@@ -128,14 +127,14 @@ impl ImageEmbedding {
 
     /// Method to generate image embeddings for a Vec of image bytes
     pub fn embed_bytes(
-        &self,
+        &mut self,
         images: &[&[u8]],
         batch_size: Option<usize>,
    ) -> anyhow::Result<Vec<Embedding>> {
         let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
 
         let output = images
-            .par_chunks(batch_size)
+            .chunks(batch_size)
             .map(|batch| {
                 // Encode the texts in the batch
                 let inputs = batch
@@ -161,7 +160,7 @@
     /// Method to generate image embeddings for a Vec of image path
     // Generic type to accept String, &str, OsString, &OsStr
     pub fn embed<S: AsRef<Path> + Send + Sync>(
-        &self,
+        &mut self,
         images: Vec<S>,
         batch_size: Option<usize>,
    ) -> anyhow::Result<Vec<Embedding>> {
@@ -169,7 +168,7 @@
         let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
 
         let output = images
-            .par_chunks(batch_size)
+            .chunks(batch_size)
             .map(|batch| {
                 // Encode the texts in the batch
                 let inputs = batch
@@ -192,7 +191,7 @@
     }
 
     /// Embed DynamicImages
-    pub fn embed_images(&self, imgs: Vec<DynamicImage>) -> anyhow::Result<Vec<Embedding>> {
+    pub fn embed_images(&mut self, imgs: Vec<DynamicImage>) -> anyhow::Result<Vec<Embedding>> {
         let inputs = imgs
             .into_iter()
             .map(|img| {
@@ -211,7 +210,7 @@
         let input_name = self.session.inputs[0].name.clone();
         let session_inputs = ort::inputs![
             input_name => Value::from_array(pixel_values_array)?,
-        ]?;
+        ];
 
         let outputs = self.session.run(session_inputs)?;
 
@@ -223,7 +222,7 @@
         };
 
         // Extract tensor and handle different dimensionalities
-        let output_data = last_hidden_state_key
+        let (shape, data) = last_hidden_state_key
             .iter()
             .find_map(|&key| {
                 outputs
@@ -231,30 +230,37 @@
                     .and_then(|v| v.try_extract_tensor::<f32>().ok())
             })
             .ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?;
-        let shape = output_data.shape();
+        let shape: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
+        let output_array = ndarray::ArrayViewD::from_shape(shape.as_slice(), data)?;
 
-        let embeddings = match shape.len() {
+        let embeddings = match output_array.ndim() {
             3 => {
                 // For 3D output [batch_size, sequence_length, hidden_size]
                 // Take only the first token, sequence_length[0] (CLS token), embedding
                 // and return [batch_size, hidden_size]
-                (0..shape[0])
+                (0..output_array.shape()[0])
                     .map(|batch_idx| {
-                        let cls_embedding =
-                            output_data.slice(ndarray::s![batch_idx, 0, ..]).to_vec();
+                        let cls_embedding = output_array
+                            .slice(ndarray::s![batch_idx, 0, ..])
+                            .to_owned()
+                            .to_vec();
                         normalize(&cls_embedding)
                     })
                     .collect()
             }
             2 => {
                 // For 2D output [batch_size, hidden_size]
-                output_data
-                    .rows()
-                    .into_iter()
+                output_array
+                    .outer_iter()
                     .map(|row| normalize(row.as_slice().unwrap()))
                     .collect()
             }
-            _ => return Err(anyhow!("Unexpected output tensor shape: {:?}", shape)),
+            _ => {
+                return Err(anyhow!(
+                    "Unexpected output tensor shape: {:?}",
+                    output_array.shape()
+                ))
+            }
         };
 
         Ok(embeddings)
diff --git a/src/lib.rs b/src/lib.rs
index 1130f5c..1b367f5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -64,9 +64,7 @@ mod text_embedding;
 
 pub use ort::execution_providers::ExecutionProviderDispatch;
 
-pub use crate::common::{
-    get_cache_dir, read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles,
-};
+pub use crate::common::{get_cache_dir, Embedding, Error, SparseEmbedding, TokenizerFiles};
 pub use crate::models::{
     model_info::ModelInfo, model_info::RerankerModelInfo, quantization::QuantizationMode,
 };
diff --git a/src/output/embedding_output.rs b/src/output/embedding_output.rs
index ac7dcc4..384d918 100644
--- a/src/output/embedding_output.rs
+++ b/src/output/embedding_output.rs
@@ -1,5 +1,4 @@
 use ndarray::{Array2, ArrayView, Dim, IxDynImpl};
-use ort::session::SessionOutputs;
 
 use crate::pooling;
 
@@ -10,12 +9,12 @@ use super::{OutputKey, OutputPrecedence};
 /// In the future, each batch will need to deal with its own post-processing, such as
 /// pooling etc. This struct should contain all the necessary information for the
 /// post-processing to be performed.
-pub struct SingleBatchOutput<'r, 's> {
-    pub session_outputs: SessionOutputs<'r, 's>,
+pub struct SingleBatchOutput {
+    pub outputs: std::collections::HashMap<String, ort::value::Value>,
     pub attention_mask_array: Array2<i64>,
 }
 
-impl SingleBatchOutput<'_, '_> {
+impl SingleBatchOutput {
     /// Select the output from the session outputs based on the given precedence.
     ///
     /// This returns a view into the tensor, which can be used to perform further
@@ -27,27 +26,18 @@
         let ort_output: &ort::value::Value = precedence
             .key_precedence()
             .find_map(|key| match key {
-                OutputKey::OnlyOne => self
-                    .session_outputs
-                    .get(self.session_outputs.keys().nth(0)?),
-                OutputKey::ByOrder(idx) => {
-                    let x = self
-                        .session_outputs
-                        .get(self.session_outputs.keys().nth(*idx)?);
-                    x
-                }
-                OutputKey::ByName(name) => self.session_outputs.get(name),
+                OutputKey::OnlyOne => self.outputs.values().next(),
+                OutputKey::ByOrder(idx) => self.outputs.values().nth(*idx),
+                OutputKey::ByName(name) => self.outputs.get(*name),
             })
             .ok_or_else(|| {
                 anyhow::Error::msg(format!(
-                    "No suitable output found in the session outputs. Available outputs: {:?}",
-                    self.session_outputs.keys().collect::<Vec<_>>()
+                    "No suitable output found in the outputs. Available outputs: {:?}",
+                    self.outputs.keys().collect::<Vec<_>>()
                 ))
             })?;
 
-        ort_output
-            .try_extract_tensor::<f32>()
-            .map_err(anyhow::Error::new)
+        ort_output.try_extract_array().map_err(anyhow::Error::new)
     }
 
     /// Select the output from the session outputs based on the given precedence and pool it.
@@ -77,13 +67,13 @@
 /// Container struct with all the outputs from the embedding layer.
///
/// This will contain one [`SingleBatchOutput`] object per batch/inference call.
-pub struct EmbeddingOutput<'r, 's> {
-    batches: Vec<SingleBatchOutput<'r, 's>>,
+pub struct EmbeddingOutput {
+    batches: Vec<SingleBatchOutput>,
 }
 
-impl<'r, 's> EmbeddingOutput<'r, 's> {
+impl EmbeddingOutput {
     /// Create a new [`EmbeddingOutput`] from a [`ort::SessionOutputs`] object.
-    pub fn new(batches: impl IntoIterator<Item = SingleBatchOutput<'r, 's>>) -> Self {
+    pub fn new(batches: impl IntoIterator<Item = SingleBatchOutput>) -> Self {
         Self {
             batches: batches.into_iter().collect(),
         }
@@ -93,7 +83,7 @@
     ///
     /// This allows the user to perform their custom extractions outside of this
     /// library.
-    pub fn into_raw(self) -> Vec<SingleBatchOutput<'r, 's>> {
+    pub fn into_raw(self) -> Vec<SingleBatchOutput> {
         self.batches
     }
 
diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs
index 4b27a51..0c7455e 100644
--- a/src/reranking/impl.rs
+++ b/src/reranking/impl.rs
@@ -16,7 +16,6 @@ use crate::{
 #[cfg(feature = "hf-hub")]
 use hf_hub::{api::sync::ApiBuilder, Cache};
 use ndarray::{s, Array};
-use rayon::{iter::ParallelIterator, slice::ParallelSlice};
 use tokenizers::Tokenizer;
 
 #[cfg(feature = "hf-hub")]
@@ -124,85 +123,70 @@ impl TextRerank {
 
     /// Rerank documents using the reranker model and returns the results sorted by score in descending order.
     pub fn rerank<S: AsRef<str> + Send + Sync>(
-        &self,
+        &mut self,
         query: S,
         documents: Vec<S>,
         return_documents: bool,
         batch_size: Option<usize>,
    ) -> Result<Vec<RerankResult>> {
         let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
-
         let q = query.as_ref();
 
-        let scores: Vec<f32> = documents
-            .par_chunks(batch_size)
-            .map(|batch| {
-                let inputs = batch.iter().map(|d| (q, d.as_ref())).collect();
-
-                let encodings = self
-                    .tokenizer
-                    .encode_batch(inputs, true)
-                    .expect("Failed to encode batch");
-
-                let encoding_length = encodings[0].len();
-                let batch_size = batch.len();
-
-                let max_size = encoding_length * batch_size;
-
-                let mut ids_array = Vec::with_capacity(max_size);
-                let mut mask_array = Vec::with_capacity(max_size);
-                let mut type_ids_array = Vec::with_capacity(max_size);
-
-                encodings.iter().for_each(|encoding| {
-                    let ids = encoding.get_ids();
-                    let mask = encoding.get_attention_mask();
-                    let type_ids = encoding.get_type_ids();
-
-                    ids_array.extend(ids.iter().map(|x| *x as i64));
-                    mask_array.extend(mask.iter().map(|x| *x as i64));
-                    type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
-                });
-
-                let inputs_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
-
-                let attention_mask_array =
-                    Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
-
-                let token_type_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;
-
-                let mut session_inputs = ort::inputs![
-                    "input_ids" => Value::from_array(inputs_ids_array)?,
-                    "attention_mask" => Value::from_array(attention_mask_array)?,
-                ]?;
-
-                if self.need_token_type_ids {
-                    session_inputs.push((
-                        "token_type_ids".into(),
-                        Value::from_array(token_type_ids_array)?.into(),
-                    ));
-                }
-
-                let outputs = self.session.run(session_inputs)?;
-
-                let outputs = outputs["logits"]
-                    .try_extract_tensor::<f32>()
-                    .expect("Failed to extract logits tensor");
-
-                let scores: Vec<f32> = outputs
-                    .slice(s![.., 0])
-                    .rows()
-                    .into_iter()
-                    .flat_map(|row| row.to_vec())
-                    .collect();
-
-                Ok(scores)
-            })
-            .collect::<Result<Vec<_>>>()?
-            .into_iter()
-            .flatten()
-            .collect();
+        let mut scores: Vec<f32> = Vec::with_capacity(documents.len());
+        for batch in documents.chunks(batch_size) {
+            let inputs = batch.iter().map(|d| (q, d.as_ref())).collect();
+            let encodings = self
+                .tokenizer
+                .encode_batch(inputs, true)
+                .expect("Failed to encode batch");
+
+            let encoding_length = encodings[0].len();
+            let batch_size = batch.len();
+            let max_size = encoding_length * batch_size;
+
+            let mut ids_array = Vec::with_capacity(max_size);
+            let mut mask_array = Vec::with_capacity(max_size);
+            let mut type_ids_array = Vec::with_capacity(max_size);
+
+            encodings.iter().for_each(|encoding| {
+                let ids = encoding.get_ids();
+                let mask = encoding.get_attention_mask();
+                let type_ids = encoding.get_type_ids();
+
+                ids_array.extend(ids.iter().map(|x| *x as i64));
+                mask_array.extend(mask.iter().map(|x| *x as i64));
+                type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
+            });
+
+            let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
+            let attention_mask_array =
+                Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
+            let token_type_ids_array =
+                Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;
+
+            let mut session_inputs = ort::inputs![
+                "input_ids" => Value::from_array(inputs_ids_array)?,
+                "attention_mask" => Value::from_array(attention_mask_array)?,
+            ];
+            if self.need_token_type_ids {
+                session_inputs.push((
+                    "token_type_ids".into(),
+                    Value::from_array(token_type_ids_array)?.into(),
+                ));
+            }
+
+            let outputs = self.session.run(session_inputs)?;
+            let outputs = outputs["logits"]
+                .try_extract_array()
+                .expect("Failed to extract logits tensor");
+            let batch_scores: Vec<f32> = outputs
+                .slice(s![.., 0])
+                .rows()
+                .into_iter()
+                .flat_map(|row| row.to_vec())
+                .collect();
+            scores.extend(batch_scores);
+        }
 
         // Return top_n_result of type Vec<RerankResult> ordered by score in descending order, don't use binary heap
         let mut top_n_result: Vec<RerankResult> = scores
@@ -214,9 +198,7 @@ impl TextRerank {
                 index,
             })
             .collect();
-
         top_n_result.sort_by(|a, b| a.score.total_cmp(&b.score).reverse());
-
         Ok(top_n_result.to_vec())
     }
 }
diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs
index c19c052..a55b7a9 100644
--- a/src/sparse_text_embedding/impl.rs
+++ b/src/sparse_text_embedding/impl.rs
@@ -12,7 +12,6 @@ use hf_hub::api::sync::ApiRepo;
 use ndarray::{Array, ArrayViewD, Axis, CowArray, Dim};
 use ort::{session::Session, value::Value};
 #[cfg_attr(not(feature = "hf-hub"), allow(unused_imports))]
-use rayon::{iter::ParallelIterator, slice::ParallelSlice};
 #[cfg(feature = "hf-hub")]
 use std::path::PathBuf;
 use tokenizers::Tokenizer;
@@ -108,7 +107,7 @@ impl SparseTextEmbedding {
     /// Method to generate sentence embeddings for a Vec of texts
     // Generic type to accept String, &str, OsString, &OsStr
     pub fn embed<S: AsRef<str> + Send + Sync>(
-        &self,
+        &mut self,
         texts: Vec<S>,
         batch_size: Option<usize>,
    ) -> Result<Vec<SparseEmbedding>> {
@@ -116,7 +115,7 @@
         let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
 
         let output = texts
-            .par_chunks(batch_size)
+            .chunks(batch_size)
             .map(|batch| {
                 // Encode the texts in the batch
                 let inputs = batch.iter().map(|text| text.as_ref()).collect();
@@ -133,33 +132,29 @@
                 let mut mask_array = Vec::with_capacity(max_size);
                 let mut type_ids_array = Vec::with_capacity(max_size);
 
-                // Not using par_iter because the closure needs to be FnMut
                 encodings.iter().for_each(|encoding| {
                     let ids = encoding.get_ids();
                     let mask = encoding.get_attention_mask();
                     let type_ids = encoding.get_type_ids();
 
-                    // Extend the preallocated arrays with the current encoding
-                    // Requires the closure to be FnMut
                     ids_array.extend(ids.iter().map(|x| *x as i64));
                     mask_array.extend(mask.iter().map(|x| *x as i64));
                     type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
                 });
 
-                // Create CowArrays from vectors
                 let inputs_ids_array =
                     Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
 
-                let owned_attention_mask =
+                let attention_mask_array =
                     Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
-                let attention_mask_array = CowArray::from(&owned_attention_mask);
+                // removed CowArray usage, use owned array
 
                 let token_type_ids_array =
                     Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;
 
                 let mut session_inputs = ort::inputs![
                     "input_ids" => Value::from_array(inputs_ids_array)?,
-                    "attention_mask" => Value::from_array(&attention_mask_array)?,
-                ]?;
+                    "attention_mask" => Value::from_array(attention_mask_array.clone())?,
+                ];
 
                 if self.need_token_type_ids {
                     session_inputs.push((
@@ -177,12 +172,15 @@
                     _ => "last_hidden_state",
                 };
 
-                let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;
+                let (shape, data) = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;
+                let shape: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
+                let output_array = ndarray::ArrayViewD::from_shape(shape.as_slice(), data)?;
+                let attention_mask_cow = ndarray::CowArray::from(&attention_mask_array);
 
                 let embeddings = SparseTextEmbedding::post_process(
                     &self.model,
-                    &output_data,
-                    &attention_mask_array,
+                    &output_array,
+                    &attention_mask_cow,
                 );
 
                 Ok(embeddings)
diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs
index 1161afe..2a19a86 100644
--- a/src/text_embedding/impl.rs
+++ b/src/text_embedding/impl.rs
@@ -18,10 +18,6 @@ use ort::{
     session::{builder::GraphOptimizationLevel, Session},
     value::Value,
 };
-use rayon::{
-    iter::{FromParallelIterator, ParallelIterator},
-    slice::ParallelSlice,
-};
 #[cfg(feature = "hf-hub")]
 use std::path::PathBuf;
 use std::thread::available_parallelism;
@@ -260,15 +256,11 @@ impl TextEmbedding {
     /// arrays are aggregated, you can define your own array transformer
     /// and use it on [`EmbeddingOutput::export_with_transformer`] to extract the
     /// embeddings with your custom output type.
-    pub fn transform<'e, 'r, 's, S: AsRef<str> + Send + Sync>(
-        &'e self,
+    pub fn transform<S: AsRef<str> + Send + Sync>(
+        &mut self,
         texts: Vec<S>,
         batch_size: Option<usize>,
-    ) -> Result<EmbeddingOutput<'r, 's>>
-    where
-        'e: 'r,
-        'e: 's,
-    {
+    ) -> Result<EmbeddingOutput> {
         // Determine the batch size according to the quantization method used.
        // Default if not specified
         let batch_size = match self.quantization {
@@ -292,70 +284,68 @@ impl TextEmbedding {
             _ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)),
         }?;
 
-        let batches = Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
-            // Encode the texts in the batch
-            let inputs = batch.iter().map(|text| text.as_ref()).collect();
-            let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
-                anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
-            })?;
-
-            // Extract the encoding length and batch size
-            let encoding_length = encodings[0].len();
-            let batch_size = batch.len();
-
-            let max_size = encoding_length * batch_size;
-
-            // Preallocate arrays with the maximum size
-            let mut ids_array = Vec::with_capacity(max_size);
-            let mut mask_array = Vec::with_capacity(max_size);
-            let mut type_ids_array = Vec::with_capacity(max_size);
-
-            // Not using par_iter because the closure needs to be FnMut
-            encodings.iter().for_each(|encoding| {
-                let ids = encoding.get_ids();
-                let mask = encoding.get_attention_mask();
-                let type_ids = encoding.get_type_ids();
-
-                // Extend the preallocated arrays with the current encoding
-                // Requires the closure to be FnMut
-                ids_array.extend(ids.iter().map(|x| *x as i64));
-                mask_array.extend(mask.iter().map(|x| *x as i64));
-                type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
-            });
-
-            // Create CowArrays from vectors
-            let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
-
-            let attention_mask_array =
-                Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
-
-            let token_type_ids_array =
-                Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;
-
-            let mut session_inputs = ort::inputs![
-                "input_ids" => Value::from_array(inputs_ids_array)?,
-                "attention_mask" => Value::from_array(attention_mask_array.view())?,
-            ]?;
-
-            if self.need_token_type_ids {
-                session_inputs.push((
-                    "token_type_ids".into(),
-                    Value::from_array(token_type_ids_array)?.into(),
-                ));
-            }
+        let batches = texts
+            .chunks(batch_size)
+            .map(|batch| {
+                // Encode the texts in the batch
+                let inputs = batch.iter().map(|text| text.as_ref()).collect();
+                let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
+                    anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
+                })?;
+
+                // Extract the encoding length and batch size
+                let encoding_length = encodings[0].len();
+                let batch_size = batch.len();
+
+                let max_size = encoding_length * batch_size;
+
+                // Preallocate arrays with the maximum size
+                let mut ids_array = Vec::with_capacity(max_size);
+                let mut mask_array = Vec::with_capacity(max_size);
+                let mut type_ids_array = Vec::with_capacity(max_size);
+
+                encodings.iter().for_each(|encoding| {
+                    let ids = encoding.get_ids();
+                    let mask = encoding.get_attention_mask();
+                    let type_ids = encoding.get_type_ids();
+
+                    ids_array.extend(ids.iter().map(|x| *x as i64));
+                    mask_array.extend(mask.iter().map(|x| *x as i64));
+                    type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
+                });
+
+                let inputs_ids_array =
+                    Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
+                let attention_mask_array =
+                    Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
+                let token_type_ids_array =
+                    Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;
+
+                let mut session_inputs = ort::inputs![
+                    "input_ids" => Value::from_array(inputs_ids_array)?,
+                    "attention_mask" => Value::from_array(attention_mask_array.clone())?,
+                ];
+
+                if self.need_token_type_ids {
+                    session_inputs.push((
+                        "token_type_ids".into(),
+                        Value::from_array(token_type_ids_array)?.into(),
+                    ));
+                }
 
-            Ok(
-                // Package all the data required for post-processing (e.g. pooling)
-                // into a SingleBatchOutput struct.
-                SingleBatchOutput {
-                    session_outputs: self
-                        .session
-                        .run(session_inputs)
-                        .map_err(anyhow::Error::new)?,
+                let outputs_map = self
+                    .session
+                    .run(session_inputs)
+                    .map_err(anyhow::Error::new)?
+                    .into_iter()
+                    .map(|(k, v)| (k.to_string(), v))
+                    .collect();
+                Ok(SingleBatchOutput {
+                    outputs: outputs_map,
                     attention_mask_array,
-                },
-            )
-        }))?;
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
 
         Ok(EmbeddingOutput::new(batches))
     }
@@ -372,7 +362,7 @@
     /// This method is a higher level method than [`TextEmbedding::transform`] by utilizing
     /// the default output precedence and array transformer for the [`TextEmbedding`] model.
     pub fn embed<S: AsRef<str> + Send + Sync>(
-        &self,
+        &mut self,
         texts: Vec<S>,
         batch_size: Option<usize>,
    ) -> Result<Vec<Embedding>> {
diff --git a/tests/embeddings.rs b/tests/embeddings.rs
index cb0a564..900b57e 100644
--- a/tests/embeddings.rs
+++ b/tests/embeddings.rs
@@ -4,14 +4,13 @@ use std::fs;
 use std::path::Path;
 
 use hf_hub::Repo;
-use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
 
 use fastembed::{
-    get_cache_dir, read_file_to_bytes, Embedding, EmbeddingModel, ImageEmbedding,
-    ImageEmbeddingModel, ImageInitOptions, InitOptions, InitOptionsUserDefined, ModelInfo,
-    OnnxSource, Pooling, QuantizationMode, RerankInitOptions, RerankInitOptionsUserDefined,
-    RerankerModel, RerankerModelInfo, SparseInitOptions, SparseTextEmbedding, TextEmbedding,
-    TextRerank, TokenizerFiles, UserDefinedEmbeddingModel, UserDefinedRerankingModel,
+    get_cache_dir, Embedding, EmbeddingModel, ImageEmbedding, ImageEmbeddingModel,
+    ImageInitOptions, InitOptions, InitOptionsUserDefined, ModelInfo, OnnxSource, Pooling,
+    QuantizationMode, RerankInitOptions, RerankInitOptionsUserDefined, RerankerModel,
+    RerankerModelInfo, SparseInitOptions, SparseTextEmbedding, TextEmbedding, TextRerank,
+    TokenizerFiles, UserDefinedEmbeddingModel, UserDefinedRerankingModel,
 };
 
 /// A small epsilon value for floating point comparisons.
@@ -104,9 +103,9 @@ macro_rules!
create_embeddings_test { #[test] fn $name() { TextEmbedding::list_supported_models() - .par_iter() + .iter() .for_each(|supported_model| { - let model: TextEmbedding = TextEmbedding::try_new(InitOptions::new(supported_model.model.clone())) + let mut model: TextEmbedding = TextEmbedding::try_new(InitOptions::new(supported_model.model.clone())) .unwrap(); let documents = vec![ @@ -162,17 +161,12 @@ create_embeddings_test!( batch_size: None, ); -create_embeddings_test!( - name: test_with_batch_size, - batch_size: Some(70), -); - #[test] fn test_sparse_embeddings() { SparseTextEmbedding::list_supported_models() - .par_iter() + .iter() .for_each(|supported_model| { - let model: SparseTextEmbedding = + let mut model: SparseTextEmbedding = SparseTextEmbedding::try_new(SparseInitOptions::new(supported_model.model.clone())) .unwrap(); @@ -224,8 +218,8 @@ fn test_user_defined_embedding_model() { .path(); // Find the onnx file - it will be any file ending with .onnx - let onnx_file = read_file_to_bytes( - &model_files_dir + let onnx_file = std::fs::read( + model_files_dir .read_dir() .unwrap() .find(|entry| { @@ -247,15 +241,13 @@ fn test_user_defined_embedding_model() { // Load the tokenizer files let tokenizer_files = TokenizerFiles { - tokenizer_file: read_file_to_bytes(&model_files_dir.join("tokenizer.json")) + tokenizer_file: std::fs::read(model_files_dir.join("tokenizer.json")) .expect("Could not read tokenizer.json"), - config_file: read_file_to_bytes(&model_files_dir.join("config.json")) + config_file: std::fs::read(model_files_dir.join("config.json")) .expect("Could not read config.json"), - special_tokens_map_file: read_file_to_bytes( - &model_files_dir.join("special_tokens_map.json"), - ) - .expect("Could not read special_tokens_map.json"), - tokenizer_config_file: read_file_to_bytes(&model_files_dir.join("tokenizer_config.json")) + special_tokens_map_file: std::fs::read(model_files_dir.join("special_tokens_map.json")) + .expect("Could not read special_tokens_map.json"), + tokenizer_config_file: std::fs::read(model_files_dir.join("tokenizer_config.json")) .expect("Could not read tokenizer_config.json"), }; // Create a UserDefinedEmbeddingModel @@ -263,7 +255,7 @@ fn test_user_defined_embedding_model() { UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean); // Try creating a TextEmbedding instance from the user-defined model - let user_defined_text_embedding = TextEmbedding::try_new_from_user_defined( + let mut user_defined_text_embedding = TextEmbedding::try_new_from_user_defined( user_defined_model, InitOptionsUserDefined::default(), ) @@ -291,7 +283,7 @@ fn test_rerank() { let test_one_model = |supported_model: &RerankerModelInfo| { println!("supported_model: {:?}", supported_model); - let result = + let mut result = TextRerank::try_new(RerankInitOptions::new(supported_model.model.clone())).unwrap(); let documents = vec![ @@ -333,7 +325,7 @@ fn test_rerank() { clean_cache(supported_model.model_code.clone()) }; TextRerank::list_supported_models() - .par_iter() + .iter() .for_each(test_one_model); } @@ -358,22 +350,18 @@ fn test_user_defined_reranking_large_model() { // Load the tokenizer files let tokenizer_files: TokenizerFiles = TokenizerFiles { - tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json").unwrap()).unwrap(), - config_file: read_file_to_bytes(&model_repo.get("config.json").unwrap()).unwrap(), - special_tokens_map_file: read_file_to_bytes( - &model_repo.get("special_tokens_map.json").unwrap(), - ) - .unwrap(), - - 
tokenizer_config_file: read_file_to_bytes(
-            &model_repo.get("tokenizer_config.json").unwrap(),
-        )
-        .unwrap(),
+        tokenizer_file: std::fs::read(model_repo.get("tokenizer.json").unwrap()).unwrap(),
+        config_file: std::fs::read(model_repo.get("config.json").unwrap()).unwrap(),
+        special_tokens_map_file: std::fs::read(model_repo.get("special_tokens_map.json").unwrap())
+            .unwrap(),
+
+        tokenizer_config_file: std::fs::read(model_repo.get("tokenizer_config.json").unwrap())
+            .unwrap(),
     };
 
     let model = UserDefinedRerankingModel::new(onnx_source, tokenizer_files);
 
-    let user_defined_reranker =
+    let mut user_defined_reranker =
         TextRerank::try_new_from_user_defined(model, Default::default()).unwrap();
 
     let documents = vec![
@@ -416,8 +404,8 @@ fn test_user_defined_reranking_model() {
         .path();
 
     // Find the onnx file - it will be any file in ./onnx ending with .onnx
-    let onnx_file = read_file_to_bytes(
-        &model_files_dir
+    let onnx_file = std::fs::read(
+        model_files_dir
             .join("onnx")
             .read_dir()
             .unwrap()
@@ -440,22 +428,20 @@
 
     // Load the tokenizer files
     let tokenizer_files = TokenizerFiles {
-        tokenizer_file: read_file_to_bytes(&model_files_dir.join("tokenizer.json"))
+        tokenizer_file: std::fs::read(model_files_dir.join("tokenizer.json"))
             .expect("Could not read tokenizer.json"),
-        config_file: read_file_to_bytes(&model_files_dir.join("config.json"))
+        config_file: std::fs::read(model_files_dir.join("config.json"))
             .expect("Could not read config.json"),
-        special_tokens_map_file: read_file_to_bytes(
-            &model_files_dir.join("special_tokens_map.json"),
-        )
-        .expect("Could not read special_tokens_map.json"),
-        tokenizer_config_file: read_file_to_bytes(&model_files_dir.join("tokenizer_config.json"))
+        special_tokens_map_file: std::fs::read(model_files_dir.join("special_tokens_map.json"))
+            .expect("Could not read special_tokens_map.json"),
+        tokenizer_config_file: std::fs::read(model_files_dir.join("tokenizer_config.json"))
             .expect("Could not read tokenizer_config.json"),
     };
     // Create a UserDefinedEmbeddingModel
     let user_defined_model = UserDefinedRerankingModel::new(onnx_file, tokenizer_files);
 
     // Try creating a TextEmbedding instance from the user-defined model
-    let user_defined_reranker = TextRerank::try_new_from_user_defined(
+    let mut user_defined_reranker = TextRerank::try_new_from_user_defined(
         user_defined_model,
         RerankInitOptionsUserDefined::default(),
     )
@@ -480,7 +466,7 @@
 #[test]
 fn test_image_embedding_model() {
     let test_one_model = |supported_model: &ModelInfo<ImageEmbeddingModel>| {
-        let model: ImageEmbedding =
+        let mut model: ImageEmbedding =
             ImageEmbedding::try_new(ImageInitOptions::new(supported_model.model.clone())).unwrap();
 
         let images = vec!["tests/assets/image_0.png", "tests/assets/image_1.png"];
@@ -491,7 +477,7 @@
         assert_eq!(embeddings.len(), images.len());
     };
     ImageEmbedding::list_supported_models()
-        .par_iter()
+        .iter()
        .for_each(test_one_model);
 }
 
@@ -523,7 +509,7 @@
 fn test_nomic_embed_vision_v1_5() {
     // Test the NomicEmbedVisionV15 model specifically because it outputs a 3D tensor with a different
     // output key ('last_hidden_state') compared to other models. This test ensures our tensor extraction
     // logic can handle both standard output keys and this model's specific naming convention.
-    let image_model = ImageEmbedding::try_new(ImageInitOptions::new(
+    let mut image_model = ImageEmbedding::try_new(ImageInitOptions::new(
         fastembed::ImageEmbeddingModel::NomicEmbedVisionV15,
     ))
     .unwrap();
@@ -534,7 +520,7 @@
     let image_embeddings = image_model.embed(images.clone(), None).unwrap();
     assert_eq!(image_embeddings.len(), images.len());
 
-    let text_model = TextEmbedding::try_new(InitOptions::new(
+    let mut text_model = TextEmbedding::try_new(InitOptions::new(
         fastembed::EmbeddingModel::NomicEmbedTextV15,
     ))
     .unwrap();
@@ -566,7 +552,7 @@ fn get_sample_text() -> String {
 
 #[test]
 fn test_batch_size_does_not_change_output() {
-    let model = TextEmbedding::try_new(
+    let mut model = TextEmbedding::try_new(
         InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_max_length(384),
     )
     .expect("Create model successfully");
@@ -598,7 +584,7 @@
 
 #[test]
 fn test_bgesmallen1point5_match_python_counterpart() {
-    let model = TextEmbedding::try_new(
+    let mut model = TextEmbedding::try_new(
         InitOptions::new(EmbeddingModel::BGESmallENV15).with_max_length(384),
     )
     .expect("Create model successfully");
@@ -637,7 +623,7 @@
 
 #[test]
 fn test_allminilml6v2_match_python_counterpart() {
-    let model = TextEmbedding::try_new(
+    let mut model = TextEmbedding::try_new(
         InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_max_length(384),
     )
     .expect("Create model successfully");

From 1dc9968d4db3891c19504ec4f44165dfe4c3f8c1 Mon Sep 17 00:00:00 2001
From: Anush
Date: Mon, 7 Jul 2025 10:04:36 +0000
Subject: [PATCH 2/2] chore(release): 5.0.0 [skip ci]
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## [5.0.0](https://github.com/Anush008/fastembed-rs/compare/v4.9.1...v5.0.0) (2025-07-07)

### ⚠ BREAKING CHANGES

* Updated ort usage (#170)

### 🧑‍💻 Code Refactoring

* Updated ort usage ([#170](https://github.com/Anush008/fastembed-rs/issues/170)) ([3489e69](https://github.com/Anush008/fastembed-rs/commit/3489e69b6463e811158318da4e56f3a2dc0c8fbb))
---
 Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Cargo.toml b/Cargo.toml
index 9e3690a..d3ee71a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fastembed"
-version = "4.9.1"
+version = "5.0.0"
 edition = "2021"
 description = "Library for generating vector embeddings, reranking locally."
 license = "Apache-2.0"