diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..4af2269
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,6 @@
+version: 2
+updates:
+  - package-ecosystem: "cargo" # See documentation for possible values
+    directory: "/" # Location of package manifests
+    schedule:
+      interval: "weekly"
diff --git a/Cargo.toml b/Cargo.toml
index 3328646..8e87310 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fastembed"
-version = "5.1.0"
+version = "5.2.0"
 edition = "2021"
 description = "Library for generating vector embeddings, reranking locally."
 license = "Apache-2.0"
@@ -30,7 +30,7 @@ ort = { version = "=2.0.0-rc.10", default-features = false, features = [
   "ndarray",
   "std"
 ] }
 serde_json = { version = "1" }
-tokenizers = { version = "0.21.2", default-features = false, features = ["onig"] }
+tokenizers = { version = "0.22.0", default-features = false, features = ["onig"] }
 [features]
 default = ["ort-download-binaries", "hf-hub-native-tls"]
diff --git a/src/common.rs b/src/common.rs
index 5023c94..ba820c3 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -112,20 +112,36 @@ pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Res
     if let serde_json::Value::Object(root_object) = special_tokens_map {
         for (_, value) in root_object.iter() {
             if value.is_string() {
-                tokenizer.add_special_tokens(&[AddedToken {
-                    content: value.as_str().unwrap().into(),
-                    special: true,
-                    ..Default::default()
-                }]);
+                if let Some(content) = value.as_str() {
+                    tokenizer.add_special_tokens(&[AddedToken {
+                        content: content.into(),
+                        special: true,
+                        ..Default::default()
+                    }]);
+                }
             } else if value.is_object() {
-                tokenizer.add_special_tokens(&[AddedToken {
-                    content: value["content"].as_str().unwrap().into(),
-                    special: true,
-                    single_word: value["single_word"].as_bool().unwrap(),
-                    lstrip: value["lstrip"].as_bool().unwrap(),
-                    rstrip: value["rstrip"].as_bool().unwrap(),
-                    normalized: value["normalized"].as_bool().unwrap(),
-                }]);
+                if let (
+                    Some(content),
+                    Some(single_word),
+                    Some(lstrip),
+                    Some(rstrip),
+                    Some(normalized),
+                ) = (
+                    value["content"].as_str(),
+                    value["single_word"].as_bool(),
+                    value["lstrip"].as_bool(),
+                    value["rstrip"].as_bool(),
+                    value["normalized"].as_bool(),
+                ) {
+                    tokenizer.add_special_tokens(&[AddedToken {
+                        content: content.into(),
+                        special: true,
+                        single_word,
+                        lstrip,
+                        rstrip,
+                        normalized,
+                    }]);
+                }
             }
         }
     }
diff --git a/src/models/image_embedding.rs b/src/models/image_embedding.rs
index 4913de6..59a4fde 100644
--- a/src/models/image_embedding.rs
+++ b/src/models/image_embedding.rs
@@ -26,6 +26,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
             model_code: String::from("Qdrant/clip-ViT-B-32-vision"),
             model_file: String::from("model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: ImageEmbeddingModel::Resnet50,
@@ -34,6 +35,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
             model_code: String::from("Qdrant/resnet50-onnx"),
             model_file: String::from("model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: ImageEmbeddingModel::UnicomVitB16,
@@ -42,6 +44,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
             model_code: String::from("Qdrant/Unicom-ViT-B-16"),
             model_file: String::from("model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: ImageEmbeddingModel::UnicomVitB32,
@@ -50,6 +53,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
             model_code: String::from("Qdrant/Unicom-ViT-B-32"),
             model_file: String::from("model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: ImageEmbeddingModel::NomicEmbedVisionV15,
@@ -58,6 +62,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
             model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
     ];
diff --git a/src/models/model_info.rs b/src/models/model_info.rs
index 4d7b184..f65170c 100644
--- a/src/models/model_info.rs
+++ b/src/models/model_info.rs
@@ -1,7 +1,8 @@
-use crate::RerankerModel;
+use crate::{OutputKey, RerankerModel};
 
 /// Data struct about the available models
 #[derive(Debug, Clone)]
+#[non_exhaustive]
 pub struct ModelInfo<T> {
     pub model: T,
     pub dim: usize,
@@ -9,10 +10,12 @@ pub struct ModelInfo<T> {
     pub model_code: String,
     pub model_file: String,
     pub additional_files: Vec<String>,
+    pub output_key: Option<OutputKey>,
 }
 
 /// Data struct about the available reranker models
 #[derive(Debug, Clone)]
+#[non_exhaustive]
 pub struct RerankerModelInfo {
     pub model: RerankerModel,
     pub description: String,
diff --git a/src/models/sparse.rs b/src/models/sparse.rs
index a89f93c..1b52d2f 100644
--- a/src/models/sparse.rs
+++ b/src/models/sparse.rs
@@ -17,6 +17,7 @@ pub fn models_list() -> Vec<ModelInfo<SparseModel>> {
         model_code: String::from("Qdrant/Splade_PP_en_v1"),
         model_file: String::from("model.onnx"),
         additional_files: Vec::new(),
+        output_key: None,
     }]
 }
diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs
index 56da9a1..8292872 100644
--- a/src/models/text_embedding.rs
+++ b/src/models/text_embedding.rs
@@ -68,6 +68,8 @@ pub enum EmbeddingModel {
     ClipVitB32,
     /// jinaai/jina-embeddings-v2-base-code
     JinaEmbeddingsV2BaseCode,
+    /// onnx-community/embeddinggemma-300m-ONNX
+    EmbeddingGemma300M,
 }
 
 /// Centralized function to initialize the models map.
@@ -80,6 +82,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Qdrant/all-MiniLM-L6-v2-onnx"),
             model_file: String::from("model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::AllMiniLML6V2Q,
@@ -88,6 +91,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/all-MiniLM-L6-v2"),
             model_file: String::from("onnx/model_quantized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::AllMiniLML12V2,
@@ -96,6 +100,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/all-MiniLM-L12-v2"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::AllMiniLML12V2Q,
@@ -104,6 +109,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/all-MiniLM-L12-v2"),
             model_file: String::from("onnx/model_quantized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGEBaseENV15,
@@ -112,6 +118,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/bge-base-en-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGEBaseENV15Q,
@@ -120,6 +127,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Qdrant/bge-base-en-v1.5-onnx-Q"),
             model_file: String::from("model_optimized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGELargeENV15,
@@ -128,6 +136,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/bge-large-en-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGELargeENV15Q,
@@ -136,6 +145,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Qdrant/bge-large-en-v1.5-onnx-Q"),
             model_file: String::from("model_optimized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGESmallENV15,
@@ -144,6 +154,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/bge-small-en-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGESmallENV15Q,
@@ -154,6 +165,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Qdrant/bge-small-en-v1.5-onnx-Q"),
             model_file: String::from("model_optimized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::NomicEmbedTextV1,
@@ -162,6 +174,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("nomic-ai/nomic-embed-text-v1"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::NomicEmbedTextV15,
@@ -170,6 +183,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::NomicEmbedTextV15Q,
@@ -180,6 +194,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
             model_file: String::from("onnx/model_quantized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::ParaphraseMLMiniLML12V2Q,
@@ -188,6 +203,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
             model_file: String::from("model_optimized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::ParaphraseMLMiniLML12V2,
@@ -196,6 +212,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/paraphrase-multilingual-MiniLM-L12-v2"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
@@ -206,6 +223,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/paraphrase-multilingual-mpnet-base-v2"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGESmallZHV15,
@@ -214,6 +232,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/bge-small-zh-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::BGELargeZHV15,
@@ -222,6 +241,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Xenova/bge-large-zh-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::ModernBertEmbedLarge,
@@ -230,6 +250,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("lightonai/modernbert-embed-large"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::MultilingualE5Small,
@@ -238,6 +259,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("intfloat/multilingual-e5-small"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::MultilingualE5Base,
@@ -246,6 +268,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("intfloat/multilingual-e5-base"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::MultilingualE5Large,
@@ -254,6 +277,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Qdrant/multilingual-e5-large-onnx"),
             model_file: String::from("model.onnx"),
             additional_files: vec!["model.onnx_data".to_string()],
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::MxbaiEmbedLargeV1,
@@ -262,6 +286,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::MxbaiEmbedLargeV1Q,
@@ -270,6 +295,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
             model_file: String::from("onnx/model_quantized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::GTEBaseENV15,
@@ -278,6 +304,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::GTEBaseENV15Q,
@@ -286,6 +313,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
             model_file: String::from("onnx/model_quantized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::GTELargeENV15,
@@ -294,6 +322,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::GTELargeENV15Q,
@@ -302,6 +331,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
             model_file: String::from("onnx/model_quantized.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::ClipVitB32,
@@ -310,6 +340,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Qdrant/clip-ViT-B-32-text"),
             model_file: String::from("model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
         },
         ModelInfo {
             model: EmbeddingModel::JinaEmbeddingsV2BaseCode,
@@ -318,6 +349,16 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("jinaai/jina-embeddings-v2-base-code"),
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
+            output_key: None,
+        },
+        ModelInfo {
+            model: EmbeddingModel::EmbeddingGemma300M,
+            dim: 768,
+            description: String::from("EmbeddingGemma is a 300M parameter model from Google"),
+            model_code: String::from("onnx-community/embeddinggemma-300m-ONNX"),
+            model_file: String::from("onnx/model.onnx"),
+            additional_files: vec!["onnx/model.onnx_data".to_string()],
+            output_key: Some(crate::OutputKey::ByName("sentence_embedding")),
         },
     ];
diff --git a/src/output/output_precedence.rs b/src/output/output_precedence.rs
index ead5bb7..5dcc6a3 100644
--- a/src/output/output_precedence.rs
+++ b/src/output/output_precedence.rs
@@ -7,7 +7,7 @@
 //! e.g. reading the output keys from the model file.
 
 /// Enum for defining the key of the output.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub enum OutputKey {
     OnlyOne,
     ByOrder(usize),
@@ -41,3 +41,9 @@ impl OutputPrecedence for &[OutputKey] {
         self.iter()
     }
 }
+
+impl OutputPrecedence for &OutputKey {
+    fn key_precedence(&self) -> impl Iterator<Item = &OutputKey> {
+        std::iter::once(*self)
+    }
+}
diff --git a/src/pooling.rs b/src/pooling.rs
index cd0e6b3..9169f61 100644
--- a/src/pooling.rs
+++ b/src/pooling.rs
@@ -59,13 +59,13 @@ pub fn mean(
     let attention_mask = attention_mask_array
         .insert_axis(ndarray::Axis(2))
         .broadcast(token_embeddings.dim())
-        .unwrap_or_else(|| {
-            panic!(
+        .ok_or_else(|| {
+            anyhow::Error::msg(format!(
                 "Could not broadcast attention mask from {:?} to {:?}",
                 attention_mask_original_dim,
                 token_embeddings.dim()
-            )
-        })
+            ))
+        })?
         .mapv(|x| x as f32);
 
     let masked_tensor = &attention_mask * &token_embeddings;
diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs
index f8c29db..189010b 100644
--- a/src/sparse_text_embedding/impl.rs
+++ b/src/sparse_text_embedding/impl.rs
@@ -119,7 +119,9 @@ impl SparseTextEmbedding {
             .map(|batch| {
                 // Encode the texts in the batch
                 let inputs = batch.iter().map(|text| text.as_ref()).collect();
-                let encodings = self.tokenizer.encode_batch(inputs, true).unwrap();
+                let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
+                    anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
+                })?;
 
                 // Extract the encoding length and batch size
                 let encoding_length = encodings[0].len();
diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs
index e85c942..3a0fdc0 100644
--- a/src/text_embedding/impl.rs
+++ b/src/text_embedding/impl.rs
@@ -3,8 +3,10 @@
 #[cfg(feature = "hf-hub")]
 use crate::common::load_tokenizer_hf_hub;
 use crate::{
-    common::load_tokenizer, models::text_embedding::models_list, models::ModelTrait,
-    pooling::Pooling, Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode,
+    common::load_tokenizer,
+    models::{text_embedding::models_list, ModelTrait},
+    pooling::Pooling,
+    Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, OutputKey, QuantizationMode,
     SingleBatchOutput,
 };
 #[cfg(feature = "hf-hub")]
@@ -80,6 +82,7 @@ impl TextEmbedding {
             session,
             post_processing,
             TextEmbedding::get_quantization_mode(&model_name),
+            model_info.output_key.clone(),
         ))
     }
@@ -109,6 +112,7 @@
             session,
             model.pooling,
             model.quantization,
+            model.output_key,
         ))
     }
@@ -118,6 +122,7 @@
         session: Session,
         post_process: Option<Pooling>,
         quantization: QuantizationMode,
+        output_key: Option<OutputKey>,
     ) -> Self {
         let need_token_type_ids = session
             .inputs
@@ -130,6 +135,7 @@
             need_token_type_ids,
             pooling: post_process,
             quantization,
+            output_key,
         }
     }
     /// Return the TextEmbedding model's directory from cache or remote retrieval
@@ -185,6 +191,8 @@
             EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),
             EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean),
+
+            EmbeddingModel::EmbeddingGemma300M => Some(Pooling::Mean),
         }
     }
@@ -365,10 +373,16 @@
         batch_size: Option<usize>,
     ) -> Result<Vec<Embedding>> {
         let batches = self.transform(texts, batch_size)?;
-
-        batches.export_with_transformer(output::transformer_with_precedence(
-            output::OUTPUT_TYPE_PRECEDENCE,
-            self.pooling.clone(),
-        ))
+        if let Some(output_key) = &self.output_key {
+            batches.export_with_transformer(output::transformer_with_precedence(
+                output_key,
+                self.pooling.clone(),
+            ))
+        } else {
+            batches.export_with_transformer(output::transformer_with_precedence(
+                output::OUTPUT_TYPE_PRECEDENCE,
+                self.pooling.clone(),
+            ))
+        }
     }
 }
diff --git a/src/text_embedding/init.rs b/src/text_embedding/init.rs
index 2e4393f..54b28fe 100644
--- a/src/text_embedding/init.rs
+++ b/src/text_embedding/init.rs
@@ -5,7 +5,7 @@ use crate::{
     common::TokenizerFiles,
     init::{HasMaxLength, InitOptionsWithLength},
     pooling::Pooling,
-    EmbeddingModel, QuantizationMode,
+    EmbeddingModel, OutputKey, QuantizationMode,
 };
 use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
 use tokenizers::Tokenizer;
@@ -80,6 +80,7 @@ pub struct UserDefinedEmbeddingModel {
     pub tokenizer_files: TokenizerFiles,
     pub pooling: Option<Pooling>,
     pub quantization: QuantizationMode,
+    pub output_key: Option<OutputKey>,
 }
 
 impl UserDefinedEmbeddingModel {
@@ -89,6 +90,7 @@
             tokenizer_files,
             quantization: QuantizationMode::None,
             pooling: None,
+            output_key: None,
         }
     }
@@ -110,4 +112,5 @@ pub struct TextEmbedding {
     pub(crate) session: Session,
     pub(crate) need_token_type_ids: bool,
     pub(crate) quantization: QuantizationMode,
+    pub(crate) output_key: Option<OutputKey>,
 }
diff --git a/tests/embeddings.rs b/tests/embeddings.rs
index 70e7772..6ef5616 100644
--- a/tests/embeddings.rs
+++ b/tests/embeddings.rs
@@ -64,6 +64,7 @@ fn verify_embeddings(model: &EmbeddingModel, embeddings: &[Embedding]) -> Result
         EmbeddingModel::ParaphraseMLMpnetBaseV2 => [0.39132136, 0.49490625, 0.65497226, 0.34237382],
         EmbeddingModel::ClipVitB32 => [0.7057363, 1.3549932, 0.46823958, 0.52351093],
         EmbeddingModel::JinaEmbeddingsV2BaseCode => [-0.31383067, -0.3758629, -0.24878195, -0.35373706],
+        EmbeddingModel::EmbeddingGemma300M => [0.22703816, 0.6947083, 0.07579082, 1.6958784],
         _ => panic!("Model {model} not found. If you have just inserted this `EmbeddingModel` variant, please update the expected embeddings."),
     };
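
Usage sketch (appended for review context, not part of the diff): this change threads an optional `output_key` from `ModelInfo` through `TextEmbedding`, and `embed` now prefers that key over the generic `output::OUTPUT_TYPE_PRECEDENCE` fallback. The new `EmbeddingGemma300M` entry sets `OutputKey::ByName("sentence_embedding")`, so the ONNX output with that name is read before the configured post-processing (`Pooling::Mean` per the new match arm) is applied. A minimal sketch of calling it, assuming the crate's existing `InitOptions` builder, the default `hf-hub` download feature, and an `anyhow` dependency in the caller; those names are outside this diff and should be treated as assumptions:

```rust
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

fn main() -> anyhow::Result<()> {
    // Fetches onnx/model.onnx plus onnx/model.onnx_data from
    // onnx-community/embeddinggemma-300m-ONNX, per the new ModelInfo entry.
    let model = TextEmbedding::try_new(InitOptions::new(EmbeddingModel::EmbeddingGemma300M))?;

    let documents = vec!["what is fastembed?", "vector search with Qdrant"];
    // embed() runs the session and reads the output selected by
    // self.output_key ("sentence_embedding") instead of the precedence list.
    let embeddings = model.embed(documents, None)?;
    assert_eq!(embeddings[0].len(), 768); // dim from the ModelInfo entry

    Ok(())
}
```

For a `UserDefinedEmbeddingModel`, the new public `output_key` field defaults to `None` in `new()`, preserving the old precedence-based lookup; setting it to `Some(OutputKey::ByName(...))` opts a custom model into the same by-name output selection.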