diff --git a/.envrc b/.envrc new file mode 100644 index 00000000..3550a30f --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/rust/lib.rs b/rust/lib.rs index 3b97b830..135eaa70 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -431,6 +431,8 @@ pub mod ffi { } } +use std::marker::PhantomData; + // Re-export the FFI structs and enums at the crate root for easy access pub use ffi::{IndexOptions, MetricKind, ScalarKind}; @@ -481,7 +483,8 @@ pub use ffi::{IndexOptions, MetricKind, ScalarKind}; /// ``` /// /// In this example, `dimensions` should be defined and valid for the vectors `a` and `b`. -pub enum MetricFunction { +#[doc(hidden)] +pub enum MetricFunctionPtr { B1X8Metric(*mut std::boxed::Box Distance + Send + Sync>), I8Metric(*mut std::boxed::Box Distance + Send + Sync>), F16Metric(*mut std::boxed::Box Distance + Send + Sync>), @@ -489,6 +492,46 @@ pub enum MetricFunction { F64Metric(*mut std::boxed::Box Distance + Send + Sync>), } +impl MetricFunctionPtr { + /// Cast inner non-owning raw pointer to boxed closure into owned Box. + /// + /// # Safety + /// + /// Pointer must be valid, and the returned value must not be dropped + /// as long as the C++ side is still using the pointer to the closure. 
+ unsafe fn into_owned(self) -> MetricFunctionOwned { + unsafe { + match self { + MetricFunctionPtr::B1X8Metric(pointer) => { + MetricFunctionOwned::B1X8Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::I8Metric(pointer) => { + MetricFunctionOwned::I8Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::F16Metric(pointer) => { + MetricFunctionOwned::F16Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::F32Metric(pointer) => { + MetricFunctionOwned::F32Metric(Box::from_raw(pointer)) + } + MetricFunctionPtr::F64Metric(pointer) => { + MetricFunctionOwned::F64Metric(Box::from_raw(pointer)) + } + } + } + } +} + +#[allow(unused)] +enum MetricFunctionOwned { + // Double boxed because Box is a wide pointer, and the C++ side needs a regular pointer + B1X8Metric(Box Distance + Send + Sync>>), + I8Metric(Box Distance + Send + Sync>>), + F16Metric(Box Distance + Send + Sync>>), + F32Metric(Box Distance + Send + Sync>>), + F64Metric(Box Distance + Send + Sync>>), +} + /// Approximate Nearest Neighbors search index for dense vectors. /// /// The `Index` struct provides an abstraction over a dense vector space, allowing @@ -500,7 +543,7 @@ pub enum MetricFunction { /// Basic usage: /// /// ```rust -/// use usearch::{Index, IndexOptions, MetricKind, ScalarKind}; +/// use usearch::{Index, IndexMethods, IndexViewMethods, IndexOptions, MetricKind, ScalarKind}; /// /// let mut options = IndexOptions::default(); /// options.dimensions = 4; // Set the number of dimensions for vectors @@ -527,31 +570,28 @@ pub enum MetricFunction { /// refer to the individual method documentation. pub struct Index { inner: cxx::UniquePtr, - metric_fn: Option, + scalar_kind: ScalarKind, + metric_fn: Option, } unsafe impl Send for Index {} unsafe impl Sync for Index {} +/// A read-only view into an index, read from a file or in-memory buffer. 
+pub struct IndexView<'buf> { + inner: Index, + _phantom_data: PhantomData<(&'buf [u8], *const ())>, +} + +unsafe impl Send for IndexView<'static> {} +unsafe impl Sync for IndexView<'static> {} + impl Drop for Index { fn drop(&mut self) { - if let Some(metric) = &self.metric_fn { - match metric { - MetricFunction::B1X8Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::I8Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::F16Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::F32Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, - MetricFunction::F64Metric(pointer) => unsafe { - drop(Box::from_raw(*pointer)); - }, + if let Some(metric) = self.metric_fn.take() { + // SAFETY: the pointed-to closure is never used again after Index is dropped. + unsafe { + drop(metric.into_owned()); } } } @@ -585,10 +625,68 @@ impl Clone for ffi::IndexOptions { } } +/// Data types are cast on the C++ side, but only some conversions are valid. +/// TODO: think about this more. I guess you can cast int8 to double, +/// but it's more likely an error. Maybe all casts except float to smaller float should be errors? 
+fn is_kind_convertible_to(a: ScalarKind, b: ScalarKind) -> bool { + match a { + ScalarKind::F16 | ScalarKind::F32 | ScalarKind::F64 | ScalarKind::BF16 => [ + ScalarKind::F16, + ScalarKind::F32, + ScalarKind::F64, + ScalarKind::BF16, + ] + .contains(&b), + ScalarKind::B1 | ScalarKind::I8 => a == b, + ScalarKind::Unknown => false, + ScalarKind { repr: _ } => unreachable!("Invalid Enum representation"), + } +} + +#[non_exhaustive] +pub enum IndexOperationError { + TypeError(ScalarKind, ScalarKind), + CXXException(cxx::Exception), +} + +impl std::error::Error for IndexOperationError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + IndexOperationError::CXXException(exception) => Some(exception), + _ => None, + } + } +} + +impl std::fmt::Display for IndexOperationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IndexOperationError::TypeError(twant, tgot) => write!( + f, + "Type Error: Attempted to use an Index storing type {twant:?} with type {tgot:?}.", + ), + IndexOperationError::CXXException(exception) => { + write!(f, "C++ Exception: {}", exception) + } + } + } +} + +impl std::fmt::Debug for IndexOperationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self, f) + } +} + /// The `VectorType` trait defines operations for managing and querying vectors /// in an index. It supports generic operations on vectors of different types, /// allowing for the addition, retrieval, and search of vectors within an index. -pub trait VectorType { +pub unsafe trait VectorType: Sized { + const KIND: ScalarKind; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunctionPtr; + /// Adds a vector to the index under the specified key. /// /// # Parameters @@ -598,8 +696,30 @@ pub trait VectorType { /// /// # Returns /// - `Ok(())` if the vector was successfully added to the index. 
- /// - `Err(cxx::Exception)` if an error occurred during the operation. - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> + /// - `Err(IndexOperationError)` if an error occurred during the operation. + fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), IndexOperationError> { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are compatible + // Check that `vector.len()` matches `dimensionality` happens on the C++ side. + unsafe { + Self::add_unchecked(index, key, vector).map_err(IndexOperationError::CXXException) + } + } + + /// Adds a vector to the index under the specified key. + /// Refer to [VectorType::add] for usage. + /// + /// # Safety + /// + /// - The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `vector`. + unsafe fn add_unchecked(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> where Self: Sized; @@ -611,12 +731,36 @@ pub trait VectorType { /// - `buffer`: A mutable slice where the retrieved vector will be stored. The size of the /// buffer determines the maximum number of elements that can be retrieved. /// - /// # Returns + /// # Returns /// - `Ok(usize)` indicating the number of elements actually written into the `buffer`. - /// - `Err(cxx::Exception)` if an error occurred during the operation. - fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result - where - Self: Sized; + /// - `Err(IndexOperationError)` if an error occurred during the operation. + fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same.
+ // Check that `buffer` is large enough happens on the C++ side. + unsafe { + Self::get_unchecked(index, key, buffer).map_err(IndexOperationError::CXXException) + } + } + + /// Retrieves a vector from the index by its key. + /// Refer to [Index::get] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `buffer`. + unsafe fn get_unchecked( + index: &Index, + key: Key, + buffer: &mut [Self], + ) -> Result; /// Performs a search in the index using the given query vector, returning /// up to `count` closest matches. @@ -628,10 +772,38 @@ pub trait VectorType { /// /// # Returns /// - `Ok(ffi::Matches)` containing the matches found. - /// - `Err(cxx::Exception)` if an error occurred during the search operation. - fn search(index: &Index, query: &[Self], count: usize) -> Result - where - Self: Sized; + /// - `Err(IndexOperationError)` if an error occurred during the search operation. + fn search( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same. + unsafe { + Self::search_unchecked(index, query, count).map_err(IndexOperationError::CXXException) + } + } + + /// Performs a search in the index using the given query vector, returning + /// up to `count` closest matches. + /// Refer to [Index::search] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `query`. + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result; /// Performs an exact (brute force) search in the index using the given query vector, returning /// up to `count` closest matches. 
This search checks all vectors in the index, guaranteeing to find @@ -644,14 +816,38 @@ pub trait VectorType { /// /// # Returns /// - `Ok(ffi::Matches)` containing the matches found. - /// - `Err(cxx::Exception)` if an error occurred during the search operation. + /// - `Err(IndexOperationError)` if an error occurred during the search operation. fn exact_search( index: &Index, query: &[Self], count: usize, - ) -> Result - where - Self: Sized; + ) -> Result { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same. + unsafe { + Self::exact_search_unchecked(index, query, count) + .map_err(IndexOperationError::CXXException) + } + } + + /// Performs an exact (brute force) search in the index using the given query vector. + /// Refer to [Index::exact_search] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `query`. + unsafe fn exact_search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result; /// Performs a filtered search in the index using a query vector and a custom /// filter function, returning up to `count` matches that satisfy the filter. @@ -671,9 +867,39 @@ pub trait VectorType { query: &[Self], count: usize, filter: F, + ) -> Result + where + F: Fn(Key) -> bool, + { + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and vector scalar types are the same. + unsafe { + Self::filtered_search_unchecked(index, query, count, filter) + .map_err(IndexOperationError::CXXException) + } + } + + /// Performs a filtered search in the index using a query vector and a custom + /// filter function. + /// Refer to [Index::filtered_search] for usage. 
+ /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as that of `query`. + unsafe fn filtered_search_unchecked( + index: &Index, + query: &[Self], + count: usize, + filter: F, ) -> Result where - Self: Sized, F: Fn(Key) -> bool; /// Changes the metric used for distance calculations within the index. @@ -685,21 +911,91 @@ pub trait VectorType { /// /// # Returns /// - `Ok(())` if the metric was successfully changed. - /// - `Err(cxx::Exception)` if an error occurred during the operation. + /// - `Err(IndexOperationError)` if an error occurred during the operation. fn change_metric( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> - where - Self: Sized; + ) -> Result<(), IndexOperationError> { + // TODO: same question as higher up. what kind of casts are allowed and sensible here? + if !is_kind_convertible_to(index.scalar_kind, Self::KIND) { + return Err(IndexOperationError::TypeError( + Self::KIND, + index.scalar_kind, + )); + } + + // SAFETY: index and metric types are the same. + unsafe { + Self::change_metric_unchecked(index, metric).map_err(IndexOperationError::CXXException) + } + } + + /// Changes the metric used for distance calculations within the index. + /// Refer to [Index::change_metric] for usage. + /// + /// # Safety + /// + /// The scalar element type of `index` (the `quantization` field in [ffi::IndexOptions]) + /// must be the same as the arguments to `metric`. + unsafe fn change_metric_unchecked( + index: &mut Index, + metric: std::boxed::Box Distance + Send + Sync>, + ) -> Result<(), cxx::Exception> { + if let Some(metric) = index.metric_fn.take() { + // SAFETY: We have an exclusive &mut to Index, so no one can be using the + // pointed-to closure. 
+ unsafe { + drop(metric.into_owned()); + } + } + + index.metric_fn = Some(Self::METRIC_FN(Box::into_raw(Box::new(metric)))); + + // Trampoline is the function that knows how to call the Rust closure. + // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, + // and `closure_address` is the address of the boxed Rust closure that is invoked + // with those two pointers to compute the distance. + extern "C" fn trampoline( + first: usize, + second: usize, + closure_address: usize, + ) -> Distance { + let first_ptr = first as *const T; + let second_ptr = second as *const T; + let closure: *mut _ = + closure_address as *mut Box Distance>; + unsafe { (*closure)(first_ptr, second_ptr) } + } + + let trampoline_fn: usize = trampoline:: as *const () as usize; + let closure_address = match index.metric_fn.as_ref().expect("Was just set to Some") { + MetricFunctionPtr::F32Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::B1X8Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::I8Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::F16Metric(metric) => (*metric as *mut _) as *mut () as usize, + MetricFunctionPtr::F64Metric(metric) => (*metric as *mut _) as *mut () as usize, + }; + index.inner.change_metric(trampoline_fn, closure_address); + + Ok(()) + } } -impl VectorType for f32 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { +unsafe impl VectorType for f32 { + const KIND: ScalarKind = ScalarKind::F32; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunctionPtr = MetricFunctionPtr::F32Metric; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_f32(query, count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -707,15 +1003,23 @@ impl VectorType for f32 {
index.inner.exact_search_f32(query, count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_f32(key, vector) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_f32(key, vector) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -738,43 +1042,23 @@ impl VectorType for f32 { .inner .filtered_search_f32(query, count, trampoline_fn, closure_address) } - - fn change_metric( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::F32Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. 
- extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const f32; - let second_ptr = second as *const f32; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::F32Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected F32Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) - } } -impl VectorType for i8 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { +unsafe impl VectorType for i8 { + const KIND: ScalarKind = ScalarKind::I8; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunctionPtr = MetricFunctionPtr::I8Metric; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_i8(query, count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -782,15 +1066,23 @@ impl VectorType for i8 { index.inner.exact_search_i8(query, count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_i8(key, vector) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_i8(key, vector) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -813,42 +1105,23 @@ impl VectorType for i8 { .inner .filtered_search_i8(query, count, trampoline_fn, closure_address) } - fn change_metric( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), 
cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::I8Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. - extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const i8; - let second_ptr = second as *const i8; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::I8Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected I8Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) - } } -impl VectorType for f64 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { +unsafe impl VectorType for f64 { + const KIND: ScalarKind = ScalarKind::F64; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunctionPtr = MetricFunctionPtr::F64Metric; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_f64(query, count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -856,15 +1129,23 @@ impl VectorType for f64 { index.inner.exact_search_f64(query, count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_f64(key, vector) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + 
unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_f64(key, vector) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -887,42 +1168,23 @@ impl VectorType for f64 { .inner .filtered_search_f64(query, count, trampoline_fn, closure_address) } - fn change_metric( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::F64Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. - extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const f64; - let second_ptr = second as *const f64; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::F64Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected F64Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) - } } -impl VectorType for f16 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { +unsafe impl VectorType for f16 { + const KIND: ScalarKind = ScalarKind::F16; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunctionPtr = MetricFunctionPtr::F16Metric; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { 
index.inner.search_f16(f16::to_i16s(query), count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -930,15 +1192,23 @@ impl VectorType for f16 { index.inner.exact_search_f16(f16::to_i16s(query), count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { - index.inner.get_f16(key, f16::to_mut_i16s(vector)) + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { + index.inner.get_f16(key, f16::to_mut_i16s(vector)) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_f16(key, f16::to_i16s(vector)) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -957,50 +1227,27 @@ impl VectorType for f16 { // Temporarily cast the closure to a raw pointer for passing. let trampoline_fn: usize = trampoline:: as *const () as usize; let closure_address: usize = &filter as *const F as usize; - index.inner.filtered_search_f16( - f16::to_i16s(query), - count, - trampoline_fn, - closure_address, - ) - } - - fn change_metric( - index: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) -> Result<(), cxx::Exception> { - // Store the metric function in the Index. - type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::F16Metric(Box::into_raw(Box::new(metric)))); - - // Trampoline is the function that knows how to call the Rust closure. - // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, - // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function - // and the number of dimensions. 
- extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance { - let first_ptr = first as *const f16; - let second_ptr = second as *const f16; - let closure: *mut MetricFn = closure_address as *mut MetricFn; - unsafe { (*closure)(first_ptr, second_ptr) } - } - - let trampoline_fn: usize = trampoline as *const () as usize; - let closure_address = match index.metric_fn { - Some(MetricFunction::F16Metric(metric)) => metric as *mut () as usize, - _ => panic!("Expected F16Metric"), - }; - index.inner.change_metric(trampoline_fn, closure_address); - - Ok(()) + index + .inner + .filtered_search_f16(f16::to_i16s(query), count, trampoline_fn, closure_address) } } -impl VectorType for b1x8 { - fn search(index: &Index, query: &[Self], count: usize) -> Result { +unsafe impl VectorType for b1x8 { + const KIND: ScalarKind = ScalarKind::B1; + const METRIC_FN: fn( + *mut std::boxed::Box Distance + Send + Sync>, + ) -> MetricFunctionPtr = MetricFunctionPtr::B1X8Metric; + + unsafe fn search_unchecked( + index: &Index, + query: &[Self], + count: usize, + ) -> Result { index.inner.search_b1x8(b1x8::to_u8s(query), count) } - fn exact_search( + unsafe fn exact_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -1008,15 +1255,23 @@ impl VectorType for b1x8 { index.inner.exact_search_b1x8(b1x8::to_u8s(query), count) } - fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result { + unsafe fn get_unchecked( + index: &Index, + key: Key, + vector: &mut [Self], + ) -> Result { index.inner.get_b1x8(key, b1x8::to_mut_u8s(vector)) } - fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> { + unsafe fn add_unchecked( + index: &Index, + key: Key, + vector: &[Self], + ) -> Result<(), cxx::Exception> { index.inner.add_b1x8(key, b1x8::to_u8s(vector)) } - fn filtered_search( + unsafe fn filtered_search_unchecked( index: &Index, query: &[Self], count: usize, @@ -1035,21 +1290,20 @@ impl VectorType for b1x8 { // Temporarily cast 
the closure to a raw pointer for passing. let trampoline_fn: usize = trampoline:: as *const () as usize; let closure_address: usize = &filter as *const F as usize; - index.inner.filtered_search_b1x8( - b1x8::to_u8s(query), - count, - trampoline_fn, - closure_address, - ) + index + .inner + .filtered_search_b1x8(b1x8::to_u8s(query), count, trampoline_fn, closure_address) } - fn change_metric( + unsafe fn change_metric_unchecked( index: &mut Index, metric: std::boxed::Box Distance + Send + Sync>, ) -> Result<(), cxx::Exception> { // Store the metric function in the Index. type MetricFn = Box Distance>; - index.metric_fn = Some(MetricFunction::B1X8Metric(Box::into_raw(Box::new(metric)))); + index.metric_fn = Some(MetricFunctionPtr::B1X8Metric(Box::into_raw(Box::new( + metric, + )))); // Trampoline is the function that knows how to call the Rust closure. // The `first` is a pointer to the first vector, `second` is a pointer to the second vector, @@ -1064,12 +1318,392 @@ impl VectorType for b1x8 { let trampoline_fn: usize = trampoline as *const () as usize; let closure_address = match index.metric_fn { - Some(MetricFunction::B1X8Metric(metric)) => metric as *mut () as usize, + Some(MetricFunctionPtr::B1X8Metric(metric)) => metric as *mut () as usize, _ => panic!("Expected F1X8Metric"), }; index.inner.change_metric(trampoline_fn, closure_address); - Ok(()) + Ok(()) + } +} + +pub trait IndexViewMethods { + /// Retrieves the expansion value used during index creation. + fn expansion_add(&self) -> usize; + + /// Retrieves the expansion value used during search. + fn expansion_search(&self) -> usize; + + /// Retrieves the hardware acceleration information. + fn hardware_acceleration(&self) -> String; + + /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query. + /// + /// # Arguments + /// + /// * `query` - A slice containing the query vector data. + /// * `count` - The maximum number of neighbors to search for. 
+ /// + /// # Returns + /// + /// A `Result` containing the matches found. + fn search( + &self, + query: &[T], + count: usize, + ) -> Result; + + /// Performs exact (brute force) Nearest Neighbors Search for closest vectors to the provided query. + /// This search checks all vectors in the index, guaranteeing to find the true nearest neighbors, + /// but may be slower for large indices. + /// + /// # Arguments + /// + /// * `query` - A slice containing the query vector data. + /// * `count` - The maximum number of neighbors to search for. + /// + /// # Returns + /// + /// A `Result` containing the matches found. + fn exact_search( + &self, + query: &[T], + count: usize, + ) -> Result; + + /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query + /// satisfying a custom filter function. + /// + /// # Arguments + /// + /// * `query` - A slice containing the query vector data. + /// * `count` - The maximum number of neighbors to search for. + /// * `filter` - A closure that takes a `Key` and returns `true` if the corresponding vector should be included in the search results, or `false` otherwise. + /// + /// # Returns + /// + /// A `Result` containing the matches found. + fn filtered_search( + &self, + query: &[T], + count: usize, + filter: F, + ) -> Result + where + F: Fn(Key) -> bool; + + /// Extracts one or more vectors matching the specified key. + /// The `vector` slice must be a multiple of the number of dimensions in the index. + /// After the execution, return the number `X` of vectors found. + /// The vector slice's first `X * dimensions` elements will be filled. + /// + /// If you are a novice user, consider `export`. + /// + /// # Arguments + /// + /// * `key` - The key associated with the vector. + /// * `vector` - A slice containing the vector data. + fn get(&self, key: Key, vector: &mut [T]) -> Result; + + /// Extracts one or more vectors matching specified key into supplied resizable vector. 
+ /// The `vector` is resized to a multiple of the number of dimensions in the index. + /// + /// # Arguments + /// + /// * `key` - The key associated with the vector. + /// * `vector` - A mutable vector containing the vector data. + fn export( + &self, + key: Key, + vector: &mut Vec, + ) -> Result; + + /// Retrieves the number of dimensions in the vectors indexed. + fn dimensions(&self) -> usize; + + /// Retrieves the connectivity parameter that limits connections-per-node in the graph. + fn connectivity(&self) -> usize; + + /// Retrieves the current number of vectors in the index. + fn size(&self) -> usize; + + /// Retrieves the total capacity of the index, including reserved space. + fn capacity(&self) -> usize; + + /// Reports expected file size after serialization. + fn serialized_length(&self) -> usize; + + /// Checks if the index contains a vector with a specified key. + /// + /// # Arguments + /// + /// * `key` - The key to be checked. + /// + /// # Returns + /// + /// `true` if the index contains the vector with the given key, `false` otherwise. + fn contains(&self, key: Key) -> bool; + + /// Counts the vectors that share the specified key. + /// + /// # Arguments + /// + /// * `key` - The key to be checked. + /// + /// # Returns + /// + /// Number of vectors found. + fn count(&self, key: Key) -> usize; + + /// Saves the index to a specified file. + /// + /// # Arguments + /// + /// * `path` - The file path where the index will be saved. + fn save(&self, path: &str) -> Result<(), cxx::Exception>; + + /// A relatively accurate lower bound on the amount of memory consumed by the system. + /// In practice, its error will be below 10%. + fn memory_usage(&self) -> usize; + + /// Saves the index to a specified buffer. + /// + /// # Arguments + /// + /// * `buffer` - The buffer where the index will be saved.
+ fn save_to_buffer(&self, buffer: &mut [u8]) -> Result<(), cxx::Exception>; +} + +pub trait IndexMethods: IndexViewMethods { + /// Updates the expansion value used during index creation. Rarely used. + fn change_expansion_add(&self, n: usize); + + /// Updates the expansion value used during search operations. + fn change_expansion_search(&self, n: usize); + + /// Changes the metric kind used to calculate the distance between vectors. + fn change_metric_kind(self: &Self, metric: ffi::MetricKind); + + /// Overrides the metric function used to calculate the distance between vectors. + fn change_metric( + &mut self, + metric: std::boxed::Box Distance + Send + Sync>, + ) -> Result<(), IndexOperationError>; + + /// Adds a vector with a specified key to the index. + /// + /// # Arguments + /// + /// * `key` - The key associated with the vector. + /// * `vector` - A slice containing the vector data. + fn add(&self, key: Key, vector: &[T]) -> Result<(), IndexOperationError>; + + /// Reserves memory for a specified number of incoming vectors. + /// + /// # Arguments + /// + /// * `capacity` - The desired total capacity, including the current size. + fn reserve(&self, capacity: usize) -> Result<(), cxx::Exception>; + + /// Reserves memory for a specified number of incoming vectors & active threads. + /// + /// # Arguments + /// + /// * `capacity` - The desired total capacity, including the current size. + /// * `threads` - The number of threads to use for the operation. + fn reserve_capacity_and_threads( + &self, + capacity: usize, + threads: usize, + ) -> Result<(), cxx::Exception>; + + /// Removes the vector associated with the given key from the index. + /// + /// # Arguments + /// + /// * `key` - The key of the vector to be removed. + /// + /// # Returns + /// + /// `true` if the vector is successfully removed, `false` otherwise. + fn remove(&self, key: Key) -> Result; + + /// Renames the vector under a specific key. 
+ /// + /// # Arguments + /// + /// * `from` - The key of the vector to be renamed. + /// * `to` - The new name. + /// + /// # Returns + /// + /// `true` if the vector is renamed, `false` otherwise. + fn rename(&self, from: Key, to: Key) -> Result; + + /// Loads the index from a specified file. + /// + /// # Arguments + /// + /// * `path` - The file path from where the index will be loaded. + fn load(&self, path: &str) -> Result<(), cxx::Exception>; + + /// Erases all members from the index, closes files, and returns RAM to OS. + fn reset(&self) -> Result<(), cxx::Exception>; +} + +impl IndexViewMethods for Index { + fn expansion_add(&self) -> usize { + self.inner.expansion_add() + } + + fn expansion_search(&self) -> usize { + self.inner.expansion_search() + } + + fn hardware_acceleration(&self) -> String { + use core::ffi::CStr; + unsafe { + let c_str = CStr::from_ptr(self.inner.hardware_acceleration()); + c_str.to_string_lossy().into_owned() + } + } + + fn search( + &self, + query: &[T], + count: usize, + ) -> Result { + T::search(self, query, count) + } + + fn exact_search( + &self, + query: &[T], + count: usize, + ) -> Result { + T::exact_search(self, query, count) + } + + fn filtered_search( + &self, + query: &[T], + count: usize, + filter: F, + ) -> Result + where + F: Fn(Key) -> bool, + { + T::filtered_search(self, query, count, filter) + } + + fn get(&self, key: Key, vector: &mut [T]) -> Result { + T::get(self, key, vector) + } + + fn export( + &self, + key: Key, + vector: &mut Vec, + ) -> Result { + let dim = self.dimensions(); + let max_matches = self.count(key); + vector.resize(dim * max_matches, T::default()); + let matches = T::get(self, key, &mut vector[..])?; + vector.resize(dim * matches, T::default()); + Ok(matches) + } + + fn dimensions(&self) -> usize { + self.inner.dimensions() + } + + fn connectivity(&self) -> usize { + self.inner.connectivity() + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn capacity(&self) -> usize { + 
self.inner.capacity() + } + + fn serialized_length(&self) -> usize { + self.inner.serialized_length() + } + + fn contains(&self, key: Key) -> bool { + self.inner.contains(key) + } + + fn count(&self, key: Key) -> usize { + self.inner.count(key) + } + + fn save(&self, path: &str) -> Result<(), cxx::Exception> { + self.inner.save(path) + } + + fn memory_usage(&self) -> usize { + self.inner.memory_usage() + } + + fn save_to_buffer(&self, buffer: &mut [u8]) -> Result<(), cxx::Exception> { + self.inner.save_to_buffer(buffer) + } +} + +impl IndexMethods for Index { + fn change_expansion_add(&self, n: usize) { + self.inner.change_expansion_add(n) + } + + fn change_expansion_search(&self, n: usize) { + self.inner.change_expansion_search(n) + } + + fn change_metric_kind(&self, metric: ffi::MetricKind) { + self.inner.change_metric_kind(metric) + } + + fn change_metric( + &mut self, + metric: std::boxed::Box Distance + Send + Sync>, + ) -> Result<(), IndexOperationError> { + T::change_metric(self, metric) + } + + fn add(&self, key: Key, vector: &[T]) -> Result<(), IndexOperationError> { + T::add(self, key, vector) + } + + fn reserve(&self, capacity: usize) -> Result<(), cxx::Exception> { + self.inner.reserve(capacity) + } + + fn reserve_capacity_and_threads( + &self, + capacity: usize, + threads: usize, + ) -> Result<(), cxx::Exception> { + self.inner.reserve_capacity_and_threads(capacity, threads) + } + + fn remove(&self, key: Key) -> Result { + self.inner.remove(key) + } + + fn rename(&self, from: Key, to: Key) -> Result { + self.inner.rename(from, to) + } + + fn load(&self, path: &str) -> Result<(), cxx::Exception> { + self.inner.load(path) + } + + fn reset(&self) -> Result<(), cxx::Exception> { + self.inner.reset() } } @@ -1078,352 +1712,146 @@ impl Index { match ffi::new_native_index(options) { Ok(inner) => Result::Ok(Self { inner, + scalar_kind: options.quantization, metric_fn: None, }), Err(err) => Err(err), } } - /// Retrieves the expansion value used during index 
creation. - pub fn expansion_add(self: &Index) -> usize { - self.inner.expansion_add() - } - - /// Retrieves the expansion value used during search. - pub fn expansion_search(self: &Index) -> usize { - self.inner.expansion_search() - } - - /// Updates the expansion value used during index creation. Rarely used. - pub fn change_expansion_add(self: &Index, n: usize) { - self.inner.change_expansion_add(n) - } - - /// Updates the expansion value used during search operations. - pub fn change_expansion_search(self: &Index, n: usize) { - self.inner.change_expansion_search(n) + /// Loads the index from a specified buffer. + /// + /// # Arguments + /// + /// * `buffer` - The buffer from where the index will be loaded. + pub fn load_from_buffer(&self, buffer: &[u8]) -> Result<(), cxx::Exception> { + self.inner.load_from_buffer(buffer) } +} - /// Changes the metric kind used to calculate the distance between vectors. - pub fn change_metric_kind(self: &Index, metric: ffi::MetricKind) { - self.inner.change_metric_kind(metric) +impl<'buf> IndexViewMethods for IndexView<'buf> { + fn expansion_add(&self) -> usize { + self.inner.expansion_add() } - /// Overrides the metric function used to calculate the distance between vectors. - pub fn change_metric( - self: &mut Index, - metric: std::boxed::Box Distance + Send + Sync>, - ) { - T::change_metric(self, metric).unwrap(); + fn expansion_search(&self) -> usize { + self.inner.expansion_search() } - /// Retrieves the hardware acceleration information. - pub fn hardware_acceleration(&self) -> String { - use core::ffi::CStr; - unsafe { - let c_str = CStr::from_ptr(self.inner.hardware_acceleration()); - c_str.to_string_lossy().into_owned() - } + fn hardware_acceleration(&self) -> String { + self.inner.hardware_acceleration() } - /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query. - /// - /// # Arguments - /// - /// * `query` - A slice containing the query vector data.
- /// * `count` - The maximum number of neighbors to search for. - /// - /// # Returns - /// - /// A `Result` containing the matches found. - pub fn search( - self: &Index, + fn search( + &self, query: &[T], count: usize, - ) -> Result { - T::search(self, query, count) + ) -> Result { + self.inner.search(query, count) } - /// Performs exact (brute force) Nearest Neighbors Search for closest vectors to the provided query. - /// This search checks all vectors in the index, guaranteeing to find the true nearest neighbors, - /// but may be slower for large indices. - /// - /// # Arguments - /// - /// * `query` - A slice containing the query vector data. - /// * `count` - The maximum number of neighbors to search for. - /// - /// # Returns - /// - /// A `Result` containing the matches found. - pub fn exact_search( - self: &Index, + fn exact_search( + &self, query: &[T], count: usize, - ) -> Result { - T::exact_search(self, query, count) + ) -> Result { + self.inner.exact_search(query, count) } - /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query - /// satisfying a custom filter function. - /// - /// # Arguments - /// - /// * `query` - A slice containing the query vector data. - /// * `count` - The maximum number of neighbors to search for. - /// * `filter` - A closure that takes a `Key` and returns `true` if the corresponding vector should be included in the search results, or `false` otherwise. - /// - /// # Returns - /// - /// A `Result` containing the matches found. - pub fn filtered_search( - self: &Index, + fn filtered_search( + &self, query: &[T], count: usize, filter: F, - ) -> Result + ) -> Result where F: Fn(Key) -> bool, { - T::filtered_search(self, query, count, filter) - } - - /// Adds a vector with a specified key to the index. - /// - /// # Arguments - /// - /// * `key` - The key associated with the vector. - /// * `vector` - A slice containing the vector data. 
- pub fn add(self: &Index, key: Key, vector: &[T]) -> Result<(), cxx::Exception> { - T::add(self, key, vector) + self.inner.filtered_search(query, count, filter) } - /// Extracts one or more vectors matching the specified key. - /// The `vector` slice must be a multiple of the number of dimensions in the index. - /// After the execution, return the number `X` of vectors found. - /// The vector slice's first `X * dimensions` elements will be filled. - /// - /// If you are a novice user, consider `export`. - /// - /// # Arguments - /// - /// * `key` - The key associated with the vector. - /// * `vector` - A slice containing the vector data. - pub fn get( - self: &Index, - key: Key, - vector: &mut [T], - ) -> Result { - T::get(self, key, vector) + fn get(&self, key: Key, vector: &mut [T]) -> Result { + self.inner.get(key, vector) } - /// Extracts one or more vectors matching specified key into supplied resizable vector. - /// The `vector` is resized to a multiple of the number of dimensions in the index. - /// - /// # Arguments - /// - /// * `key` - The key associated with the vector. - /// * `vector` - A mutable vector containing the vector data. - pub fn export( - self: &Index, + fn export( + &self, key: Key, vector: &mut Vec, - ) -> Result { - let dim = self.dimensions(); - let max_matches = self.count(key); - vector.resize(dim * max_matches, T::default()); - let matches = T::get(self, key, &mut vector[..])?; - vector.resize(dim * matches, T::default()); - Ok(matches) - } - - /// Reserves memory for a specified number of incoming vectors. - /// - /// # Arguments - /// - /// * `capacity` - The desired total capacity, including the current size. - pub fn reserve(self: &Index, capacity: usize) -> Result<(), cxx::Exception> { - self.inner.reserve(capacity) - } - - /// Reserves memory for a specified number of incoming vectors & active threads. - /// - /// # Arguments - /// - /// * `capacity` - The desired total capacity, including the current size. 
- /// * `threads` - The number of threads to use for the operation. - pub fn reserve_capacity_and_threads( - self: &Index, - capacity: usize, - threads: usize, - ) -> Result<(), cxx::Exception> { - self.inner.reserve_capacity_and_threads(capacity, threads) + ) -> Result { + self.inner.export(key, vector) } - /// Retrieves the number of dimensions in the vectors indexed. - pub fn dimensions(self: &Index) -> usize { + fn dimensions(&self) -> usize { self.inner.dimensions() } - /// Retrieves the connectivity parameter that limits connections-per-node in the graph. - pub fn connectivity(self: &Index) -> usize { + fn connectivity(&self) -> usize { self.inner.connectivity() } - /// Retrieves the current number of vectors in the index. - pub fn size(self: &Index) -> usize { + fn size(&self) -> usize { self.inner.size() } - /// Retrieves the total capacity of the index, including reserved space. - pub fn capacity(self: &Index) -> usize { + fn capacity(&self) -> usize { self.inner.capacity() } - /// Reports expected file size after serialization. - pub fn serialized_length(self: &Index) -> usize { + fn serialized_length(&self) -> usize { self.inner.serialized_length() } - /// Removes the vector associated with the given key from the index. - /// - /// # Arguments - /// - /// * `key` - The key of the vector to be removed. - /// - /// # Returns - /// - /// `true` if the vector is successfully removed, `false` otherwise. - pub fn remove(self: &Index, key: Key) -> Result { - self.inner.remove(key) - } - - /// Renames the vector under a specific key. - /// - /// # Arguments - /// - /// * `from` - The key of the vector to be renamed. - /// * `to` - The new name. - /// - /// # Returns - /// - /// `true` if the vector is renamed, `false` otherwise. - pub fn rename(self: &Index, from: Key, to: Key) -> Result { - self.inner.rename(from, to) - } - - /// Checks if the index contains a vector with a specified key. - /// - /// # Arguments - /// - /// * `key` - The key to be checked. 
- /// - /// # Returns - /// - /// `true` if the index contains the vector with the given key, `false` otherwise. - pub fn contains(self: &Index, key: Key) -> bool { + fn contains(&self, key: Key) -> bool { self.inner.contains(key) } - /// Count the count of vectors with the same specified key. - /// - /// # Arguments - /// - /// * `key` - The key to be checked. - /// - /// # Returns - /// - /// Number of vectors found. - pub fn count(self: &Index, key: Key) -> usize { + fn count(&self, key: Key) -> usize { self.inner.count(key) } - /// Saves the index to a specified file. - /// - /// # Arguments - /// - /// * `path` - The file path where the index will be saved. - pub fn save(self: &Index, path: &str) -> Result<(), cxx::Exception> { + fn save(&self, path: &str) -> Result<(), cxx::Exception> { self.inner.save(path) } - /// Loads the index from a specified file. - /// - /// # Arguments - /// - /// * `path` - The file path from where the index will be loaded. - pub fn load(self: &Index, path: &str) -> Result<(), cxx::Exception> { - self.inner.load(path) - } - - /// Creates a view of the index from a file without loading it into memory. - /// - /// # Arguments - /// - /// * `path` - The file path from where the view will be created. - pub fn view(self: &Index, path: &str) -> Result<(), cxx::Exception> { - self.inner.view(path) - } - - /// Erases all members from the index, closes files, and returns RAM to OS. - pub fn reset(self: &Index) -> Result<(), cxx::Exception> { - self.inner.reset() - } - - /// A relatively accurate lower bound on the amount of memory consumed by the system. - /// In practice, its error will be below 10%. - pub fn memory_usage(self: &Index) -> usize { + fn memory_usage(&self) -> usize { self.inner.memory_usage() } - /// Saves the index to a specified file. - /// - /// # Arguments - /// - /// * `buffer` - The buffer where the index will be saved. 
- pub fn save_to_buffer(self: &Index, buffer: &mut [u8]) -> Result<(), cxx::Exception> { + fn save_to_buffer(&self, buffer: &mut [u8]) -> Result<(), cxx::Exception> { self.inner.save_to_buffer(buffer) } +} - /// Loads the index from a specified file. +impl<'buf> IndexView<'buf> { + /// Creates a view of the index backed by a borrowed in-memory buffer, without copying it. /// /// # Arguments /// - /// * `buffer` - The buffer from where the index will be loaded. - pub fn load_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> { - self.inner.load_from_buffer(buffer) + /// * `buffer` - The buffer from where the view will be created. + pub fn new_from_buffer(buffer: &'buf [u8]) -> Result { + let inner = Index::new(&IndexOptions::default())?; + inner.inner.view_from_buffer(buffer)?; + Ok(IndexView { + inner, + _phantom_data: PhantomData, + }) } +} +impl IndexView<'static> { /// Creates a view of the index from a file without loading it into memory. /// /// # Arguments /// - /// * `buffer` - The buffer from where the view will be created. - /// - /// # Safety - /// - /// This function is marked as `unsafe` because it stores a pointer to the input buffer. - /// The caller must ensure that the buffer outlives the index and is not dropped - /// or modified for the duration of the index's use. Dereferencing a pointer to a - /// temporary buffer after it has been dropped can lead to undefined behavior, - /// which violates Rust's memory safety guarantees. - /// - /// Example of misuse: - /// - /// ```rust,ignore - /// let index: usearch::Index = usearch::new_index(&usearch::IndexOptions::default()).unwrap(); - /// - /// let temporary = vec![0u8; 100]; - /// index.view_from_buffer(&temporary); - /// std::mem::drop(temporary); - /// - /// let query = vec![0.0; 256]; - /// let results = index.search(&query, 5).unwrap(); - /// ``` - /// - /// The above example would result in use-after-free and undefined behavior.
- pub unsafe fn view_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> { - self.inner.view_from_buffer(buffer) + /// * `path` - The file path from where the view will be created. + pub fn new_from_file(path: &str) -> Result { + let inner = Index::new(&IndexOptions::default())?; + inner.inner.view(path)?; + Ok(IndexView { + inner, + _phantom_data: PhantomData, + }) } } @@ -1439,8 +1867,9 @@ mod tests { use crate::b1x8; use crate::new_index; - use crate::Index; use crate::Key; + use crate::{Index, IndexView}; + use crate::{IndexMethods, IndexViewMethods}; use std::env; @@ -1531,10 +1960,12 @@ mod tests { let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1]; let too_long: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1]; let too_short: [f32; 4] = [0.3, 0.2, 0.4, 0.0]; + let wrong_type: [i8; 5] = [1, 2, 3, 4, 5]; assert!(index.add(1, &first).is_ok()); assert!(index.add(2, &second).is_ok()); assert!(index.add(3, &too_long).is_err()); assert!(index.add(4, &too_short).is_err()); + assert!(index.add(5, &wrong_type).is_err()); assert_eq!(index.size(), 2); // Test using Vec @@ -1567,6 +1998,7 @@ mod tests { let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1]; let too_long: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1]; let too_short: [f32; 4] = [0.3, 0.2, 0.4, 0.0]; + let wrong_type: [i8; 5] = [1, 2, 3, 4, 5]; assert!(index.add(1, &first).is_ok()); assert!(index.add(2, &second).is_ok()); assert_eq!(index.size(), 2); @@ -1574,7 +2006,11 @@ mod tests { //assert!(index.add(4, &too_short).is_err()); assert!(index.search(&too_long, 1).is_err()); + assert!(index.exact_search(&too_long, 1).is_err()); assert!(index.search(&too_short, 1).is_err()); + assert!(index.exact_search(&too_short, 1).is_err()); + assert!(index.search(&wrong_type, 1).is_err()); + assert!(index.exact_search(&wrong_type, 1).is_err()); } #[test] @@ -1685,7 +2121,13 @@ mod tests { // Validate serialization assert!(index.save("index.rust.usearch").is_ok()); assert!(index.load("index.rust.usearch").is_ok()); - 
assert!(index.view("index.rust.usearch").is_ok()); + + let index_view = IndexView::new_from_file("index.rust.usearch").unwrap(); + let results = index_view.search(&first, 10).unwrap(); + println!("{:?}", results); + assert_eq!(results.keys.len(), 2); + let mut out = [0f32; 5]; + assert!(index_view.get(43, &mut out).is_ok()); // Make sure every function is called at least once assert!(new_index(&options).is_ok()); @@ -1887,12 +2329,43 @@ mod tests { (a_slice[0] - b_slice[0]).abs() * first_factor + (a_slice[1] - b_slice[1]).abs() * second_factor }); - index.change_metric(stateful_distance); + assert!(index.change_metric(stateful_distance).is_ok()); + + let wrong_type = Box::new(move |_: *const b1x8, _: *const b1x8| 0.0); + assert!(index.change_metric(wrong_type).is_err()); let another_vector: [f32; 2] = [0.0, 1.0]; index.add(2, &another_vector).unwrap(); } + #[test] + fn test_change_metric_leak() { + let options = IndexOptions { + dimensions: 2, + quantization: ScalarKind::F32, + ..Default::default() + }; + let mut index = Index::new(&options).unwrap(); + index.reserve(10).unwrap(); + + let vector: [f32; 2] = [1.0, 0.0]; + index.add(1, &vector).unwrap(); + + let counter = std::sync::Arc::new(std::sync::Mutex::new(0f32)); + + let n: i32 = 100; + for _ in 0..n { + let counter_copy = counter.clone(); + let metric = + Box::new(move |_: *const f32, _: *const f32| *counter_copy.lock().unwrap()); + index.change_metric(metric).unwrap(); + } + drop(index); + // Only one reference to counter (the one held in this scope, not the closure) + // should be left. + assert_eq!(std::sync::Arc::strong_count(&counter), 1); + } + #[test] fn test_binary_vectors_and_hamming_distance() { let index = Index::new(&IndexOptions { @@ -2035,4 +2508,15 @@ mod tests { "All searches should find exact matches" ); } + + #[test] + fn test_index_file_view_is_sync() { + #[allow(unused)] + fn assert_sync() {} + #[allow(unused)] + fn assert_send() {} + + assert_sync::>(); + assert_send::>(); + } }