feat: Use Taskprocessing TextToText provider as LLM
Signed-off-by: Marcel Klehr <[email protected]>
marcelklehr committed Jul 29, 2024
commit 36bf26bcf5e351bdec0f2322f2d889a55525b0ab
2 changes: 2 additions & 0 deletions config.cpu.yaml
@@ -40,6 +40,8 @@ embedding:
    device: cpu

llm:
  nc_texttotext:

  llama:
    model_path: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
    n_batch: 512
4 changes: 3 additions & 1 deletion config.gpu.yaml
@@ -40,6 +40,8 @@ embedding:
    device: cuda

llm:
  nc_texttotext:

  llama:
    model_path: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
    n_batch: 512
@@ -69,4 +71,4 @@ llm:
pipeline_kwargs:
config:
max_length: 200
template: ""
template: ""
2 changes: 1 addition & 1 deletion context_chat_backend/chain/query_proc.py
@@ -25,7 +25,7 @@ def get_pruned_query(llm: LLM, config: TConfig, query: str, template: str, text_
        or llm_config.get('config', {}).get('max_new_tokens') \
        or max(
            llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_new_tokens', 0),
            llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_length')
            llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_length', 0)
        ) \
        or 4096

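The added `, 0` default in `query_proc.py` matters when a Hugging Face style config defines neither `max_new_tokens` nor `max_length`: `dict.get('max_length')` then returns `None`, and `max(0, None)` raises a `TypeError` instead of falling through to the 4096 default. A minimal standalone illustration (not repo code):

```python
cfg: dict = {}  # a pipeline_kwargs config with neither max_new_tokens nor max_length

# Old behaviour: cfg.get('max_length') returns None, and max(0, None) raises
# TypeError: '>' not supported between instances of 'NoneType' and 'int'.
# max(cfg.get('max_new_tokens', 0), cfg.get('max_length'))

# New behaviour: both lookups fall back to 0, so the surrounding `... or 4096`
# expression supplies the context-size default instead of crashing.
n_ctx = max(cfg.get('max_new_tokens', 0), cfg.get('max_length', 0)) or 4096
print(n_ctx)  # 4096
```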
3 changes: 2 additions & 1 deletion context_chat_backend/download.py
@@ -202,7 +202,8 @@ def background_init(app: FastAPI):
    for model_type in ('embedding', 'llm'):
        model_name = _get_model_name_or_path(config, model_type)
        if model_name is None:
            raise Exception(f'Error: Model name/path not found for {model_type}')
            update_progress(app, progress := progress + 50)
            continue

        if not _download_model(model_name):
            raise Exception(f'Error: Model download failed for {model_name}')
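A missing model name no longer aborts background initialisation: the Nextcloud TaskProcessing backend has no local weights to fetch, so the loop simply credits that model type's share of the progress bar and moves on. A toy sketch of the resulting control flow, using a hypothetical stand-in for `_get_model_name_or_path` (the real helper reads the app config):

```python
progress = 0

def model_name_for(model_type: str) -> str | None:
    # hypothetical stand-in: the llm backend is nc_texttotext, which names no model file
    return {'embedding': 'hkunlp/instructor-base', 'llm': None}[model_type]

for model_type in ('embedding', 'llm'):
    name = model_name_for(model_type)
    if name is None:
        progress += 50  # nothing to download, but initialisation still advances
        continue
    print(f'would download {name}')  # only the embedding model is fetched

print(progress)  # 50
```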
2 changes: 1 addition & 1 deletion context_chat_backend/models/__init__.py
@@ -5,7 +5,7 @@
from langchain.schema.embeddings import Embeddings

_embedding_models = ['llama', 'hugging_face', 'instructor']
_llm_models = ['llama', 'hugging_face', 'ctransformer']
_llm_models = ['nc_texttotext', 'llama', 'hugging_face', 'ctransformer']

models = {
    'embedding': _embedding_models,
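Registering `nc_texttotext` in `_llm_models` is what lets the new config key pass validation and be routed to the module added below. A sketch of how such name-to-module dispatch typically works, assuming the loader imports `context_chat_backend.models.<name>` and calls its module-level `get_model_for` (which the new file exposes); the repo's actual loader may differ in details:

```python
from importlib import import_module

_llm_models = ['nc_texttotext', 'llama', 'hugging_face', 'ctransformer']

def load_llm(model_name: str, model_config: dict):
    # hypothetical loader: validate the name, import the matching backend module,
    # and let that module construct the LangChain LLM object
    if model_name not in _llm_models:
        raise ValueError(f'unknown llm backend: {model_name}')
    module = import_module(f'context_chat_backend.models.{model_name}')
    return module.get_model_for('llm', model_config)
```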
83 changes: 83 additions & 0 deletions context_chat_backend/models/nc_texttotext.py
@@ -0,0 +1,83 @@
import json
import time
from typing import Any, Dict, List, Optional

from nc_py_api import Nextcloud
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM


def get_model_for(model_type: str, model_config: dict):
    if model_config is None:
        return None

    if model_type == 'llm':
        return CustomLLM()

    return None


class CustomLLM(LLM):
    """A custom chat model that queries Nextcloud's TextToText provider."""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Override this method to implement the LLM logic.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            The model output as a string. Actual completions SHOULD NOT include the prompt.
        """
        nc = Nextcloud()

        print(json.dumps(prompt))

        # schedule a core:text2text task via the OCS TaskProcessing API
        response = nc.ocs("POST", "/ocs/v1.php/taskprocessing/schedule", json={
            "type": "core:text2text",
            "appId": "context_chat_backend",
            "input": {
                "input": prompt
            }
        })

        task_id = response["task"]["id"]

        # poll the task every 5 seconds until it either succeeds or fails
        while response['task']['status'] != 'STATUS_SUCCESSFUL' and response['task']['status'] != 'STATUS_FAILED':
            time.sleep(5)
            response = nc.ocs("GET", f"/ocs/v1.php/taskprocessing/task/{task_id}")
            print(json.dumps(response))

        if response['task']['status'] == 'STATUS_FAILED':
            raise RuntimeError('Nextcloud TaskProcessing Task failed')

        return response['task']['output']['output']

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per-token pricing for their model and monitor
            # costs for the given LLM).
            "model_name": "NextcloudTextToTextProvider",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model. Used for logging purposes only."""
return "nc_texttotetx"
1 change: 1 addition & 0 deletions requirements.in.txt
@@ -29,3 +29,4 @@ unstructured @ git+https://github.com/kyteinsky/unstructured@d3a404cfb541dae8e16
unstructured-client
weaviate-client
xlrd
nc_py_api
2 changes: 2 additions & 0 deletions requirements.txt
@@ -82,6 +82,7 @@ mpmath==1.3.0
msg-parser==1.2.0
multidict==6.0.5
mypy-extensions==1.0.0
nc-py-api==0.14.0
nest-asyncio==1.6.0
networkx==3.3
nltk==3.8.1
@@ -189,5 +190,6 @@ websockets==12.0
wrapt==1.16.0
xlrd==2.0.1
XlsxWriter==3.2.0
xmltodict==0.13.0
yarl==1.9.4
zipp==3.19.2