Skip to content

Commit 8315d9d

Browse files
committed
add user agent
1 parent 5cea61c commit 8315d9d

File tree

2 files changed

+41
-24
lines changed

2 files changed

+41
-24
lines changed

src/langchain_google_spanner/loader.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
import datetime
1616
import json
1717
from dataclasses import dataclass
18-
from typing import Any, Dict, Iterator, List, Union
18+
from typing import Any, Dict, Iterator, List, Optional, Union
1919

2020
from google.cloud.spanner import Client, KeySet # type: ignore
2121
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect # type: ignore
2222
from google.cloud.spanner_v1.data_types import JsonObject # type: ignore
2323
from langchain_community.document_loaders.base import BaseLoader
2424
from langchain_core.documents import Document
2525

26+
from .version import __version__
27+
28+
# User-agent token that identifies this package's document loader/saver on
# Spanner clients. The "/" separates the product name from the package version
# so the token reads as "product/version" rather than running the two together.
USER_AGENT_LOADER = "langchain-google-spanner-python:document_loader/" + __version__
29+
2630
OPERATION_TIMEOUT_SECONDS = 240
2731
MUTATION_BATCH_SIZE = 1000
2832

@@ -37,6 +41,17 @@ class Column:
3741
nullable: bool = True
3842

3943

44+
def client_with_user_agent(client: Optional[Client], user_agent: str) -> Client:
    """Return a Spanner client whose user agent includes ``user_agent``.

    Args:
        client: An existing client to tag, or None to create a default
            ``Client()``. A client that is passed in is modified in place.
        user_agent: The user-agent token to record on the client.

    Returns:
        The client that was passed in (or the newly created one), with
        ``user_agent`` present in its client-info user agent string.
    """
    # Compare against None explicitly so a caller-supplied client is never
    # swapped for a default one just because it evaluates as falsy.
    if client is None:
        client = Client()
    # NOTE: ``_client_info`` is a private attribute of the client object.
    client_agent = client._client_info.user_agent
    if not client_agent:
        # Nothing set yet: our token becomes the whole user agent.
        client._client_info.user_agent = user_agent
    elif user_agent not in client_agent:
        # Keep the existing value and append ours. The substring check makes
        # repeated calls with the same token a no-op.
        client._client_info.user_agent = " ".join([client_agent, user_agent])
    return client
53+
54+
4055
def _load_row_to_doc(
4156
format: str,
4257
content_columns: List[str],
@@ -123,21 +138,20 @@ def __init__(
123138
instance_id: str,
124139
database_id: str,
125140
query: str,
126-
client: Client = Client(),
127141
content_columns: List[str] = [],
128142
metadata_columns: List[str] = [],
129143
format: str = "text",
130144
databoost: bool = False,
131145
metadata_json_column: str = METADATA_COL_NAME,
132146
staleness: Union[float, datetime.datetime] = 0.0,
147+
client: Optional[Client] = None,
133148
):
134149
"""Initialize Spanner document loader.
135150
136151
Args:
137152
instance_id: The Spanner instance to load data from.
138153
database_id: The Spanner database to load data from.
139154
query: A GoogleSQL or PostgreSQL query. Users must match dialect to their database.
140-
client: The connection object to use. This can be used to customize project id and credentials.
141155
content_columns: The list of column(s) or field(s) to use for a Document's page content.
142156
Page content is the default field for embeddings generation.
143157
metadata_columns: The list of column(s) or field(s) to use for metadata.
@@ -146,6 +160,7 @@ def __init__(
146160
databoost: Use data boost on read. Note: needs extra IAM permissions and higher cost.
147161
metadata_json_column: The name of the JSON column to use as the metadata's base dictionary.
148162
staleness: The time bound for stale read. Takes either a datetime or float.
163+
client: The connection object to use. This can be used to customize project id and credentials.
149164
"""
150165
self.instance_id = instance_id
151166
self.database_id = database_id
@@ -158,7 +173,7 @@ def __init__(
158173
if self.format not in formats:
159174
raise Exception("Use one of 'text', 'JSON', 'YAML', 'CSV'.")
160175
self.databoost = databoost
161-
self.client = client
176+
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
162177
self.staleness = staleness
163178
instance = self.client.instance(instance_id)
164179
if not instance.exists():
@@ -230,30 +245,30 @@ def __init__(
230245
instance_id: str,
231246
database_id: str,
232247
table_name: str,
233-
client: Client = Client(),
234248
content_column: str = CONTENT_COL_NAME,
235249
metadata_columns: List[str] = [],
236250
metadata_json_column: str = METADATA_COL_NAME,
251+
client: Optional[Client] = None,
237252
):
238253
"""Initialize Spanner document saver.
239254
240255
Args:
241256
instance_id: The Spanner instance to load data to.
242257
database_id: The Spanner database to load data to.
243258
table_name: The table name to load data to.
244-
client: The connection object to use. This can be used to customized project id and credentials.
245259
content_column: The name of the content column. Defaulted to the first column.
246260
metadata_columns: This is for user to opt-in a selection of columns to use. Defaulted to use
247261
all columns.
248262
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
263+
client: The connection object to use. This can be used to customize project id and credentials.
249264
"""
250265
self.instance_id = instance_id
251266
self.database_id = database_id
252267
self.table_name = table_name
253268
self.content_column = content_column
254269
self.metadata_columns = metadata_columns
255270
self.metadata_json_column = metadata_json_column
256-
self.client = client
271+
self.client = client_with_user_agent(client, USER_AGENT_LOADER)
257272
instance = self.client.instance(instance_id)
258273
if not instance.exists():
259274
raise Exception("Instance doesn't exist.")
@@ -349,7 +364,7 @@ def init_document_table(
349364
metadata_json_column: The name of the special JSON column. Defaulted to use "langchain_metadata".
350365
"""
351366
primary_key = primary_key or content_column
352-
client = Client()
367+
client = client_with_user_agent(None, USER_AGENT_LOADER)
353368
metadata_json_column = metadata_json_column if store_metadata else ""
354369
instance = client.instance(instance_id)
355370
if not instance.exists():

tests/test_spanner_loader.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def setup_database(self, client):
5858
instance_id,
5959
google_database,
6060
table_name,
61-
client,
61+
client=client,
6262
content_column="product_id",
6363
metadata_columns=["product_name", "description", "price", "dummy_col"],
6464
)
@@ -355,7 +355,7 @@ def test_loader_custom_format_error(self, client):
355355
instance_id,
356356
google_database,
357357
query,
358-
client,
358+
client=client,
359359
format="NOT_A_FORMAT",
360360
)
361361
docs = loader.load()
@@ -367,7 +367,7 @@ def test_loader_custom_content_key_error(self, client):
367367
instance_id,
368368
google_database,
369369
query,
370-
client,
370+
client=client,
371371
content_columns=["NOT_A_COLUMN"],
372372
)
373373
docs = loader.load()
@@ -379,7 +379,7 @@ def test_loader_custom_metadata_key_error(self, client):
379379
instance_id,
380380
google_database,
381381
query,
382-
client,
382+
client=client,
383383
metadata_columns=["NOT_A_COLUMN"],
384384
)
385385
docs = loader.load()
@@ -405,7 +405,7 @@ def test_loader_custom_json_metadata(self, client):
405405
instance_id,
406406
google_database,
407407
table_name,
408-
client,
408+
client=client,
409409
content_column="product_id",
410410
metadata_columns=["product_name", "description", "price"],
411411
metadata_json_column="my_metadata",
@@ -471,7 +471,7 @@ def setup_database(self, client):
471471
instance_id,
472472
pg_database,
473473
table_name,
474-
client,
474+
client=client,
475475
content_column="product_id",
476476
metadata_columns=["product_name", "description", "price", "dummy_col"],
477477
)
@@ -768,7 +768,7 @@ def test_loader_custom_format_error(self, client):
768768
instance_id,
769769
pg_database,
770770
query,
771-
client,
771+
client=client,
772772
format="NOT_A_FORMAT",
773773
)
774774

@@ -779,7 +779,7 @@ def test_loader_custom_content_key_error(self, client):
779779
instance_id,
780780
pg_database,
781781
query,
782-
client,
782+
client=client,
783783
content_columns=["NOT_A_COLUMN"],
784784
)
785785
docs = loader.load()
@@ -791,7 +791,7 @@ def test_loader_custom_metadata_key_error(self, client):
791791
instance_id,
792792
pg_database,
793793
query,
794-
client,
794+
client=client,
795795
metadata_columns=["NOT_A_COLUMN"],
796796
)
797797
docs = loader.load()
@@ -817,7 +817,7 @@ def test_loader_custom_json_metadata(self, client):
817817
instance_id,
818818
pg_database,
819819
table_name,
820-
client,
820+
client=client,
821821
content_column="product_id",
822822
metadata_columns=["product_name", "description", "price"],
823823
metadata_json_column="my_metadata",
@@ -881,7 +881,7 @@ def test_saver_google_sql(self, google_client):
881881
instance_id, google_database, table_name
882882
)
883883
saver = SpannerDocumentSaver(
884-
instance_id, google_database, table_name, google_client
884+
instance_id, google_database, table_name, client=google_client
885885
)
886886
query = f"SELECT * FROM {table_name}"
887887
loader = SpannerLoader(
@@ -901,7 +901,9 @@ def test_saver_google_sql(self, google_client):
901901

902902
def test_saver_pg(self, pg_client):
903903
SpannerDocumentSaver.init_document_table(instance_id, pg_database, table_name)
904-
saver = SpannerDocumentSaver(instance_id, pg_database, table_name, pg_client)
904+
saver = SpannerDocumentSaver(
905+
instance_id, pg_database, table_name, client=pg_client
906+
)
905907
query = f"SELECT * FROM {table_name}"
906908
loader = SpannerLoader(
907909
client=pg_client,
@@ -935,7 +937,7 @@ def test_saver_google_sql_with_custom_schema(self, google_client):
935937
instance_id,
936938
google_database,
937939
table_name,
938-
google_client,
940+
client=google_client,
939941
content_column="my_page_content",
940942
)
941943
query = f"SELECT * FROM {table_name}"
@@ -981,7 +983,7 @@ def test_saver_pg_with_custom_schema(self, pg_client):
981983
instance_id,
982984
pg_database,
983985
table_name,
984-
pg_client,
986+
client=pg_client,
985987
content_column="my_page_content",
986988
)
987989
query = f"SELECT * FROM {table_name}"
@@ -1015,7 +1017,7 @@ def test_delete(self, google_client):
10151017
instance_id, google_database, table_name
10161018
)
10171019
saver = SpannerDocumentSaver(
1018-
instance_id, google_database, table_name, google_client
1020+
instance_id, google_database, table_name, client=google_client
10191021
)
10201022
query = f"SELECT * FROM {table_name}"
10211023
loader = SpannerLoader(
@@ -1040,7 +1042,7 @@ def test_saver_with_bad_docs(self, google_client):
10401042
instance_id, google_database, table_name
10411043
)
10421044
saver = SpannerDocumentSaver(
1043-
instance_id, google_database, table_name, google_client
1045+
instance_id, google_database, table_name, client=google_client
10441046
)
10451047
with pytest.raises(Exception):
10461048
saver.add_documents([1, 2, 3])

0 commit comments

Comments
 (0)