Skip to content
Merged
Next Next commit
Implement vector search api
  • Loading branch information
cfraz89 committed Apr 24, 2025
commit 2b0a325b49cd36b09c2a1ff3c0f4d0ce5a743876
183 changes: 183 additions & 0 deletions async-openai/src/types/vector_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,186 @@ pub struct VectorStoreFileBatchObject {
pub status: VectorStoreFileBatchStatus,
pub file_counts: VectorStoreFileBatchCounts,
}

/// Request payload for searching a vector store.
///
/// Every field other than `query` is optional and is left out of the
/// serialized output when it is `None`.
///
/// The derived builder is named `VectorStoreSearchRequestArgs`. Its setters
/// accept any value convertible with `Into`, take the inner value of `Option`
/// fields directly (`strip_option`), and fields that are never set fall back
/// to their defaults. Its `build` function reports failures as `OpenAIError`.
#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)]
#[builder(name = "VectorStoreSearchRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct VectorStoreSearchRequest {
/// The search query: either a single string or a list of strings.
pub query: VectorStoreSearchQuery,

/// Whether to rewrite the natural language query for vector search.
#[serde(skip_serializing_if = "Option::is_none")]
pub rewrite_query: Option<bool>,

/// The maximum number of results to return. This number should be between 1 and 50 inclusive.
///
/// The range is not enforced by this type; any `u8` value is accepted here.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_num_results: Option<u8>,

/// A filter to apply based on file attributes.
#[serde(skip_serializing_if = "Option::is_none")]
pub filters: Option<VectorStoreSearchFilter>,

/// Ranking options for search.
#[serde(skip_serializing_if = "Option::is_none")]
pub ranking_options: Option<RankingOptions>,
}

/// The query of a vector store search: one string or a list of strings.
///
/// The enum is `#[serde(untagged)]`, so a value is serialized as a bare JSON
/// string or a bare JSON array of strings, without any variant name.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum VectorStoreSearchQuery {
/// A single query to search for.
Text(String),
/// A list of queries to search for.
Array(Vec<String>),
}

impl Default for VectorStoreSearchQuery {
fn default() -> Self {
Self::Text(String::new())
}
}

impl From<String> for VectorStoreSearchQuery {
fn from(query: String) -> Self {
Self::Text(query)
}
}

impl From<&str> for VectorStoreSearchQuery {
fn from(query: &str) -> Self {
Self::Text(query.to_string())
}
}

impl From<Vec<String>> for VectorStoreSearchQuery {
fn from(query: Vec<String>) -> Self {
Self::Array(query)
}
}

/// A filter applied to a vector store search: either a single comparison or a
/// compound of several comparisons.
///
/// The enum is `#[serde(untagged)]`: the inner filter is serialized directly
/// with no variant name, and on deserialization the variants are tried in
/// declaration order (`Comparison` first, then `Compound`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum VectorStoreSearchFilter {
/// A single key/value comparison.
Comparison(ComparisonFilter),
/// Several comparisons combined with a logical operation.
Compound(CompoundFilter),
}

/// A filter used to compare a specified attribute key to a given value using a defined comparison operation.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ComparisonFilter {
/// Specifies the comparison operator: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
///
/// Left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub r#type: Option<ComparisonType>,

/// The key to compare against the value.
pub key: String,

/// The value to compare against the attribute key; supports string, number, or boolean types.
pub value: ComparisonValue,
}

/// Specifies the comparison operator: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
///
/// Variant names are serialized in lowercase (`rename_all = "lowercase"`).
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ComparisonType {
/// Serialized as `eq`.
Eq,
/// Serialized as `ne`.
Ne,
/// Serialized as `gt`.
Gt,
/// Serialized as `gte`.
Gte,
/// Serialized as `lt`.
Lt,
/// Serialized as `lte`.
Lte,
}

/// The value to compare against the attribute key; supports string, number, or boolean types.
///
/// The enum is `#[serde(untagged)]`, so the value is serialized as a bare JSON
/// string, integer, or boolean.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ComparisonValue {
/// A string value.
String(String),
/// An integer value.
// NOTE(review): only `i64` is representable here, so a non-integer number
// cannot be expressed or deserialized — confirm whether floats are needed.
Number(i64),
/// A boolean value.
Boolean(bool),
}

/// Ranking options for search.
///
/// Both fields are optional and are left out of the serialized output when
/// `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct RankingOptions {
/// The ranker to use for the search.
#[serde(skip_serializing_if = "Option::is_none")]
pub ranker: Option<Ranker>,

/// The score threshold for the search.
// NOTE(review): presumably results scoring below this value are excluded
// and the valid range matches result scores — confirm against the API
// reference.
#[serde(skip_serializing_if = "Option::is_none")]
pub score_threshold: Option<f32>,
}

/// The ranker selectable through `RankingOptions::ranker`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum Ranker {
/// Serialized as `auto`.
#[serde(rename = "auto")]
Auto,
/// Serialized as `default-2024-11-15`.
#[serde(rename = "default-2024-11-15")]
Default20241115,
}

/// Combine multiple filters using `and` or `or`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CompoundFilter {
/// Type of operation: `and` or `or`.
pub r#type: Option<CompoundFilterType>,

/// Array of filters to combine. Items can be `ComparisonFilter` or `CompoundFilter`
pub filters: Vec<ComparisonFilter>,
}

/// Type of operation: `and` or `or`.
///
/// Variant names are serialized in lowercase (`rename_all = "lowercase"`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CompoundFilterType {
/// Serialized as `and`.
And,
/// Serialized as `or`.
Or,
}

/// One page of results from a vector store search.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct VectorStoreSearchResultsPage {
/// The object type, which is always `vector_store.search_results.page`.
pub object: String,

/// The query used for this search, as a list of strings.
pub search_query: Vec<String>,

/// The list of search result items.
pub data: Vec<VectorStoreSearchResultItem>,

/// Indicates if there are more results to fetch.
pub has_more: bool,

/// The token for the next page, if any.
pub next_page: Option<String>,
}

/// A single item in a page of vector store search results.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct VectorStoreSearchResultItem {
/// The ID of the vector store file.
pub file_id: String,

/// The name of the vector store file.
pub filename: String,

/// The similarity score for the result (minimum: 0, maximum: 1).
pub score: f32,

/// Attributes of the vector store file, keyed by attribute name; values
/// are kept as arbitrary JSON.
pub attributes: HashMap<String, serde_json::Value>,

/// Content chunks from the file.
pub content: Vec<VectorStoreSearchResultContentObject>,
}

/// One content chunk of a search result item.
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct VectorStoreSearchResultContentObject {
/// The type of content, kept as a free-form string.
pub r#type: String,

/// The text content returned from search.
pub text: String,
}
15 changes: 14 additions & 1 deletion async-openai/src/vector_stores.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::{
error::OpenAIError,
types::{
CreateVectorStoreRequest, DeleteVectorStoreResponse, ListVectorStoresResponse,
UpdateVectorStoreRequest, VectorStoreObject,
UpdateVectorStoreRequest, VectorStoreObject, VectorStoreSearchRequest,
VectorStoreSearchResultsPage,
},
vector_store_file_batches::VectorStoreFileBatches,
Client, VectorStoreFiles,
Expand Down Expand Up @@ -78,4 +79,16 @@ impl<'c, C: Config> VectorStores<'c, C> {
.post(&format!("/vector_stores/{vector_store_id}"), request)
.await
}

/// Searches the vector store identified by `vector_store_id`.
///
/// Issues a POST through the client to `/vector_stores/{vector_store_id}/search`,
/// handing it `request`, and returns the resulting page of search results or
/// the `OpenAIError` reported by the client.
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn search(
    &self,
    vector_store_id: &str,
    request: VectorStoreSearchRequest,
) -> Result<VectorStoreSearchResultsPage, OpenAIError> {
    let path = format!("/vector_stores/{vector_store_id}/search");
    self.client.post(&path, request).await
}
}