2 changes: 1 addition & 1 deletion src/common.rs
@@ -88,7 +88,7 @@ pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Res
     let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32;
     let pad_token = tokenizer_config["pad_token"]
         .as_str()
-        .expect("Error reading pad_token from tokenier_config.json")
+        .expect("Error reading pad_token from tokenizer_config.json")
         .into();

     let mut tokenizer = tokenizer
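For context, the corrected message lives in `load_tokenizer`'s pad-token lookup. A minimal sketch of that lookup, assuming both JSON files are already parsed into `serde_json::Value`s; the `pad_settings` helper is illustrative, not part of the crate:

```rust
use serde_json::Value;

// Hypothetical helper mirroring the pad-token lookup in `load_tokenizer`.
fn pad_settings(config: &Value, tokenizer_config: &Value) -> (u32, String) {
    // A missing `pad_token_id` silently falls back to 0, per `unwrap_or(0)`.
    let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32;
    // A missing `pad_token` is a hard error, now with the corrected file name.
    let pad_token = tokenizer_config["pad_token"]
        .as_str()
        .expect("Error reading pad_token from tokenizer_config.json")
        .to_string();
    (pad_id, pad_token)
}
```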
32 changes: 21 additions & 11 deletions src/lib.rs
@@ -67,23 +67,33 @@ pub use ort::execution_providers::ExecutionProviderDispatch;
 pub use crate::common::{
     read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles, DEFAULT_CACHE_DIR,
 };
-pub use crate::image_embedding::{
-    ImageEmbedding, ImageInitOptions, ImageInitOptionsUserDefined, UserDefinedImageEmbeddingModel,
-};
-pub use crate::models::image_embedding::ImageEmbeddingModel;
-pub use crate::models::reranking::{RerankerModel, RerankerModelInfo};
 pub use crate::models::{
-    model_info::ModelInfo, quantization::QuantizationMode, text_embedding::EmbeddingModel,
+    model_info::ModelInfo, model_info::RerankerModelInfo, quantization::QuantizationMode,
 };
 pub use crate::output::{EmbeddingOutput, OutputKey, OutputPrecedence, SingleBatchOutput};
 pub use crate::pooling::Pooling;
-pub use crate::reranking::{
-    OnnxSource, RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
-    UserDefinedRerankingModel,
+
+// For Text Embedding
+pub use crate::models::text_embedding::EmbeddingModel;
+pub use crate::text_embedding::{
+    InitOptions, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel,
 };
+
+// For Sparse Text Embedding
+pub use crate::models::sparse::SparseModel;
 pub use crate::sparse_text_embedding::{
     SparseInitOptions, SparseTextEmbedding, UserDefinedSparseModel,
 };
-pub use crate::text_embedding::{
-    InitOptions, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel,
+
+// For Image Embedding
+pub use crate::image_embedding::{
+    ImageEmbedding, ImageInitOptions, ImageInitOptionsUserDefined, UserDefinedImageEmbeddingModel,
 };
+pub use crate::models::image_embedding::ImageEmbeddingModel;
+
+// For Reranking
+pub use crate::models::reranking::RerankerModel;
+pub use crate::reranking::{
+    OnnxSource, RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
+    UserDefinedRerankingModel,
+};
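The regrouping is purely organizational: every type keeps its crate-root path, so the change is not breaking for downstream users. A sketch of the consumer view, assuming the crate is consumed under its published name `fastembed`:

```rust
// All re-exports stay at the crate root; only lib.rs's internal layout changed.
use fastembed::{
    // Text embedding
    EmbeddingModel, InitOptions, TextEmbedding,
    // Sparse text embedding
    SparseModel, SparseTextEmbedding,
    // Reranking
    RerankerModel, RerankerModelInfo, TextRerank,
};

fn main() {
    // Model enums implement `Display` through their model-info lists.
    println!("{}", EmbeddingModel::AllMiniLML6V2);
}
```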
12 changes: 12 additions & 0 deletions src/models/model_info.rs
@@ -1,3 +1,5 @@
+use crate::RerankerModel;
+
 /// Data struct about the available models
 #[derive(Debug, Clone)]
 pub struct ModelInfo<T> {
@@ -8,3 +10,13 @@ pub struct ModelInfo<T> {
     pub model_file: String,
     pub additional_files: Vec<String>,
 }
+
+/// Data struct about the available reranker models
+#[derive(Debug, Clone)]
+pub struct RerankerModelInfo {
+    pub model: RerankerModel,
+    pub description: String,
+    pub model_code: String,
+    pub model_file: String,
+    pub additional_files: Vec<String>,
+}
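`RerankerModelInfo` keeps the same shape as `ModelInfo<T>` with the `model` field pinned to `RerankerModel`; it is functionally what `ModelInfo<RerankerModel>` would be, kept as its own named type. An illustrative construction; the variant name and field values below are placeholders, not a real list entry:

```rust
// Placeholder values for illustration only; real entries come from
// `reranker_model_list()` in src/models/reranking.rs.
let info = RerankerModelInfo {
    model: RerankerModel::BGERerankerBase, // hypothetical variant name
    description: "reranker model for demonstration".to_string(),
    model_code: "BAAI/bge-reranker-base".to_string(),
    model_file: "onnx/model.onnx".to_string(),
    additional_files: Vec::new(),
};
assert!(info.additional_files.is_empty());
```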
12 changes: 2 additions & 10 deletions src/models/reranking.rs
@@ -1,5 +1,7 @@
 use std::fmt::Display;

+use crate::RerankerModelInfo;
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum RerankerModel {
     /// BAAI/bge-reranker-base
@@ -46,16 +48,6 @@ pub fn reranker_model_list() -> Vec<RerankerModelInfo> {
     reranker_model_list
 }

-/// Data struct about the available reanker models
-#[derive(Debug, Clone)]
-pub struct RerankerModelInfo {
-    pub model: RerankerModel,
-    pub description: String,
-    pub model_code: String,
-    pub model_file: String,
-    pub additional_files: Vec<String>,
-}
-
 impl Display for RerankerModel {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let model_info = reranker_model_list()
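The `Display` impl above is cut off by the diff fold. Assuming it follows the same lookup-by-variant pattern as the other model enums in the crate, a plausible completion is sketched below; the `.find(...)` predicate and the `model_code` write are inferred, not confirmed by this diff:

```rust
impl std::fmt::Display for RerankerModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Inferred body: locate this variant's entry and print its model code.
        // The derived `PartialEq` on `RerankerModel` enables the comparison.
        let model_info = reranker_model_list()
            .into_iter()
            .find(|info| info.model == *self)
            .expect("Model not found in reranker model list.");
        write!(f, "{}", model_info.model_code)
    }
}
```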
45 changes: 1 addition & 44 deletions src/models/sparse.rs
@@ -1,7 +1,6 @@
 use std::fmt::Display;

-use crate::{common::SparseEmbedding, ModelInfo};
-use ndarray::{ArrayViewD, Axis, CowArray, Dim};
+use crate::ModelInfo;

 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum SparseModel {
@@ -20,48 +19,6 @@ pub fn models_list() -> Vec<ModelInfo<SparseModel>> {
     }]
 }

-impl SparseModel {
-    pub fn post_process(
-        &self,
-        model_output: &ArrayViewD<f32>,
-        attention_mask: &CowArray<i64, Dim<[usize; 2]>>,
-    ) -> Vec<SparseEmbedding> {
-        match self {
-            SparseModel::SPLADEPPV1 => {
-                // Apply ReLU and logarithm transformation
-                let relu_log = model_output.mapv(|x| (1.0 + x.max(0.0)).ln());
-
-                // Convert to f32 and expand the dimensions
-                let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));
-
-                // Weight the transformed values by the attention mask
-                let weighted_log = relu_log * attention_mask;
-
-                // Get the max scores
-                let scores = weighted_log.fold_axis(Axis(1), f32::NEG_INFINITY, |r, &v| r.max(v));
-
-                scores
-                    .rows()
-                    .into_iter()
-                    .map(|row_scores| {
-                        let mut values: Vec<f32> = Vec::with_capacity(scores.len());
-                        let mut indices: Vec<usize> = Vec::with_capacity(scores.len());
-
-                        row_scores.into_iter().enumerate().for_each(|(idx, f)| {
-                            if *f > 0.0 {
-                                values.push(*f);
-                                indices.push(idx);
-                            }
-                        });
-
-                        SparseEmbedding { values, indices }
-                    })
-                    .collect()
-            }
-        }
-    }
-}
-
 impl Display for SparseModel {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let model_info = models_list()
78 changes: 1 addition & 77 deletions src/models/text_embedding.rs
@@ -1,11 +1,7 @@
-use crate::pooling::Pooling;
+use std::{collections::HashMap, fmt::Display, sync::OnceLock};

 use super::model_info::ModelInfo;
-
 use super::quantization::QuantizationMode;
-
-use std::{collections::HashMap, fmt::Display, sync::OnceLock};
-
 /// Lazy static list of all available models.
 static MODEL_MAP: OnceLock<HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>>> = OnceLock::new();

@@ -338,78 +334,6 @@ pub fn models_list() -> Vec<ModelInfo<EmbeddingModel>> {
     models_map().values().cloned().collect()
 }

-impl EmbeddingModel {
-    pub fn get_default_pooling_method(&self) -> Option<Pooling> {
-        match self {
-            EmbeddingModel::AllMiniLML6V2 => Some(Pooling::Mean),
-            EmbeddingModel::AllMiniLML6V2Q => Some(Pooling::Mean),
-            EmbeddingModel::AllMiniLML12V2 => Some(Pooling::Mean),
-            EmbeddingModel::AllMiniLML12V2Q => Some(Pooling::Mean),
-
-            EmbeddingModel::BGEBaseENV15 => Some(Pooling::Cls),
-            EmbeddingModel::BGEBaseENV15Q => Some(Pooling::Cls),
-            EmbeddingModel::BGELargeENV15 => Some(Pooling::Cls),
-            EmbeddingModel::BGELargeENV15Q => Some(Pooling::Cls),
-            EmbeddingModel::BGESmallENV15 => Some(Pooling::Cls),
-            EmbeddingModel::BGESmallENV15Q => Some(Pooling::Cls),
-            EmbeddingModel::BGESmallZHV15 => Some(Pooling::Cls),
-
-            EmbeddingModel::NomicEmbedTextV1 => Some(Pooling::Mean),
-            EmbeddingModel::NomicEmbedTextV15 => Some(Pooling::Mean),
-            EmbeddingModel::NomicEmbedTextV15Q => Some(Pooling::Mean),
-
-            EmbeddingModel::ParaphraseMLMiniLML12V2 => Some(Pooling::Mean),
-            EmbeddingModel::ParaphraseMLMiniLML12V2Q => Some(Pooling::Mean),
-            EmbeddingModel::ParaphraseMLMpnetBaseV2 => Some(Pooling::Mean),
-
-            EmbeddingModel::MultilingualE5Base => Some(Pooling::Mean),
-            EmbeddingModel::MultilingualE5Small => Some(Pooling::Mean),
-            EmbeddingModel::MultilingualE5Large => Some(Pooling::Mean),
-
-            EmbeddingModel::MxbaiEmbedLargeV1 => Some(Pooling::Cls),
-            EmbeddingModel::MxbaiEmbedLargeV1Q => Some(Pooling::Cls),
-
-            EmbeddingModel::GTEBaseENV15 => Some(Pooling::Cls),
-            EmbeddingModel::GTEBaseENV15Q => Some(Pooling::Cls),
-            EmbeddingModel::GTELargeENV15 => Some(Pooling::Cls),
-            EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls),
-
-            EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),
-
-            EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean),
-        }
-    }
-
-    /// Get the quantization mode of the model.
-    ///
-    /// Any models with a `Q` suffix in their name are quantized models.
-    ///
-    /// Currently only 6 supported models have dynamic quantization:
-    /// - Alibaba-NLP/gte-base-en-v1.5
-    /// - Alibaba-NLP/gte-large-en-v1.5
-    /// - mixedbread-ai/mxbai-embed-large-v1
-    /// - nomic-ai/nomic-embed-text-v1.5
-    /// - Xenova/all-MiniLM-L12-v2
-    /// - Xenova/all-MiniLM-L6-v2
-    ///
-    // TODO: Update this list when more models are added
-    pub fn get_quantization_mode(&self) -> QuantizationMode {
-        match self {
-            EmbeddingModel::AllMiniLML6V2Q => QuantizationMode::Dynamic,
-            EmbeddingModel::AllMiniLML12V2Q => QuantizationMode::Dynamic,
-            EmbeddingModel::BGEBaseENV15Q => QuantizationMode::Static,
-            EmbeddingModel::BGELargeENV15Q => QuantizationMode::Static,
-            EmbeddingModel::BGESmallENV15Q => QuantizationMode::Static,
-            EmbeddingModel::NomicEmbedTextV15Q => QuantizationMode::Dynamic,
-            EmbeddingModel::ParaphraseMLMiniLML12V2Q => QuantizationMode::Static,
-            EmbeddingModel::MxbaiEmbedLargeV1Q => QuantizationMode::Dynamic,
-            EmbeddingModel::GTEBaseENV15Q => QuantizationMode::Dynamic,
-            EmbeddingModel::GTELargeENV15Q => QuantizationMode::Dynamic,
-            _ => QuantizationMode::None,
-        }
-    }
-}
-
 impl Display for EmbeddingModel {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let model_info = get_model_info(self).expect("Model not found.");
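The deleted `impl EmbeddingModel` block (pooling defaults plus the dynamic/static quantization table) leaves this file, while `QuantizationMode` itself remains re-exported at the crate root. A minimal sketch of consuming that enum; `describe_quantization` is a hypothetical helper, and the three variants are exactly the ones the removed table produced:

```rust
use fastembed::QuantizationMode;

// Hypothetical helper; not part of the crate.
fn describe_quantization(mode: QuantizationMode) -> &'static str {
    match mode {
        // Dynamic: activation ranges are computed at inference time.
        QuantizationMode::Dynamic => "dynamically quantized",
        // Static: ranges were fixed ahead of time with calibration data.
        QuantizationMode::Static => "statically quantized",
        QuantizationMode::None => "full precision",
    }
}

fn main() {
    println!("{}", describe_quantization(QuantizationMode::Dynamic));
}
```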
2 changes: 1 addition & 1 deletion src/output/embedding_output.rs
@@ -62,7 +62,7 @@ impl SingleBatchOutput<'_, '_> {

        // If there is none pooling, default to cls so as not to break the existing implementations
        // TODO: Consider return output as is to support custom model that has built-in pooling layer:
-       // - [] Add model with built-in pooling to the list of supported model in ``models::text_embdding::models_list``
+       // - [] Add model with built-in pooling to the list of supported model in ``models::text_embedding::models_list``
        // - [] Write unit test for new model
        // - [] Update ``pooling::Pooling`` to include None type
        // - [] Change the line below to return output as is
10 changes: 5 additions & 5 deletions src/reranking/impl.rs
@@ -122,7 +122,7 @@ impl TextRerank {
         Ok(Self::new(tokenizer, session))
     }

-    /// Reranks documents using the reranker model and returns the results sorted by score in descending order.
+    /// Rerank documents using the reranker model and returns the results sorted by score in descending order.
     pub fn rerank<S: AsRef<str> + Send + Sync>(
         &self,
         query: S,
@@ -151,16 +151,16 @@ impl TextRerank {

             let mut ids_array = Vec::with_capacity(max_size);
             let mut mask_array = Vec::with_capacity(max_size);
-            let mut typeids_array = Vec::with_capacity(max_size);
+            let mut type_ids_array = Vec::with_capacity(max_size);

             encodings.iter().for_each(|encoding| {
                 let ids = encoding.get_ids();
                 let mask = encoding.get_attention_mask();
-                let typeids = encoding.get_type_ids();
+                let type_ids = encoding.get_type_ids();

                 ids_array.extend(ids.iter().map(|x| *x as i64));
                 mask_array.extend(mask.iter().map(|x| *x as i64));
-                typeids_array.extend(typeids.iter().map(|x| *x as i64));
+                type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
             });

             let inputs_ids_array =
@@ -170,7 +170,7 @@ impl TextRerank {
                 Array::from_shape_vec((batch_size, encoding_length), mask_array)?;

             let token_type_ids_array =
-                Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
+                Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;

             let mut session_inputs = ort::inputs![
                 "input_ids" => Value::from_array(inputs_ids_array)?,
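All three renamed buffers follow the same flatten-then-reshape pattern: each encoding's `u32` values are widened to `i64` and appended row-major, and a single `from_shape_vec` call then recovers the `(batch, seq_len)` matrix the ONNX session expects. A self-contained sketch of just that pattern, with made-up token ids:

```rust
use ndarray::Array;

fn main() -> Result<(), ndarray::ShapeError> {
    // Two already-padded encodings of equal length; token ids are made up.
    let encodings: Vec<Vec<u32>> = vec![vec![101, 2023, 102], vec![101, 2003, 102]];
    let (batch_size, encoding_length) = (encodings.len(), encodings[0].len());

    // Flatten row-major, widening u32 -> i64 as the ONNX inputs require.
    let mut ids_array = Vec::with_capacity(batch_size * encoding_length);
    for ids in &encodings {
        ids_array.extend(ids.iter().map(|x| *x as i64));
    }

    // One reshape recovers the (batch, seq_len) matrix.
    let input_ids = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
    assert_eq!(input_ids.shape(), &[2, 3]);
    Ok(())
}
```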
56 changes: 50 additions & 6 deletions src/sparse_text_embedding/impl.rs
@@ -12,7 +12,7 @@ use hf_hub::{
     api::sync::{ApiBuilder, ApiRepo},
     Cache,
 };
-use ndarray::{Array, CowArray};
+use ndarray::{Array, ArrayViewD, Axis, CowArray, Dim};
 use ort::{session::Session, value::Value};
 #[cfg_attr(not(feature = "online"), allow(unused_imports))]
 use rayon::{iter::ParallelIterator, slice::ParallelSlice};
@@ -138,19 +138,19 @@ impl SparseTextEmbedding {
             // Preallocate arrays with the maximum size
             let mut ids_array = Vec::with_capacity(max_size);
             let mut mask_array = Vec::with_capacity(max_size);
-            let mut typeids_array = Vec::with_capacity(max_size);
+            let mut type_ids_array = Vec::with_capacity(max_size);

             // Not using par_iter because the closure needs to be FnMut
             encodings.iter().for_each(|encoding| {
                 let ids = encoding.get_ids();
                 let mask = encoding.get_attention_mask();
-                let typeids = encoding.get_type_ids();
+                let type_ids = encoding.get_type_ids();

                 // Extend the preallocated arrays with the current encoding
                 // Requires the closure to be FnMut
                 ids_array.extend(ids.iter().map(|x| *x as i64));
                 mask_array.extend(mask.iter().map(|x| *x as i64));
-                typeids_array.extend(typeids.iter().map(|x| *x as i64));
+                type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
             });

             // Create CowArrays from vectors
@@ -161,7 +161,7 @@ impl SparseTextEmbedding {
             let attention_mask_array = CowArray::from(&owned_attention_mask);

             let token_type_ids_array =
-                Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
+                Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;

             let mut session_inputs = ort::inputs![
                 "input_ids" => Value::from_array(inputs_ids_array)?,
@@ -186,7 +186,11 @@ impl SparseTextEmbedding {

             let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;

-            let embeddings = self.model.post_process(&output_data, &attention_mask_array);
+            let embeddings = SparseTextEmbedding::post_process(
+                &self.model,
+                &output_data,
+                &attention_mask_array,
+            );

             Ok(embeddings)
         })
@@ -197,4 +201,44 @@

         Ok(output)
     }
+
+    fn post_process(
+        model_name: &SparseModel,
+        model_output: &ArrayViewD<f32>,
+        attention_mask: &CowArray<i64, Dim<[usize; 2]>>,
+    ) -> Vec<SparseEmbedding> {
+        match model_name {
+            SparseModel::SPLADEPPV1 => {
+                // Apply ReLU and logarithm transformation
+                let relu_log = model_output.mapv(|x| (1.0 + x.max(0.0)).ln());
+
+                // Convert to f32 and expand the dimensions
+                let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));
+
+                // Weight the transformed values by the attention mask
+                let weighted_log = relu_log * attention_mask;
+
+                // Get the max scores
+                let scores = weighted_log.fold_axis(Axis(1), f32::NEG_INFINITY, |r, &v| r.max(v));
+
+                scores
+                    .rows()
+                    .into_iter()
+                    .map(|row_scores| {
+                        let mut values: Vec<f32> = Vec::with_capacity(scores.len());
+                        let mut indices: Vec<usize> = Vec::with_capacity(scores.len());
+
+                        row_scores.into_iter().enumerate().for_each(|(idx, f)| {
+                            if *f > 0.0 {
+                                values.push(*f);
+                                indices.push(idx);
+                            }
+                        });
+
+                        SparseEmbedding { values, indices }
+                    })
+                    .collect()
+            }
+        }
+    }
 }
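The relocated `post_process` is the whole SPLADE++ decoding step: apply log(1 + ReLU(x)) to each token-by-vocabulary logit, zero out padding via the attention mask, take the maximum over the token axis, and keep only strictly positive terms as (index, value) pairs. A toy walk-through with made-up numbers (batch of one, two tokens, a four-term vocabulary, second token padded):

```rust
use ndarray::{array, Axis};

fn main() {
    // (batch=1, tokens=2, vocab=4) raw logits; all values are made up.
    let logits = array![[[1.0_f32, -0.5, 0.0, 2.0], [0.2, 3.0, -1.0, 0.0]]];
    // (batch=1, tokens=2) attention mask: the second token is padding.
    let mask = array![[1.0_f32, 0.0]];

    // log(1 + ReLU(x)), as in `post_process`.
    let relu_log = logits.mapv(|x| (1.0 + x.max(0.0)).ln());
    // Zero out padded tokens, then take the per-term max over the token axis.
    let weighted = relu_log * mask.insert_axis(Axis(2));
    let scores = weighted.fold_axis(Axis(1), f32::NEG_INFINITY, |r, &v| r.max(v));

    // Keep only strictly positive terms: the sparse (index, value) pairs.
    for row in scores.rows() {
        let pairs: Vec<(usize, f32)> = row
            .iter()
            .enumerate()
            .filter(|&(_, &v)| v > 0.0)
            .map(|(i, &v)| (i, v))
            .collect();
        println!("{pairs:?}"); // [(0, 0.6931472), (3, 1.0986123)]
    }
}
```

Taking the max over tokens rather than a sum is what keeps a single strongly activated token visible without letting padding or weak tokens dilute it.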