Skip to content

Commit ca217dd

Browse files
authored
test(ANN): add integration tests and retrofit (googleapis#138)
* tests(VectorStore): add integration tests for ANN. This change adds integration tests for Approximate Nearest Neighbor search in the VectorStore. * Use a dedicated database for the ANN integration tests. * Use the Python major+minor version in identifiers to avoid database conflicts.
1 parent 3872af1 commit ca217dd

File tree

1 file changed

+264
-3
lines changed

1 file changed

+264
-3
lines changed

tests/integration/test_spanner_vector_store.py

Lines changed: 264 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
import datetime
1616
import os
17+
import sys
1718
import uuid
19+
from typing import Dict
1820

1921
import pytest
2022
from google.cloud.spanner import Client # type: ignore
@@ -26,13 +28,21 @@
2628
QueryParameters,
2729
SpannerVectorStore,
2830
TableColumn,
31+
VectorSearchIndex,
2932
)
3033

3134
project_id = os.environ["PROJECT_ID"]
3235
instance_id = os.environ["INSTANCE_ID"]
3336
google_database = os.environ["GOOGLE_DATABASE"]
3437
pg_database = os.environ["PG_DATABASE"]
3538
table_name = "test_table" + str(uuid.uuid4()).replace("-", "_")
39+
# Cloud Spanner takes 30+ minutes to create a Vector search index
40+
# hence in order to make integration tests usable, before
41+
# they fix the bad delay, let's reuse the same database and never DROP
42+
# the database nor table to allow for effective reuse.
43+
ann_db = os.environ.get("GOOGLE_SPANNER_ANN_DB", "my-spanner-db-ann")
44+
uniq_py_suffix = f"_py{sys.version_info.major}{sys.version_info.minor}"
45+
table_name_ANN = f"our_table_ann{uniq_py_suffix}"
3646

3747

3848
OPERATION_TIMEOUT_SECONDS = 240
@@ -50,7 +60,11 @@ def cleanupGSQL(client):
5060
print("\nPerforming GSQL cleanup after each test...")
5161

5262
database = client.instance(instance_id).database(google_database)
53-
operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
63+
operation = database.update_ddl(
64+
[
65+
f"DROP TABLE IF EXISTS {table_name}",
66+
]
67+
)
5468
operation.result(OPERATION_TIMEOUT_SECONDS)
5569

5670
# Code to perform teardown after each test goes here
@@ -71,7 +85,7 @@ def cleanupPGSQL(client):
7185
print("\n PGSQL Cleanup complete.")
7286

7387

74-
class TestStaticUtilityGoogleSQL:
88+
class TestStaticUtilityGoogleSQL_KNN:
7589
@pytest.fixture(autouse=True)
7690
def setup_database(self, client, cleanupGSQL):
7791
yield
@@ -389,6 +403,252 @@ def test_spanner_vector_search_data4(self, setup_database):
389403
assert len(docs) == 3
390404

391405

406+
title_vector_size = 3
407+
title_vector_index_name = f"title_v_index{uniq_py_suffix}"
408+
title_vector_embedding_column = TableColumn(
409+
name="title_embedding", type="ARRAY<FLOAT64>", is_null=True
410+
)
411+
412+
413+
class TestSpannerVectorStoreGoogleSQL_ANN:
    """Integration tests for Approximate Nearest Neighbor (ANN) search.

    Cloud Spanner vector search indices currently take a very long time
    (30+ minutes) to create and destroy, so the table and its index are
    created exactly once and are never dropped; the cleanup only deletes
    rows so the database can be reused across runs.
    """

    # Class-level sentinel: the first run of the class-scoped fixture
    # records a key here so the slow DDL setup is performed only once.
    initialized: Dict[str, int] = dict()

    @pytest.fixture(scope="class")
    def setup_database(self, client):
        """Yield ``(loader, embeddings, cleanup_db)`` for the ANN tests."""
        if len(self.initialized) == 0:
            self.initialized["a"] = 1

            SpannerVectorStore.init_vector_store_table(
                instance_id=instance_id,
                database_id=ann_db,
                table_name=table_name_ANN,
                vector_size=title_vector_size,
                id_column="row_id",
                metadata_columns=[
                    TableColumn(name="metadata", type="JSON", is_null=True),
                    TableColumn(name="title", type="STRING(MAX)", is_null=False),
                ],
                embedding_column=title_vector_embedding_column,
                secondary_indexes=[
                    VectorSearchIndex(
                        index_name=title_vector_index_name,
                        columns=[title_vector_embedding_column.name],
                        distance_type=DistanceStrategy.COSINE,
                        nullable_column=True,
                        num_branches=1000,
                        tree_depth=3,
                        num_leaves=100000,
                    ),
                ],
            )

        loader = HNLoader("https://news.ycombinator.com/item?id=34817881")
        embeddings = FakeEmbeddings(size=title_vector_size)

        def cleanup_db():
            # Only delete rows; dropping the table or index would force the
            # 30+ minute index rebuild on the next run.
            print("\nPerforming GSQL cleanup...")
            database = client.instance(instance_id).database(ann_db)

            def delete_from_table(txn):
                return txn.execute_update(f"DELETE FROM {table_name_ANN} WHERE 1=1")

            database.run_in_transaction(delete_from_table)

            # Cloud Spanner Vector index creation takes multitudes of time
            # hence trying to drop and recreate indices can make tests
            # run for more than 2+ hours, hence this comment.
            # TODO: Uncomment these operations when Cloud Spanner has fixed
            # the problem.
            # operation = database.update_ddl([
            #     f"DROP VECTOR INDEX IF EXISTS {title_vector_index_name}",
            #     f"DROP TABLE IF EXISTS {table_name_ANN}",
            # ])
            # operation.result(OPERATION_TIMEOUT_SECONDS)
            print("\nGSQL Cleanup complete.")

        yield loader, embeddings, cleanup_db

    @staticmethod
    def _ann_query_parameters():
        # Query parameters shared by every ANN search test below.
        return QueryParameters(
            algorithm=QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR,
            distance_strategy=DistanceStrategy.COSINE,
            max_staleness=datetime.timedelta(seconds=15),
        )

    @staticmethod
    def _make_store(embeddings, query_parameters=None):
        # Build the SpannerVectorStore used by every test; these constructor
        # arguments were previously duplicated in each test method.
        kwargs = dict(
            instance_id=instance_id,
            database_id=ann_db,
            table_name=table_name_ANN,
            id_column="row_id",
            ignore_metadata_columns=[],
            vector_index_name=title_vector_index_name,
            embedding_column=title_vector_embedding_column,
            embedding_service=embeddings,
            metadata_json_column="metadata",
        )
        if query_parameters is not None:
            kwargs["query_parameters"] = query_parameters
        return SpannerVectorStore(**kwargs)

    def test_add_documents(self, setup_database):
        loader, embeddings, _ = setup_database
        db = self._make_store(embeddings)

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        ids_row_inserted = db.add_documents(documents=docs, ids=ids)
        assert ids == ids_row_inserted

    def test_add_texts(self, setup_database):
        _, embeddings, _ = setup_database
        db = self._make_store(embeddings)

        texts = [
            "Langchain Test Text 1",
            "Langchain Test Text 2",
            "Langchain Test Text 3",
        ]
        ids = [str(uuid.uuid4()) for _ in range(len(texts))]
        ids_row_inserted = db.add_texts(
            texts=texts,
            ids=ids,
            metadatas=[
                {"title": "Title 1"},
                {"title": "Title 2"},
                {"title": "Title 3"},
            ],
        )
        assert ids == ids_row_inserted

    def test_delete(self, setup_database):
        loader, embeddings, _ = setup_database
        db = self._make_store(embeddings)

        docs = loader.load()
        deleted = db.delete(documents=[docs[0], docs[1]])
        assert deleted

    def test_similarity_search(self, setup_database):
        loader, embeddings, _ = setup_database
        db = self._make_store(embeddings, self._ann_query_parameters())

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        db.add_documents(documents=docs, ids=ids)

        docs = db.similarity_search(
            "Testing the langchain integration with spanner",
            k=2,
        )
        assert len(docs) == 2

    def test_similarity_search_by_vector(self, setup_database):
        loader, embeddings, _ = setup_database
        db = self._make_store(embeddings, self._ann_query_parameters())

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        db.add_documents(documents=docs, ids=ids)

        embeds = embeddings.embed_query(
            "Testing the langchain integration with spanner"
        )
        docs = db.similarity_search_by_vector(
            embeds,
            k=3,
        )
        assert len(docs) == 3

    def test_max_marginal_relevance_search_with_score_by_vector(self, setup_database):
        loader, embeddings, _ = setup_database
        db = self._make_store(embeddings, self._ann_query_parameters())

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        db.add_documents(documents=docs, ids=ids)

        embeds = embeddings.embed_query(
            "Testing the langchain integration with spanner"
        )
        docs = db.max_marginal_relevance_search_with_score_by_vector(
            embeds,
            k=3,
        )
        assert len(docs) == 3

    def test_last_for_cleanup(self, setup_database):
        # Runs last (pytest preserves definition order) and purges the rows
        # inserted by the preceding tests.
        loader, _, cleanup = setup_database
        _ = loader
        cleanup()
650+
651+
392652
class TestSpannerVectorStorePGSQL:
393653
@pytest.fixture(scope="class")
394654
def setup_database(self, client):
@@ -565,7 +825,8 @@ def test_spanner_vector_search_data4(self, setup_database):
565825
)
566826

567827
docs = db.max_marginal_relevance_search(
568-
"Testing the langchain integration with spanner", k=3
828+
"Testing the langchain integration with spanner",
829+
k=3,
569830
)
570831

571832
assert len(docs) == 3

0 commit comments

Comments
 (0)