1717import datetime
1818import logging
1919from abc import ABC , abstractmethod
20+ from dataclasses import dataclass
2021from enum import Enum
2122from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Type , Union
2223
4041
4142USER_AGENT_VECTOR_STORE = "langchain-google-spanner-python:vector_store/" + __version__
4243
43- KNN_DISTANCE_SEARCH_QUERY_ALIAS = "distance"
44-
45- from dataclasses import dataclass
44+ DISTANCE_SEARCH_QUERY_ALIAS = "distance"
4645
4746
4847def client_with_user_agent (
@@ -423,7 +422,7 @@ def _generate_sql(
423422 primary_key ,
424423 secondary_indexes : Optional [List [SecondaryIndex | VectorSearchIndex ]] = None ,
425424 vector_size : Optional [int ] = None ,
426- ):
425+ ) -> List [ str ] :
427426 """
428427 Generate SQL for creating the vector store table.
429428
@@ -637,11 +636,12 @@ def __init__(
637636 embedding_service : Embeddings ,
638637 id_column : str = ID_COLUMN_NAME ,
639638 content_column : str = CONTENT_COLUMN_NAME ,
640- embedding_column : str = EMBEDDING_COLUMN_NAME ,
639+ embedding_column : Optional [ str | TableColumn ] = None ,
641640 client : Optional [spanner .Client ] = None ,
642641 metadata_columns : Optional [List [str ]] = None ,
643642 ignore_metadata_columns : Optional [List [str ]] = None ,
644643 metadata_json_column : Optional [str ] = None ,
644+ vector_index_name : Optional [str ] = None , # Default vector index name for ANN queries, used when no index_name is passed to the search call.
645645 query_parameters : QueryParameters = QueryParameters (),
646646 ):
647647 """
@@ -667,8 +667,22 @@ def __init__(
667667 self ._client = client_with_user_agent (client , USER_AGENT_VECTOR_STORE )
668668 self ._id_column = id_column
669669 self ._content_column = content_column
670- self ._embedding_column = embedding_column
670+ if embedding_column is None :
671+ embedding_column = EMBEDDING_COLUMN_NAME
672+ self ._embedding_column = ""
673+ self ._embedding_column_type = ""
674+ self ._embedding_column_is_nullable = False
675+ if isinstance (embedding_column , TableColumn ):
676+ self ._embedding_column_type = embedding_column .type
677+ self ._embedding_column = embedding_column .name
678+ self ._embedding_column_is_nullable = embedding_column .is_null
679+ embedding_column = embedding_column .name
680+ elif isinstance (embedding_column , str ):
681+ self ._embedding_column = embedding_column
671682 self ._metadata_json_column = metadata_json_column
683+ self ._vector_index_name = ""
684+ if vector_index_name :
685+ self ._vector_index_name = vector_index_name
672686
673687 self ._query_parameters = query_parameters
674688 self ._embedding_service = embedding_service
@@ -1022,9 +1036,9 @@ def similarity_search_with_score_by_vector(
10221036 """
10231037 if self .__using_ANN :
10241038 results , column_order_map = self ._get_rows_by_similarity_search_ann (
1025- embedding ,
1026- k ,
1027- pre_filter ,
1039+ embedding = embedding ,
1040+ k = k ,
1041+ pre_filter = pre_filter ,
10281042 ** kwargs ,
10291043 )
10301044 else :
@@ -1050,14 +1064,16 @@ def _get_rows_by_similarity_search_ann(
10501064 ):
10511065 sql = SpannerVectorStore ._generate_sql_for_ANN (
10521066 self ._table_name ,
1053- index_name ,
1067+ index_name or self . _vector_index_name ,
10541068 self ._embedding_column ,
10551069 embedding ,
10561070 num_leaves ,
10571071 k ,
10581072 self ._query_parameters .distance_strategy ,
10591073 pre_filter = pre_filter ,
1060- embedding_column_is_nullable = embedding_column_is_nullable ,
1074+ embedding_column_type = self ._embedding_column_type ,
1075+ embedding_column_is_nullable = self ._embedding_column_is_nullable
1076+ or embedding_column_is_nullable ,
10611077 ascending = ascending ,
10621078 return_columns = return_columns or self ._columns_to_insert ,
10631079 )
@@ -1066,10 +1082,10 @@ def _get_rows_by_similarity_search_ann(
10661082 ** staleness if staleness is not None else {}
10671083 ) as snapshot :
10681084 results = snapshot .execute_sql (sql = sql )
1069- column_order_map = {
1070- value : index for index , value in enumerate ( self . _columns_to_insert )
1071- }
1072- return results , column_order_map
1085+ columns = ( self . _columns_to_insert or []). copy ()
1086+ columns . append ( DISTANCE_SEARCH_QUERY_ALIAS )
1087+ column_order_map = { value : index for index , value in enumerate ( columns ) }
1088+ return list ( results ) , column_order_map
10731089
10741090 @staticmethod
10751091 def _generate_sql_for_ANN (
@@ -1081,13 +1097,17 @@ def _generate_sql_for_ANN(
10811097 k : int ,
10821098 strategy : DistanceStrategy = DistanceStrategy .COSINE ,
10831099 pre_filter : Optional [str ] = None ,
1100+ embedding_column_type : str = "ARRAY<FLOAT32>" ,
10841101 embedding_column_is_nullable : bool = False ,
10851102 ascending : bool = True ,
10861103 return_columns : Optional [List [str ]] = None ,
10871104 ) -> str :
10881105 if not embedding_column_name :
10891106 raise Exception ("embedding_column_name must be set" )
10901107
1108+ if not index_name :
1109+ raise Exception ("index_name must be set" )
1110+
10911111 ann_strategy_name = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS .get (strategy , None )
10921112 if not ann_strategy_name :
10931113 raise Exception (f"{ strategy } is not supported for ANN" )
@@ -1099,8 +1119,12 @@ def _generate_sql_for_ANN(
10991119 if not column_names :
11001120 column_names = "*"
11011121
1122+ distance_alias = DISTANCE_SEARCH_QUERY_ALIAS
11021123 sql = (
1103- f"SELECT { column_names } FROM { table_name } "
1124+ f"SELECT { column_names } , { ann_strategy_name } ("
1125+ + f"{ embedding_column_type } { embedding } , { embedding_column_name } , options => JSON '"
1126+ + ('{"num_leaves_to_search": %s}\' ) as %s\n ' % (num_leaves , distance_alias ))
1127+ + f"FROM { table_name } "
11041128 + "@{FORCE_INDEX="
11051129 + f"{ index_name } "
11061130 + (
@@ -1111,10 +1135,9 @@ def _generate_sql_for_ANN(
11111135 + ("" if not pre_filter else f" AND { pre_filter } " )
11121136 + "\n "
11131137 )
1114- + f"ORDER BY { ann_strategy_name } (\n "
1115- + f" ARRAY<FLOAT32>{ embedding } , { embedding_column_name } , options => JSON '"
1116- + '{"num_leaves_to_search": %s}\' )%s\n '
1117- % (num_leaves , "" if ascending else " DESC" )
1138+ + f"ORDER BY { distance_alias } "
1139+ + ("" if ascending else " DESC" )
1140+ + "\n "
11181141 )
11191142
11201143 if k :
@@ -1144,7 +1167,7 @@ def _get_rows_by_similarity_search_knn(
11441167 column_order_map = {
11451168 value : index for index , value in enumerate (self ._columns_to_insert )
11461169 }
1147- column_order_map [KNN_DISTANCE_SEARCH_QUERY_ALIAS ] = len (self ._columns_to_insert )
1170+ column_order_map [DISTANCE_SEARCH_QUERY_ALIAS ] = len (self ._columns_to_insert )
11481171
11491172 sql_query = """
11501173 SELECT {select_column_names} {distance_function}({embedding_column}, {vector_embedding_placeholder}) AS {distance_alias}
@@ -1160,7 +1183,7 @@ def _get_rows_by_similarity_search_knn(
11601183 filter = pre_filter if pre_filter is not None else "1 = 1" ,
11611184 k_count = k ,
11621185 distance_function = distance_function ,
1163- distance_alias = KNN_DISTANCE_SEARCH_QUERY_ALIAS ,
1186+ distance_alias = DISTANCE_SEARCH_QUERY_ALIAS ,
11641187 )
11651188
11661189 with self ._database .snapshot (
@@ -1195,9 +1218,7 @@ def _get_documents_from_query_results(
11951218 }
11961219
11971220 doc = Document (page_content = page_content , metadata = metadata )
1198- documents .append (
1199- (doc , row [column_order_map [KNN_DISTANCE_SEARCH_QUERY_ALIAS ]])
1200- )
1221+ documents .append ((doc , row [column_order_map [DISTANCE_SEARCH_QUERY_ALIAS ]]))
12011222
12021223 return documents
12031224
@@ -1221,7 +1242,10 @@ def similarity_search(
12211242 """
12221243 embedding = self ._embedding_service .embed_query (query )
12231244 documents = self .similarity_search_with_score_by_vector (
1224- embedding = embedding , k = k , pre_filter = pre_filter
1245+ embedding = embedding ,
1246+ k = k ,
1247+ pre_filter = pre_filter ,
1248+ ** kwargs ,
12251249 )
12261250 return [doc for doc , _ in documents ]
12271251
@@ -1245,7 +1269,10 @@ def similarity_search_with_score(
12451269 """
12461270 embedding = self ._embedding_service .embed_query (query )
12471271 documents = self .similarity_search_with_score_by_vector (
1248- embedding = embedding , k = k , pre_filter = pre_filter
1272+ embedding = embedding ,
1273+ k = k ,
1274+ pre_filter = pre_filter ,
1275+ ** kwargs ,
12491276 )
12501277 return documents
12511278
@@ -1313,9 +1340,9 @@ def max_marginal_relevance_search_with_score_by_vector(
13131340 """
13141341 if self .__using_ANN :
13151342 results , column_order_map = self ._get_rows_by_similarity_search_ann (
1316- embedding ,
1317- fetch_k ,
1318- pre_filter ,
1343+ embedding = embedding ,
1344+ k = fetch_k ,
1345+ pre_filter = pre_filter ,
13191346 ** kwargs ,
13201347 )
13211348 else :
@@ -1367,7 +1394,12 @@ def max_marginal_relevance_search_by_vector(
13671394 List of Documents selected by maximal marginal relevance.
13681395 """
13691396 documents_with_scores = self .max_marginal_relevance_search_with_score_by_vector (
1370- embedding , k , fetch_k , lambda_mult , pre_filter
1397+ embedding ,
1398+ k ,
1399+ fetch_k ,
1400+ lambda_mult ,
1401+ pre_filter ,
1402+ ** kwargs ,
13711403 )
13721404
13731405 return [doc for doc , _ in documents_with_scores ]
0 commit comments