Skip to content

Commit d0dff0b

Browse files
wip: working multivector query
1 parent 36e4f05 commit d0dff0b

File tree

1 file changed

+126
-67
lines changed

1 file changed

+126
-67
lines changed

redisvl/query/aggregate.py

Lines changed: 126 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -260,76 +260,139 @@ class MultiVectorQuery(AggregationQuery):
260260
results = index.query(query)
261261
262262
263+
264+
FT.AGGREGATE multi_vector_test
265+
"@user_embedding:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0}
266+
| @image_embedding:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}"
267+
PARAMS 4
268+
vector_0 "\xcd\xcc\xcc=\xcd\xcc\xcc=\x00\x00\x00?"
269+
vector_1 "\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?\x9a\x99\x99\x99\x99\x99\xb9?"
270+
APPLY "(2 - @distance_0)/2" AS score_0
271+
APPLY "(2 - @distance_1)/2" AS score_1
272+
DIALECT 2
273+
APPLY "(@score_0 + @score_1)" AS combined_score
274+
SORTBY 2 @combined_score
275+
ASC
276+
MAX 10
277+
LOAD 2 score_0 score_1
278+
279+
280+
281+
282+
263283
FT.AGGREGATE 'idx:characters'
264-
"@embedding1:[VECTOR_RANGE .7 $vector1]=>{$YIELD_DISTANCE_AS: vector_distance1} | @embedding2:[VECTOR_RANGE 1.0 $vector2]=>{$YIELD_DISTANCE_AS: vector_distance2} | @embedding3:[VECTOR_RANGE 1.7 $vector3]=>{$YIELD_DISTANCE_AS: vector_distance3} | @name:(James)"
265-
ADDSCORES
266-
SCORER BM25STD.NORM
267-
LOAD 2 created_at @embedding
268-
APPLY 'case(exists(@vector_distance1), @vector_distance1, 0.0)' as v1
269-
APPLY 'case(exists(@vector_distance2), @vector_distance2, 0.0)' as v2
270-
APPLY 'case(exists(@vector_distance3), @vector_distance3, 0.0)' as v3
284+
"@embedding1:[VECTOR_RANGE .7 $vector1]=>{$YIELD_DISTANCE_AS: vector_distance1}
285+
| @embedding2:[VECTOR_RANGE 1.0 $vector2]=>{$YIELD_DISTANCE_AS: vector_distance2}
286+
| @embedding3:[VECTOR_RANGE 1.7 $vector3]=>{$YIELD_DISTANCE_AS: vector_distance3}
287+
| @name:(James)"
288+
### ADDSCORES
289+
### SCORER BM25STD.NORM
290+
### LOAD 2 created_at @embedding
291+
APPLY '(2 - @vector_distance1)/2' as v1
292+
APPLY '(2 - @vector_distance2)/2' as v2
293+
APPLY '(2 - @vector_distance3)/2' as v3
271294
APPLY '(@__score * 0.3 + (@v1 * 0.3) + (@v2 * 1.2) + (@v3 * 0.1))' AS final_score
272295
PARAMS 6 vector1 "\xe4\xd6..." vector2 "\x89\xa0..." vector3 "\x3c\x19..."
273296
SORTBY 2 @final_score DESC
274297
DIALECT 2
275298
LIMIT 0 100
276299
277-
278300
"""
279301

280302
DISTANCE_ID: str = "vector_distance"
281-
VECTOR_PARAM: str = "vector"
282303

283304
def __init__(
284305
self,
285306
vectors: Union[bytes, List[bytes], List[float], List[List[float]]],
286307
vector_field_names: Union[str, List[str]],
308+
weights: List[float] = [1.0],
309+
return_fields: Optional[List[str]] = None,
287310
filter_expression: Optional[Union[str, FilterExpression]] = None,
288-
weights: Union[float, List[float]] = 1.0,
289-
dtypes: Union[str, List[str]] = "float32",
311+
dtypes: List[str] = ["float32"],
290312
num_results: int = 10,
291-
return_fields: Optional[List[str]] = None,
313+
return_score: bool = False,
292314
dialect: int = 2,
293315
):
294316
"""
295317
Instantiates a MultiVectorQuery object.
296318
297319
Args:
298320
vectors (Union[bytes, List[bytes], List[float], List[List[float]]]): The vectors to perform vector similarity search.
299-
vector_field_names (str): The vector field names to search in.
300-
filter_expression (Optional[FilterExpression], optional): The filter expression to use.
301-
Defaults to None.
302-
weights (Union[float, List[float]], optional): The weights of the vector similarity.
321+
vector_field_names (Union[str, List[str]]): The vector field names to search in.
322+
weights (List[float]): The weights of the vector similarity.
303323
Documents will be scored as:
304324
score = (w1) * score1 + (w2) * score2 + (w3) * score3 + ...
305-
Defaults to 1.0, which corresponds to equal weighting
306-
dtype (Union[str, List[str]] optional): The data types of the vectors. Defaults to "float32" for all vectors.
307-
num_results (int, optional): The number of results to return. Defaults to 10.
325+
Defaults to [1.0]; a single weight is repeated for every vector, which corresponds to equal weighting.
308326
return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
327+
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
328+
Defaults to None.
329+
dtypes (List[str]): The data types of the vectors. Defaults to ["float32"] for all vectors.
330+
num_results (int, optional): The number of results to return. Defaults to 10.
331+
return_score (bool): Whether to return the combined vector similarity score.
332+
Defaults to False.
309333
dialect (int, optional): The Redis dialect version. Defaults to 2.
310334
311335
Raises:
312336
ValueError: The number of vectors, vector field names, and weights do not agree.
313-
TypeError: If the stopwords are not a set, list, or tuple of strings.
314337
"""
315338

316-
self._vectors = vectors
317-
self._vector_fields = vector_field_names
318339
self._filter_expression = filter_expression
319-
self._weights = weights
320340
self._dtypes = dtypes
321341
self._num_results = num_results
322342

343+
if len(vectors) == 0 or len(vector_field_names) == 0 or len(weights) == 0:
344+
raise ValueError(
345+
f"""The number of vectors and vector field names must be equal.
346+
If weights are specified their number must match the number of vectors and vector field names also.
347+
Length of vectors list: {len(vectors) = }
348+
Length of vector_field_names list: {len(vector_field_names) = }
349+
Length of weights list: {len(weights) = }
350+
"""
351+
)
352+
353+
if isinstance(vectors, bytes) or isinstance(vectors[0], float):
354+
self._vectors = [vectors]
355+
else:
356+
self._vectors = vectors
357+
if isinstance(vector_field_names, str):
358+
self._vector_field_names = [vector_field_names]
359+
else:
360+
self._vector_field_names = vector_field_names
361+
if len(weights) == 1:
362+
self._weights = weights * len(vectors)
363+
else:
364+
self._weights = weights
365+
if len(dtypes) == 1:
366+
self._dtypes = dtypes * len(vectors)
367+
else:
368+
self._dtypes = dtypes
369+
370+
if (len(self._vectors) != len(self._vector_field_names)) or (
371+
len(self._vectors) != len(self._weights)
372+
):
373+
raise ValueError(
374+
f"""The number of vectors and vector field names must be equal.
375+
If weights are specified their number must match the number of vectors and vector field names also.
376+
Length of vectors list: {len(self._vectors) = }
377+
Length of vector_field_names list: {len(self._vector_field_names) = }
378+
Length of weights list: {len(self._weights) = }
379+
"""
380+
)
381+
323382
query_string = self._build_query_string()
324383
super().__init__(query_string)
325384

326-
self.scorer(text_scorer)
327-
self.add_scores()
328-
self.apply(
329-
vector_similarity=f"(2 - @{self.DISTANCE_ID})/2", text_score="@__score"
330-
)
331-
self.apply(hybrid_score=f"{1-alpha}*@text_score + {alpha}*@vector_similarity")
332-
self.sort_by(Desc("@hybrid_score"), max=num_results) # type: ignore
385+
# construct the scoring string based on the vector similarity scores and weights
386+
combined_scores = []
387+
for i, w in enumerate(self._weights):
388+
combined_scores.append(f"@score_{i} * {w}")
389+
combined_score_string = " + ".join(combined_scores)
390+
combined_score_string = f"'({combined_score_string})'"
391+
392+
self.apply(combined_score=combined_score_string)
393+
394+
# self.add_scores()
395+
self.sort_by(Desc("@combined_score"), max=num_results) # type: ignore
333396
self.dialect(dialect)
334397
if return_fields:
335398
self.load(*return_fields) # type: ignore[arg-type]
@@ -341,49 +404,45 @@ def params(self) -> Dict[str, Any]:
341404
Returns:
342405
Dict[str, Any]: The parameters for the aggregation.
343406
"""
344-
if isinstance(self._vector, list):
345-
vector = array_to_buffer(self._vector, dtype=self._dtype)
346-
else:
347-
vector = self._vector
348-
349-
params = {self.VECTOR_PARAM: vector}
350-
407+
params = {}
408+
for i, (vector, vector_field, dtype) in enumerate(zip(
409+
self._vectors, self._vector_field_names, self._dtypes
410+
)):
411+
if isinstance(vector, list):
412+
vector = array_to_buffer(vector, dtype=dtype)
413+
params[f"vector_{i}"] = vector
351414
return params
352415

353-
def _tokenize_and_escape_query(self, user_query: str) -> str:
354-
"""Convert a raw user query to a redis full text query joined by ORs
355-
Args:
356-
user_query (str): The user query to tokenize and escape.
357-
358-
Returns:
359-
str: The tokenized and escaped query string.
360-
Raises:
361-
ValueError: If the text string becomes empty after stopwords are removed.
362-
"""
363-
escaper = TokenEscaper()
364-
365-
tokens = [
366-
escaper.escape(
367-
token.strip().strip(",").replace("“", "").replace("”", "").lower()
368-
)
369-
for token in user_query.split()
370-
]
371-
tokenized = " | ".join(
372-
[token for token in tokens if token and token not in self._stopwords]
373-
)
374-
375-
if not tokenized:
376-
raise ValueError("text string cannot be empty after removing stopwords")
377-
return tokenized
378-
379416
def _build_query_string(self) -> str:
380417
"""Build the full query string for multi-vector search with optional filtering."""
418+
419+
filter_expression = self._filter_expression
381420
if isinstance(self._filter_expression, FilterExpression):
382421
filter_expression = str(self._filter_expression)
383-
else:
384-
filter_expression = ""
385422

386423
# base vector clauses, one per vector field (KNN and VECTOR_RANGE variants; only VECTOR_RANGE is used)
387-
knn_query = f"KNN {self._num_results} @{self._vector_field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}"
424+
knn_queries = []
425+
range_queries = []
426+
for i, (vector, field) in enumerate(zip(self._vectors, self._vector_field_names)):
427+
knn_queries.append(f"[KNN {self._num_results} @{field} $vector_{i} AS distance_{i}]")
428+
range_queries.append(f"@{field}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}")
429+
430+
knn_query = " | ".join(knn_queries) ## knn_queries format doesn't work
431+
knn_query = " | ".join(range_queries)
432+
433+
# calculate the respective vector similarities
434+
apply_string = ""
435+
for i, (vector, field_name, weight) in enumerate(
436+
zip(self._vectors, self._vector_field_names, self._weights)
437+
):
438+
apply_string += f'APPLY "(2 - @distance_{i})/2" AS score_{i} '
388439

389-
return f"{filter_expression})=>[{knn_query}]"
440+
return (
441+
f"{knn_query} {filter_expression} {apply_string}"
442+
if filter_expression
443+
else f"{knn_query} {apply_string}"
444+
)
445+
446+
def __str__(self) -> str:
447+
"""Return the string representation of the query."""
448+
return " ".join([str(x) for x in self.build_args()])

0 commit comments

Comments
 (0)