@@ -260,76 +260,139 @@ class MultiVectorQuery(AggregationQuery):
260260            results = index.query(query) 
261261
262262
263+ 
264+         FT.AGGREGATE multi_vector_test  
265+         "@user_embedding:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} 
266+         | @image_embedding:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}"  
267+         PARAMS 4 
268+         vector_0 "\xcd \xcc \xcc =\xcd \xcc \xcc =\x00 \x00 \x00 ?"  
269+         vector_1 "\x9a \x99 \x99 \x99 \x99 \x99 \xb9 ?\x9a \x99 \x99 \x99 \x99 \x99 \xb9 ?\x9a \x99 \x99 \x99 \x99 \x99 \xb9 ?\x9a \x99 \x99 \x99 \x99 \x99 \xb9 ?\x9a \x99 \x99 \x99 \x99 \x99 \xb9 ?"  
270+         APPLY "(2 - @distance_0)/2" AS score_0 
271+         APPLY "(2 - @distance_1)/2" AS score_1  
272+         DIALECT 2 
273+         APPLY "(@score_0 + @score_1)" AS combined_score 
274+         SORTBY 2 @combined_score  
275+         ASC  
276+         MAX 10  
277+         LOAD 2 score_0 score_1 
278+ 
279+ 
280+ 
281+ 
282+ 
263283    FT.AGGREGATE 'idx:characters' 
264-      "@embedding1:[VECTOR_RANGE .7 $vector1]=>{$YIELD_DISTANCE_AS: vector_distance1} | @embedding2:[VECTOR_RANGE 1.0 $vector2]=>{$YIELD_DISTANCE_AS: vector_distance2} | @embedding3:[VECTOR_RANGE 1.7 $vector3]=>{$YIELD_DISTANCE_AS: vector_distance3}  | @name:(James)" 
265-      ADDSCORES 
266-      SCORER BM25STD.NORM 
267-      LOAD 2 created_at @embedding 
268-      APPLY 'case(exists(@vector_distance1), @vector_distance1, 0.0)' as v1 
269-      APPLY 'case(exists(@vector_distance2), @vector_distance2, 0.0)' as v2 
270-      APPLY 'case(exists(@vector_distance3), @vector_distance3, 0.0)' as v3 
284+      "@embedding1:[VECTOR_RANGE .7 $vector1]=>{$YIELD_DISTANCE_AS: vector_distance1} 
285+      | @embedding2:[VECTOR_RANGE 1.0 $vector2]=>{$YIELD_DISTANCE_AS: vector_distance2} 
286+      | @embedding3:[VECTOR_RANGE 1.7 $vector3]=>{$YIELD_DISTANCE_AS: vector_distance3} 
287+      | @name:(James)" 
288+      ### ADDSCORES 
289+      ### SCORER BM25STD.NORM 
290+      ### LOAD 2 created_at @embedding 
291+      APPLY '(2 - @vector_distance1)/2' as v1 
292+      APPLY '(2 - @vector_distance2)/2' as v2 
293+      APPLY '(2 - @vector_distance3)/2' as v3 
271294     APPLY '(@__score * 0.3 + (@v1 * 0.3) + (@v2 * 1.2) + (@v3 * 0.1))' AS final_score 
272295     PARAMS 6 vector1 "\xe4 \xd6 ..." vector2 "\x89 \xa0 ..." vector3 "\x3c \x19 ..." 
273296     SORTBY 2 @final_score DESC 
274297     DIALECT 2 
275298     LIMIT 0 100 
276299
277- 
278300    """ 
279301
280302    DISTANCE_ID : str  =  "vector_distance" 
281-     VECTOR_PARAM : str  =  "vector" 
282303
283304    def  __init__ (
284305        self ,
285306        vectors : Union [bytes , List [bytes ], List [float ], List [List [float ]]],
286307        vector_field_names : Union [str , List [str ]],
308+         weights : List [float ] =  [1.0 ],
309+         return_fields : Optional [List [str ]] =  None ,
287310        filter_expression : Optional [Union [str , FilterExpression ]] =  None ,
288-         weights : Union [float , List [float ]] =  1.0 ,
289-         dtypes : Union [str , List [str ]] =  "float32" ,
311+         dtypes : List [str ] =  ["float32" ],
290312        num_results : int  =  10 ,
291-         return_fields :  Optional [ List [ str ]]  =  None ,
313+         return_score :  bool  =  False ,
292314        dialect : int  =  2 ,
293315    ):
294316        """ 
295317        Instantiates a MultiVectorQuery object. 
296318
297319        Args: 
298320            vectors (Union[bytes, List[bytes], List[float], List[List[float]]): The vectors to perform vector similarity search. 
299-             vector_field_names (str): The vector field names to search in. 
300-             filter_expression (Optional[FilterExpression], optional): The filter expression to use. 
301-                 Defaults to None. 
302-             weights (Union[float, List[float]], optional): The weights of the vector similarity. 
321+             vector_field_names (Union[str, List[str]]): The vector field names to search in. 
322+             weights (List[float]): The weights of the vector similarity. 
303323                Documents will be scored as: 
304324                score = (w1) * score1 + (w2) * score2 + (w3) * score3 + ... 
305-                 Defaults to 1.0, which corresponds to equal weighting 
306-             dtype (Union[str, List[str]] optional): The data types of the vectors. Defaults to "float32" for all vectors. 
307-             num_results (int, optional): The number of results to return. Defaults to 10. 
325+                 Defaults to [1.0], which corresponds to equal weighting 
308326            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None. 
327+             filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use. 
328+                 Defaults to None. 
329+             dtypes (List[str]): The data types of the vectors. Defaults to ["float32"] for all vectors. 
330+             num_results (int, optional): The number of results to return. Defaults to 10. 
331+             return_score (bool): Whether to return the combined vector similarity score. 
332+                 Defaults to False. 
309333            dialect (int, optional): The Redis dialect version. Defaults to 2. 
310334
311335        Raises: 
312336            ValueError: The number of vectors, vector field names, and weights do not agree. 
313-             TypeError: If the stopwords are not a set, list, or tuple of strings. 
314337        """ 
315338
316-         self ._vectors  =  vectors 
317-         self ._vector_fields  =  vector_field_names 
318339        self ._filter_expression  =  filter_expression 
319-         self ._weights  =  weights 
320340        self ._dtypes  =  dtypes 
321341        self ._num_results  =  num_results 
322342
343+         if  len (vectors ) ==  0  or  len (vector_field_names ) ==  0  or  len (weights ) ==  0 :
344+             raise  ValueError (
345+                 f"""The number of vectors and vector field names must be equal. 
346+                              If weights are specified their number must match the number of vectors and vector field names also. 
347+                             Length of vectors list: { len (vectors ) =  }  
348+                             Length of vector_field_names list: { len (vector_field_names ) =  }  
349+                             Length of weights list: { len (weights ) =  }  
350+                             """ 
351+             )
352+ 
353+         if  isinstance (vectors , bytes ) or  isinstance (vectors [0 ], float ):
354+             self ._vectors  =  [vectors ]
355+         else :
356+             self ._vectors  =  vectors 
357+         if  isinstance (vector_field_names , str ):
358+             self ._vector_field_names  =  [vector_field_names ]
359+         else :
360+             self ._vector_field_names  =  vector_field_names 
361+         if  len (weights ) ==  1 :
362+             self ._weights  =  weights  *  len (vectors )
363+         else :
364+             self ._weights  =  weights 
365+         if  len (dtypes ) ==  1 :
366+             self ._dtypes  =  dtypes  *  len (vectors )
367+         else :
368+             self ._dtypes  =  dtypes 
369+ 
370+         if  (len (self ._vectors ) !=  len (self ._vector_field_names )) or  (
371+             len (self ._vectors ) !=  len (self ._weights )
372+         ):
373+             raise  ValueError (
374+                 f"""The number of vectors and vector field names must be equal. 
375+                              If weights are specified their number must match the number of vectors and vector field names also. 
376+                             Length of vectors list: { len (self ._vectors ) =  }  
377+                             Length of vector_field_names list: { len (self ._vector_field_names ) =  }  
378+                             Length of weights list: { len (self ._weights ) =  }  
379+                             """ 
380+             )
381+ 
323382        query_string  =  self ._build_query_string ()
324383        super ().__init__ (query_string )
325384
326-         self .scorer (text_scorer )
327-         self .add_scores ()
328-         self .apply (
329-             vector_similarity = f"(2 - @{ self .DISTANCE_ID }  )/2" , text_score = "@__score" 
330-         )
331-         self .apply (hybrid_score = f"{ 1 - alpha }  *@text_score + { alpha }  *@vector_similarity" )
332-         self .sort_by (Desc ("@hybrid_score" ), max = num_results )  # type: ignore 
385+         # construct the scoring string based on the vector similarity scores and weights 
386+         combined_scores  =  []
387+         for  i , w  in  enumerate (self ._weights ):
388+             combined_scores .append (f"@score_{ i }   * { w }  " )
389+         combined_score_string  =  " + " .join (combined_scores )
390+         combined_score_string  =  f"'({ combined_score_string }  )'" 
391+ 
392+         self .apply (combined_score = combined_score_string )
393+ 
394+         # self.add_scores() 
395+         self .sort_by (Desc ("@combined_score" ), max = num_results )  # type: ignore 
333396        self .dialect (dialect )
334397        if  return_fields :
335398            self .load (* return_fields )  # type: ignore[arg-type] 
@@ -341,49 +404,45 @@ def params(self) -> Dict[str, Any]:
341404        Returns: 
342405            Dict[str, Any]: The parameters for the aggregation. 
343406        """ 
344-         if   isinstance ( self . _vector ,  list ): 
345-              vector   =   array_to_buffer ( self . _vector , dtype = self . _dtype ) 
346-         else : 
347-              vector   =   self . _vector 
348- 
349-         params   =  { self . VECTOR_PARAM :  vector } 
350- 
407+         params   =  {} 
408+         for   i , ( vector ,  vector_field , dtype )  in   enumerate ( zip ( 
409+              self . _vectors ,  self . _vector_field_names ,  self . _dtypes 
410+         )): 
411+              if   isinstance ( vector ,  list ): 
412+                  vector   =   array_to_buffer ( vector ,  dtype = dtype ) 
413+              params [ f"vector_ { i } " ]  =   vector 
351414        return  params 
352415
353-     def  _tokenize_and_escape_query (self , user_query : str ) ->  str :
354-         """Convert a raw user query to a redis full text query joined by ORs 
355-         Args: 
356-             user_query (str): The user query to tokenize and escape. 
357- 
358-         Returns: 
359-             str: The tokenized and escaped query string. 
360-         Raises: 
361-             ValueError: If the text string becomes empty after stopwords are removed. 
362-         """ 
363-         escaper  =  TokenEscaper ()
364- 
365-         tokens  =  [
366-             escaper .escape (
367-                 token .strip ().strip ("," ).replace ("“" , "" ).replace ("”" , "" ).lower ()
368-             )
369-             for  token  in  user_query .split ()
370-         ]
371-         tokenized  =  " | " .join (
372-             [token  for  token  in  tokens  if  token  and  token  not  in   self ._stopwords ]
373-         )
374- 
375-         if  not  tokenized :
376-             raise  ValueError ("text string cannot be empty after removing stopwords" )
377-         return  tokenized 
378- 
379416    def  _build_query_string (self ) ->  str :
380417        """Build the full query string for text search with optional filtering.""" 
418+ 
419+         filter_expression  =  self ._filter_expression 
381420        if  isinstance (self ._filter_expression , FilterExpression ):
382421            filter_expression  =  str (self ._filter_expression )
383-         else :
384-             filter_expression  =  "" 
385422
386423        # base KNN query 
387-         knn_query  =  f"KNN { self ._num_results }   @{ self ._vector_field }   ${ self .VECTOR_PARAM }   AS { self .DISTANCE_ID }  " 
424+         knn_queries  =  []
425+         range_queries  =  []
426+         for  i , (vector , field ) in  enumerate (zip (self ._vectors , self ._vector_field_names )):
427+             knn_queries .append (f"[KNN { self ._num_results }   @{ field }   $vector_{ i }   AS distance_{ i }  ]" )
428+             range_queries .append (f"@{ field }  :[VECTOR_RANGE 2.0 $vector_{ i }  ]=>{{$YIELD_DISTANCE_AS: distance_{ i }  }}" )
429+ 
430+         knn_query  =  " | " .join (knn_queries ) ## knn_queries format doesn't work 
431+         knn_query  =  " | " .join (range_queries )
432+ 
433+         # calculate the respective vector similarities 
434+         apply_string  =  "" 
435+         for  i , (vector , field_name , weight ) in  enumerate (
436+             zip (self ._vectors , self ._vector_field_names , self ._weights )
437+         ):
438+             apply_string  +=  f'APPLY "(2 - @distance_{ i }  )/2" AS score_{ i }   ' 
388439
389-         return  f"{ filter_expression }  )=>[{ knn_query }  ]" 
440+         return  (
441+             f"{ knn_query }   { filter_expression }   { apply_string }  " 
442+             if  filter_expression 
443+             else  f"{ knn_query }   { apply_string }  " 
444+         )
445+ 
446+     def  __str__ (self ) ->  str :
447+         """Return the string representation of the query.""" 
448+         return  " " .join ([str (x ) for  x  in  self .build_args ()])
0 commit comments