Skip to content

Commit a561b82

Browse files
committed
fix mypy
1 parent c2836b8 commit a561b82

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/langchain_google_spanner/document_loader.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717
from typing import Any, Dict, Iterator, List, Optional
1818

19-
from google.cloud.spanner import Client, KeySet
19+
from google.cloud.spanner import Client, KeySet # type: ignore
2020
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect # type: ignore
2121
from google.cloud.spanner_v1.data_types import JsonObject # type: ignore
2222
from langchain_community.document_loaders.base import BaseLoader
@@ -68,16 +68,15 @@ def _load_doc_to_row(table_fields: List[str], doc: Document, metadata_json_colum
6868
and col != metadata_json_column
6969
and col in doc_metadata
7070
):
71-
row += (doc_metadata[col],)
71+
row = (*row, doc_metadata[col]) # type: ignore
7272
del doc_metadata[col]
7373
if metadata_json_column in table_fields:
7474
metadata_json = {}
75-
print(f"metadata json column is {metadata_json_column}")
7675
if metadata_json_column in doc_metadata:
7776
metadata_json = doc_metadata[metadata_json_column]
7877
del doc_metadata[metadata_json_column]
7978
metadata_json = {**metadata_json, **doc_metadata}
80-
row += (json.dumps(metadata_json),)
79+
row += (json.dumps(metadata_json),) # type: ignore
8180
return row
8281

8382

@@ -95,29 +94,29 @@ def __init__(
9594
instance: str,
9695
database: str,
9796
query: str,
97+
client: Client = Client(),
9898
content_columns: List[str] = [],
9999
metadata_columns: List[str] = [],
100100
format: str = "text",
101101
databoost: bool = False,
102102
metadata_json_column: str = "",
103-
client: Optional[Client] = Client(),
104-
staleness: Optional[int] = 0,
103+
staleness: float = 0.0,
105104
):
106105
"""Initialize Spanner document loader.
107106
108107
Args:
109108
instance: The Spanner instance to load data from.
110109
database: The Spanner database to load data from.
111110
query: A GoogleSQL or PostgreSQL query. Users must match dialect to their database.
111+
client: The connection object to use. This can be used to customize project id and credentials.
112112
content_columns: The list of column(s) or field(s) to use for a Document's page content.
113113
Page content is the default field for embeddings generation.
114114
metadata_columns: The list of column(s) or field(s) to use for metadata.
115115
format: Set the format of page content if using multiple columns or fields.
116116
Format included: 'text', 'JSON', 'YAML', 'CSV'.
117117
databoost: Use data boost on read. Note: needs extra IAM permissions and higher cost.
118118
metadata_json_column: The name of the JSON column to use as the metadata's base dictionary.
119-
client: Optional. The connection object to use. This can be used to customize project id and credentials.
120-
staleness: Optional. The time bound for stale read.
119+
staleness: The time bound for stale read.
121120
"""
122121
self.instance = instance
123122
self.database = database
@@ -271,7 +270,7 @@ def init_document_table(
271270
metadata_columns: List[Any] = [],
272271
primary_key: str = "",
273272
store_metadata: bool = True,
274-
metadata_json_column: str = "",
273+
metadata_json_column: Optional[str] = "",
275274
):
276275
"""
277276
Create a new table to store docs with a custom schema.
@@ -285,6 +284,7 @@ def init_document_table(
285284
primary_key: The name of the primary key.
286285
store_metadata: If true, extra metadata will be stored in the "langchain_metadata" column.
287286
Defaulted to true.
287+
metadata_json_column: Optional. The name of the special JSON column. Defaulted to use "langchain_metadata".
288288
"""
289289
content_column = content_column or CONTENT_COL_NAME
290290
primary_key = primary_key or content_column
@@ -322,9 +322,9 @@ def create_table(
322322
metadata_columns: List[Any],
323323
):
324324
"""Create a new table in Spanner database."""
325-
database = client.instance(instance).database(database)
326-
database.reload()
327-
dialect = database.database_dialect
325+
spanner_database = client.instance(instance).database(database)
326+
spanner_database.reload()
327+
dialect = spanner_database.database_dialect
328328

329329
ddl = f"CREATE TABLE {table_name} ("
330330
if dialect == DatabaseDialect.POSTGRESQL:
@@ -344,5 +344,5 @@ def create_table(
344344
ddl += f"{metadata_json_column} JSON NOT NULL,"
345345
ddl += f") PRIMARY KEY ({primary_key})"
346346

347-
operation = database.update_ddl([ddl])
347+
operation = spanner_database.update_ddl([ddl])
348348
operation.result(OPERATION_TIMEOUT_SECONDS)

0 commit comments

Comments
 (0)