2626from .version import __version__
2727
2828USER_AGENT_LOADER = "langchain-google-spanner-python:document_loader" + __version__
29+ USER_AGENT_SAVER = "langchain-google-spanner-python:document_saver" + __version__
2930
3031OPERATION_TIMEOUT_SECONDS = 240
3132MUTATION_BATCH_SIZE = 1000
@@ -44,6 +45,7 @@ class Column:
4445def client_with_user_agent (client : Optional [Client ], user_agent : str ) -> Client :
4546 if not client :
4647 client = Client ()
48+
4749 client_agent = client ._client_info .user_agent
4850 if not client_agent :
4951 client ._client_info .user_agent = user_agent
@@ -75,7 +77,6 @@ def _load_row_to_doc(
7577 metadata : Dict [str , Any ] = {}
7678 if metadata_json_column in metadata_columns and row .get (metadata_json_column ):
7779 metadata = row [metadata_json_column ]
78-
7980 for c in metadata_columns :
8081 if c != metadata_json_column :
8182 metadata [c ] = row [c ]
@@ -113,6 +114,7 @@ def _load_doc_to_row(
113114 del doc_metadata [col ]
114115 if col == content_column :
115116 row .append (doc .page_content )
117+
116118 if metadata_json_column in table_fields :
117119 metadata_json = {}
118120 if metadata_json_column in doc_metadata :
@@ -121,6 +123,7 @@ def _load_doc_to_row(
121123 metadata_json = {** metadata_json , ** doc_metadata }
122124 j = json .dumps (metadata_json ) if parse_json else metadata_json
123125 row .append (j ) # type: ignore
126+
124127 return tuple (row )
125128
126129
@@ -169,15 +172,18 @@ def __init__(
169172 self .metadata_columns = metadata_columns
170173 self .format = format
171174 self .metadata_json_column = metadata_json_column
172- formats = ["JSON" , "text" , "YAML" , "CSV" ]
173- if self .format not in formats :
174- raise Exception ("Use one of 'text', 'JSON', 'YAML', 'CSV'." )
175175 self .databoost = databoost
176176 self .client = client_with_user_agent (client , USER_AGENT_LOADER )
177177 self .staleness = staleness
178+
179+ formats = ["JSON" , "text" , "YAML" , "CSV" ]
180+ if self .format not in formats :
181+ raise Exception ("Use one of 'text', 'JSON', 'YAML', 'CSV'." )
182+
178183 instance = self .client .instance (instance_id )
179184 if not instance .exists ():
180185 raise Exception ("Instance doesn't exist." )
186+
181187 database = instance .database (database_id )
182188 if not database .exists ():
183189 raise Exception ("Database doesn't exist." )
@@ -217,6 +223,7 @@ def lazy_load(self) -> Iterator[Document]:
217223 partitions = snapshot .generate_query_batches (
218224 sql = self .query , data_boost_enabled = self .databoost
219225 )
226+
220227 for partition in partitions :
221228 r = snapshot .process_query_batch (partition )
222229 results = r .to_dict_list ()
@@ -227,6 +234,7 @@ def lazy_load(self) -> Iterator[Document]:
227234 metadata_columns = self .metadata_columns or [
228235 col for col in column_names if col not in content_columns
229236 ]
237+
230238 for row in results :
231239 yield _load_row_to_doc (
232240 self .format ,
@@ -248,6 +256,7 @@ def __init__(
248256 content_column : str = CONTENT_COL_NAME ,
249257 metadata_columns : List [str ] = [],
250258 metadata_json_column : str = METADATA_COL_NAME ,
259+ primary_key : Optional [str ] = None ,
251260 client : Optional [Client ] = None ,
252261 ):
253262 """Initialize Spanner document saver.
@@ -268,23 +277,30 @@ def __init__(
268277 self .content_column = content_column
269278 self .metadata_columns = metadata_columns
270279 self .metadata_json_column = metadata_json_column
271- self .client = client_with_user_agent (client , USER_AGENT_LOADER )
280+ self .primary_key = primary_key or self .content_column
281+ self .client = client_with_user_agent (client , USER_AGENT_SAVER )
282+
272283 instance = self .client .instance (instance_id )
273284 if not instance .exists ():
274285 raise Exception ("Instance doesn't exist." )
286+
275287 database = instance .database (database_id )
276288 if not database .exists ():
277289 raise Exception ("Database doesn't exist." )
290+
278291 database .reload ()
279292 self .dialect = database .database_dialect
293+
280294 table = database .table (table_name )
281295 if not table .exists ():
282296 raise Exception (
283297 "Table doesn't exist. Create table with SpannerDocumentSaver.init_document_table function."
284298 )
285- self ._table_fields = [
286- n .name for n in table .schema if n .name != metadata_json_column
287- ]
299+
300+ self ._table_fields = [self .primary_key ]
301+ for n in table .schema :
302+ if n .name != metadata_json_column and n .name != self .primary_key :
303+ self ._table_fields .append (n .name )
288304 self ._table_fields .append (metadata_json_column )
289305
290306 def add_documents (self , documents : List [Document ]):
@@ -328,6 +344,7 @@ def delete(self, documents: List[Document]):
328344 keyset = docs_keys ,
329345 partition_size_bytes = 5000000 ,
330346 )
347+
331348 for partition in partitions :
332349 keys_to_delete = []
333350 for row in snapshot .process_read_batch (partition ):
@@ -363,15 +380,18 @@ def init_document_table(
363380 Defaulted to true.
364381 metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
365382 """
383+ client = Client ()
366384 primary_key = primary_key or content_column
367- client = client_with_user_agent (None , USER_AGENT_LOADER )
368385 metadata_json_column = metadata_json_column if store_metadata else ""
386+
369387 instance = client .instance (instance_id )
370388 if not instance .exists ():
371389 raise Exception ("Instance doesn't exist." )
390+
372391 database = instance .database (database_id )
373392 if not database .exists ():
374393 raise Exception ("Database doesn't exist." )
394+
375395 # create table with custom schema
376396 SpannerDocumentSaver .create_table (
377397 client ,
@@ -395,7 +415,19 @@ def create_table(
395415 content_column : str ,
396416 metadata_columns : List [Column ],
397417 ):
398- """Create a new table in Spanner database."""
418+ """
419+ Create a new table in Spanner database.
420+
421+ Args:
422+ client: The connection object to use.
423+ instance_id: The Spanner instance to load data to.
424+ database_id: The Spanner database to load data to.
425+ table_name: The table name to load data to.
426+ primary_key: The name of the primary key for the table.
427+ metadata_json_column: The name of the special JSON column.
428+ content_column: The name of the content column.
429+ metadata_columns: The metadata columns for custom schema.
430+ """
399431 database = client .instance (instance_id ).database (database_id )
400432 database .reload ()
401433 dialect = database .database_dialect
0 commit comments