Skip to content

Commit 09172ac

Browse files
committed
update lint and tests
1 parent 6cf4e32 commit 09172ac

File tree

3 files changed

+71
-67
lines changed

3 files changed

+71
-67
lines changed

docs/document_loader.ipynb

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,7 @@
143143
"outputs": [],
144144
"source": [
145145
"custom_content_loader = SpannerLoader(\n",
146-
" instance_id,\n",
147-
" database_id,\n",
148-
" query,\n",
149-
" content_columns = [\"custom_content\"]\n",
146+
" instance_id, database_id, query, content_columns=[\"custom_content\"]\n",
150147
")"
151148
]
152149
},
@@ -173,10 +170,7 @@
173170
"outputs": [],
174171
"source": [
175172
"custom_metadata_loader = SpannerLoader(\n",
176-
" instance_id,\n",
177-
" database_id,\n",
178-
" query,\n",
179-
" metadata_columns = [\"column1\", \"column2\"]\n",
173+
" instance_id, database_id, query, metadata_columns=[\"column1\", \"column2\"]\n",
180174
")"
181175
]
182176
},
@@ -196,10 +190,7 @@
196190
"outputs": [],
197191
"source": [
198192
"custom_metadata_json_loader = SpannerLoader(\n",
199-
" instance_id,\n",
200-
" database_id,\n",
201-
" query,\n",
202-
" metadata_json_column = \"another-json-column\"\n",
193+
" instance_id, database_id, query, metadata_json_column=\"another-json-column\"\n",
203194
")"
204195
]
205196
},
@@ -223,7 +214,7 @@
223214
" instance_id,\n",
224215
" database_id,\n",
225216
" query,\n",
226-
" staleness = timestamp,\n",
217+
" staleness=timestamp,\n",
227218
")"
228219
]
229220
},
@@ -238,7 +229,7 @@
238229
" instance_id,\n",
239230
" database_id,\n",
240231
" query,\n",
241-
" staleness = duration,\n",
232+
" staleness=duration,\n",
242233
")"
243234
]
244235
},
@@ -261,7 +252,7 @@
261252
" instance_id,\n",
262253
" database_id,\n",
263254
" query,\n",
264-
" databoost = True,\n",
255+
" databoost=True,\n",
265256
")"
266257
]
267258
},
@@ -404,6 +395,8 @@
404395
"metadata": {},
405396
"outputs": [],
406397
"source": [
398+
"from langchain_google_spanner import Column\n",
399+
"\n",
407400
"new_table_name = \"my_new_table\"\n",
408401
"\n",
409402
"SpannerDocumentSaver.init_document_table(\n",
@@ -412,8 +405,8 @@
412405
" new_table_name,\n",
413406
" content_column=\"my-page-content\",\n",
414407
" metadata_columns=[\n",
415-
" ('category', 'STRING(36)', True),\n",
416-
" ('price', 'FLOAT64', False),\n",
408+
" Column(\"category\", \"STRING(36)\", True),\n",
409+
" Column(\"price\", \"FLOAT64\", False),\n",
417410
" ],\n",
418411
")"
419412
]

src/langchain_google_spanner/document_loader.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import datetime
1616
import json
1717
from dataclasses import dataclass
18-
from typing import Any, Dict, Iterator, List, Optional, Union
18+
from typing import Any, Dict, Iterator, List, Union
1919

2020
from google.cloud.spanner import Client, KeySet # type: ignore
2121
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect # type: ignore
@@ -107,7 +107,7 @@ def _load_doc_to_row(
107107
del doc_metadata[metadata_json_column]
108108
metadata_json = {**metadata_json, **doc_metadata}
109109
j = json.dumps(metadata_json) if parse_json else metadata_json
110-
row.append(j)
110+
row.append(j) # type: ignore
111111
return tuple(row)
112112

113113

@@ -131,7 +131,7 @@ def __init__(
131131
format: str = "text",
132132
databoost: bool = False,
133133
metadata_json_column: str = METADATA_COL_NAME,
134-
staleness: Union[float, datetime.datetime] = 15.0,
134+
staleness: Union[float, datetime.datetime] = 0.0,
135135
):
136136
"""Initialize Spanner document loader.
137137
@@ -162,13 +162,11 @@ def __init__(
162162
self.databoost = databoost
163163
self.client = client
164164
self.staleness = staleness
165-
if not self.client.instance(self.instance_id).exists():
165+
instance = self.client.instance(instance_id)
166+
if not instance.exists():
166167
raise Exception("Instance doesn't exist.")
167-
if (
168-
not self.client.instance(self.instance_id)
169-
.database(self.database_id)
170-
.exists()
171-
):
168+
database = instance.database(database_id)
169+
if not database.exists():
172170
raise Exception("Database doesn't exist.")
173171

174172
def load(self) -> List[Document]:
@@ -235,9 +233,9 @@ def __init__(
235233
database_id: str,
236234
table_name: str,
237235
client: Client = Client(),
238-
content_column: Optional[str] = "",
239-
metadata_columns: Optional[List[str]] = [],
240-
metadata_json_column: Optional[str] = METADATA_COL_NAME,
236+
content_column: str = "",
237+
metadata_columns: List[str] = [],
238+
metadata_json_column: str = METADATA_COL_NAME,
241239
):
242240
"""Initialize Spanner document saver.
243241
@@ -246,11 +244,10 @@ def __init__(
246244
database_id: The Spanner database to load data to.
247245
table_name: The table name to load data to.
248246
client: The connection object to use. This can be used to customize the project ID and credentials.
249-
content_column: Optional. The name of the content column. Defaulted to the first column.
250-
metadata_columns: Optional. This is for user to opt-in a selection of columns to use. Defaulted to use
247+
content_column: The name of the content column. Defaulted to the first column.
248+
metadata_columns: This is for user to opt-in a selection of columns to use. Defaulted to use
251249
all columns.
252-
store_metadata: If true, extra metadata will be stored in the "langchain_metadata" column.
253-
metadata_json_column: Optional. The name of the special JSON column. Defaulted to use "langchain_metadata".
250+
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
254251
"""
255252
self.instance_id = instance_id
256253
self.database_id = database_id
@@ -329,7 +326,7 @@ def init_document_table(
329326
metadata_columns: List[Column] = [],
330327
primary_key: str = "",
331328
store_metadata: bool = True,
332-
metadata_json_column: Optional[str] = None,
329+
metadata_json_column: str = METADATA_COL_NAME,
333330
):
334331
"""
335332
Create a new table to store docs with a custom schema.
@@ -343,13 +340,11 @@ def init_document_table(
343340
primary_key: The name of the primary key.
344341
store_metadata: If true, extra metadata will be stored in the "langchain_metadata" column.
345342
Defaulted to true.
346-
metadata_json_column: Optional. The name of the special JSON column. Defaulted to use "langchain_metadata".
343+
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
347344
"""
348345
primary_key = primary_key or content_column
349-
metadata_json_column = (
350-
(metadata_json_column or METADATA_COL_NAME) if store_metadata else None
351-
)
352346
client = Client()
347+
metadata_json_column = metadata_json_column if store_metadata else ""
353348
instance = client.instance(instance_id)
354349
if not instance.exists():
355350
raise Exception("Instance doesn't exist.")
@@ -375,7 +370,7 @@ def create_table(
375370
database_id: str,
376371
table_name: str,
377372
primary_key: str,
378-
metadata_json_column: Optional[str],
373+
metadata_json_column: str,
379374
content_column: str,
380375
metadata_columns: List[Column],
381376
):

tests/document_loader_test.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from google.cloud.spanner import Client, KeySet # type: ignore
2020
from langchain_core.documents import Document
2121

22-
from langchain_google_spanner.document_loader import SpannerDocumentSaver, SpannerLoader
22+
from langchain_google_spanner.document_loader import (
23+
Column,
24+
SpannerDocumentSaver,
25+
SpannerLoader,
26+
)
2327

2428
project_id = os.environ["PROJECT_ID"]
2529
instance_id = os.environ["INSTANCE_ID"]
@@ -47,9 +51,9 @@ def setup_database(self, client):
4751
table_name,
4852
content_column="product_id",
4953
metadata_columns=[
50-
("product_name", "STRING(1024)", True),
51-
("description", "STRING(1024)", False),
52-
("price", "INT64", False),
54+
Column("product_name", "STRING(1024)", True),
55+
Column("description", "STRING(1024)", False),
56+
Column("price", "INT64", False),
5357
],
5458
)
5559

@@ -258,9 +262,9 @@ def test_loader_custom_json_metadata(self, client):
258262
table_name,
259263
content_column="product_id",
260264
metadata_columns=[
261-
("product_name", "STRING(1024)", True),
262-
("description", "STRING(1024)", False),
263-
("price", "INT64", False),
265+
Column("product_name", "STRING(1024)", True),
266+
Column("description", "STRING(1024)", False),
267+
Column("price", "INT64", False),
264268
],
265269
metadata_json_column="my_metadata",
266270
)
@@ -324,9 +328,9 @@ def setup_database(self, client):
324328
table_name,
325329
content_column="product_id",
326330
metadata_columns=[
327-
("product_name", "VARCHAR(1024)", True),
328-
("description", "VARCHAR(1024)", False),
329-
("price", "INT", False),
331+
Column("product_name", "VARCHAR(1024)", True),
332+
Column("description", "VARCHAR(1024)", False),
333+
Column("price", "INT", False),
330334
],
331335
)
332336

@@ -535,9 +539,9 @@ def test_loader_custom_json_metadata(self, client):
535539
table_name,
536540
content_column="product_id",
537541
metadata_columns=[
538-
("product_name", "VARCHAR(1024)", True),
539-
("description", "VARCHAR(1024)", False),
540-
("price", "INT", False),
542+
Column("product_name", "VARCHAR(1024)", True),
543+
Column("description", "VARCHAR(1024)", False),
544+
Column("price", "INT", False),
541545
],
542546
metadata_json_column="my_metadata",
543547
)
@@ -606,15 +610,17 @@ def setup_pg_client(self, client) -> Client:
606610
yield client
607611

608612
def test_saver_google_sql(self, google_client):
609-
SpannerDocumentSaver.init_document_table(instance_id, google_database, table_name)
613+
SpannerDocumentSaver.init_document_table(
614+
instance_id, google_database, table_name
615+
)
610616
saver = SpannerDocumentSaver(
611617
instance_id, google_database, table_name, google_client
612618
)
613619
query = f"SELECT * FROM {table_name}"
614620
loader = SpannerLoader(
615621
client=google_client,
616-
instance=instance_id,
617-
database=google_database,
622+
instance_id=instance_id,
623+
database_id=google_database,
618624
query=query,
619625
)
620626
expected_docs = [
@@ -631,7 +637,10 @@ def test_saver_pg(self, pg_client):
631637
saver = SpannerDocumentSaver(instance_id, pg_database, table_name, pg_client)
632638
query = f"SELECT * FROM {table_name}"
633639
loader = SpannerLoader(
634-
client=pg_client, instance=instance_id, database=pg_database, query=query
640+
client=pg_client,
641+
instance_id=instance_id,
642+
database_id=pg_database,
643+
query=query,
635644
)
636645
expected_docs = [
637646
Document(page_content="Hello, World!", metadata={"source": "my-computer"}),
@@ -649,8 +658,8 @@ def test_saver_google_sql_with_custom_schema(self, google_client):
649658
table_name,
650659
content_column="my_page_content",
651660
metadata_columns=[
652-
("category", "STRING(35)", True),
653-
("price", "INT64", False),
661+
Column("category", "STRING(35)", True),
662+
Column("price", "INT64", False),
654663
],
655664
primary_key="my_page_content",
656665
store_metadata=True,
@@ -661,8 +670,8 @@ def test_saver_google_sql_with_custom_schema(self, google_client):
661670
query = f"SELECT * FROM {table_name}"
662671
loader = SpannerLoader(
663672
client=google_client,
664-
instance=instance_id,
665-
database=google_database,
673+
instance_id=instance_id,
674+
database_id=google_database,
666675
query=query,
667676
)
668677
expected_docs = [
@@ -691,16 +700,19 @@ def test_saver_pg_with_custom_schema(self, pg_client):
691700
table_name,
692701
content_column="my_page_content",
693702
metadata_columns=[
694-
("category", "VARCHAR(35)", True),
695-
("price", "INT", False),
703+
Column("category", "VARCHAR(35)", True),
704+
Column("price", "INT", False),
696705
],
697706
primary_key="my_page_content",
698707
store_metadata=True,
699708
)
700709
saver = SpannerDocumentSaver(instance_id, pg_database, table_name, pg_client)
701710
query = f"SELECT * FROM {table_name}"
702711
loader = SpannerLoader(
703-
client=pg_client, instance=instance_id, database=pg_database, query=query
712+
client=pg_client,
713+
instance_id=instance_id,
714+
database_id=pg_database,
715+
query=query,
704716
)
705717
expected_docs = [
706718
Document(
@@ -722,15 +734,17 @@ def test_saver_pg_with_custom_schema(self, pg_client):
722734
]
723735

724736
def test_delete(self, google_client):
725-
SpannerDocumentSaver.init_document_table(instance_id, google_database, table_name)
737+
SpannerDocumentSaver.init_document_table(
738+
instance_id, google_database, table_name
739+
)
726740
saver = SpannerDocumentSaver(
727741
instance_id, google_database, table_name, google_client
728742
)
729743
query = f"SELECT * FROM {table_name}"
730744
loader = SpannerLoader(
731745
client=google_client,
732-
instance=instance_id,
733-
database=google_database,
746+
instance_id=instance_id,
747+
database_id=google_database,
734748
query=query,
735749
)
736750
expected_docs = [
@@ -745,7 +759,9 @@ def test_delete(self, google_client):
745759
assert loader.load() == [expected_docs[1]]
746760

747761
def test_saver_with_bad_docs(self, google_client):
748-
SpannerDocumentSaver.init_document_table(instance_id, google_database, table_name)
762+
SpannerDocumentSaver.init_document_table(
763+
instance_id, google_database, table_name
764+
)
749765
saver = SpannerDocumentSaver(
750766
instance_id, google_database, table_name, google_client
751767
)

0 commit comments

Comments
 (0)