Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Added ClipVitB32 model to support text-to-image search. (Anush0…
  • Loading branch information
bvgastel authored Nov 23, 2024
commit a890a6140d83e0da4b01061533ffe47d24da7212
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
- [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default
- [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
- [**mixedbread-ai/mxbai-embed-large-v1**](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
- [**Qdrant/clip-ViT-B-32-text**](https://huggingface.co/Qdrant/clip-ViT-B-32-text) - pairs with the image model clip-ViT-B-32-vision for image-to-text search

<details>
<summary>Click to see full List</summary>
Expand Down
11 changes: 11 additions & 0 deletions src/models/text_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ pub enum EmbeddingModel {
GTELargeENV15,
/// Quantized Alibaba-NLP/gte-large-en-v1.5
GTELargeENV15Q,
/// Qdrant/clip-ViT-B-32-text
ClipVitB32,
}

/// Centralized function to initialize the models map.
Expand Down Expand Up @@ -256,6 +258,13 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
model_file: String::from("onnx/model_quantized.onnx"),
},
ModelInfo {
model: EmbeddingModel::ClipVitB32,
dim: 512,
description: String::from("CLIP text encoder based on ViT-B/32"),
model_code: String::from("Qdrant/clip-ViT-B-32-text"),
model_file: String::from("model.onnx"),
},
];

// TODO: Use when out in stable
Expand Down Expand Up @@ -327,6 +336,8 @@ impl EmbeddingModel {
EmbeddingModel::GTEBaseENV15Q => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15 => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls),

EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),
}
}

Expand Down
1 change: 1 addition & 0 deletions tests/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ fn verify_embeddings(model: &EmbeddingModel, embeddings: &[Embedding]) -> Result
EmbeddingModel::ParaphraseMLMiniLML12V2 => [-0.07795018, -0.059113946, -0.043668486, -0.1880083],
EmbeddingModel::ParaphraseMLMiniLML12V2Q => [-0.07749095, -0.058981877, -0.043487836, -0.18775631],
EmbeddingModel::ParaphraseMLMpnetBaseV2 => [0.39132136, 0.49490625, 0.65497226, 0.34237382],
EmbeddingModel::ClipVitB32 => [0.7057363, 1.3549932, 0.46823958, 0.52351093],
_ => panic!("Model {model} not found. If you have just inserted this `EmbeddingModel` variant, please update the expected embeddings."),
};

Expand Down