Skip to content

Commit 33fb157

Browse files
committed
update user agent saver and primary key
1 parent 272f444 commit 33fb157

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

src/langchain_google_spanner/loader.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .version import __version__
2727

2828
# User-agent token attached to clients created for the document loader.
# The "/" separates the product name from its version ("name/version");
# without it the version would be fused onto the product name.
USER_AGENT_LOADER = "langchain-google-spanner-python:document_loader/" + __version__
29+
# User-agent token attached to clients created for the document saver.
# The "/" separates the product name from its version ("name/version");
# without it the version would be fused onto the product name.
USER_AGENT_SAVER = "langchain-google-spanner-python:document_saver/" + __version__
2930

3031
OPERATION_TIMEOUT_SECONDS = 240
3132
MUTATION_BATCH_SIZE = 1000
@@ -44,6 +45,7 @@ class Column:
4445
def client_with_user_agent(client: Optional[Client], user_agent: str) -> Client:
4546
if not client:
4647
client = Client()
48+
4749
client_agent = client._client_info.user_agent
4850
if not client_agent:
4951
client._client_info.user_agent = user_agent
@@ -75,7 +77,6 @@ def _load_row_to_doc(
7577
metadata: Dict[str, Any] = {}
7678
if metadata_json_column in metadata_columns and row.get(metadata_json_column):
7779
metadata = row[metadata_json_column]
78-
7980
for c in metadata_columns:
8081
if c != metadata_json_column:
8182
metadata[c] = row[c]
@@ -113,6 +114,7 @@ def _load_doc_to_row(
113114
del doc_metadata[col]
114115
if col == content_column:
115116
row.append(doc.page_content)
117+
116118
if metadata_json_column in table_fields:
117119
metadata_json = {}
118120
if metadata_json_column in doc_metadata:
@@ -121,6 +123,7 @@ def _load_doc_to_row(
121123
metadata_json = {**metadata_json, **doc_metadata}
122124
j = json.dumps(metadata_json) if parse_json else metadata_json
123125
row.append(j) # type: ignore
126+
124127
return tuple(row)
125128

126129

@@ -169,15 +172,18 @@ def __init__(
169172
self.metadata_columns = metadata_columns
170173
self.format = format
171174
self.metadata_json_column = metadata_json_column
172-
formats = ["JSON", "text", "YAML", "CSV"]
173-
if self.format not in formats:
174-
raise Exception("Use one of 'text', 'JSON', 'YAML', 'CSV'.")
175175
self.databoost = databoost
176176
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
177177
self.staleness = staleness
178+
179+
formats = ["JSON", "text", "YAML", "CSV"]
180+
if self.format not in formats:
181+
raise Exception("Use one of 'text', 'JSON', 'YAML', 'CSV'.")
182+
178183
instance = self.client.instance(instance_id)
179184
if not instance.exists():
180185
raise Exception("Instance doesn't exist.")
186+
181187
database = instance.database(database_id)
182188
if not database.exists():
183189
raise Exception("Database doesn't exist.")
@@ -217,6 +223,7 @@ def lazy_load(self) -> Iterator[Document]:
217223
partitions = snapshot.generate_query_batches(
218224
sql=self.query, data_boost_enabled=self.databoost
219225
)
226+
220227
for partition in partitions:
221228
r = snapshot.process_query_batch(partition)
222229
results = r.to_dict_list()
@@ -227,6 +234,7 @@ def lazy_load(self) -> Iterator[Document]:
227234
metadata_columns = self.metadata_columns or [
228235
col for col in column_names if col not in content_columns
229236
]
237+
230238
for row in results:
231239
yield _load_row_to_doc(
232240
self.format,
@@ -248,6 +256,7 @@ def __init__(
248256
content_column: str = CONTENT_COL_NAME,
249257
metadata_columns: List[str] = [],
250258
metadata_json_column: str = METADATA_COL_NAME,
259+
primary_key: Optional[str] = None,
251260
client: Optional[Client] = None,
252261
):
253262
"""Initialize Spanner document saver.
@@ -268,23 +277,30 @@ def __init__(
268277
self.content_column = content_column
269278
self.metadata_columns = metadata_columns
270279
self.metadata_json_column = metadata_json_column
271-
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
280+
self.primary_key = primary_key or self.content_column
281+
self.client = client_with_user_agent(client, USER_AGENT_SAVER)
282+
272283
instance = self.client.instance(instance_id)
273284
if not instance.exists():
274285
raise Exception("Instance doesn't exist.")
286+
275287
database = instance.database(database_id)
276288
if not database.exists():
277289
raise Exception("Database doesn't exist.")
290+
278291
database.reload()
279292
self.dialect = database.database_dialect
293+
280294
table = database.table(table_name)
281295
if not table.exists():
282296
raise Exception(
283297
"Table doesn't exist. Create table with SpannerDocumentSaver.init_document_table function."
284298
)
285-
self._table_fields = [
286-
n.name for n in table.schema if n.name != metadata_json_column
287-
]
299+
300+
self._table_fields = [self.primary_key]
301+
for n in table.schema:
302+
if n.name != metadata_json_column and n.name != self.primary_key:
303+
self._table_fields.append(n.name)
288304
self._table_fields.append(metadata_json_column)
289305

290306
def add_documents(self, documents: List[Document]):
@@ -328,6 +344,7 @@ def delete(self, documents: List[Document]):
328344
keyset=docs_keys,
329345
partition_size_bytes=5000000,
330346
)
347+
331348
for partition in partitions:
332349
keys_to_delete = []
333350
for row in snapshot.process_read_batch(partition):
@@ -363,15 +380,18 @@ def init_document_table(
363380
Defaulted to true.
364381
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
365382
"""
383+
client = Client()
366384
primary_key = primary_key or content_column
367-
client = client_with_user_agent(None, USER_AGENT_LOADER)
368385
metadata_json_column = metadata_json_column if store_metadata else ""
386+
369387
instance = client.instance(instance_id)
370388
if not instance.exists():
371389
raise Exception("Instance doesn't exist.")
390+
372391
database = instance.database(database_id)
373392
if not database.exists():
374393
raise Exception("Database doesn't exist.")
394+
375395
# create table with custom schema
376396
SpannerDocumentSaver.create_table(
377397
client,
@@ -395,7 +415,19 @@ def create_table(
395415
content_column: str,
396416
metadata_columns: List[Column],
397417
):
398-
"""Create a new table in Spanner database."""
418+
"""
419+
Create a new table in Spanner database.
420+
421+
Args:
422+
client: The connection object to use.
423+
instance_id: The Spanner instance to load data to.
424+
database_id: The Spanner database to load data to.
425+
table_name: The table name to load data to.
426+
primary_key: The name of the primary key for the table.
427+
metadata_json_column: The name of the special JSON column.
428+
content_column: The name of the content column.
429+
metadata_columns: The metadata columns for custom schema.
430+
"""
399431
database = client.instance(instance_id).database(database_id)
400432
database.reload()
401433
dialect = database.database_dialect

0 commit comments

Comments
 (0)