Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ use std::io::Read;
use std::{fs::File, path::PathBuf};
use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

pub const DEFAULT_CACHE_DIR: &str = ".fastembed_cache";
const DEFAULT_CACHE_DIR: &str = ".fastembed_cache";

pub fn get_cache_dir() -> String {
std::env::var("FASTEMBED_CACHE_DIR").unwrap_or(DEFAULT_CACHE_DIR.into())
}

pub struct SparseEmbedding {
pub indices: Vec<usize>,
Expand Down
4 changes: 2 additions & 2 deletions src/image_embedding/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::{Path, PathBuf};

use ort::{execution_providers::ExecutionProviderDispatch, session::Session};

use crate::{ImageEmbeddingModel, DEFAULT_CACHE_DIR};
use crate::{get_cache_dir, ImageEmbeddingModel};

use super::{utils::Compose, DEFAULT_EMBEDDING_MODEL};

Expand Down Expand Up @@ -48,7 +48,7 @@ impl Default for ImageInitOptions {
Self {
model_name: DEFAULT_EMBEDDING_MODEL,
execution_providers: Default::default(),
cache_dir: Path::new(DEFAULT_CACHE_DIR).to_path_buf(),
cache_dir: Path::new(&get_cache_dir()).to_path_buf(),
show_download_progress: true,
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mod text_embedding;
pub use ort::execution_providers::ExecutionProviderDispatch;

pub use crate::common::{
read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles, DEFAULT_CACHE_DIR,
get_cache_dir, read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles,
};
pub use crate::models::{
model_info::ModelInfo, model_info::RerankerModelInfo, quantization::QuantizationMode,
Expand Down
4 changes: 2 additions & 2 deletions src/reranking/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::path::{Path, PathBuf};
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
use tokenizers::Tokenizer;

use crate::{RerankerModel, TokenizerFiles, DEFAULT_CACHE_DIR};
use crate::{common::get_cache_dir, RerankerModel, TokenizerFiles};

use super::{DEFAULT_MAX_LENGTH, DEFAULT_RE_RANKER_MODEL};

Expand Down Expand Up @@ -63,7 +63,7 @@ impl Default for RerankInitOptions {
model_name: DEFAULT_RE_RANKER_MODEL,
execution_providers: Default::default(),
max_length: DEFAULT_MAX_LENGTH,
cache_dir: Path::new(DEFAULT_CACHE_DIR).to_path_buf(),
cache_dir: Path::new(&get_cache_dir()).to_path_buf(),
show_download_progress: true,
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/sparse_text_embedding/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::path::{Path, PathBuf};
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
use tokenizers::Tokenizer;

use crate::{models::sparse::SparseModel, TokenizerFiles, DEFAULT_CACHE_DIR};
use crate::{common::get_cache_dir, models::sparse::SparseModel, TokenizerFiles};

use super::{DEFAULT_EMBEDDING_MODEL, DEFAULT_MAX_LENGTH};

Expand Down Expand Up @@ -56,7 +56,7 @@ impl Default for SparseInitOptions {
model_name: DEFAULT_EMBEDDING_MODEL,
execution_providers: Default::default(),
max_length: DEFAULT_MAX_LENGTH,
cache_dir: Path::new(DEFAULT_CACHE_DIR).to_path_buf(),
cache_dir: Path::new(&get_cache_dir()).to_path_buf(),
show_download_progress: true,
}
}
Expand Down
6 changes: 2 additions & 4 deletions src/text_embedding/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
//!

use crate::{
common::{TokenizerFiles, DEFAULT_CACHE_DIR},
pooling::Pooling,
EmbeddingModel, QuantizationMode,
common::TokenizerFiles, get_cache_dir, pooling::Pooling, EmbeddingModel, QuantizationMode,
};
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -66,7 +64,7 @@ impl Default for InitOptions {
model_name: DEFAULT_EMBEDDING_MODEL,
execution_providers: Default::default(),
max_length: DEFAULT_MAX_LENGTH,
cache_dir: Path::new(DEFAULT_CACHE_DIR).to_path_buf(),
cache_dir: Path::new(&get_cache_dir()).to_path_buf(),
show_download_progress: true,
}
}
Expand Down
18 changes: 9 additions & 9 deletions tests/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use hf_hub::Repo;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

use fastembed::{
read_file_to_bytes, Embedding, EmbeddingModel, ImageEmbedding, ImageEmbeddingModel,
ImageInitOptions, InitOptions, InitOptionsUserDefined, ModelInfo, OnnxSource, Pooling,
QuantizationMode, RerankInitOptions, RerankInitOptionsUserDefined, RerankerModel,
RerankerModelInfo, SparseInitOptions, SparseTextEmbedding, TextEmbedding, TextRerank,
TokenizerFiles, UserDefinedEmbeddingModel, UserDefinedRerankingModel, DEFAULT_CACHE_DIR,
get_cache_dir, read_file_to_bytes, Embedding, EmbeddingModel, ImageEmbedding,
ImageEmbeddingModel, ImageInitOptions, InitOptions, InitOptionsUserDefined, ModelInfo,
OnnxSource, Pooling, QuantizationMode, RerankInitOptions, RerankInitOptionsUserDefined,
RerankerModel, RerankerModelInfo, SparseInitOptions, SparseTextEmbedding, TextEmbedding,
TextRerank, TokenizerFiles, UserDefinedEmbeddingModel, UserDefinedRerankingModel,
};

/// A small epsilon value for floating point comparisons.
Expand Down Expand Up @@ -209,7 +209,7 @@ fn test_user_defined_embedding_model() {

// Get the directory of the model
let model_name = test_model_info.model_code.replace('/', "--");
let model_dir = Path::new(DEFAULT_CACHE_DIR).join(format!("models--{}", model_name));
let model_dir = Path::new(&get_cache_dir()).join(format!("models--{}", model_name));

// Find the "snapshots" sub-directory
let snapshots_dir = model_dir.join("snapshots");
Expand Down Expand Up @@ -341,7 +341,7 @@ fn test_rerank() {
#[test]
fn test_user_defined_reranking_large_model() {
// Setup model to download from Hugging Face
let cache = hf_hub::Cache::new(std::path::PathBuf::from(fastembed::DEFAULT_CACHE_DIR));
let cache = hf_hub::Cache::new(std::path::PathBuf::from(&fastembed::get_cache_dir()));
let api = hf_hub::api::sync::ApiBuilder::from_cache(cache)
.with_progress(true)
.build()
Expand Down Expand Up @@ -401,7 +401,7 @@ fn test_user_defined_reranking_model() {

// Get the directory of the model
let model_name = test_model_info.model_code.replace('/', "--");
let model_dir = Path::new(DEFAULT_CACHE_DIR).join(format!("models--{}", model_name));
let model_dir = Path::new(&get_cache_dir()).join(format!("models--{}", model_name));

// Find the "snapshots" sub-directory
let snapshots_dir = model_dir.join("snapshots");
Expand Down Expand Up @@ -554,7 +554,7 @@ fn test_nomic_embed_vision_v1_5() {

fn clean_cache(model_code: String) {
let repo = Repo::model(model_code);
let cache_dir = format!("{}/{}", DEFAULT_CACHE_DIR, repo.folder_name());
let cache_dir = format!("{}/{}", &get_cache_dir(), repo.folder_name());
fs::remove_dir_all(cache_dir).ok();
}

Expand Down
8 changes: 4 additions & 4 deletions tests/optimum_cli_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
use std::{path::PathBuf, process};

use fastembed::{
Pooling, QuantizationMode, TextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel,
DEFAULT_CACHE_DIR,
get_cache_dir, Pooling, QuantizationMode, TextEmbedding, TokenizerFiles,
UserDefinedEmbeddingModel,
};

const EPS: f32 = 1e-4;
Expand Down Expand Up @@ -87,8 +87,8 @@ macro_rules! create_test {
let repo_name = $repo_name;
let repo_owner = $repo_owner;
let model_name = format!("{}/{}", repo_owner, repo_name);
let output_path =
format!("{DEFAULT_CACHE_DIR}/exported--{repo_owner}--{repo_name}-onnx");
let cache_dir = get_cache_dir();
let output_path = format!("{cache_dir}/exported--{repo_owner}--{repo_name}-onnx");
let output = PathBuf::from(output_path);

assert!(
Expand Down