Skip to content

Commit 87f4ced

Browse files
authored
chore: Bump ort to v2.0.0-rc.9 (#125)
* chore: Bump ort to v2.0.0-rc.9
  Signed-off-by: Anush008 <[email protected]>
* refactor: Dedup imports
  Signed-off-by: Anush008 <[email protected]>
* ci: Custom link v1.20.1 ONNX build
  Signed-off-by: Anush008 <[email protected]>
* chore: no 'half'
---------
Signed-off-by: Anush008 <[email protected]>
1 parent a746b09 commit 87f4ced

File tree

13 files changed

+57
-26
lines changed

13 files changed

+57
-26
lines changed

.github/workflows/test.yml

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,40 @@ jobs:
2020
steps:
2121
- uses: actions/checkout@v3
2222

23+
- name: Restore Builds
24+
id: cache-build-restore
25+
uses: actions/cache/restore@v4
26+
with:
27+
key: '${{ runner.os }}-cargox-${{ hashFiles(''**/Cargo.toml'') }}'
28+
path: |
29+
onnxruntime/build/Linux/Release/
30+
31+
- name: Compile ONNX Runtime for Linux
32+
if: steps.cache-build-restore.outputs.cache-hit != 'true'
33+
run: |
34+
echo Cloning ONNX Runtime repository...
35+
git clone https://github.com/microsoft/onnxruntime --recursive --branch v1.20.1 --single-branch --depth 1
36+
cd onnxruntime
37+
./build.sh --update --build --config Release --parallel --compile_no_warning_as_error --skip_submodule_sync
38+
cd ..
39+
2340
- name: Cargo Test With Release Build
24-
run: cargo test --release
41+
run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --release --no-default-features --features online
2542

2643
- name: Cargo Test Offline
27-
run: cargo test --no-default-features --features ort-download-binaries
44+
run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --no-default-features
2845

2946
- name: Cargo Clippy
3047
run: cargo clippy
3148

3249
- name: Cargo FMT
3350
run: cargo fmt --all -- --check
51+
52+
- name: Always Save Cache
53+
id: cache-build-save
54+
if: always() && steps.cache-build-restore.outputs.cache-hit != 'true'
55+
uses: actions/cache/save@v4
56+
with:
57+
key: '${{ steps.cache-build-restore.outputs.cache-primary-key }}'
58+
path: |
59+
onnxruntime/build/Linux/Release/

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ authors = [
1313
"Luya Wang <[email protected]>",
1414
1515
"Denny Wong <[email protected]>",
16-
"Alex Rozgo <[email protected]>"
16+
"Alex Rozgo <[email protected]>",
1717
]
1818
documentation = "https://docs.rs/fastembed"
1919
repository = "https://github.com/Anush008/fastembed-rs"
@@ -26,8 +26,8 @@ anyhow = { version = "1" }
2626
hf-hub = { version = "0.3", default-features = false }
2727
image = "0.25.2"
2828
ndarray = { version = "0.16", default-features = false }
29-
ort = { version = "=2.0.0-rc.8", default-features = false, features = [
30-
"half", "ndarray",
29+
ort = { version = "=2.0.0-rc.9", default-features = false, features = [
30+
"ndarray",
3131
] }
3232
rayon = { version = "1.10", default-features = false }
3333
serde_json = { version = "1" }

src/image_embedding/impl.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use hf_hub::{
44
Cache,
55
};
66
use ndarray::{Array3, ArrayView3};
7-
use ort::{GraphOptimizationLevel, Session, Value};
7+
use ort::{
8+
session::{builder::GraphOptimizationLevel, Session},
9+
value::Value,
10+
};
811
#[cfg(feature = "online")]
912
use std::path::PathBuf;
1013
use std::{path::Path, thread::available_parallelism};

src/image_embedding/init.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::path::{Path, PathBuf};
22

3-
use ort::{ExecutionProviderDispatch, Session};
3+
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
44

55
use crate::{ImageEmbeddingModel, DEFAULT_CACHE_DIR};
66

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ mod reranking;
6262
mod sparse_text_embedding;
6363
mod text_embedding;
6464

65-
pub use ort::ExecutionProviderDispatch;
65+
pub use ort::execution_providers::ExecutionProviderDispatch;
6666

6767
pub use crate::common::{
6868
read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles, DEFAULT_CACHE_DIR,

src/output/embedding_output.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use ndarray::{Array2, ArrayView, Dim, IxDynImpl};
2+
use ort::session::SessionOutputs;
23

34
use crate::pooling;
45

@@ -10,7 +11,7 @@ use super::{OutputKey, OutputPrecedence};
1011
/// pooling etc. This struct should contain all the necessary information for the
1112
/// post-processing to be performed.
1213
pub struct SingleBatchOutput<'r, 's> {
13-
pub session_outputs: ort::SessionOutputs<'r, 's>,
14+
pub session_outputs: SessionOutputs<'r, 's>,
1415
pub attention_mask_array: Array2<i64>,
1516
}
1617

@@ -23,17 +24,12 @@ impl<'r, 's> SingleBatchOutput<'r, 's> {
2324
&self,
2425
precedence: &impl OutputPrecedence,
2526
) -> anyhow::Result<ArrayView<f32, Dim<IxDynImpl>>> {
26-
let ort_output = precedence
27+
let ort_output: &ort::value::Value = precedence
2728
.key_precedence()
2829
.find_map(|key| match key {
29-
OutputKey::OnlyOne => {
30-
// Only export the value if there is only one output available.
31-
if self.session_outputs.len() == 1 {
32-
self.session_outputs.values().next()
33-
} else {
34-
None
35-
}
36-
}
30+
OutputKey::OnlyOne => self
31+
.session_outputs
32+
.get(self.session_outputs.keys().nth(0)?),
3733
OutputKey::ByOrder(idx) => {
3834
let x = self
3935
.session_outputs

src/reranking/impl.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
use anyhow::Result;
2+
use ort::{
3+
session::{builder::GraphOptimizationLevel, Session},
4+
value::Value,
5+
};
26
use std::thread::available_parallelism;
37

48
#[cfg(feature = "online")]
@@ -10,7 +14,6 @@ use crate::{
1014
#[cfg(feature = "online")]
1115
use hf_hub::{api::sync::ApiBuilder, Cache};
1216
use ndarray::{s, Array};
13-
use ort::{GraphOptimizationLevel, Session, Value};
1417
use rayon::{iter::ParallelIterator, slice::ParallelSlice};
1518
use tokenizers::Tokenizer;
1619

src/reranking/init.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::path::{Path, PathBuf};
22

3-
use ort::{ExecutionProviderDispatch, Session};
3+
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
44
use tokenizers::Tokenizer;
55

66
use crate::{RerankerModel, TokenizerFiles, DEFAULT_CACHE_DIR};

src/sparse_text_embedding/impl.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@ use hf_hub::{
1111
Cache,
1212
};
1313
use ndarray::{Array, CowArray};
14+
use ort::{session::Session, value::Value};
1415
#[cfg_attr(not(feature = "online"), allow(unused_imports))]
15-
use ort::GraphOptimizationLevel;
16-
use ort::{Session, Value};
1716
use rayon::{iter::ParallelIterator, slice::ParallelSlice};
1817
#[cfg(feature = "online")]
1918
use std::path::PathBuf;
@@ -35,6 +34,7 @@ impl SparseTextEmbedding {
3534
#[cfg(feature = "online")]
3635
pub fn try_new(options: SparseInitOptions) -> Result<Self> {
3736
use super::SparseInitOptions;
37+
use ort::{session::builder::GraphOptimizationLevel, session::Session};
3838

3939
let SparseInitOptions {
4040
model_name,

src/sparse_text_embedding/init.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::path::{Path, PathBuf};
22

3-
use ort::{ExecutionProviderDispatch, Session};
3+
use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
44
use tokenizers::Tokenizer;
55

66
use crate::{models::sparse::SparseModel, TokenizerFiles, DEFAULT_CACHE_DIR};

0 commit comments

Comments
 (0)