1414
1515import datetime
1616import os
17+ import sys
1718import uuid
19+ from typing import Dict
1820
1921import pytest
2022from google .cloud .spanner import Client # type: ignore
2628 QueryParameters ,
2729 SpannerVectorStore ,
2830 TableColumn ,
31+ VectorSearchIndex ,
2932)
3033
# Required connection parameters; fail fast at import time if unset.
project_id = os.environ["PROJECT_ID"]
instance_id = os.environ["INSTANCE_ID"]
google_database = os.environ["GOOGLE_DATABASE"]
pg_database = os.environ["PG_DATABASE"]

# Unique per-run table for the KNN tests; UUID dashes are invalid in
# Spanner identifiers, so replace them with underscores.
table_name = f"test_table{str(uuid.uuid4()).replace('-', '_')}"

# Cloud Spanner takes 30+ minutes to create a Vector search index
# hence in order to make integration tests usable, before
# they fix the bad delay, let's reuse the same database and never DROP
# the database nor table to allow for effective reuse.
ann_db = os.environ.get("GOOGLE_SPANNER_ANN_DB", "my-spanner-db-ann")
# Suffix shared tables/indices per Python version so parallel CI jobs
# running different interpreters don't collide on the reused ANN schema.
uniq_py_suffix = f"_py{sys.version_info.major}{sys.version_info.minor}"
table_name_ANN = f"our_table_ann{uniq_py_suffix}"


OPERATION_TIMEOUT_SECONDS = 240
@@ -50,7 +60,11 @@ def cleanupGSQL(client):
5060 print ("\n Performing GSQL cleanup after each test..." )
5161
5262 database = client .instance (instance_id ).database (google_database )
53- operation = database .update_ddl ([f"DROP TABLE IF EXISTS { table_name } " ])
63+ operation = database .update_ddl (
64+ [
65+ f"DROP TABLE IF EXISTS { table_name } " ,
66+ ]
67+ )
5468 operation .result (OPERATION_TIMEOUT_SECONDS )
5569
5670 # Code to perform teardown after each test goes here
@@ -71,7 +85,7 @@ def cleanupPGSQL(client):
7185 print ("\n PGSQL Cleanup complete." )
7286
7387
74- class TestStaticUtilityGoogleSQL :
88+ class TestStaticUtilityGoogleSQL_KNN :
7589 @pytest .fixture (autouse = True )
7690 def setup_database (self , client , cleanupGSQL ):
7791 yield
@@ -389,6 +403,252 @@ def test_spanner_vector_search_data4(self, setup_database):
389403 assert len (docs ) == 3
390404
391405
# Dimensionality of the fake title embeddings used by the ANN tests.
title_vector_size = 3

# Vector index name, per-Python-version so shared-database CI runs
# against the long-lived ANN schema do not clash.
title_vector_index_name = f"title_v_index{uniq_py_suffix}"

# Nullable ARRAY<FLOAT64> column that stores the title embeddings.
title_vector_embedding_column = TableColumn(
    name="title_embedding",
    type="ARRAY<FLOAT64>",
    is_null=True,
)
411+
412+
class TestSpannerVectorStoreGoogleSQL_ANN:
    """ANN (approximate nearest neighbor) integration tests for
    SpannerVectorStore against a GoogleSQL-dialect Spanner database.

    Cloud Spanner vector search indices currently take a very long time
    to create and destroy, so this class reuses one long-lived database,
    table, and index (``ann_db`` / ``table_name_ANN``) and only deletes
    rows between runs instead of dropping and recreating the schema.
    """

    # Sadly currently Cloud Spanner Vector Search Indices being
    # created and destroyed takes a very long time hence we
    # are creating and tearing down the indices exactly once.
    initialized: Dict[str, int] = dict()

    @pytest.fixture(scope="class")
    def setup_database(self, client):
        # Provision the shared table + vector index only on first entry;
        # subsequent runs in the same process reuse the existing schema.
        if not self.initialized:
            self.initialized["a"] = 1

            SpannerVectorStore.init_vector_store_table(
                instance_id=instance_id,
                database_id=ann_db,
                table_name=table_name_ANN,
                vector_size=title_vector_size,
                id_column="row_id",
                metadata_columns=[
                    TableColumn(name="metadata", type="JSON", is_null=True),
                    TableColumn(name="title", type="STRING(MAX)", is_null=False),
                ],
                embedding_column=title_vector_embedding_column,
                secondary_indexes=[
                    VectorSearchIndex(
                        index_name=title_vector_index_name,
                        columns=[title_vector_embedding_column.name],
                        distance_type=DistanceStrategy.COSINE,
                        nullable_column=True,
                        num_branches=1000,
                        tree_depth=3,
                        num_leaves=100000,
                    ),
                ],
            )

        loader = HNLoader("https://news.ycombinator.com/item?id=34817881")
        embeddings = FakeEmbeddings(size=title_vector_size)

        def cleanup_db():
            """Delete all rows from the shared ANN table, keeping the
            table and vector index in place for reuse."""
            print("\nPerforming GSQL cleanup...")
            database = client.instance(instance_id).database(ann_db)

            def delete_from_table(txn):
                return txn.execute_update(f"DELETE FROM {table_name_ANN} WHERE 1=1")

            database.run_in_transaction(delete_from_table)

            # Cloud Spanner Vector index creation takes multitudes of time
            # hence trying to drop and recreate indices can make tests
            # run for more than 2+ hours, hence this comment.
            # TODO: Uncomment these operations when Cloud Spanner has fixed the problem.
            # operation = database.update_ddl([
            #     f"DROP VECTOR INDEX IF EXISTS {title_vector_index_name}",
            #     f"DROP TABLE IF EXISTS {table_name_ANN}",
            # ])
            # operation.result(OPERATION_TIMEOUT_SECONDS)
            print("\nGSQL Cleanup complete.")

        yield loader, embeddings, cleanup_db

    @staticmethod
    def _ann_query_parameters():
        """Query parameters selecting ANN search with COSINE distance
        and a bounded-staleness read of up to 15 seconds."""
        return QueryParameters(
            algorithm=QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR,
            distance_strategy=DistanceStrategy.COSINE,
            max_staleness=datetime.timedelta(seconds=15),
        )

    def _make_store(self, embeddings, query_parameters=None):
        """Build a SpannerVectorStore bound to the shared ANN table.

        Args:
            embeddings: the embedding service to use.
            query_parameters: optional QueryParameters; when omitted the
                store is constructed with its default search settings.
        """
        kwargs = dict(
            instance_id=instance_id,
            database_id=ann_db,
            table_name=table_name_ANN,
            id_column="row_id",
            ignore_metadata_columns=[],
            vector_index_name=title_vector_index_name,
            embedding_column=title_vector_embedding_column,
            embedding_service=embeddings,
            metadata_json_column="metadata",
        )
        if query_parameters is not None:
            kwargs["query_parameters"] = query_parameters
        return SpannerVectorStore(**kwargs)

    def test_add_documents(self, setup_database):
        loader, embeddings, _ = setup_database

        db = self._make_store(embeddings)

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        ids_row_inserted = db.add_documents(documents=docs, ids=ids)
        assert ids == ids_row_inserted

    def test_add_texts(self, setup_database):
        loader, embeddings, _ = setup_database

        db = self._make_store(embeddings)

        texts = [
            "Langchain Test Text 1",
            "Langchain Test Text 2",
            "Langchain Test Text 3",
        ]
        ids = [str(uuid.uuid4()) for _ in range(len(texts))]
        ids_row_inserted = db.add_texts(
            texts=texts,
            ids=ids,
            metadatas=[
                {"title": "Title 1"},
                {"title": "Title 2"},
                {"title": "Title 3"},
            ],
        )
        assert ids == ids_row_inserted

    def test_delete(self, setup_database):
        loader, embeddings, _ = setup_database

        db = self._make_store(embeddings)

        docs = loader.load()
        deleted = db.delete(documents=[docs[0], docs[1]])

        assert deleted

    def test_similarity_search(self, setup_database):
        loader, embeddings, _ = setup_database

        db = self._make_store(embeddings, self._ann_query_parameters())

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        db.add_documents(documents=docs, ids=ids)

        docs = db.similarity_search(
            "Testing the langchain integration with spanner",
            k=2,
        )

        assert len(docs) == 2

    def test_similarity_search_by_vector(self, setup_database):
        loader, embeddings, _ = setup_database

        db = self._make_store(embeddings, self._ann_query_parameters())

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        db.add_documents(documents=docs, ids=ids)

        embeds = embeddings.embed_query(
            "Testing the langchain integration with spanner"
        )

        docs = db.similarity_search_by_vector(
            embeds,
            k=3,
        )

        assert len(docs) == 3

    def test_max_marginal_relevance_search_with_score_by_vector(self, setup_database):
        loader, embeddings, _ = setup_database

        db = self._make_store(embeddings, self._ann_query_parameters())

        docs = loader.load()
        ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        db.add_documents(documents=docs, ids=ids)

        embeds = embeddings.embed_query(
            "Testing the langchain integration with spanner"
        )

        docs = db.max_marginal_relevance_search_with_score_by_vector(
            embeds,
            k=3,
        )

        assert len(docs) == 3

    def test_last_for_cleanup(self, setup_database):
        # NOTE(review): relies on pytest's default in-class test ordering
        # to run last and clear the shared table's rows for the next run.
        loader, _, cleanup = setup_database
        _ = loader
        cleanup()
651+
392652class TestSpannerVectorStorePGSQL :
393653 @pytest .fixture (scope = "class" )
394654 def setup_database (self , client ):
@@ -565,7 +825,8 @@ def test_spanner_vector_search_data4(self, setup_database):
565825 )
566826
567827 docs = db .max_marginal_relevance_search (
568- "Testing the langchain integration with spanner" , k = 3
828+ "Testing the langchain integration with spanner" ,
829+ k = 3 ,
569830 )
570831
571832 assert len (docs ) == 3
0 commit comments