Skip to content

Commit 0dadbb0

Browse files
committed
add document loader
1 parent f5349d7 commit 0dadbb0

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
from collections.abc import Mapping
from typing import Iterator, List, Optional

from google.cloud.spanner import Client, KeySet
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
21+
22+
# Name of the result column that SpannerLoader reads the base page content from.
CONTENT_COL_NAME = "page_content"
23+
# Name of the result column that SpannerLoader reads Document metadata from
# when no explicit metadata columns are configured.
METADATA_COL_NAME = "langchain_metadata"
24+
25+
26+
class SpannerLoader(BaseLoader):
    """Loads data from Google Cloud Spanner.

    Runs a single SQL query against a Spanner database and converts every
    returned row into one ``Document``.
    """

    # Accepted values for the ``format`` argument.
    _SUPPORTED_FORMATS = ("JSON", "text", "YAML", "CSV")
    # Formats in which every content value is prefixed with its column name.
    _HEADER_FORMATS = ("JSON", "YAML")

    def __init__(
        self,
        instance: str,
        database: str,
        query: str,
        client: Optional[Client] = None,
        staleness: Optional[int] = 0,
        content_columns: Optional[List[str]] = None,
        metadata_columns: Optional[List[str]] = None,
        format: str = "text",
        databoost: bool = False,
    ):
        """Initialize Spanner document loader.

        Args:
            instance: The Spanner instance to load data from.
            database: The Spanner database to load data from.
            query: A GoogleSQL or PostgreSQL query. Users must match dialect
                to their database.
            client: Optional. The connection object to use. This can be used
                to customize project id and credentials. When omitted, a
                default ``Client()`` is created for this loader.
            staleness: Optional. The time bound for stale read, in seconds.
                ``None`` is treated as 0.
            content_columns: Extra column(s) or field(s) appended, in order,
                to the ``page_content`` column to build a Document's page
                content. Page content is the default field for embeddings
                generation.
            metadata_columns: The list of column(s) or field(s) to use for
                metadata. When empty, the ``langchain_metadata`` column plus
                every column other than ``page_content`` is used.
            format: Set the format of page content if using multiple columns
                or fields. Format included: 'text', 'JSON', 'YAML', 'CSV'.
                For 'JSON' and 'YAML' each value is written as
                ``column: value`` on its own line; otherwise the values are
                joined with single spaces.
            databoost: Use data boost on read. Note: needs extra IAM
                permissions and higher cost.

        Raises:
            ValueError: If ``format`` is not one of the supported formats.
            Exception: If the instance or the database does not exist.
        """
        # Validate local arguments first so a bad call fails before any
        # network round trip is made.
        if format not in self._SUPPORTED_FORMATS:
            raise ValueError("Use one of 'text', 'JSON', 'YAML', 'CSV'")

        self.instance = instance
        self.database = database
        self.query = query
        # The client is created here, not in the signature: a default of
        # ``Client()`` would be built once at import time and shared by every
        # loader.
        self.client = client if client is not None else Client()
        self.staleness = staleness
        # Copy the column lists so the loader never shares state with the
        # caller's (or a default) list object.
        self.content_columns = list(content_columns) if content_columns else []
        self.metadata_columns = list(metadata_columns) if metadata_columns else []
        self.format = format
        self.databoost = databoost

        spanner_instance = self.client.instance(self.instance)
        if not spanner_instance.exists():
            raise Exception("Instance doesn't exist.")
        if not spanner_instance.database(self.database).exists():
            raise Exception("Database doesn't exist.")

    def load(self) -> List[Document]:
        """Load all documents produced by the query into a list."""
        return list(self.lazy_load())

    def lazy_load(self) -> Iterator[Document]:
        """Yield one Document per row returned by the query.

        Being a generator, nothing is executed until iteration starts.

        Raises:
            Exception: If the query fails; the underlying error is chained
                as ``__cause__``.
        """
        # NOTE(review): ``to_dict_list()`` materializes the whole result set.
        # The original author's note suggests generate_query_batches() with
        # process_query_batch() as the way to get partitioned, streamed reads
        # with Data Boost -- confirm before changing.
        db = self.client.instance(self.instance).database(self.database)
        duration = datetime.timedelta(seconds=self.staleness or 0)
        with db.snapshot(exact_staleness=duration) as snapshot:
            try:
                rows = snapshot.execute_sql(
                    sql=self.query, data_boost_enabled=self.databoost
                ).to_dict_list()
            except Exception as err:
                raise Exception("Fail to execute query") from err

        # Rows are fully fetched above, so the snapshot is released before
        # control is handed back to the consumer.
        for row in rows:
            yield self.load_row_to_document(row)

    def load_row_to_document(self, row):
        """Convert one result row (a column-name -> value mapping) to a Document.

        Args:
            row: Mapping of column name to value for a single result row.

        Returns:
            A Document whose page content and metadata are built from ``row``.

        Raises:
            KeyError: If ``page_content`` or any configured content or
                metadata column is missing from ``row``.
        """
        return Document(
            page_content=self._build_page_content(row),
            metadata=self._build_metadata(row),
        )

    def _build_page_content(self, row) -> str:
        """Join the page_content column and the extra content columns."""
        columns = [CONTENT_COL_NAME, *self.content_columns]
        for column in columns:
            if column not in row:
                raise KeyError(f"column {column} doesn't exist.")

        if self.format in self._HEADER_FORMATS:
            return "\n".join(f"{column}: {row[column]}" for column in columns)
        # str() lets non-string column values be concatenated as well.
        return " ".join(str(row[column]) for column in columns)

    def _build_metadata(self, row) -> dict:
        """Collect the metadata dict for a row without mutating the row."""
        metadata = {}

        if self.metadata_columns:
            for column in self.metadata_columns:
                value = row[column]
                if isinstance(value, Mapping):
                    # Mapping-valued columns are merged into the metadata.
                    metadata.update(value)
                else:
                    metadata[column] = value
            return metadata

        base = row.get(METADATA_COL_NAME)
        if isinstance(base, Mapping):
            metadata.update(base)
        elif base is not None:
            # Keep a non-mapping value rather than dropping it silently.
            metadata[METADATA_COL_NAME] = base

        for column, value in row.items():
            if column != CONTENT_COL_NAME and column != METADATA_COL_NAME:
                metadata[column] = value
        return metadata

0 commit comments

Comments
 (0)