44 changes: 33 additions & 11 deletions .github/workflows/test.yml
@@ -1,17 +1,13 @@
 name: "Cargo Tests"
 on:
-  pull_request:
-    types:
-      - opened
-      - edited
-      - synchronize
-      - reopened
-  schedule:
-    - cron: 0 0 * * *
+  pull_request:
+  schedule:
+    - cron: 0 0 * * *
 
 env:
   CARGO_TERM_COLOR: always
   RUSTFLAGS: "-Dwarnings"
+  ONNX_VERSION: v1.20.1
 
 jobs:
   test:
@@ -20,14 +16,40 @@ jobs:
     steps:
       - uses: actions/checkout@v3
 
+      - name: Restore Builds
+        id: cache-build-restore
+        uses: actions/cache/restore@v4
+        with:
+          key: '${{ runner.os }}-onnxruntime-${{ env.ONNX_VERSION }}'
+          path: |
+            onnxruntime/build/Linux/Release/
+
+      - name: Compile ONNX Runtime for Linux
+        if: steps.cache-build-restore.outputs.cache-hit != 'true'
+        run: |
+          echo Cloning ONNX Runtime repository...
+          git clone https://github.com/microsoft/onnxruntime --recursive --branch $ONNX_VERSION --single-branch --depth 1
+          cd onnxruntime
+          ./build.sh --update --build --config Release --parallel --compile_no_warning_as_error --skip_submodule_sync
+          cd ..
+
       - name: Cargo Test With Release Build
-        run: cargo test --release
+        run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --release --no-default-features --features online
 
       - name: Cargo Test Offline
-        run: cargo test --no-default-features --features ort-download-binaries
+        run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --no-default-features
 
       - name: Cargo Clippy
         run: cargo clippy
 
       - name: Cargo FMT
         run: cargo fmt --all -- --check
+
+      - name: Always Save Cache
+        id: cache-build-save
+        if: always() && steps.cache-build-restore.outputs.cache-hit != 'true'
+        uses: actions/cache/save@v4
+        with:
+          key: '${{ steps.cache-build-restore.outputs.cache-primary-key }}'
+          path: |
+            onnxruntime/build/Linux/Release/
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fastembed"
-version = "4.1.0"
+version = "4.3.0"
 edition = "2021"
 description = "Rust implementation of https://github.com/qdrant/fastembed"
 license = "Apache-2.0"
@@ -31,7 +31,7 @@ ort = { git = "https://github.com/pykeio/ort", rev = "2a9f66d", default-features
 ] }
 rayon = { version = "1.10", default-features = false }
 serde_json = { version = "1" }
-tokenizers = { version = "0.19", default-features = false, features = ["onig"] }
+tokenizers = { version = "0.21", default-features = false, features = ["onig"] }
 
 [features]
 default = ["ort-download-binaries", "online"]
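For context on the feature flags exercised by the CI changes above: `default = ["ort-download-binaries", "online"]`, and the `online` feature gates the hub-download code paths touched later in this diff. A minimal sketch of that gating pattern (illustrative function, not the crate's actual code):

```rust
// Illustrative only: mirrors how `#[cfg(feature = "online")]` items in this
// crate are compiled out when CI runs `cargo test --no-default-features`.
#[cfg(feature = "online")]
pub fn fetch_from_hub() {
    // hub-download path, present only with the `online` feature
}

#[cfg(not(feature = "online"))]
pub fn fetch_from_hub() {
    // offline builds get a stub (or omit the API entirely) instead
}
```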
24 changes: 10 additions & 14 deletions README.md
@@ -28,6 +28,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
 - [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default
 - [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
 - [**mixedbread-ai/mxbai-embed-large-v1**](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
+- [**Qdrant/clip-ViT-B-32-text**](https://huggingface.co/Qdrant/clip-ViT-B-32-text) - pairs with the image model clip-ViT-B-32-vision for image-to-text search
 
 <details>
 <summary>Click to see full List</summary>
@@ -39,7 +40,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
 - [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2)
 - [**sentence-transformers/paraphrase-multilingual-mpnet-base-v2**](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)
 - [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
-- [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5)
+- [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) - pairs with the image model nomic-embed-vision-v1.5 for image-to-text search
 - [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small)
 - [**intfloat/multilingual-e5-base**](https://huggingface.co/intfloat/multilingual-e5-base)
 - [**intfloat/multilingual-e5-large**](https://huggingface.co/intfloat/multilingual-e5-large)
@@ -58,6 +59,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
 - [**Qdrant/resnet50-onnx**](https://huggingface.co/Qdrant/resnet50-onnx)
 - [**Qdrant/Unicom-ViT-B-16**](https://huggingface.co/Qdrant/Unicom-ViT-B-16)
 - [**Qdrant/Unicom-ViT-B-32**](https://huggingface.co/Qdrant/Unicom-ViT-B-32)
+- [**nomic-ai/nomic-embed-vision-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-vision-v1.5)
 
 ### Reranking
 
@@ -158,23 +160,17 @@ println!("Rerank result: {:?}", results);
 
 Alternatively, local model files can be used for inference via the `try_new_from_user_defined(...)` methods of respective structs.
 
-## 🚒 Under the hood
+## ✊ Support
 
-### Why fast?
+To support the library, please consider donating to our primary upstream dependency, [`ort`](https://github.com/pykeio/ort?tab=readme-ov-file#-sponsor-ort) - The Rust wrapper for the ONNX runtime.
 
-It's important we justify the "fast" in FastEmbed. FastEmbed is fast because:
+## ⚙️ Under the hood
 
-1. Quantized model weights
-2. ONNX Runtime which allows for inference on CPU, GPU, and other dedicated runtimes
+It's important we justify the "fast" in FastEmbed. FastEmbed is fast because of:
 
-### Why light?
-
-1. No hidden dependencies via Huggingface Transformers
-
-### Why accurate?
-
-1. Better than OpenAI Ada-002
-2. Top of the Embedding leaderboards e.g. [MTEB](https://huggingface.co/spaces/mteb/leaderboard)
+1. Quantized model weights.
+2. ONNX Runtime which allows for inference on CPU, GPU, and other dedicated runtimes.
+3. No hidden dependencies via Huggingface Transformers.
 
 ## 📄 LICENSE
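Since the README now advertises paired text/image CLIP models, a usage sketch may help. This is a minimal sketch assuming the crate's v4 API (`InitOptions::new`, `ImageInitOptions::new`, an `ImageEmbeddingModel::ClipVitB32` vision variant) and that both encoders emit L2-normalized vectors:

```rust
use fastembed::{
    EmbeddingModel, ImageEmbedding, ImageEmbeddingModel, ImageInitOptions, InitOptions,
    TextEmbedding,
};

fn main() -> anyhow::Result<()> {
    // Text side: the newly added CLIP text encoder (512-dim).
    let text_model = TextEmbedding::try_new(InitOptions::new(EmbeddingModel::ClipVitB32))?;
    // Image side: its counterpart vision encoder.
    let image_model =
        ImageEmbedding::try_new(ImageInitOptions::new(ImageEmbeddingModel::ClipVitB32))?;

    let texts = text_model.embed(vec!["a photo of a dog"], None)?;
    let images = image_model.embed(vec!["assets/dog.jpg"], None)?; // hypothetical path

    // With normalized embeddings, cosine similarity reduces to a dot product.
    let score: f32 = texts[0].iter().zip(&images[0]).map(|(t, i)| t * i).sum();
    println!("image-to-text similarity: {score}");
    Ok(())
}
```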
66 changes: 49 additions & 17 deletions src/image_embedding/impl.rs
@@ -4,7 +4,10 @@ use hf_hub::{
     Cache,
 };
 use ndarray::{Array3, ArrayView3};
-use ort::{GraphOptimizationLevel, Session, Value};
+use ort::{
+    session::{builder::GraphOptimizationLevel, Session},
+    value::Value,
+};
 #[cfg(feature = "online")]
 use std::path::PathBuf;
 use std::{path::Path, thread::available_parallelism};
@@ -14,6 +17,8 @@ use crate::{
     ModelInfo,
 };
 use anyhow::anyhow;
+#[cfg(feature = "online")]
+use anyhow::Context;
 
 #[cfg(feature = "online")]
 use super::ImageInitOptions;
@@ -49,13 +54,13 @@ impl ImageEmbedding {
 
         let preprocessor_file = model_repo
             .get("preprocessor_config.json")
-            .unwrap_or_else(|_| panic!("Failed to retrieve preprocessor_config.json"));
+            .context("Failed to retrieve preprocessor_config.json")?;
         let preprocessor = Compose::from_file(preprocessor_file)?;
 
         let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file;
         let model_file_reference = model_repo
             .get(&model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {}", model_file_name))?;
 
         let session = Session::builder()?
             .with_execution_providers(execution_providers)?
@@ -108,8 +113,7 @@ impl ImageEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;
 
         let repo = api.model(model.to_string());
         Ok(repo)
@@ -169,24 +173,52 @@ impl ImageEmbedding {
                 let outputs = self.session.run(session_inputs)?;
 
                 // Try to get the only output key
-                // If multiple, then default to `image_embeds`
+                // If multiple, then default to a few known keys: `image_embeds` and `last_hidden_state`
                 let last_hidden_state_key = match outputs.len() {
-                    1 => outputs.keys().next().unwrap(),
-                    _ => "image_embeds",
+                    1 => vec![outputs.keys().next().unwrap()],
+                    _ => vec!["image_embeds", "last_hidden_state"],
                 };
 
-                // Extract and normalize embeddings
-                let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;
-
-                let embeddings: Vec<Vec<f32>> = output_data
-                    .rows()
-                    .into_iter()
-                    .map(|row| normalize(row.as_slice().unwrap()))
-                    .collect();
+                // Extract tensor and handle different dimensionalities
+                let output_data = last_hidden_state_key
+                    .iter()
+                    .find_map(|&key| {
+                        outputs
+                            .get(key)
+                            .and_then(|v| v.try_extract_tensor::<f32>().ok())
+                    })
+                    .ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?;
+                let shape = output_data.shape();
+
+                let embeddings: Vec<Vec<f32>> = match shape.len() {
+                    3 => {
+                        // For 3D output [batch_size, sequence_length, hidden_size],
+                        // take only the first token (CLS token) embedding
+                        // and return [batch_size, hidden_size]
+                        (0..shape[0])
+                            .map(|batch_idx| {
+                                let cls_embedding =
                                    output_data.slice(ndarray::s![batch_idx, 0, ..]).to_vec();
+                                normalize(&cls_embedding)
+                            })
+                            .collect()
+                    }
+                    2 => {
+                        // For 2D output [batch_size, hidden_size]
+                        output_data
+                            .rows()
+                            .into_iter()
+                            .map(|row| normalize(row.as_slice().unwrap()))
+                            .collect()
+                    }
+                    _ => return Err(anyhow!("Unexpected output tensor shape: {:?}", shape)),
+                };
 
                 Ok(embeddings)
             })
-            .flat_map(|result: Result<Vec<Vec<f32>>, anyhow::Error>| result.unwrap())
+            .collect::<anyhow::Result<Vec<_>>>()?
+            .into_iter()
+            .flatten()
             .collect();
 
         Ok(output)
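The 3D branch above keeps only the first (CLS) token per batch item. A self-contained sketch of that slicing with `ndarray` (shapes and values are made up, and `normalize` stands in for the crate's helper):

```rust
use ndarray::{s, Array3};

fn normalize(v: &[f32]) -> Vec<f32> {
    // Simple L2 normalization, standing in for the crate's `normalize` helper.
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(f32::EPSILON);
    v.iter().map(|x| x / norm).collect()
}

fn main() {
    // Pretend model output of shape [batch_size=2, sequence_length=3, hidden_size=4].
    let output = Array3::from_shape_fn((2, 3, 4), |(b, t, h)| (b + t + h) as f32 + 1.0);

    // Mirror the 3D branch: token 0 (CLS) per batch item -> [batch_size, hidden_size].
    let embeddings: Vec<Vec<f32>> = (0..output.shape()[0])
        .map(|batch_idx| normalize(&output.slice(s![batch_idx, 0, ..]).to_vec()))
        .collect();

    assert_eq!(embeddings.len(), 2);
    assert_eq!(embeddings[0].len(), 4);
}
```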
2 changes: 1 addition & 1 deletion src/image_embedding/init.rs
@@ -1,6 +1,6 @@
 use std::path::{Path, PathBuf};
 
-use ort::{ExecutionProviderDispatch, Session};
+use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
 
 use crate::{ImageEmbeddingModel, DEFAULT_CACHE_DIR};
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -62,7 +62,7 @@ mod reranking;
 mod sparse_text_embedding;
 mod text_embedding;
 
-pub use ort::ExecutionProviderDispatch;
+pub use ort::execution_providers::ExecutionProviderDispatch;
 
 pub use crate::common::{
     read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles, DEFAULT_CACHE_DIR,
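Because the type is re-exported at the same path, downstream code is unaffected by ort's module reshuffle. A quick check, with the forwarding destination only assumed:

```rust
// Downstream imports keep working; only ort's internal module path changed.
use fastembed::ExecutionProviderDispatch;

fn take_providers(providers: Vec<ExecutionProviderDispatch>) {
    // e.g. forwarded to an init-options builder (assumed usage, not shown in this diff)
    let _ = providers;
}
```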
11 changes: 10 additions & 1 deletion src/models/image_embedding.rs
@@ -12,6 +12,8 @@ pub enum ImageEmbeddingModel {
     UnicomVitB16,
     /// Qdrant/Unicom-ViT-B-32
     UnicomVitB32,
+    /// nomic-ai/nomic-embed-vision-v1.5
+    NomicEmbedVisionV15,
 }
 
 pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
@@ -43,7 +45,14 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
             description: String::from("Unicom Unicom-ViT-B-32 from open-metric-learning"),
             model_code: String::from("Qdrant/Unicom-ViT-B-32"),
             model_file: String::from("model.onnx"),
-        }
+        },
+        ModelInfo {
+            model: ImageEmbeddingModel::NomicEmbedVisionV15,
+            dim: 768,
+            description: String::from("Nomic NomicEmbedVisionV15"),
+            model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
+            model_file: String::from("onnx/model.onnx"),
+        },
     ];
 
     // TODO: Use when out in stable
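A hedged usage sketch for the new vision variant, assuming the v4 `ImageInitOptions::new` builder and the `online` feature for downloads:

```rust
use fastembed::{ImageEmbedding, ImageEmbeddingModel, ImageInitOptions};

fn main() -> anyhow::Result<()> {
    let model = ImageEmbedding::try_new(ImageInitOptions::new(
        ImageEmbeddingModel::NomicEmbedVisionV15,
    ))?;
    // 768-dim vectors, aligned with nomic-embed-text-v1.5 for image-to-text search.
    let embeddings = model.embed(vec!["assets/image_0.png"], None)?; // hypothetical path
    assert_eq!(embeddings[0].len(), 768);
    Ok(())
}
```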
11 changes: 11 additions & 0 deletions src/models/text_embedding.rs
@@ -63,6 +63,8 @@ pub enum EmbeddingModel {
     GTELargeENV15,
     /// Quantized Alibaba-NLP/gte-large-en-v1.5
     GTELargeENV15Q,
+    /// Qdrant/clip-ViT-B-32-text
+    ClipVitB32,
 }
 
 /// Centralized function to initialize the models map.
@@ -256,6 +258,13 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
             model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
             model_file: String::from("onnx/model_quantized.onnx"),
         },
+        ModelInfo {
+            model: EmbeddingModel::ClipVitB32,
+            dim: 512,
+            description: String::from("CLIP text encoder based on ViT-B/32"),
+            model_code: String::from("Qdrant/clip-ViT-B-32-text"),
+            model_file: String::from("model.onnx"),
+        },
     ];
 
     // TODO: Use when out in stable
@@ -327,6 +336,8 @@ impl EmbeddingModel {
             EmbeddingModel::GTEBaseENV15Q => Some(Pooling::Cls),
             EmbeddingModel::GTELargeENV15 => Some(Pooling::Cls),
             EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls),
+
+            EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),
         }
     }
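Likewise for the new CLIP text variant, a minimal sketch assuming the v4 `InitOptions::new` builder (dimension and pooling follow the entries above):

```rust
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

fn main() -> anyhow::Result<()> {
    let model = TextEmbedding::try_new(InitOptions::new(EmbeddingModel::ClipVitB32))?;
    // 512-dim, mean-pooled per the `Pooling::Mean` mapping above.
    let embeddings = model.embed(vec!["a photo of a cat"], None)?;
    assert_eq!(embeddings[0].len(), 512);
    Ok(())
}
```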
20 changes: 7 additions & 13 deletions src/output/embedding_output.rs
@@ -1,5 +1,5 @@
 use ndarray::{Array2, ArrayView, Dim, IxDynImpl};
-use ort::Value;
+use ort::session::SessionOutputs;
 
 use crate::pooling;
 
@@ -11,7 +11,7 @@ use super::{OutputKey, OutputPrecedence};
 /// pooling etc. This struct should contain all the necessary information for the
 /// post-processing to be performed.
 pub struct SingleBatchOutput<'r, 's> {
-    pub session_outputs: ort::SessionOutputs<'r, 's>,
+    pub session_outputs: SessionOutputs<'r, 's>,
     pub attention_mask_array: Array2<i64>,
 }
 
@@ -23,19 +23,13 @@ impl SingleBatchOutput<'_, '_> {
     pub fn select_output<'a>(
         &'a self,
         precedence: &impl OutputPrecedence,
-    ) -> anyhow::Result<ArrayView<'a, f32, Dim<IxDynImpl>>> {
-        let ort_output: &Value = precedence
+    ) -> anyhow::Result<ArrayView<f32, Dim<IxDynImpl>>> {
+        let ort_output: &ort::value::Value = precedence
             .key_precedence()
             .find_map(|key| match key {
-                OutputKey::OnlyOne => {
-                    // Only export the value if there is only one output available.
-                    if self.session_outputs.len() == 1 {
-                        let key = self.session_outputs.keys().next().unwrap();
-                        self.session_outputs.get(key)
-                    } else {
-                        None
-                    }
-                }
+                OutputKey::OnlyOne => self
+                    .session_outputs
+                    .get(self.session_outputs.keys().nth(0)?),
                 OutputKey::ByOrder(idx) => {
                     let x = self
                         .session_outputs
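The rewritten `OnlyOne` arm above relies on key-precedence lookup. A generic sketch of that idea with hypothetical stand-in types (a plain `HashMap`, not the crate's real `SessionOutputs`):

```rust
use std::collections::HashMap;

// Try candidate keys in order; take the first output that resolves.
fn select<'a>(outputs: &'a HashMap<&str, Vec<f32>>, precedence: &[&str]) -> Option<&'a Vec<f32>> {
    precedence.iter().find_map(|key| outputs.get(key))
}

fn main() {
    let mut outputs = HashMap::new();
    outputs.insert("last_hidden_state", vec![0.1, 0.2]);
    assert!(select(&outputs, &["image_embeds", "last_hidden_state"]).is_some());
}
```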