
Commit b78e09b

add option to select model
1 parent c7ec8c7

File tree: 4 files changed, +41 -31 lines


knowledge_gpt/core/qa.py

Lines changed: 3 additions & 15 deletions
```diff
@@ -1,11 +1,10 @@
-from typing import Any, List
+from typing import List
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from knowledge_gpt.core.prompts import STUFF_PROMPT
 from langchain.docstore.document import Document
-from langchain.chat_models import ChatOpenAI
 from knowledge_gpt.core.embedding import FolderIndex
-from knowledge_gpt.core.debug import FakeChatModel
 from pydantic import BaseModel
+from langchain.chat_models.base import BaseChatModel


 class AnswerWithSources(BaseModel):
@@ -16,9 +15,8 @@ class AnswerWithSources(BaseModel):
 def query_folder(
     query: str,
     folder_index: FolderIndex,
+    llm: BaseChatModel,
     return_all: bool = False,
-    model: str = "openai",
-    **model_kwargs: Any,
 ) -> AnswerWithSources:
     """Queries a folder index for an answer.

@@ -33,15 +31,6 @@ def query_folder(
     Returns:
         AnswerWithSources: The answer and the source documents.
     """
-    supported_models = {
-        "openai": ChatOpenAI,
-        "debug": FakeChatModel,
-    }
-
-    if model in supported_models:
-        llm = supported_models[model](**model_kwargs)
-    else:
-        raise ValueError(f"Model {model} not supported.")

     chain = load_qa_with_sources_chain(
         llm=llm,
@@ -73,5 +62,4 @@ def get_sources(answer: str, folder_index: FolderIndex) -> List[Document]:
         for doc in file.docs:
             if doc.metadata["source"] in source_keys:
                 source_docs.append(doc)
-
     return source_docs
```
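
The effect of this hunk is that `query_folder` no longer constructs a model itself; the caller injects any LangChain `BaseChatModel`. A minimal sketch of the new call shape, assuming a `FolderIndex` named `folder_index` built earlier by `embed_files` and an `answer` field on `AnswerWithSources` (neither is shown in this diff):

```python
from langchain.chat_models import ChatOpenAI
from knowledge_gpt.core.qa import query_folder

# folder_index is assumed to be a FolderIndex produced by embed_files().
llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key="sk-...", temperature=0)
result = query_folder(query="What is the document about?", folder_index=folder_index, llm=llm)
print(result.answer)  # assumes AnswerWithSources exposes an `answer` field
```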

knowledge_gpt/core/utils.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -2,6 +2,10 @@
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.docstore.document import Document

+from langchain.chat_models import ChatOpenAI
+from knowledge_gpt.core.debug import FakeChatModel
+from langchain.chat_models.base import BaseChatModel
+

 def pop_docs_upto_limit(
     query: str, chain: StuffDocumentsChain, docs: List[Document], max_len: int
@@ -16,3 +20,13 @@ def pop_docs_upto_limit(
         token_count = chain.prompt_length(docs, question=query)  # type: ignore

     return docs
+
+
+def get_llm(model: str, **kwargs) -> BaseChatModel:
+    if model == "debug":
+        return FakeChatModel()
+
+    if "gpt" in model:
+        return ChatOpenAI(model=model, **kwargs)  # type: ignore
+
+    raise ValueError(f"Model {model} not supported!")
```
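
`get_llm` is a small name-based dispatcher: `"debug"` returns the fake model, any name containing `"gpt"` goes to `ChatOpenAI`, and anything else raises. The substring check keeps the dispatcher open-ended, so future `gpt-*` model ids work without code changes. A quick usage sketch (the key value is a placeholder):

```python
from knowledge_gpt.core.utils import get_llm

llm = get_llm(model="gpt-4", openai_api_key="sk-...", temperature=0)  # -> ChatOpenAI
fake = get_llm(model="debug")  # -> FakeChatModel, no API key needed
get_llm(model="llama-2")  # raises ValueError("Model llama-2 not supported!")
```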

knowledge_gpt/main.py

Lines changed: 18 additions & 14 deletions
```diff
@@ -16,13 +16,15 @@
 from knowledge_gpt.core.chunking import chunk_file
 from knowledge_gpt.core.embedding import embed_files
 from knowledge_gpt.core.qa import query_folder
+from knowledge_gpt.core.utils import get_llm
+

 EMBEDDING = "openai"
 VECTOR_STORE = "faiss"
-MODEL = "openai"
+MODEL_LIST = ["gpt-3.5-turbo", "gpt-4"]

-# For testing
-EMBEDDING, VECTOR_STORE, MODEL = ["debug"] * 3
+# Uncomment to enable debug mode
+# MODEL_LIST.insert(0, "debug")

 st.set_page_config(page_title="KnowledgeGPT", page_icon="📖", layout="wide")
 st.header("📖KnowledgeGPT")
@@ -48,6 +50,13 @@
     help="Scanned documents are not supported yet!",
 )

+model: str = st.selectbox("Model", options=MODEL_LIST)  # type: ignore
+
+with st.expander("Advanced Options"):
+    return_all_chunks = st.checkbox("Show all chunks retrieved from vector search")
+    show_full_doc = st.checkbox("Show parsed contents of the document")
+
+
 if not uploaded_file:
     st.stop()

@@ -61,15 +70,16 @@
 if not is_file_valid(file):
     st.stop()

-if MODEL != "debug" and not is_open_ai_key_valid(openai_api_key):
+
+if not is_open_ai_key_valid(openai_api_key, model):
     st.stop()


 with st.spinner("Indexing document... This may take a while⏳"):
     folder_index = embed_files(
         files=[chunked_file],
-        embedding=EMBEDDING,
-        vector_store=VECTOR_STORE,
+        embedding=EMBEDDING if model != "debug" else "debug",
+        vector_store=VECTOR_STORE if model != "debug" else "debug",
         openai_api_key=openai_api_key,
     )

@@ -78,11 +88,6 @@
     submit = st.form_submit_button("Submit")


-with st.expander("Advanced Options"):
-    return_all_chunks = st.checkbox("Show all chunks retrieved from vector search")
-    show_full_doc = st.checkbox("Show parsed contents of the document")
-
-
 if show_full_doc:
     with st.expander("Document"):
         # Hack to get around st.markdown rendering LaTeX
@@ -96,13 +101,12 @@
 # Output Columns
 answer_col, sources_col = st.columns(2)

+llm = get_llm(model=model, openai_api_key=openai_api_key, temperature=0)
 result = query_folder(
     folder_index=folder_index,
     query=query,
     return_all=return_all_chunks,
-    model=MODEL,
-    openai_api_key=openai_api_key,
-    temperature=0,
+    llm=llm,
 )

 with answer_col:
```
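
Enabling the commented-out debug line threads `"debug"` through every stage, so the whole pipeline runs offline. A hedged sketch of that path, using only the signatures introduced in this commit (calling the Streamlit-cached validator headlessly is an assumption):

```python
from knowledge_gpt.core.utils import get_llm
from knowledge_gpt.ui import is_open_ai_key_valid

MODEL_LIST = ["gpt-3.5-turbo", "gpt-4"]
MODEL_LIST.insert(0, "debug")  # the toggle main.py leaves commented out

model = "debug"  # as if chosen in the selectbox
assert is_open_ai_key_valid(None, model)  # short-circuits to True, no API call
llm = get_llm(model=model)  # FakeChatModel; embed_files likewise receives "debug"
```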

knowledge_gpt/ui.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -43,18 +43,22 @@ def display_file_read_error(e: Exception) -> NoReturn:


 @st.cache_data(show_spinner=False)
-def is_open_ai_key_valid(openai_api_key) -> bool:
+def is_open_ai_key_valid(openai_api_key, model: str) -> bool:
+    if model == "debug":
+        return True
+
     if not openai_api_key:
         st.error("Please enter your OpenAI API key in the sidebar!")
         return False
     try:
         openai.ChatCompletion.create(
-            model="gpt-3.5-turbo",
+            model=model,
             messages=[{"role": "user", "content": "test"}],
             api_key=openai_api_key,
         )
     except Exception as e:
         st.error(f"{e.__class__.__name__}: {e}")
         logger.error(f"{e.__class__.__name__}: {e}")
         return False
+
     return True
```
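
The validator now probes the selected model instead of hard-coding `gpt-3.5-turbo`, so a key without, say, `gpt-4` access fails at startup rather than mid-query. `openai.ChatCompletion.create` is the pre-1.0 `openai` SDK call style; a standalone sketch of the same probe, minus the Streamlit error reporting:

```python
import openai  # assumes openai<1.0, matching the call style in ui.py

def key_valid_for_model(api_key: str, model: str) -> bool:
    """Best-effort probe: any auth or model-access error counts as invalid."""
    try:
        openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "user", "content": "test"}],
            api_key=api_key,
        )
    except Exception:
        return False
    return True
```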
