diff --git a/README.md b/README.md
index 5dd9da0..4c7dc0e 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@
 - [**sentence-transformers/all-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)
 - [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2)
 - [**sentence-transformers/paraphrase-multilingual-mpnet-base-v2**](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)
+- [**lightonai/ModernBERT-embed-large**](https://huggingface.co/lightonai/modernbert-embed-large)
 - [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
 - [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) - pairs with the `nomic-embed-vision-v1.5` image model for image-to-text search
 - [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small)
diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs
index a60f748..89778ae 100644
--- a/src/models/text_embedding.rs
+++ b/src/models/text_embedding.rs
@@ -41,6 +41,8 @@ pub enum EmbeddingModel {
     ParaphraseMLMpnetBaseV2,
     /// BAAI/bge-small-zh-v1.5
     BGESmallZHV15,
+    /// lightonai/modernbert-embed-large
+    ModernBertEmbedLarge,
     /// intfloat/multilingual-e5-small
     MultilingualE5Small,
     /// intfloat/multilingual-e5-base
@@ -210,6 +212,14 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_file: String::from("onnx/model.onnx"),
             additional_files: Vec::new(),
         },
+        ModelInfo {
+            model: EmbeddingModel::ModernBertEmbedLarge,
+            dim: 1024,
+            description: String::from("Large model of ModernBert Text Embeddings"),
+            model_code: String::from("lightonai/modernbert-embed-large"),
+            model_file: String::from("onnx/model.onnx"),
+            additional_files: Vec::new(),
+        },
         ModelInfo {
             model: EmbeddingModel::MultilingualE5Small,
             dim: 384,
diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs
index 52f6b05..23222c7 100644
--- a/src/text_embedding/impl.rs
+++ b/src/text_embedding/impl.rs
@@ -180,6 +180,8 @@ impl TextEmbedding {
             EmbeddingModel::ParaphraseMLMiniLML12V2Q => Some(Pooling::Mean),
             EmbeddingModel::ParaphraseMLMpnetBaseV2 => Some(Pooling::Mean),
 
+            EmbeddingModel::ModernBertEmbedLarge => Some(Pooling::Mean),
+
             EmbeddingModel::MultilingualE5Base => Some(Pooling::Mean),
             EmbeddingModel::MultilingualE5Small => Some(Pooling::Mean),
             EmbeddingModel::MultilingualE5Large => Some(Pooling::Mean),
diff --git a/tests/embeddings.rs b/tests/embeddings.rs
index dc39694..79e1248 100644
--- a/tests/embeddings.rs
+++ b/tests/embeddings.rs
@@ -50,6 +50,7 @@ fn verify_embeddings(model: &EmbeddingModel, embeddings: &[Embedding]) -> Result
         EmbeddingModel::GTEBaseENV15Q => [-1.7032102, -1.7076654, -1.729326, -1.5317788],
         EmbeddingModel::GTELargeENV15 => [-1.6457459, -1.6582386, -1.6809471, -1.6070237],
         EmbeddingModel::GTELargeENV15Q => [-1.6044945, -1.6469251, -1.6828246, -1.6265479],
+        EmbeddingModel::ModernBertEmbedLarge => [ 0.24799639, 0.32174295, 0.17255782, 0.32919246],
         EmbeddingModel::MultilingualE5Base => [-0.057211064, -0.14287914, -0.071678676, -0.17549144],
         EmbeddingModel::MultilingualE5Large => [-0.7473163, -0.76040405, -0.7537941, -0.72920954],
         EmbeddingModel::MultilingualE5Small => [-0.2640718, -0.13929011, -0.08091972, -0.12388548],
@@ -671,3 +672,41 @@ fn test_allminilml6v2_match_python_counterpart() {
         assert!((expected - actual).abs() < tolerance);
     }
 }
+
+#[test]
+fn test_modernbert_embeddings() {
+    let supported_model = TextEmbedding::list_supported_models()
+        .into_iter()
+        .find(|model| matches!(model.model, EmbeddingModel::ModernBertEmbedLarge))
+        .expect("ModernBERT model not found in supported models");
+
+    let model: TextEmbedding =
+        TextEmbedding::try_new(InitOptions::new(supported_model.model.clone())).unwrap();
+
+    let documents = vec![
+        "Hello, World!",
+        "This is an example passage.",
+        "fastembed-rs is licensed under Apache-2.0",
+        "Some other short text here blah blah blah",
+    ];
+
+    let embeddings = model.embed(documents.clone(), None).unwrap();
+    assert_eq!(embeddings.len(), documents.len());
+
+    for embedding in &embeddings {
+        assert_eq!(embedding.len(), supported_model.dim);
+    }
+
+    match verify_embeddings(&supported_model.model, &embeddings) {
+        Ok(_) => {}
+        Err(mismatched_indices) => {
+            panic!(
+                "Mismatched embeddings for ModernBERT: {sentences:?}",
+                sentences = &mismatched_indices
+                    .iter()
+                    .map(|&i| documents[i])
+                    .collect::<Vec<_>>()
+            );
+        }
+    }
+}