Skip to content

Commit 9b7b9c4

Browse files
committed
update loader formatting
1 parent 46b229a commit 9b7b9c4

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

src/langchain_google_spanner/loader.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -46,28 +46,23 @@ def _load_row_to_doc(
4646
) -> Document:
4747
page_content = ""
4848
if format == "text":
49-
page_content = " ".join(str(row[c]) for c in content_columns if c in row)
49+
page_content = " ".join(str(row[c]) for c in content_columns)
5050
elif format == "YAML":
51-
page_content = "\n".join(
52-
f"{c}: {str(row[c])}" for c in content_columns if c in row
53-
)
51+
page_content = "\n".join(f"{c}: {str(row[c])}" for c in content_columns)
5452
elif format == "JSON":
5553
j = {}
5654
for c in content_columns:
57-
if c in row:
58-
j[c] = row[c]
55+
j[c] = row[c]
5956
page_content = json.dumps(j)
6057
elif format == "CSV":
61-
page_content = ", ".join(str(row[c]) for c in content_columns if c in row)
62-
if not page_content:
63-
raise Exception("column page_content doesn't exist.")
58+
page_content = ", ".join(str(row[c]) for c in content_columns)
6459

6560
metadata: Dict[str, Any] = {}
6661
if metadata_json_column in metadata_columns and row.get(metadata_json_column):
6762
metadata = row[metadata_json_column]
6863

6964
for c in metadata_columns:
70-
if c in row and c != metadata_json_column:
65+
if c != metadata_json_column:
7166
metadata[c] = row[c]
7267

7368
return Document(page_content=page_content, metadata=metadata)
@@ -273,9 +268,9 @@ def __init__(
273268
"Table doesn't exist. Create table with SpannerDocumentSaver.init_document_table function."
274269
)
275270
self._table_fields = [content_column]
276-
self._table_fields.append(
277-
[n.name for n in table.schema if n.name != metadata_json_column]
278-
)
271+
for n in table.schema:
272+
if n.name != metadata_json_column:
273+
self._table_fields.append(n.name)
279274
self._table_fields.append(metadata_json_column)
280275

281276
def add_documents(self, documents: List[Document]):

tests/test_spanner_loader.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -255,6 +255,18 @@ def test_loader_custom_format_error(self, client):
255255
client,
256256
format="NOT_A_FORMAT",
257257
)
258+
docs = loader.load()
259+
260+
def test_loader_custom_key_error(self, client):
261+
query = f"SELECT * FROM {table_name}"
262+
with pytest.raises(Exception):
263+
SpannerLoader(
264+
instance_id,
265+
google_database,
266+
query,
267+
client,
268+
content_columns=["NOT_A_COLUMN"],
269+
)
258270

259271
def test_loader_custom_json_metadata(self, client):
260272
database = client.instance(instance_id).database(google_database)

0 commit comments

Comments
 (0)