Skip to content

Commit 524678b

Browse files
authored
fix: make ANN algorithm updates based off usage + testing (googleapis#140)
This change is carved out of PR googleapis#138, due to the fact that at present, Google Cloud Spanner vector index creation and deletion takes a very long time and is non-deterministic hence unreliable for integration tests so these updates which have unit tests are exclusive of the integration tests and allow for usage directly as Google engineering figures out the backend issues. Updates googleapis#94.
1 parent 5a25f91 commit 524678b

File tree

2 files changed

+103
-63
lines changed

2 files changed

+103
-63
lines changed

src/langchain_google_spanner/vector_store.py

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import datetime
1818
import logging
1919
from abc import ABC, abstractmethod
20+
from dataclasses import dataclass
2021
from enum import Enum
2122
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
2223

@@ -40,9 +41,7 @@
4041

4142
USER_AGENT_VECTOR_STORE = "langchain-google-spanner-python:vector_store/" + __version__
4243

43-
KNN_DISTANCE_SEARCH_QUERY_ALIAS = "distance"
44-
45-
from dataclasses import dataclass
44+
DISTANCE_SEARCH_QUERY_ALIAS = "distance"
4645

4746

4847
def client_with_user_agent(
@@ -423,7 +422,7 @@ def _generate_sql(
423422
primary_key,
424423
secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None,
425424
vector_size: Optional[int] = None,
426-
):
425+
) -> List[str]:
427426
"""
428427
Generate SQL for creating the vector store table.
429428
@@ -637,11 +636,12 @@ def __init__(
637636
embedding_service: Embeddings,
638637
id_column: str = ID_COLUMN_NAME,
639638
content_column: str = CONTENT_COLUMN_NAME,
640-
embedding_column: str = EMBEDDING_COLUMN_NAME,
639+
embedding_column: Optional[str | TableColumn] = None,
641640
client: Optional[spanner.Client] = None,
642641
metadata_columns: Optional[List[str]] = None,
643642
ignore_metadata_columns: Optional[List[str]] = None,
644643
metadata_json_column: Optional[str] = None,
644+
vector_index_name: Optional[str] = None, # For ANN.
645645
query_parameters: QueryParameters = QueryParameters(),
646646
):
647647
"""
@@ -667,8 +667,22 @@ def __init__(
667667
self._client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE)
668668
self._id_column = id_column
669669
self._content_column = content_column
670-
self._embedding_column = embedding_column
670+
if embedding_column is None:
671+
embedding_column = EMBEDDING_COLUMN_NAME
672+
self._embedding_column = ""
673+
self._embedding_column_type = ""
674+
self._embedding_column_is_nullable = False
675+
if isinstance(embedding_column, TableColumn):
676+
self._embedding_column_type = embedding_column.type
677+
self._embedding_column = embedding_column.name
678+
self._embedding_column_is_nullable = embedding_column.is_null
679+
embedding_column = embedding_column.name
680+
elif isinstance(embedding_column, str):
681+
self._embedding_column = embedding_column
671682
self._metadata_json_column = metadata_json_column
683+
self._vector_index_name = ""
684+
if vector_index_name:
685+
self._vector_index_name = vector_index_name
672686

673687
self._query_parameters = query_parameters
674688
self._embedding_service = embedding_service
@@ -1022,9 +1036,9 @@ def similarity_search_with_score_by_vector(
10221036
"""
10231037
if self.__using_ANN:
10241038
results, column_order_map = self._get_rows_by_similarity_search_ann(
1025-
embedding,
1026-
k,
1027-
pre_filter,
1039+
embedding=embedding,
1040+
k=k,
1041+
pre_filter=pre_filter,
10281042
**kwargs,
10291043
)
10301044
else:
@@ -1050,14 +1064,16 @@ def _get_rows_by_similarity_search_ann(
10501064
):
10511065
sql = SpannerVectorStore._generate_sql_for_ANN(
10521066
self._table_name,
1053-
index_name,
1067+
index_name or self._vector_index_name,
10541068
self._embedding_column,
10551069
embedding,
10561070
num_leaves,
10571071
k,
10581072
self._query_parameters.distance_strategy,
10591073
pre_filter=pre_filter,
1060-
embedding_column_is_nullable=embedding_column_is_nullable,
1074+
embedding_column_type=self._embedding_column_type,
1075+
embedding_column_is_nullable=self._embedding_column_is_nullable
1076+
or embedding_column_is_nullable,
10611077
ascending=ascending,
10621078
return_columns=return_columns or self._columns_to_insert,
10631079
)
@@ -1066,10 +1082,10 @@ def _get_rows_by_similarity_search_ann(
10661082
**staleness if staleness is not None else {}
10671083
) as snapshot:
10681084
results = snapshot.execute_sql(sql=sql)
1069-
column_order_map = {
1070-
value: index for index, value in enumerate(self._columns_to_insert)
1071-
}
1072-
return results, column_order_map
1085+
columns = (self._columns_to_insert or []).copy()
1086+
columns.append(DISTANCE_SEARCH_QUERY_ALIAS)
1087+
column_order_map = {value: index for index, value in enumerate(columns)}
1088+
return list(results), column_order_map
10731089

10741090
@staticmethod
10751091
def _generate_sql_for_ANN(
@@ -1081,13 +1097,17 @@ def _generate_sql_for_ANN(
10811097
k: int,
10821098
strategy: DistanceStrategy = DistanceStrategy.COSINE,
10831099
pre_filter: Optional[str] = None,
1100+
embedding_column_type: str = "ARRAY<FLOAT32>",
10841101
embedding_column_is_nullable: bool = False,
10851102
ascending: bool = True,
10861103
return_columns: Optional[List[str]] = None,
10871104
) -> str:
10881105
if not embedding_column_name:
10891106
raise Exception("embedding_column_name must be set")
10901107

1108+
if not index_name:
1109+
raise Exception("index_name must be set")
1110+
10911111
ann_strategy_name = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS.get(strategy, None)
10921112
if not ann_strategy_name:
10931113
raise Exception(f"{strategy} is not supported for ANN")
@@ -1099,8 +1119,12 @@ def _generate_sql_for_ANN(
10991119
if not column_names:
11001120
column_names = "*"
11011121

1122+
distance_alias = DISTANCE_SEARCH_QUERY_ALIAS
11021123
sql = (
1103-
f"SELECT {column_names} FROM {table_name}"
1124+
f"SELECT {column_names}, {ann_strategy_name}("
1125+
+ f"{embedding_column_type}{embedding}, {embedding_column_name}, options => JSON '"
1126+
+ ('{"num_leaves_to_search": %s}\') as %s\n' % (num_leaves, distance_alias))
1127+
+ f"FROM {table_name}"
11041128
+ "@{FORCE_INDEX="
11051129
+ f"{index_name}"
11061130
+ (
@@ -1111,10 +1135,9 @@ def _generate_sql_for_ANN(
11111135
+ ("" if not pre_filter else f" AND {pre_filter}")
11121136
+ "\n"
11131137
)
1114-
+ f"ORDER BY {ann_strategy_name}(\n"
1115-
+ f" ARRAY<FLOAT32>{embedding}, {embedding_column_name}, options => JSON '"
1116-
+ '{"num_leaves_to_search": %s}\')%s\n'
1117-
% (num_leaves, "" if ascending else " DESC")
1138+
+ f"ORDER BY {distance_alias}"
1139+
+ ("" if ascending else " DESC")
1140+
+ "\n"
11181141
)
11191142

11201143
if k:
@@ -1144,7 +1167,7 @@ def _get_rows_by_similarity_search_knn(
11441167
column_order_map = {
11451168
value: index for index, value in enumerate(self._columns_to_insert)
11461169
}
1147-
column_order_map[KNN_DISTANCE_SEARCH_QUERY_ALIAS] = len(self._columns_to_insert)
1170+
column_order_map[DISTANCE_SEARCH_QUERY_ALIAS] = len(self._columns_to_insert)
11481171

11491172
sql_query = """
11501173
SELECT {select_column_names} {distance_function}({embedding_column}, {vector_embedding_placeholder}) AS {distance_alias}
@@ -1160,7 +1183,7 @@ def _get_rows_by_similarity_search_knn(
11601183
filter=pre_filter if pre_filter is not None else "1 = 1",
11611184
k_count=k,
11621185
distance_function=distance_function,
1163-
distance_alias=KNN_DISTANCE_SEARCH_QUERY_ALIAS,
1186+
distance_alias=DISTANCE_SEARCH_QUERY_ALIAS,
11641187
)
11651188

11661189
with self._database.snapshot(
@@ -1195,9 +1218,7 @@ def _get_documents_from_query_results(
11951218
}
11961219

11971220
doc = Document(page_content=page_content, metadata=metadata)
1198-
documents.append(
1199-
(doc, row[column_order_map[KNN_DISTANCE_SEARCH_QUERY_ALIAS]])
1200-
)
1221+
documents.append((doc, row[column_order_map[DISTANCE_SEARCH_QUERY_ALIAS]]))
12011222

12021223
return documents
12031224

@@ -1221,7 +1242,10 @@ def similarity_search(
12211242
"""
12221243
embedding = self._embedding_service.embed_query(query)
12231244
documents = self.similarity_search_with_score_by_vector(
1224-
embedding=embedding, k=k, pre_filter=pre_filter
1245+
embedding=embedding,
1246+
k=k,
1247+
pre_filter=pre_filter,
1248+
**kwargs,
12251249
)
12261250
return [doc for doc, _ in documents]
12271251

@@ -1245,7 +1269,10 @@ def similarity_search_with_score(
12451269
"""
12461270
embedding = self._embedding_service.embed_query(query)
12471271
documents = self.similarity_search_with_score_by_vector(
1248-
embedding=embedding, k=k, pre_filter=pre_filter
1272+
embedding=embedding,
1273+
k=k,
1274+
pre_filter=pre_filter,
1275+
**kwargs,
12491276
)
12501277
return documents
12511278

@@ -1313,9 +1340,9 @@ def max_marginal_relevance_search_with_score_by_vector(
13131340
"""
13141341
if self.__using_ANN:
13151342
results, column_order_map = self._get_rows_by_similarity_search_ann(
1316-
embedding,
1317-
fetch_k,
1318-
pre_filter,
1343+
embedding=embedding,
1344+
k=fetch_k,
1345+
pre_filter=pre_filter,
13191346
**kwargs,
13201347
)
13211348
else:
@@ -1367,7 +1394,12 @@ def max_marginal_relevance_search_by_vector(
13671394
List of Documents selected by maximal marginal relevance.
13681395
"""
13691396
documents_with_scores = self.max_marginal_relevance_search_with_score_by_vector(
1370-
embedding, k, fetch_k, lambda_mult, pre_filter
1397+
embedding,
1398+
k,
1399+
fetch_k,
1400+
lambda_mult,
1401+
pre_filter,
1402+
**kwargs,
13711403
)
13721404

13731405
return [doc for doc, _ in documents_with_scores]

tests/unit/test_vectore_store.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,12 @@ def test_generate_sql_for_ANN(self):
215215
)
216216

217217
want = (
218-
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
218+
"SELECT DocId, APPROX_COSINE_DISTANCE("
219+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
220+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
221+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
219222
+ "WHERE 1=1\n"
220-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
221-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
222-
+ "'{\"num_leaves_to_search\": 10}')\n"
223+
+ "ORDER BY distance\n"
223224
+ "LIMIT 100"
224225
)
225226

@@ -239,11 +240,12 @@ def test_generate_sql_for_ANN_column_is_nullable(self):
239240
)
240241

241242
want = (
242-
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
243+
"SELECT DocId, APPROX_COSINE_DISTANCE("
244+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
245+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
246+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
243247
+ "WHERE DocEmbedding IS NOT NULL\n"
244-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
245-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
246-
+ "'{\"num_leaves_to_search\": 10}')\n"
248+
+ "ORDER BY distance\n"
247249
+ "LIMIT 100"
248250
)
249251

@@ -262,11 +264,12 @@ def test_generate_sql_for_ANN_column_unspecified_return_columns_star_result(self
262264
)
263265

264266
want = (
265-
"SELECT * FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
267+
"SELECT *, APPROX_COSINE_DISTANCE("
268+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
269+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
270+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
266271
+ "WHERE DocEmbedding IS NOT NULL\n"
267-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
268-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
269-
+ "'{\"num_leaves_to_search\": 10}')\n"
272+
+ "ORDER BY distance\n"
270273
+ "LIMIT 100"
271274
)
272275

@@ -286,11 +289,12 @@ def test_generate_sql_for_ANN_order_DESC(self):
286289
)
287290

288291
want = (
289-
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
292+
"SELECT DocId, APPROX_COSINE_DISTANCE("
293+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
294+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
295+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
290296
+ "WHERE 1=1\n"
291-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
292-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
293-
+ "'{\"num_leaves_to_search\": 10}') DESC\n"
297+
+ "ORDER BY distance DESC\n"
294298
+ "LIMIT 100"
295299
)
296300

@@ -309,11 +313,12 @@ def test_generate_sql_for_ANN_specified_limit(self):
309313
)
310314

311315
want = (
312-
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
316+
"SELECT DocId, APPROX_COSINE_DISTANCE("
317+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
318+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
319+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
313320
+ "WHERE 1=1\n"
314-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
315-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
316-
+ "'{\"num_leaves_to_search\": 10}')\n"
321+
+ "ORDER BY distance\n"
317322
+ "LIMIT 100"
318323
)
319324

@@ -333,11 +338,12 @@ def test_generate_sql_for_ANN_specified_pre_filter(self):
333338
)
334339

335340
want = (
336-
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
341+
"SELECT DocId, APPROX_COSINE_DISTANCE("
342+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
343+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
344+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
337345
+ "WHERE categoryId!=20\n"
338-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
339-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
340-
+ "'{\"num_leaves_to_search\": 10}')\n"
346+
+ "ORDER BY distance\n"
341347
+ "LIMIT 100"
342348
)
343349

@@ -358,11 +364,12 @@ def test_generate_sql_for_ANN_specified_pre_filter_with_nullable_column(self):
358364
)
359365

360366
want = (
361-
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
367+
"SELECT DocId, APPROX_COSINE_DISTANCE("
368+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
369+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
370+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
362371
+ "WHERE DocEmbedding IS NOT NULL AND categoryId!=9\n"
363-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
364-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
365-
+ "'{\"num_leaves_to_search\": 10}')\n"
372+
+ "ORDER BY distance\n"
366373
+ "LIMIT 100"
367374
)
368375

@@ -383,11 +390,12 @@ def test_generate_sql_for_ANN_no_pre_filter_non_nullable(self):
383390
)
384391

385392
want = (
386-
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
393+
"SELECT DocId, APPROX_COSINE_DISTANCE("
394+
+ "ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
395+
+ "'{\"num_leaves_to_search\": 10}') as distance\n"
396+
+ "FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
387397
+ "WHERE DocEmbedding IS NOT NULL AND DocId!=2\n"
388-
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
389-
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
390-
+ "'{\"num_leaves_to_search\": 10}')\n"
398+
+ "ORDER BY distance\n"
391399
+ "LIMIT 100"
392400
)
393401

0 commit comments

Comments
 (0)