Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 3 additions & 52 deletions src/image_embedding/init.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,9 @@
use std::path::{Path, PathBuf};

use super::utils::Compose;
use crate::{init::InitOptions, ImageEmbeddingModel};
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};

use crate::{get_cache_dir, ImageEmbeddingModel};

use super::{utils::Compose, DEFAULT_EMBEDDING_MODEL};

/// Options for initializing the ImageEmbedding model
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ImageInitOptions {
pub model_name: ImageEmbeddingModel,
pub execution_providers: Vec<ExecutionProviderDispatch>,
pub cache_dir: PathBuf,
pub show_download_progress: bool,
}

impl ImageInitOptions {
pub fn new(model_name: ImageEmbeddingModel) -> Self {
Self {
model_name,
..Default::default()
}
}

pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
self.cache_dir = cache_dir;
self
}

pub fn with_execution_providers(
mut self,
execution_providers: Vec<ExecutionProviderDispatch>,
) -> Self {
self.execution_providers = execution_providers;
self
}

pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
self.show_download_progress = show_download_progress;
self
}
}

impl Default for ImageInitOptions {
fn default() -> Self {
Self {
model_name: DEFAULT_EMBEDDING_MODEL,
execution_providers: Default::default(),
cache_dir: Path::new(&get_cache_dir()).to_path_buf(),
show_download_progress: true,
}
}
}
pub type ImageInitOptions = InitOptions<ImageEmbeddingModel>;

/// Options for initializing UserDefinedImageEmbeddingModel
///
Expand Down
2 changes: 0 additions & 2 deletions src/image_embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use crate::models::image_embedding::ImageEmbeddingModel;
const DEFAULT_BATCH_SIZE: usize = 256;
const DEFAULT_EMBEDDING_MODEL: ImageEmbeddingModel = ImageEmbeddingModel::ClipVitB32;

mod utils;

Expand Down
117 changes: 117 additions & 0 deletions src/init.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use crate::get_cache_dir;
use ort::execution_providers::ExecutionProviderDispatch;
use std::path::PathBuf;

/// Implemented by model types that provide a default maximum length.
///
/// `MAX_LENGTH` is the value `InitOptionsWithLength::default()` stores in its
/// `max_length` field.
pub trait HasMaxLength {
/// Default maximum length associated with the implementing model type.
const MAX_LENGTH: usize;
}

/// Initialization options, generic over the model type `M`, for models that
/// also take a maximum length.
///
/// Holds the same fields as `InitOptions` plus `max_length`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct InitOptionsWithLength<M> {
/// The model to initialize.
pub model_name: M,
/// Execution providers for the model (empty by default).
pub execution_providers: Vec<ExecutionProviderDispatch>,
/// Cache directory for the model file (defaults to `get_cache_dir()`).
pub cache_dir: PathBuf,
/// Whether to show download progress (defaults to `true`).
pub show_download_progress: bool,
/// Maximum length (defaults to `M::MAX_LENGTH`).
pub max_length: usize,
}

/// Initialization options, generic over the model type `M`, for models that
/// do not take a maximum length.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct InitOptions<M> {
/// The model to initialize.
pub model_name: M,
/// Execution providers for the model (empty by default).
pub execution_providers: Vec<ExecutionProviderDispatch>,
/// Cache directory for the model file (defaults to `get_cache_dir()`).
pub cache_dir: PathBuf,
/// Whether to show download progress (defaults to `true`).
pub show_download_progress: bool,
}

impl<M: Default + HasMaxLength> Default for InitOptionsWithLength<M> {
fn default() -> Self {
Self {
model_name: M::default(),
execution_providers: Default::default(),
cache_dir: get_cache_dir().into(),
show_download_progress: true,
max_length: M::MAX_LENGTH,
}
}
}

impl<M: Default> Default for InitOptions<M> {
fn default() -> Self {
Self {
model_name: M::default(),
execution_providers: Default::default(),
cache_dir: get_cache_dir().into(),
show_download_progress: true,
}
}
}

impl<M: Default + HasMaxLength> InitOptionsWithLength<M> {
    /// Create a new `InitOptionsWithLength` with the given model name.
    ///
    /// Every other field takes its value from `Default::default()`.
    pub fn new(model_name: M) -> Self {
        Self {
            model_name,
            ..Default::default()
        }
    }

    /// Set the maximum length.
    pub fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }

    /// Set the cache directory for the model file.
    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
        self.cache_dir = cache_dir;
        self
    }

    /// Set the execution providers for the model.
    pub fn with_execution_providers(
        mut self,
        execution_providers: Vec<ExecutionProviderDispatch>,
    ) -> Self {
        self.execution_providers = execution_providers;
        self
    }

    /// Set whether to show download progress.
    pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
        self.show_download_progress = show_download_progress;
        self
    }
}

impl<M: Default> InitOptions<M> {
    /// Create a new `InitOptions` with the given model name.
    ///
    /// Every other field takes its value from `Default::default()`.
    pub fn new(model_name: M) -> Self {
        Self {
            model_name,
            ..Default::default()
        }
    }

    /// Set the cache directory for the model file.
    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
        self.cache_dir = cache_dir;
        self
    }

    /// Set the execution providers for the model.
    pub fn with_execution_providers(
        mut self,
        execution_providers: Vec<ExecutionProviderDispatch>,
    ) -> Self {
        self.execution_providers = execution_providers;
        self
    }

    /// Set whether to show download progress.
    pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
        self.show_download_progress = show_download_progress;
        self
    }
}
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

mod common;
mod image_embedding;
mod init;
mod models;
pub mod output;
mod pooling;
Expand All @@ -74,7 +75,8 @@ pub use crate::pooling::Pooling;
// For Text Embedding
pub use crate::models::text_embedding::EmbeddingModel;
pub use crate::text_embedding::{
InitOptions, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel,
InitOptionsUserDefined, TextEmbedding, TextInitOptions as InitOptions,
UserDefinedEmbeddingModel,
};

// For Sparse Text Embedding
Expand Down
3 changes: 2 additions & 1 deletion src/models/image_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use std::{fmt::Display, str::FromStr};

use super::model_info::ModelInfo;

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum ImageEmbeddingModel {
/// Qdrant/clip-ViT-B-32-vision
#[default]
ClipVitB32,
/// Qdrant/resnet50-onnx
Resnet50,
Expand Down
3 changes: 2 additions & 1 deletion src/models/reranking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use std::{fmt::Display, str::FromStr};

use crate::RerankerModelInfo;

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum RerankerModel {
/// BAAI/bge-reranker-base
#[default]
BGERerankerBase,
/// rozgo/bge-reranker-v2-m3
BGERerankerV2M3,
Expand Down
3 changes: 2 additions & 1 deletion src/models/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use std::{fmt::Display, str::FromStr};

use crate::ModelInfo;

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub enum SparseModel {
/// prithivida/Splade_PP_en_v1
#[default]
SPLADEPPV1,
}

Expand Down
3 changes: 2 additions & 1 deletion src/models/text_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::model_info::ModelInfo;
/// Lazy static list of all available models.
static MODEL_MAP: OnceLock<HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>>> = OnceLock::new();

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
/// sentence-transformers/all-MiniLM-L6-v2
AllMiniLML6V2,
Expand All @@ -24,6 +24,7 @@ pub enum EmbeddingModel {
/// Quantized BAAI/bge-large-en-v1.5
BGELargeENV15Q,
/// BAAI/bge-small-en-v1.5 - Default
#[default]
BGESmallENV15,
/// Quantized BAAI/bge-small-en-v1.5
BGESmallENV15Q,
Expand Down
6 changes: 4 additions & 2 deletions src/output/embedding_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl SingleBatchOutput {
pub fn select_output(
&self,
precedence: &impl OutputPrecedence,
) -> anyhow::Result<ArrayView<f32, Dim<IxDynImpl>>> {
) -> anyhow::Result<ArrayView<'_, f32, Dim<IxDynImpl>>> {
let ort_output: &ort::value::Value = precedence
.key_precedence()
.find_map(|key| match key {
Expand All @@ -35,7 +35,9 @@ impl SingleBatchOutput {
}
}
OutputKey::ByOrder(idx) => self.outputs.get(*idx).map(|(_, v)| v),
OutputKey::ByName(name) => self.outputs.iter().find(|(n, _)| n == name).map(|(_, v)| v),
OutputKey::ByName(name) => {
self.outputs.iter().find(|(n, _)| n == name).map(|(_, v)| v)
}
})
.ok_or_else(|| {
anyhow::Error::msg(format!(
Expand Down
2 changes: 1 addition & 1 deletion src/reranking/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ impl TextRerank {
use super::RerankInitOptions;

let RerankInitOptions {
max_length,
model_name,
execution_providers,
max_length,
cache_dir,
show_download_progress,
} = options;
Expand Down
68 changes: 10 additions & 58 deletions src/reranking/init.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,25 @@
use std::path::{Path, PathBuf};

use super::DEFAULT_MAX_LENGTH;
use crate::{
init::{HasMaxLength, InitOptionsWithLength},
RerankerModel, TokenizerFiles,
};
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
use std::path::PathBuf;
use tokenizers::Tokenizer;

use crate::{common::get_cache_dir, RerankerModel, TokenizerFiles};

use super::{DEFAULT_MAX_LENGTH, DEFAULT_RE_RANKER_MODEL};

#[derive(Debug)]
pub struct TextRerank {
pub tokenizer: Tokenizer,
pub(crate) session: Session,
pub(crate) need_token_type_ids: bool,
}

/// Options for initializing the reranking model
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RerankInitOptions {
pub model_name: RerankerModel,
pub execution_providers: Vec<ExecutionProviderDispatch>,
pub max_length: usize,
pub cache_dir: PathBuf,
pub show_download_progress: bool,
impl HasMaxLength for RerankerModel {
const MAX_LENGTH: usize = DEFAULT_MAX_LENGTH;
}

impl RerankInitOptions {
pub fn new(model_name: RerankerModel) -> Self {
Self {
model_name,
..Default::default()
}
}

pub fn with_max_length(mut self, max_length: usize) -> Self {
self.max_length = max_length;
self
}

pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
self.cache_dir = cache_dir;
self
}

pub fn with_execution_providers(
mut self,
execution_providers: Vec<ExecutionProviderDispatch>,
) -> Self {
self.execution_providers = execution_providers;
self
}

pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self {
self.show_download_progress = show_download_progress;
self
}
}

impl Default for RerankInitOptions {
fn default() -> Self {
Self {
model_name: DEFAULT_RE_RANKER_MODEL,
execution_providers: Default::default(),
max_length: DEFAULT_MAX_LENGTH,
cache_dir: Path::new(&get_cache_dir()).to_path_buf(),
show_download_progress: true,
}
}
}
/// Options for initializing the reranking models
pub type RerankInitOptions = InitOptionsWithLength<RerankerModel>;

/// Options for initializing UserDefinedRerankerModel
///
Expand Down
3 changes: 0 additions & 3 deletions src/reranking/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use crate::RerankerModel;

const DEFAULT_RE_RANKER_MODEL: RerankerModel = RerankerModel::BGERerankerBase;
const DEFAULT_MAX_LENGTH: usize = 512;
const DEFAULT_BATCH_SIZE: usize = 256;

Expand Down
4 changes: 2 additions & 2 deletions src/sparse_text_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ impl SparseTextEmbedding {
use ort::{session::builder::GraphOptimizationLevel, session::Session};

let SparseInitOptions {
model_name,
execution_providers,
max_length,
model_name,
cache_dir,
show_download_progress,
execution_providers,
} = options;

let threads = available_parallelism()?.get();
Expand Down
Loading
Loading