diff --git a/src/common.rs b/src/common.rs index ff5de2d..5023c94 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,6 +1,7 @@ use anyhow::Result; #[cfg(feature = "hf-hub")] use hf_hub::api::sync::{ApiBuilder, ApiRepo}; +#[cfg(feature = "hf-hub")] use std::path::PathBuf; use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; diff --git a/src/lib.rs b/src/lib.rs index 5fc314c..bdfbc4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,11 +72,16 @@ pub use crate::models::{ pub use crate::output::{EmbeddingOutput, OutputKey, OutputPrecedence, SingleBatchOutput}; pub use crate::pooling::Pooling; +// For all Embedding +pub use crate::init::{InitOptions as BaseInitOptions, InitOptionsWithLength}; +pub use crate::models::ModelTrait; + // For Text Embedding pub use crate::models::text_embedding::EmbeddingModel; +#[deprecated(note = "use `TextInitOptions` instead")] +pub use crate::text_embedding::TextInitOptions as InitOptions; pub use crate::text_embedding::{ - InitOptionsUserDefined, TextEmbedding, TextInitOptions as InitOptions, - UserDefinedEmbeddingModel, + InitOptionsUserDefined, TextEmbedding, TextInitOptions, UserDefinedEmbeddingModel, }; // For Sparse Text Embedding diff --git a/src/models/mod.rs b/src/models/mod.rs index 03daff4..6f010e2 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,6 +1,13 @@ +use crate::ModelInfo; + pub mod image_embedding; pub mod model_info; pub mod quantization; pub mod reranking; pub mod sparse; pub mod text_embedding; + +pub trait ModelTrait { + type Model; + fn get_model_info(model: &Self::Model) -> Option<&ModelInfo>; +} diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs index 1365af8..56da9a1 100644 --- a/src/models/text_embedding.rs +++ b/src/models/text_embedding.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, convert::TryFrom, fmt::Display, str::FromStr, sync::OnceLock}; -use super::model_info::ModelInfo; +use super::{model_info::ModelInfo, ModelTrait}; /// Lazy static list of all available models. static MODEL_MAP: OnceLock>> = OnceLock::new(); @@ -342,11 +342,6 @@ pub fn models_map() -> &'static HashMap Option<&ModelInfo> { - models_map().get(model) -} - /// Get a list of all available models. /// /// This will assign new memory to the models list; where possible, use @@ -355,9 +350,18 @@ pub fn models_list() -> Vec> { models_map().values().cloned().collect() } +impl ModelTrait for EmbeddingModel { + type Model = Self; + + /// Get model information by model code. + fn get_model_info(model: &EmbeddingModel) -> Option<&ModelInfo> { + models_map().get(model) + } +} + impl Display for EmbeddingModel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let model_info = get_model_info(self).expect("Model not found."); + let model_info = EmbeddingModel::get_model_info(self).expect("Model not found."); write!(f, "{}", model_info.model_code) } } diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index cd074b0..e85c942 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -3,10 +3,9 @@ #[cfg(feature = "hf-hub")] use crate::common::load_tokenizer_hf_hub; use crate::{ - common::load_tokenizer, - models::text_embedding::{get_model_info, models_list}, - pooling::Pooling, - Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput, + common::load_tokenizer, models::text_embedding::models_list, models::ModelTrait, + pooling::Pooling, Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, + SingleBatchOutput, }; #[cfg(feature = "hf-hub")] use anyhow::Context; @@ -225,7 +224,7 @@ impl TextEmbedding { /// Get ModelInfo from EmbeddingModel pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo> { - get_model_info(model).ok_or_else(|| { + EmbeddingModel::get_model_info(model).ok_or_else(|| { anyhow::Error::msg(format!( "Model {model:?} not found. Please check if the model is supported \ by the current version."