From d261cceea39fc9c677a99b6e4ce270f43a3e06e4 Mon Sep 17 00:00:00 2001 From: Angel <77701490+cimandef@users.noreply.github.com> Date: Sat, 16 Aug 2025 18:35:08 +0000 Subject: [PATCH 1/2] refactor: create a common options interface (#179) * refactor: create a common struct & impl for embeddings * chore: Misc. clippy update Signed-off-by: Anush008 --------- Signed-off-by: Anush008 --- src/image_embedding/init.rs | 55 +------------- src/image_embedding/mod.rs | 2 - src/init.rs | 117 ++++++++++++++++++++++++++++++ src/lib.rs | 4 +- src/models/image_embedding.rs | 3 +- src/models/reranking.rs | 3 +- src/models/sparse.rs | 3 +- src/models/text_embedding.rs | 3 +- src/output/embedding_output.rs | 6 +- src/reranking/impl.rs | 2 +- src/reranking/init.rs | 68 +++-------------- src/reranking/mod.rs | 3 - src/sparse_text_embedding/impl.rs | 4 +- src/sparse_text_embedding/init.rs | 68 +++-------------- src/sparse_text_embedding/mod.rs | 3 - src/text_embedding/impl.rs | 9 +-- src/text_embedding/init.rs | 74 +++---------------- src/text_embedding/mod.rs | 3 - 18 files changed, 174 insertions(+), 256 deletions(-) create mode 100644 src/init.rs diff --git a/src/image_embedding/init.rs b/src/image_embedding/init.rs index 3c2113e..b8f8fd4 100644 --- a/src/image_embedding/init.rs +++ b/src/image_embedding/init.rs @@ -1,58 +1,9 @@ -use std::path::{Path, PathBuf}; - +use super::utils::Compose; +use crate::{init::InitOptions, ImageEmbeddingModel}; use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; -use crate::{get_cache_dir, ImageEmbeddingModel}; - -use super::{utils::Compose, DEFAULT_EMBEDDING_MODEL}; - /// Options for initializing the ImageEmbedding model -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct ImageInitOptions { - pub model_name: ImageEmbeddingModel, - pub execution_providers: Vec, - pub cache_dir: PathBuf, - pub show_download_progress: bool, -} - -impl ImageInitOptions { - pub fn new(model_name: ImageEmbeddingModel) -> Self { - Self { - model_name, - ..Default::default() - } - } - - pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { - self.cache_dir = cache_dir; - self - } - - pub fn with_execution_providers( - mut self, - execution_providers: Vec, - ) -> Self { - self.execution_providers = execution_providers; - self - } - - pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { - self.show_download_progress = show_download_progress; - self - } -} - -impl Default for ImageInitOptions { - fn default() -> Self { - Self { - model_name: DEFAULT_EMBEDDING_MODEL, - execution_providers: Default::default(), - cache_dir: Path::new(&get_cache_dir()).to_path_buf(), - show_download_progress: true, - } - } -} +pub type ImageInitOptions = InitOptions; /// Options for initializing UserDefinedImageEmbeddingModel /// diff --git a/src/image_embedding/mod.rs b/src/image_embedding/mod.rs index ac1645f..a205074 100644 --- a/src/image_embedding/mod.rs +++ b/src/image_embedding/mod.rs @@ -1,6 +1,4 @@ -use crate::models::image_embedding::ImageEmbeddingModel; const DEFAULT_BATCH_SIZE: usize = 256; -const DEFAULT_EMBEDDING_MODEL: ImageEmbeddingModel = ImageEmbeddingModel::ClipVitB32; mod utils; diff --git a/src/init.rs b/src/init.rs new file mode 100644 index 0000000..0fb8da1 --- /dev/null +++ b/src/init.rs @@ -0,0 +1,117 @@ +use crate::get_cache_dir; +use ort::execution_providers::ExecutionProviderDispatch; +use std::path::PathBuf; + +pub trait HasMaxLength { + const MAX_LENGTH: usize; +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct InitOptionsWithLength { + pub model_name: M, + pub execution_providers: Vec, + pub cache_dir: PathBuf, + pub show_download_progress: bool, + pub max_length: usize, +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct InitOptions { + pub model_name: M, + pub execution_providers: Vec, + pub cache_dir: PathBuf, + pub show_download_progress: bool, +} + +impl Default for InitOptionsWithLength { + fn default() -> Self { + Self { + model_name: M::default(), + execution_providers: Default::default(), + cache_dir: get_cache_dir().into(), + show_download_progress: true, + max_length: M::MAX_LENGTH, + } + } +} + +impl Default for InitOptions { + fn default() -> Self { + Self { + model_name: M::default(), + execution_providers: Default::default(), + cache_dir: get_cache_dir().into(), + show_download_progress: true, + } + } +} + +impl InitOptionsWithLength { + /// Crea a new InitOptionsWithLength with the given model name + pub fn new(model_name: M) -> Self { + Self { + model_name, + ..Default::default() + } + } + + /// Set the maximum maximum length + pub fn with_max_length(mut self, max_lenght: usize) -> Self { + self.max_length = max_lenght; + self + } + + /// Set the cache directory for the model file + pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { + self.cache_dir = cache_dir; + self + } + + /// Set the execution providers for the model + pub fn with_execution_providers( + mut self, + execution_providers: Vec, + ) -> Self { + self.execution_providers = execution_providers; + self + } + + /// Set whether to show download progress + pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { + self.show_download_progress = show_download_progress; + self + } +} + +impl InitOptions { + /// Crea a new InitOptionsWithLength with the given model name + pub fn new(model_name: M) -> Self { + Self { + model_name, + ..Default::default() + } + } + + /// Set the cache directory for the model file + pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { + self.cache_dir = cache_dir; + self + } + + /// Set the execution providers for the model + pub fn with_execution_providers( + mut self, + execution_providers: Vec, + ) -> Self { + self.execution_providers = execution_providers; + self + } + + /// Set whether to show download progress + pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { + self.show_download_progress = show_download_progress; + self + } +} diff --git a/src/lib.rs b/src/lib.rs index bdcf21a..5fc314c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,7 @@ mod common; mod image_embedding; +mod init; mod models; pub mod output; mod pooling; @@ -74,7 +75,8 @@ pub use crate::pooling::Pooling; // For Text Embedding pub use crate::models::text_embedding::EmbeddingModel; pub use crate::text_embedding::{ - InitOptions, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel, + InitOptionsUserDefined, TextEmbedding, TextInitOptions as InitOptions, + UserDefinedEmbeddingModel, }; // For Sparse Text Embedding diff --git a/src/models/image_embedding.rs b/src/models/image_embedding.rs index 185212c..4913de6 100644 --- a/src/models/image_embedding.rs +++ b/src/models/image_embedding.rs @@ -2,9 +2,10 @@ use std::{fmt::Display, str::FromStr}; use super::model_info::ModelInfo; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub enum ImageEmbeddingModel { /// Qdrant/clip-ViT-B-32-vision + #[default] ClipVitB32, /// Qdrant/resnet50-onnx Resnet50, diff --git a/src/models/reranking.rs b/src/models/reranking.rs index 5fb6b77..5b10da1 100644 --- a/src/models/reranking.rs +++ b/src/models/reranking.rs @@ -2,9 +2,10 @@ use std::{fmt::Display, str::FromStr}; use crate::RerankerModelInfo; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub enum RerankerModel { /// BAAI/bge-reranker-base + #[default] BGERerankerBase, /// rozgo/bge-reranker-v2-m3 BGERerankerV2M3, diff --git a/src/models/sparse.rs b/src/models/sparse.rs index f933a7a..a89f93c 100644 --- a/src/models/sparse.rs +++ b/src/models/sparse.rs @@ -2,9 +2,10 @@ use std::{fmt::Display, str::FromStr}; use crate::ModelInfo; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Clone, PartialEq, Eq)] pub enum SparseModel { /// prithivida/Splade_PP_en_v1 + #[default] SPLADEPPV1, } diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs index fa5b242..1365af8 100644 --- a/src/models/text_embedding.rs +++ b/src/models/text_embedding.rs @@ -5,7 +5,7 @@ use super::model_info::ModelInfo; /// Lazy static list of all available models. static MODEL_MAP: OnceLock>> = OnceLock::new(); -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] pub enum EmbeddingModel { /// sentence-transformers/all-MiniLM-L6-v2 AllMiniLML6V2, @@ -24,6 +24,7 @@ pub enum EmbeddingModel { /// Quantized BAAI/bge-large-en-v1.5 BGELargeENV15Q, /// BAAI/bge-small-en-v1.5 - Default + #[default] BGESmallENV15, /// Quantized BAAI/bge-small-en-v1.5 BGESmallENV15Q, diff --git a/src/output/embedding_output.rs b/src/output/embedding_output.rs index 5dcc29d..1862cb6 100644 --- a/src/output/embedding_output.rs +++ b/src/output/embedding_output.rs @@ -22,7 +22,7 @@ impl SingleBatchOutput { pub fn select_output( &self, precedence: &impl OutputPrecedence, - ) -> anyhow::Result>> { + ) -> anyhow::Result>> { let ort_output: &ort::value::Value = precedence .key_precedence() .find_map(|key| match key { @@ -35,7 +35,9 @@ impl SingleBatchOutput { } } OutputKey::ByOrder(idx) => self.outputs.get(*idx).map(|(_, v)| v), - OutputKey::ByName(name) => self.outputs.iter().find(|(n, _)| n == name).map(|(_, v)| v), + OutputKey::ByName(name) => { + self.outputs.iter().find(|(n, _)| n == name).map(|(_, v)| v) + } }) .ok_or_else(|| { anyhow::Error::msg(format!( diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs index 0c7455e..2c6ef41 100644 --- a/src/reranking/impl.rs +++ b/src/reranking/impl.rs @@ -54,9 +54,9 @@ impl TextRerank { use super::RerankInitOptions; let RerankInitOptions { + max_length, model_name, execution_providers, - max_length, cache_dir, show_download_progress, } = options; diff --git a/src/reranking/init.rs b/src/reranking/init.rs index e204814..8119e00 100644 --- a/src/reranking/init.rs +++ b/src/reranking/init.rs @@ -1,12 +1,12 @@ -use std::path::{Path, PathBuf}; - +use super::DEFAULT_MAX_LENGTH; +use crate::{ + init::{HasMaxLength, InitOptionsWithLength}, + RerankerModel, TokenizerFiles, +}; use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; +use std::path::PathBuf; use tokenizers::Tokenizer; -use crate::{common::get_cache_dir, RerankerModel, TokenizerFiles}; - -use super::{DEFAULT_MAX_LENGTH, DEFAULT_RE_RANKER_MODEL}; - #[derive(Debug)] pub struct TextRerank { pub tokenizer: Tokenizer, @@ -14,60 +14,12 @@ pub struct TextRerank { pub(crate) need_token_type_ids: bool, } -/// Options for initializing the reranking model -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct RerankInitOptions { - pub model_name: RerankerModel, - pub execution_providers: Vec, - pub max_length: usize, - pub cache_dir: PathBuf, - pub show_download_progress: bool, +impl HasMaxLength for RerankerModel { + const MAX_LENGTH: usize = DEFAULT_MAX_LENGTH; } -impl RerankInitOptions { - pub fn new(model_name: RerankerModel) -> Self { - Self { - model_name, - ..Default::default() - } - } - - pub fn with_max_length(mut self, max_length: usize) -> Self { - self.max_length = max_length; - self - } - - pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { - self.cache_dir = cache_dir; - self - } - - pub fn with_execution_providers( - mut self, - execution_providers: Vec, - ) -> Self { - self.execution_providers = execution_providers; - self - } - - pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { - self.show_download_progress = show_download_progress; - self - } -} - -impl Default for RerankInitOptions { - fn default() -> Self { - Self { - model_name: DEFAULT_RE_RANKER_MODEL, - execution_providers: Default::default(), - max_length: DEFAULT_MAX_LENGTH, - cache_dir: Path::new(&get_cache_dir()).to_path_buf(), - show_download_progress: true, - } - } -} +/// Options for initializing the reranking models +pub type RerankInitOptions = InitOptionsWithLength; /// Options for initializing UserDefinedRerankerModel /// diff --git a/src/reranking/mod.rs b/src/reranking/mod.rs index 86c5660..45e4452 100644 --- a/src/reranking/mod.rs +++ b/src/reranking/mod.rs @@ -1,6 +1,3 @@ -use crate::RerankerModel; - -const DEFAULT_RE_RANKER_MODEL: RerankerModel = RerankerModel::BGERerankerBase; const DEFAULT_MAX_LENGTH: usize = 512; const DEFAULT_BATCH_SIZE: usize = 256; diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs index a55b7a9..f8c29db 100644 --- a/src/sparse_text_embedding/impl.rs +++ b/src/sparse_text_embedding/impl.rs @@ -35,11 +35,11 @@ impl SparseTextEmbedding { use ort::{session::builder::GraphOptimizationLevel, session::Session}; let SparseInitOptions { - model_name, - execution_providers, max_length, + model_name, cache_dir, show_download_progress, + execution_providers, } = options; let threads = available_parallelism()?.get(); diff --git a/src/sparse_text_embedding/init.rs b/src/sparse_text_embedding/init.rs index 59482a7..2992954 100644 --- a/src/sparse_text_embedding/init.rs +++ b/src/sparse_text_embedding/init.rs @@ -1,66 +1,20 @@ -use std::path::{Path, PathBuf}; - -use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; +use ort::session::Session; use tokenizers::Tokenizer; -use crate::{common::get_cache_dir, models::sparse::SparseModel, TokenizerFiles}; +use crate::{ + init::{HasMaxLength, InitOptionsWithLength}, + models::sparse::SparseModel, + TokenizerFiles, +}; -use super::{DEFAULT_EMBEDDING_MODEL, DEFAULT_MAX_LENGTH}; +use super::DEFAULT_MAX_LENGTH; -/// Options for initializing the SparseTextEmbedding model -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct SparseInitOptions { - pub model_name: SparseModel, - pub execution_providers: Vec, - pub max_length: usize, - pub cache_dir: PathBuf, - pub show_download_progress: bool, +impl HasMaxLength for SparseModel { + const MAX_LENGTH: usize = DEFAULT_MAX_LENGTH; } -impl SparseInitOptions { - pub fn new(model_name: SparseModel) -> Self { - Self { - model_name, - ..Default::default() - } - } - - pub fn with_max_length(mut self, max_length: usize) -> Self { - self.max_length = max_length; - self - } - - pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { - self.cache_dir = cache_dir; - self - } - - pub fn with_execution_providers( - mut self, - execution_providers: Vec, - ) -> Self { - self.execution_providers = execution_providers; - self - } - - pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { - self.show_download_progress = show_download_progress; - self - } -} - -impl Default for SparseInitOptions { - fn default() -> Self { - Self { - model_name: DEFAULT_EMBEDDING_MODEL, - execution_providers: Default::default(), - max_length: DEFAULT_MAX_LENGTH, - cache_dir: Path::new(&get_cache_dir()).to_path_buf(), - show_download_progress: true, - } - } -} +/// Options for initializing the SparseTextEmbedding model +pub type SparseInitOptions = InitOptionsWithLength; /// Struct for "bring your own" embedding models /// diff --git a/src/sparse_text_embedding/mod.rs b/src/sparse_text_embedding/mod.rs index 89e7a84..0cd8877 100644 --- a/src/sparse_text_embedding/mod.rs +++ b/src/sparse_text_embedding/mod.rs @@ -1,8 +1,5 @@ -use crate::models::sparse::SparseModel; - const DEFAULT_BATCH_SIZE: usize = 256; const DEFAULT_MAX_LENGTH: usize = 512; -const DEFAULT_EMBEDDING_MODEL: SparseModel = SparseModel::SPLADEPPV1; mod init; pub use init::*; diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index 2a19a86..cd074b0 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -24,7 +24,7 @@ use std::thread::available_parallelism; use tokenizers::Tokenizer; #[cfg(feature = "hf-hub")] -use super::InitOptions; +use super::TextInitOptions; use super::{ output, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel, DEFAULT_BATCH_SIZE, }; @@ -36,15 +36,14 @@ impl TextEmbedding { /// /// Uses the total number of CPUs available as the number of intra-threads #[cfg(feature = "hf-hub")] - pub fn try_new(options: InitOptions) -> Result { - let InitOptions { + pub fn try_new(options: TextInitOptions) -> Result { + let TextInitOptions { + max_length, model_name, execution_providers, - max_length, cache_dir, show_download_progress, } = options; - let threads = available_parallelism()?.get(); let model_repo = TextEmbedding::retrieve_model( diff --git a/src/text_embedding/init.rs b/src/text_embedding/init.rs index 468c90a..2e4393f 100644 --- a/src/text_embedding/init.rs +++ b/src/text_embedding/init.rs @@ -2,73 +2,22 @@ //! use crate::{ - common::TokenizerFiles, get_cache_dir, pooling::Pooling, EmbeddingModel, QuantizationMode, + common::TokenizerFiles, + init::{HasMaxLength, InitOptionsWithLength}, + pooling::Pooling, + EmbeddingModel, QuantizationMode, }; use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; -use std::path::{Path, PathBuf}; use tokenizers::Tokenizer; -use super::{DEFAULT_EMBEDDING_MODEL, DEFAULT_MAX_LENGTH}; +use super::DEFAULT_MAX_LENGTH; -/// Options for initializing the TextEmbedding model -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct InitOptions { - pub model_name: EmbeddingModel, - pub execution_providers: Vec, - pub max_length: usize, - pub cache_dir: PathBuf, - pub show_download_progress: bool, +impl HasMaxLength for EmbeddingModel { + const MAX_LENGTH: usize = DEFAULT_MAX_LENGTH; } -impl InitOptions { - /// Create a new InitOptions with the given model name - pub fn new(model_name: EmbeddingModel) -> Self { - Self { - model_name, - ..Default::default() - } - } - - /// Set the maximum length of the input text - pub fn with_max_length(mut self, max_length: usize) -> Self { - self.max_length = max_length; - self - } - - /// Set the cache directory for the model files - pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { - self.cache_dir = cache_dir; - self - } - - /// Set the execution providers for the model - pub fn with_execution_providers( - mut self, - execution_providers: Vec, - ) -> Self { - self.execution_providers = execution_providers; - self - } - - /// Set whether to show download progress - pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { - self.show_download_progress = show_download_progress; - self - } -} - -impl Default for InitOptions { - fn default() -> Self { - Self { - model_name: DEFAULT_EMBEDDING_MODEL, - execution_providers: Default::default(), - max_length: DEFAULT_MAX_LENGTH, - cache_dir: Path::new(&get_cache_dir()).to_path_buf(), - show_download_progress: true, - } - } -} +/// Options for initializing the TextEmbedding model +pub type TextInitOptions = InitOptionsWithLength; /// Options for initializing UserDefinedEmbeddingModel /// @@ -113,8 +62,8 @@ impl Default for InitOptionsUserDefined { /// Convert InitOptions to InitOptionsUserDefined /// /// This is useful for when the user wants to use the same options for both the default and user-defined models -impl From for InitOptionsUserDefined { - fn from(options: InitOptions) -> Self { +impl From for InitOptionsUserDefined { + fn from(options: TextInitOptions) -> Self { InitOptionsUserDefined { execution_providers: options.execution_providers, max_length: options.max_length, @@ -126,7 +75,6 @@ impl From for InitOptionsUserDefined { /// /// The onnx_file and tokenizer_files are expecting the files' bytes #[derive(Debug, Clone, PartialEq, Eq)] -#[non_exhaustive] pub struct UserDefinedEmbeddingModel { pub onnx_file: Vec, pub tokenizer_files: TokenizerFiles, diff --git a/src/text_embedding/mod.rs b/src/text_embedding/mod.rs index dc07263..bea7012 100644 --- a/src/text_embedding/mod.rs +++ b/src/text_embedding/mod.rs @@ -1,12 +1,9 @@ //! Text embedding module, containing the main struct [TextEmbedding] and its //! initialization options. -use crate::models::text_embedding::EmbeddingModel; - // Constants. const DEFAULT_BATCH_SIZE: usize = 256; const DEFAULT_MAX_LENGTH: usize = 512; -const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallENV15; // Output precedence and transforming functions. pub mod output; From 733ba748ff306e92fa429c646f6912418e35b2a2 Mon Sep 17 00:00:00 2001 From: Angel <77701490+cimandef@users.noreply.github.com> Date: Sat, 16 Aug 2025 18:36:42 +0000 Subject: [PATCH 2/2] chore(release): 5.0.3 [skip ci] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## [5.0.3](https://github.com/Anush008/fastembed-rs/compare/v5.0.2...v5.0.3) (2025-08-16) ## [5.0.3](https://github.com/Anush008/fastembed-rs/compare/v5.0.2...v5.0.3) (2025-08-16) ### 🧑‍💻 Code Refactoring * create a common options interface ([#179](https://github.com/Anush008/fastembed-rs/issues/179)) ([d261cce](https://github.com/Anush008/fastembed-rs/commit/d261cceea39fc9c677a99b6e4ce270f43a3e06e4)) --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 962f39b..660b6f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fastembed" -version = "5.0.2" +version = "5.0.3" edition = "2021" description = "Library for generating vector embeddings, reranking locally." license = "Apache-2.0"