1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import pytest
16- import os
1715import json
16+ import os
17+
18+ import pytest
1819from google .cloud .spanner import Client , KeySet # type: ignore
1920from langchain_core .documents import Document
2021
2425instance = os .environ ["INSTANCE_ID" ]
2526google_database = os .environ ["GOOGLE_DATABASE" ]
2627pg_database = os .environ ["PG_DATABASE" ]
27- table_name = os .environ ["TABLE_NAME" ]
28+ table_name = os .environ ["TABLE_NAME" ]. replace ( "-" , "_" )
2829
2930OPERATION_TIMEOUT_SECONDS = 240
3031
@@ -51,7 +52,7 @@ def setup_pg_client(client) -> Client:
5152 yield client
5253
5354
54- class TestSpannerDocumentLoader :
55+ class TestSpannerDocumentLoaderGoogleSQL :
5556 @pytest .fixture (autouse = True , scope = "class" )
5657 def setup_database (self , client ):
5758 database = client .instance (instance ).database (google_database )
@@ -275,9 +276,10 @@ def test_loader_custom_json_metadata(self, client):
275276 content_column = "product_id" ,
276277 metadata_columns = [
277278 ("product_name" , "STRING(1024)" , True ),
278- ("description" , "JSON " , False ),
279+ ("description" , "STRING(1024) " , False ),
279280 ("price" , "INT64" , False ),
280281 ],
282+ metadata_json_column = "my_metadata" ,
281283 )
282284
283285 saver = SpannerDocumentSaver (
@@ -287,16 +289,17 @@ def test_loader_custom_json_metadata(self, client):
287289 client ,
288290 content_column = "product_id" ,
289291 metadata_columns = ["product_name" , "description" , "price" ],
292+ metadata_json_column = "my_metadata" ,
290293 )
291294 test_documents = [
292295 Document (
293296 page_content = "1" ,
294297 metadata = {
295298 "product_name" : "cards" ,
296- "description" : json . loads ( '{"player1": " playing cards are cool"}' ) ,
299+ "description" : " playing cards are cool" ,
297300 "price" : 10 ,
298301 "extra_metadata" : "foobar" ,
299- "langchain_metadata " : {
302+ "my_metadata " : {
300303 "foo" : "bar" ,
301304 },
302305 },
@@ -309,23 +312,298 @@ def test_loader_custom_json_metadata(self, client):
309312 google_database ,
310313 query ,
311314 client = client ,
312- metadata_json_column = "description " ,
315+ metadata_json_column = "my_metadata " ,
313316 )
314317 docs = loader .load ()
315318 assert docs == [
316319 Document (
317320 page_content = "1" ,
318321 metadata = {
319322 "product_name" : "cards" ,
320- "player1 " : "playing cards are cool" ,
323+ "description " : "playing cards are cool" ,
321324 "price" : 10 ,
325+ "foo" : "bar" ,
326+ "extra_metadata" : "foobar" ,
327+ },
328+ ),
329+ ]
330+
331+
332+ class TestSpannerDocumentLoaderPostgreSQL :
333+ @pytest .fixture (autouse = True , scope = "class" )
334+ def setup_database (self , client ):
335+ database = client .instance (instance ).database (pg_database )
336+ operation = database .update_ddl ([f"DROP TABLE IF EXISTS { table_name } " ])
337+ operation .result (OPERATION_TIMEOUT_SECONDS )
338+ SpannerDocumentSaver .init_document_table (
339+ instance ,
340+ pg_database ,
341+ table_name ,
342+ content_column = "product_id" ,
343+ metadata_columns = [
344+ ("product_name" , "VARCHAR(1024)" , True ),
345+ ("description" , "VARCHAR(1024)" , False ),
346+ ("price" , "INT" , False ),
347+ ],
348+ )
349+
350+ saver = SpannerDocumentSaver (
351+ instance ,
352+ pg_database ,
353+ table_name ,
354+ client ,
355+ content_column = "product_id" ,
356+ metadata_columns = ["product_name" , "description" , "price" ],
357+ )
358+ test_documents = [
359+ Document (
360+ page_content = "1" ,
361+ metadata = {
362+ "product_name" : "cards" ,
363+ "description" : "playing cards are cool" ,
364+ "price" : 10 ,
365+ "extra_metadata" : "foobar" ,
322366 "langchain_metadata" : {
323367 "foo" : "bar" ,
324- "extra_metadata" : "foobar" ,
325368 },
326369 },
327370 ),
328371 ]
372+ saver .add_documents (test_documents )
373+
374+ # Default CUJs
375+ @pytest .mark .parametrize (
376+ "query, expected" ,
377+ [
378+ pytest .param (
379+ f"SELECT * FROM { table_name } " ,
380+ [
381+ Document (
382+ page_content = "1" ,
383+ metadata = {
384+ "extra_metadata" : "foobar" ,
385+ "foo" : "bar" ,
386+ "product_name" : "cards" ,
387+ "description" : "playing cards are cool" ,
388+ "price" : 10 ,
389+ },
390+ )
391+ ],
392+ ),
393+ pytest .param (
394+ f"SELECT product_name, description FROM { table_name } " ,
395+ [
396+ Document (
397+ page_content = "cards" ,
398+ metadata = {
399+ "description" : "playing cards are cool" ,
400+ },
401+ )
402+ ],
403+ ),
404+ ],
405+ )
406+ def test_loader_with_query (self , client , query , expected ):
407+ loader = SpannerLoader (instance , pg_database , query , client = client )
408+ docs = loader .load ()
409+ assert docs == expected
410+
411+ def test_loader_missing_table_and_query (self ):
412+ with pytest .raises (Exception ):
413+ SpannerLoader (instance , pg_database )
414+
415+ # Custom CUJs
416+ def test_loader_custom_content (self , client ):
417+ query = f"SELECT * FROM { table_name } "
418+ loader = SpannerLoader (
419+ instance ,
420+ pg_database ,
421+ query ,
422+ client = client ,
423+ content_columns = ["description" , "price" ],
424+ )
425+ docs = loader .load ()
426+ assert docs == [
427+ Document (
428+ page_content = "playing cards are cool 10" ,
429+ metadata = {
430+ "extra_metadata" : "foobar" ,
431+ "foo" : "bar" ,
432+ "product_id" : "1" ,
433+ "product_name" : "cards" ,
434+ },
435+ ),
436+ ]
437+
438+ def test_loader_custom_metadata (self , client ):
439+ query = f"SELECT * FROM { table_name } "
440+ loader = SpannerLoader (
441+ instance ,
442+ pg_database ,
443+ query ,
444+ client = client ,
445+ metadata_columns = ["product_name" , "price" ],
446+ )
447+ docs = loader .load ()
448+ assert docs == [
449+ Document (
450+ page_content = "1" ,
451+ metadata = {"product_name" : "cards" , "price" : 10 },
452+ ),
453+ ]
454+
455+ def test_loader_custom_content_and_metadata (self , client ):
456+ query = f"SELECT * FROM { table_name } "
457+ loader = SpannerLoader (
458+ instance ,
459+ pg_database ,
460+ query ,
461+ client = client ,
462+ content_columns = ["product_name" ],
463+ metadata_columns = ["product_id" , "price" ],
464+ )
465+ docs = loader .load ()
466+ assert docs == [
467+ Document (
468+ page_content = "cards" ,
469+ metadata = {"product_id" : "1" , "price" : 10 },
470+ ),
471+ ]
472+
473+ def test_loader_custom_format_json (self , client ):
474+ query = f"SELECT * FROM { table_name } "
475+ loader = SpannerLoader (
476+ instance ,
477+ pg_database ,
478+ query ,
479+ client = client ,
480+ content_columns = ["product_id" , "product_name" ],
481+ format = "JSON" ,
482+ )
483+ docs = loader .load ()
484+ assert docs == [
485+ Document (
486+ page_content = "product_id: 1 product_name: cards" ,
487+ metadata = {
488+ "extra_metadata" : "foobar" ,
489+ "foo" : "bar" ,
490+ "description" : "playing cards are cool" ,
491+ "price" : 10 ,
492+ },
493+ )
494+ ]
495+
496+ def test_loader_custom_format_yaml (self , client ):
497+ query = f"SELECT * FROM { table_name } "
498+ loader = SpannerLoader (
499+ instance , pg_database , query , client = client , format = "YAML"
500+ )
501+ docs = loader .load ()
502+ assert docs == [
503+ Document (
504+ page_content = "product_id: 1" ,
505+ metadata = {
506+ "product_name" : "cards" ,
507+ "description" : "playing cards are cool" ,
508+ "price" : 10 ,
509+ "foo" : "bar" ,
510+ "extra_metadata" : "foobar" ,
511+ },
512+ )
513+ ]
514+
515+ def test_loader_custom_format_csv (self , client ):
516+ query = f"SELECT * FROM { table_name } "
517+ loader = SpannerLoader (
518+ instance , pg_database , query , client = client , format = "CSV"
519+ )
520+ docs = loader .load ()
521+ assert docs == [
522+ Document (
523+ page_content = "1" ,
524+ metadata = {
525+ "product_name" : "cards" ,
526+ "description" : "playing cards are cool" ,
527+ "price" : 10 ,
528+ "foo" : "bar" ,
529+ "extra_metadata" : "foobar" ,
530+ },
531+ )
532+ ]
533+
534+ def test_loader_custom_format_error (self , client ):
535+ query = f"SELECT * FROM { table_name } "
536+ with pytest .raises (Exception ):
537+ SpannerLoader (
538+ instance ,
539+ pg_database ,
540+ query ,
541+ client ,
542+ format = "NOT_A_FORMAT" ,
543+ )
544+
545+ def test_loader_custom_json_metadata (self , client ):
546+ database = client .instance (instance ).database (pg_database )
547+ operation = database .update_ddl ([f"DROP TABLE IF EXISTS { table_name } " ])
548+ operation .result (OPERATION_TIMEOUT_SECONDS )
549+ SpannerDocumentSaver .init_document_table (
550+ instance ,
551+ pg_database ,
552+ table_name ,
553+ content_column = "product_id" ,
554+ metadata_columns = [
555+ ("product_name" , "VARCHAR(1024)" , True ),
556+ ("description" , "VARCHAR(1024)" , False ),
557+ ("price" , "INT" , False ),
558+ ],
559+ metadata_json_column = "my_metadata" ,
560+ )
561+
562+ saver = SpannerDocumentSaver (
563+ instance ,
564+ pg_database ,
565+ table_name ,
566+ client ,
567+ content_column = "product_id" ,
568+ metadata_columns = ["product_name" , "description" , "price" ],
569+ metadata_json_column = "my_metadata" ,
570+ )
571+ test_documents = [
572+ Document (
573+ page_content = "1" ,
574+ metadata = {
575+ "product_name" : "cards" ,
576+ "description" : "playing cards are cool" ,
577+ "price" : 10 ,
578+ "extra_metadata" : "foobar" ,
579+ "my_metadata" : {
580+ "foo" : "bar" ,
581+ },
582+ },
583+ ),
584+ ]
585+ saver .add_documents (test_documents )
586+ query = f"SELECT * FROM { table_name } "
587+ loader = SpannerLoader (
588+ instance ,
589+ pg_database ,
590+ query ,
591+ client = client ,
592+ metadata_json_column = "my_metadata" ,
593+ )
594+ docs = loader .load ()
595+ assert docs == [
596+ Document (
597+ page_content = "1" ,
598+ metadata = {
599+ "product_name" : "cards" ,
600+ "description" : "playing cards are cool" ,
601+ "price" : 10 ,
602+ "foo" : "bar" ,
603+ "extra_metadata" : "foobar" ,
604+ },
605+ ),
606+ ]
329607
330608
331609class TestSpannerDocumentSaver :
@@ -384,7 +662,10 @@ def test_saver_google_sql_with_custom_schema(self, google_client):
384662 )
385663 query = f"SELECT * FROM { table_name } "
386664 loader = SpannerLoader (
387- client = google_client , instance = instance , database = pg_database , query = query
665+ client = google_client ,
666+ instance = instance ,
667+ database = google_database ,
668+ query = query ,
388669 )
389670 expected_docs = [
390671 Document (
0 commit comments