Skip to content

Commit 12cfd97

Browse files
committed
update tests
1 parent 8541d97 commit 12cfd97

File tree

2 files changed

+293
-12
lines changed

2 files changed

+293
-12
lines changed

src/langchain_google_spanner/document_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import datetime
1616
import json
17-
from typing import Iterator, List, Optional, Dict, Any
17+
from typing import Any, Dict, Iterator, List, Optional
1818

1919
from google.cloud.spanner import Client, KeySet
2020
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect # type: ignore

src/tests/document_loader_test.py

Lines changed: 292 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pytest
16-
import os
1715
import json
16+
import os
17+
18+
import pytest
1819
from google.cloud.spanner import Client, KeySet # type: ignore
1920
from langchain_core.documents import Document
2021

@@ -24,7 +25,7 @@
2425
instance = os.environ["INSTANCE_ID"]
2526
google_database = os.environ["GOOGLE_DATABASE"]
2627
pg_database = os.environ["PG_DATABASE"]
27-
table_name = os.environ["TABLE_NAME"]
28+
# Hyphens in the configured name are swapped for underscores before the value
# is interpolated into SQL/DDL below.
# NOTE(review): presumably because hyphens are not valid in unquoted Spanner
# table identifiers — confirm against the CI-provided TABLE_NAME.
table_name = os.environ["TABLE_NAME"].replace("-", "_")
2829

2930
OPERATION_TIMEOUT_SECONDS = 240
3031

@@ -51,7 +52,7 @@ def setup_pg_client(client) -> Client:
5152
yield client
5253

5354

54-
class TestSpannerDocumentLoader:
55+
class TestSpannerDocumentLoaderGoogleSQL:
5556
@pytest.fixture(autouse=True, scope="class")
5657
def setup_database(self, client):
5758
database = client.instance(instance).database(google_database)
@@ -275,9 +276,10 @@ def test_loader_custom_json_metadata(self, client):
275276
content_column="product_id",
276277
metadata_columns=[
277278
("product_name", "STRING(1024)", True),
278-
("description", "JSON", False),
279+
("description", "STRING(1024)", False),
279280
("price", "INT64", False),
280281
],
282+
metadata_json_column="my_metadata",
281283
)
282284

283285
saver = SpannerDocumentSaver(
@@ -287,16 +289,17 @@ def test_loader_custom_json_metadata(self, client):
287289
client,
288290
content_column="product_id",
289291
metadata_columns=["product_name", "description", "price"],
292+
metadata_json_column="my_metadata",
290293
)
291294
test_documents = [
292295
Document(
293296
page_content="1",
294297
metadata={
295298
"product_name": "cards",
296-
"description": json.loads('{"player1": "playing cards are cool"}'),
299+
"description": "playing cards are cool",
297300
"price": 10,
298301
"extra_metadata": "foobar",
299-
"langchain_metadata": {
302+
"my_metadata": {
300303
"foo": "bar",
301304
},
302305
},
@@ -309,23 +312,298 @@ def test_loader_custom_json_metadata(self, client):
309312
google_database,
310313
query,
311314
client=client,
312-
metadata_json_column="description",
315+
metadata_json_column="my_metadata",
313316
)
314317
docs = loader.load()
315318
assert docs == [
316319
Document(
317320
page_content="1",
318321
metadata={
319322
"product_name": "cards",
320-
"player1": "playing cards are cool",
323+
"description": "playing cards are cool",
321324
"price": 10,
325+
"foo": "bar",
326+
"extra_metadata": "foobar",
327+
},
328+
),
329+
]
330+
331+
332+
class TestSpannerDocumentLoaderPostgreSQL:
    """Integration tests for ``SpannerLoader`` run against ``pg_database``.

    Every test talks to a live Spanner instance through the ``client`` fixture
    and reads the table named by the module-level ``table_name``.
    """

    @pytest.fixture(autouse=True, scope="class")
    def setup_database(self, client):
        """Recreate the test table and seed it with one document.

        Runs once per class (``scope="class"``, ``autouse=True``): drops the
        table if present, creates it via
        ``SpannerDocumentSaver.init_document_table`` and stores one document.
        """
        database = client.instance(instance).database(pg_database)
        operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
        # Block until the DDL has been applied (or the timeout elapses).
        operation.result(OPERATION_TIMEOUT_SECONDS)
        SpannerDocumentSaver.init_document_table(
            instance,
            pg_database,
            table_name,
            content_column="product_id",
            metadata_columns=[
                ("product_name", "VARCHAR(1024)", True),
                ("description", "VARCHAR(1024)", False),
                ("price", "INT", False),
            ],
        )

        saver = SpannerDocumentSaver(
            instance,
            pg_database,
            table_name,
            client,
            content_column="product_id",
            metadata_columns=["product_name", "description", "price"],
        )
        test_documents = [
            Document(
                page_content="1",
                metadata={
                    "product_name": "cards",
                    "description": "playing cards are cool",
                    "price": 10,
                    "extra_metadata": "foobar",
                    "langchain_metadata": {
                        "foo": "bar",
                    },
                },
            ),
        ]
        saver.add_documents(test_documents)

    # Default CUJs
    @pytest.mark.parametrize(
        "query, expected",
        [
            pytest.param(
                f"SELECT * FROM {table_name}",
                [
                    Document(
                        page_content="1",
                        metadata={
                            "extra_metadata": "foobar",
                            "foo": "bar",
                            "product_name": "cards",
                            "description": "playing cards are cool",
                            "price": 10,
                        },
                    )
                ],
            ),
            pytest.param(
                f"SELECT product_name, description FROM {table_name}",
                [
                    Document(
                        page_content="cards",
                        metadata={
                            "description": "playing cards are cool",
                        },
                    )
                ],
            ),
        ],
    )
    def test_loader_with_query(self, client, query, expected):
        """Load with default options and compare against ``expected``."""
        loader = SpannerLoader(instance, pg_database, query, client=client)
        docs = loader.load()
        assert docs == expected

    def test_loader_missing_table_and_query(self):
        """Constructing a loader without a query must raise."""
        with pytest.raises(Exception):
            SpannerLoader(instance, pg_database)

    # Custom CUJs
    def test_loader_custom_content(self, client):
        """``content_columns`` selects which columns form ``page_content``."""
        query = f"SELECT * FROM {table_name}"
        loader = SpannerLoader(
            instance,
            pg_database,
            query,
            client=client,
            content_columns=["description", "price"],
        )
        docs = loader.load()
        assert docs == [
            Document(
                page_content="playing cards are cool 10",
                metadata={
                    "extra_metadata": "foobar",
                    "foo": "bar",
                    "product_id": "1",
                    "product_name": "cards",
                },
            ),
        ]

    def test_loader_custom_metadata(self, client):
        """``metadata_columns`` restricts metadata to the listed columns."""
        query = f"SELECT * FROM {table_name}"
        loader = SpannerLoader(
            instance,
            pg_database,
            query,
            client=client,
            metadata_columns=["product_name", "price"],
        )
        docs = loader.load()
        assert docs == [
            Document(
                page_content="1",
                metadata={"product_name": "cards", "price": 10},
            ),
        ]

    def test_loader_custom_content_and_metadata(self, client):
        """``content_columns`` and ``metadata_columns`` can be combined."""
        query = f"SELECT * FROM {table_name}"
        loader = SpannerLoader(
            instance,
            pg_database,
            query,
            client=client,
            content_columns=["product_name"],
            metadata_columns=["product_id", "price"],
        )
        docs = loader.load()
        assert docs == [
            Document(
                page_content="cards",
                metadata={"product_id": "1", "price": 10},
            ),
        ]

    def test_loader_custom_format_json(self, client):
        """Load with ``format="JSON"`` and two content columns."""
        query = f"SELECT * FROM {table_name}"
        loader = SpannerLoader(
            instance,
            pg_database,
            query,
            client=client,
            content_columns=["product_id", "product_name"],
            format="JSON",
        )
        docs = loader.load()
        assert docs == [
            Document(
                page_content="product_id: 1 product_name: cards",
                metadata={
                    "extra_metadata": "foobar",
                    "foo": "bar",
                    "description": "playing cards are cool",
                    "price": 10,
                },
            )
        ]

    def test_loader_custom_format_yaml(self, client):
        """Load with ``format="YAML"`` and default columns."""
        query = f"SELECT * FROM {table_name}"
        loader = SpannerLoader(
            instance, pg_database, query, client=client, format="YAML"
        )
        docs = loader.load()
        assert docs == [
            Document(
                page_content="product_id: 1",
                metadata={
                    "product_name": "cards",
                    "description": "playing cards are cool",
                    "price": 10,
                    "foo": "bar",
                    "extra_metadata": "foobar",
                },
            )
        ]

    def test_loader_custom_format_csv(self, client):
        """Load with ``format="CSV"`` and default columns."""
        query = f"SELECT * FROM {table_name}"
        loader = SpannerLoader(
            instance, pg_database, query, client=client, format="CSV"
        )
        docs = loader.load()
        assert docs == [
            Document(
                page_content="1",
                metadata={
                    "product_name": "cards",
                    "description": "playing cards are cool",
                    "price": 10,
                    "foo": "bar",
                    "extra_metadata": "foobar",
                },
            )
        ]

    def test_loader_custom_format_error(self, client):
        """An unknown ``format`` value must be rejected at construction."""
        query = f"SELECT * FROM {table_name}"
        with pytest.raises(Exception):
            # ``client`` is passed by keyword, like every other call in this
            # class, so the failure comes from the bad ``format`` value and
            # not from the client landing in some other positional parameter.
            SpannerLoader(
                instance,
                pg_database,
                query,
                client=client,
                format="NOT_A_FORMAT",
            )

    def test_loader_custom_json_metadata(self, client):
        """Round-trip documents through a custom ``metadata_json_column``.

        NOTE: this drops and recreates the table shared by the whole class, so
        any test that runs after it sees the table created here rather than
        the one built by ``setup_database``.
        """
        database = client.instance(instance).database(pg_database)
        operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"])
        operation.result(OPERATION_TIMEOUT_SECONDS)
        SpannerDocumentSaver.init_document_table(
            instance,
            pg_database,
            table_name,
            content_column="product_id",
            metadata_columns=[
                ("product_name", "VARCHAR(1024)", True),
                ("description", "VARCHAR(1024)", False),
                ("price", "INT", False),
            ],
            metadata_json_column="my_metadata",
        )

        saver = SpannerDocumentSaver(
            instance,
            pg_database,
            table_name,
            client,
            content_column="product_id",
            metadata_columns=["product_name", "description", "price"],
            metadata_json_column="my_metadata",
        )
        test_documents = [
            Document(
                page_content="1",
                metadata={
                    "product_name": "cards",
                    "description": "playing cards are cool",
                    "price": 10,
                    "extra_metadata": "foobar",
                    "my_metadata": {
                        "foo": "bar",
                    },
                },
            ),
        ]
        saver.add_documents(test_documents)
        query = f"SELECT * FROM {table_name}"
        loader = SpannerLoader(
            instance,
            pg_database,
            query,
            client=client,
            metadata_json_column="my_metadata",
        )
        docs = loader.load()
        assert docs == [
            Document(
                page_content="1",
                metadata={
                    "product_name": "cards",
                    "description": "playing cards are cool",
                    "price": 10,
                    "foo": "bar",
                    "extra_metadata": "foobar",
                },
            ),
        ]
329607

330608

331609
class TestSpannerDocumentSaver:
@@ -384,7 +662,10 @@ def test_saver_google_sql_with_custom_schema(self, google_client):
384662
)
385663
query = f"SELECT * FROM {table_name}"
386664
loader = SpannerLoader(
387-
client=google_client, instance=instance, database=pg_database, query=query
665+
client=google_client,
666+
instance=instance,
667+
database=google_database,
668+
query=query,
388669
)
389670
expected_docs = [
390671
Document(

0 commit comments

Comments
 (0)