Commit 4b7f366

Merge branch 'dev'
2 parents: ecff70f + f82314b

32 files changed: +769, -254 lines

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Local data
-data/local_data/
+resources/local_data/

 # Prototyping notebooks
 notebooks/

data/paul_graham_essay.docx

-48.1 KB
Binary file not shown.

knowledge_gpt/components/sidebar.py

Lines changed: 1 addition & 6 deletions
@@ -7,10 +7,6 @@
 load_dotenv()


-def set_openai_api_key(api_key: str):
-    st.session_state["OPENAI_API_KEY"] = api_key
-
-
 def sidebar():
     with st.sidebar:
         st.markdown(

@@ -28,8 +24,7 @@ def sidebar():
             or st.session_state.get("OPENAI_API_KEY", ""),
         )

-        if api_key_input:
-            set_openai_api_key(api_key_input)
+        st.session_state["OPENAI_API_KEY"] = api_key_input

         st.markdown("---")
         st.markdown("# About")

knowledge_gpt/core/caching.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import streamlit as st
from streamlit.runtime.caching.hashing import HashFuncsDict

import knowledge_gpt.core.parsing as parsing
import knowledge_gpt.core.chunking as chunking
import knowledge_gpt.core.embedding as embedding
from knowledge_gpt.core.parsing import File


def file_hash_func(file: File) -> str:
    """Get a unique hash for a file"""
    return file.id


@st.cache_resource()
def bootstrap_caching():
    """Patch module functions with caching"""

    # Get all subtypes of File from module
    file_subtypes = [
        cls
        for cls in vars(parsing).values()
        if isinstance(cls, type) and issubclass(cls, File) and cls != File
    ]
    file_hash_funcs: HashFuncsDict = {cls: file_hash_func for cls in file_subtypes}

    parsing.read_file = st.cache_data(show_spinner=False)(parsing.read_file)
    chunking.chunk_file = st.cache_data(show_spinner=False, hash_funcs=file_hash_funcs)(
        chunking.chunk_file
    )
    embedding.embed_files = st.cache_data(
        show_spinner=False, hash_funcs=file_hash_funcs
    )(embedding.embed_files)
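
bootstrap_caching() works by monkey-patching the parsing, chunking, and embedding module functions with st.cache_data, using File.id as the hash for File arguments so identical uploads hit the cache. A minimal sketch of how an app entry point might use it; the functions are accessed through their modules so the patched wrappers are picked up, and the uploader label is illustrative:

import streamlit as st

import knowledge_gpt.core.parsing as parsing
from knowledge_gpt.core.caching import bootstrap_caching

# Patch read_file / chunk_file / embed_files with st.cache_data before first use,
# so Streamlit reruns with the same upload reuse cached results.
bootstrap_caching()

uploaded = st.file_uploader("Upload a pdf, docx, or txt file", type=["pdf", "docx", "txt"])
if uploaded is not None:
    file = parsing.read_file(uploaded)  # resolves to the cached wrapper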

knowledge_gpt/core/chunking.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from knowledge_gpt.core.parsing import File


def chunk_file(
    file: File, chunk_size: int, chunk_overlap: int = 0, model_name="gpt-3.5-turbo"
) -> File:
    """Chunks each document in a file into smaller documents
    according to the specified chunk size and overlap
    where the size is determined by the number of tokens for the specified model.
    """

    # split each document into chunks
    chunked_docs = []
    for doc in file.docs:
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            model_name=model_name,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

        chunks = text_splitter.split_text(doc.page_content)

        for i, chunk in enumerate(chunks):
            doc = Document(
                page_content=chunk,
                metadata={
                    "page": doc.metadata.get("page", 1),
                    "chunk": i + 1,
                    "source": f"{doc.metadata.get('page', 1)}-{i + 1}",
                },
            )
            chunked_docs.append(doc)

    chunked_file = file.copy()
    chunked_file.docs = chunked_docs
    return chunked_file
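
chunk_file splits every document in a File into chunks whose size is measured in tokens for the given model (via tiktoken) and stamps each chunk with page, chunk, and a combined "page-chunk" source key. A minimal usage sketch with an in-memory text file; the sample text and the chunk_size of 300 are illustrative:

from io import BytesIO

from knowledge_gpt.core.chunking import chunk_file
from knowledge_gpt.core.parsing import TxtFile

# Build a small in-memory .txt upload; the BytesIO needs a .name for File parsing.
buf = BytesIO(b"Essays cover startups, programming, and growth. " * 100)
buf.name = "essay.txt"
file = TxtFile.from_bytes(buf)

chunked = chunk_file(file, chunk_size=300, chunk_overlap=0)
for doc in chunked.docs:
    # e.g. source "1-2" means page 1, chunk 2
    print(doc.metadata["source"], len(doc.page_content))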

knowledge_gpt/core/embedding.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
from langchain.vectorstores import VectorStore
from knowledge_gpt.core.parsing import File
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from typing import List, Type
from langchain.docstore.document import Document


class FolderIndex:
    """Index for a collection of files (a folder)"""

    def __init__(self, files: List[File], index: VectorStore):
        self.name: str = "default"
        self.files = files
        self.index: VectorStore = index

    @staticmethod
    def _combine_files(files: List[File]) -> List[Document]:
        """Combines all the documents in a list of files into a single list."""

        all_texts = []
        for file in files:
            for doc in file.docs:
                doc.metadata["file_name"] = file.name
                doc.metadata["file_id"] = file.id
                all_texts.append(doc)

        return all_texts

    @classmethod
    def from_files(
        cls, files: List[File], embeddings: Embeddings, vector_store: Type[VectorStore]
    ) -> "FolderIndex":
        """Creates an index from files."""

        all_docs = cls._combine_files(files)

        index = vector_store.from_documents(
            documents=all_docs,
            embedding=embeddings,
        )

        return cls(files=files, index=index)


def embed_files(
    files: List[File], embedding: str, vector_store: str, **kwargs
) -> FolderIndex:
    """Embeds a collection of files and stores them in a FolderIndex."""

    supported_embeddings: dict[str, Type[Embeddings]] = {
        "openai": OpenAIEmbeddings,
    }
    supported_vector_stores: dict[str, Type[VectorStore]] = {
        "faiss": FAISS,
    }

    if embedding in supported_embeddings:
        _embeddings = supported_embeddings[embedding](**kwargs)
    else:
        raise NotImplementedError(f"Embedding {embedding} not supported.")

    if vector_store in supported_vector_stores:
        _vector_store = supported_vector_stores[vector_store]
    else:
        raise NotImplementedError(f"Vector store {vector_store} not supported.")

    return FolderIndex.from_files(
        files=files, embeddings=_embeddings, vector_store=_vector_store
    )
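
embed_files resolves the embedding backend and vector store by name ("openai" and "faiss" are the only values wired up here), forwards any extra keyword arguments to the embeddings constructor, and returns a FolderIndex wrapping the resulting store. A minimal sketch continuing from the chunking sketch above, assuming OpenAIEmbeddings accepts an openai_api_key keyword; the key value is a placeholder:

from knowledge_gpt.core.embedding import embed_files

folder_index = embed_files(
    [chunked],                   # one or more chunked File objects
    embedding="openai",          # -> OpenAIEmbeddings(**kwargs)
    vector_store="faiss",        # -> FAISS.from_documents(...)
    openai_api_key="sk-...",     # placeholder; forwarded to the embeddings constructor
)

# The underlying VectorStore stays accessible for direct similarity search.
docs = folder_index.index.similarity_search("What is the essay about?", k=5)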

knowledge_gpt/core/parsing.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
from io import BytesIO
from typing import List, Any, Optional
import re

import docx2txt
from langchain.docstore.document import Document
from pypdf import PdfReader
from hashlib import md5

from abc import abstractmethod, ABC
from copy import deepcopy


class File(ABC):
    """Represents an uploaded file comprised of Documents"""

    def __init__(
        self,
        name: str,
        id: str,
        metadata: Optional[dict[str, Any]] = None,
        docs: Optional[List[Document]] = None,
    ):
        self.name = name
        self.id = id
        self.metadata = metadata or {}
        self.docs = docs or []

    @classmethod
    @abstractmethod
    def from_bytes(cls, file: BytesIO) -> "File":
        """Creates a File from a BytesIO object"""

    def __repr__(self) -> str:
        return (
            f"File(name={self.name}, id={self.id},"
            f" metadata={self.metadata}, docs={self.docs})"
        )

    def __str__(self) -> str:
        return f"File(name={self.name}, id={self.id}, metadata={self.metadata})"

    def copy(self) -> "File":
        """Create a deep copy of this File"""
        return self.__class__(
            name=self.name,
            id=self.id,
            metadata=deepcopy(self.metadata),
            docs=deepcopy(self.docs),
        )


def strip_consecutive_newlines(text: str) -> str:
    """Strips consecutive newlines from a string
    possibly with whitespace in between
    """
    return re.sub(r"\s*\n\s*", "\n", text)


class DocxFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO) -> "DocxFile":
        text = docx2txt.process(file)
        text = strip_consecutive_newlines(text)
        doc = Document(page_content=text.strip())
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])


class PdfFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO) -> "PdfFile":
        pdf = PdfReader(file)
        docs = []
        for i, page in enumerate(pdf.pages):
            text = page.extract_text()
            text = strip_consecutive_newlines(text)
            doc = Document(page_content=text.strip())
            doc.metadata["page"] = i + 1
            docs.append(doc)
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=docs)


class TxtFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO) -> "TxtFile":
        text = file.read().decode("utf-8")
        text = strip_consecutive_newlines(text)
        file.seek(0)
        doc = Document(page_content=text.strip())
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])


def read_file(file: BytesIO) -> File:
    """Reads an uploaded file and returns a File object"""
    if file.name.endswith(".docx"):
        return DocxFile.from_bytes(file)
    elif file.name.endswith(".pdf"):
        return PdfFile.from_bytes(file)
    elif file.name.endswith(".txt"):
        return TxtFile.from_bytes(file)
    else:
        raise NotImplementedError
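
read_file dispatches on the upload's .name suffix to DocxFile, PdfFile, or TxtFile; each subclass hashes the upload with md5 to give the File an id, which the caching layer above uses as its hash key. A minimal sketch; the file path is hypothetical:

from io import BytesIO

from knowledge_gpt.core.parsing import PdfFile, read_file

# read_file expects a BytesIO-like object with a .name attribute,
# e.g. a Streamlit UploadedFile or a buffer prepared like this:
with open("essay.pdf", "rb") as f:
    buf = BytesIO(f.read())
buf.name = "essay.pdf"

file = read_file(buf)
assert isinstance(file, PdfFile)
print(file.id)          # content hash used as the cache key
print(len(file.docs))   # one Document per PDF page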

knowledge_gpt/core/qa.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
from typing import Any, List
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from knowledge_gpt.core.prompts import STUFF_PROMPT
from langchain.docstore.document import Document
from langchain.chat_models import ChatOpenAI
from knowledge_gpt.core.embedding import FolderIndex
from pydantic import BaseModel


class AnswerWithSources(BaseModel):
    answer: str
    sources: List[Document]


def query_folder(
    query: str, folder_index: FolderIndex, return_all: bool = False, **model_kwargs: Any
) -> AnswerWithSources:
    """Queries a folder index for an answer.

    Args:
        query (str): The query to search for.
        folder_index (FolderIndex): The folder index to search.
        return_all (bool): Whether to return all the documents from the embedding or
            just the sources for the answer.
        **model_kwargs (Any): Keyword arguments for the model.

    Returns:
        AnswerWithSources: The answer and the source documents.
    """

    chain = load_qa_with_sources_chain(
        llm=ChatOpenAI(**model_kwargs),
        chain_type="stuff",
        prompt=STUFF_PROMPT,
    )

    relevant_docs = folder_index.index.similarity_search(query, k=5)
    result = chain(
        {"input_documents": relevant_docs, "question": query}, return_only_outputs=True
    )
    sources = relevant_docs

    if not return_all:
        sources = get_sources(result["output_text"], folder_index)

    answer = result["output_text"].split("SOURCES: ")[0]

    return AnswerWithSources(answer=answer, sources=sources)


def get_sources(answer: str, folder_index: FolderIndex) -> List[Document]:
    """Retrieves the docs that were used to generate the answer."""

    source_keys = [s for s in answer.split("SOURCES: ")[-1].split(", ")]

    source_docs = []
    for file in folder_index.files:
        for doc in file.docs:
            if doc.metadata["source"] in source_keys:
                source_docs.append(doc)

    return source_docs
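
query_folder retrieves the five most similar chunks, stuffs them into STUFF_PROMPT for the chat model, and then maps the "SOURCES:" keys in the reply back to chunk metadata via get_sources. A sketch of the whole pipeline wired together in a Streamlit app; it assumes OpenAIEmbeddings and ChatOpenAI accept an openai_api_key keyword and that ChatOpenAI accepts model and temperature, and the widget labels and chunk_size are illustrative:

import streamlit as st

from knowledge_gpt.components.sidebar import sidebar
from knowledge_gpt.core.parsing import read_file
from knowledge_gpt.core.chunking import chunk_file
from knowledge_gpt.core.embedding import embed_files
from knowledge_gpt.core.qa import query_folder

sidebar()  # renders the sidebar and stores the API key in st.session_state
api_key = st.session_state.get("OPENAI_API_KEY", "")

uploaded = st.file_uploader("Upload a document", type=["pdf", "docx", "txt"])
query = st.text_area("Ask a question about the document")

if uploaded and query and api_key:
    file = chunk_file(read_file(uploaded), chunk_size=300)
    folder_index = embed_files(
        [file], embedding="openai", vector_store="faiss", openai_api_key=api_key
    )
    result = query_folder(
        query,
        folder_index,
        model="gpt-3.5-turbo",   # forwarded to ChatOpenAI
        temperature=0,
        openai_api_key=api_key,
    )
    st.markdown(result.answer)
    for doc in result.sources:
        st.markdown(f"{doc.metadata['source']}: {doc.page_content[:200]}")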
