Skip to content

Commit 742be06

Browse files
authored
Fix/localai (langgenius#2840)
1 parent af98954 commit 742be06

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

api/core/model_runtime/model_providers/localai/llm/llm.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections.abc import Generator
22
from typing import cast
3-
from urllib.parse import urljoin
43

54
from httpx import Timeout
65
from openai import (
@@ -19,6 +18,7 @@
1918
from openai.types.chat import ChatCompletion, ChatCompletionChunk
2019
from openai.types.chat.chat_completion_message import FunctionCall
2120
from openai.types.completion import Completion
21+
from yarl import URL
2222

2323
from core.model_runtime.entities.common_entities import I18nObject
2424
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -181,7 +181,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
181181
UserPromptMessage(content='ping')
182182
], model_parameters={
183183
'max_tokens': 10,
184-
}, stop=[])
184+
}, stop=[], stream=False)
185185
except Exception as ex:
186186
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
187187

@@ -227,14 +227,20 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
227227
)
228228
]
229229

230+
model_properties = {
231+
ModelPropertyKey.MODE: completion_model,
232+
} if completion_model else {}
233+
234+
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
235+
230236
entity = AIModelEntity(
231237
model=model,
232238
label=I18nObject(
233239
en_US=model
234240
),
235241
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
236242
model_type=ModelType.LLM,
237-
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
243+
model_properties=model_properties,
238244
parameter_rules=rules
239245
)
240246

@@ -319,7 +325,7 @@ def _to_client_kwargs(self, credentials: dict) -> dict:
319325
client_kwargs = {
320326
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
321327
"api_key": "1",
322-
"base_url": urljoin(credentials['server_url'], 'v1'),
328+
"base_url": str(URL(credentials['server_url']) / 'v1'),
323329
}
324330

325331
return client_kwargs

api/core/model_runtime/model_providers/localai/localai.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,12 @@ model_credential_schema:
5656
placeholder:
5757
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
5858
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
59+
- variable: context_size
60+
label:
61+
zh_Hans: 上下文大小
62+
en_US: Context size
63+
placeholder:
64+
zh_Hans: 输入上下文大小
65+
en_US: Enter context size
66+
required: false
67+
type: text-input

api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import time
22
from json import JSONDecodeError, dumps
3-
from os.path import join
43
from typing import Optional
54

65
from requests import post
6+
from yarl import URL
77

8-
from core.model_runtime.entities.model_entities import PriceType
8+
from core.model_runtime.entities.common_entities import I18nObject
9+
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
910
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
1011
from core.model_runtime.errors.invoke import (
1112
InvokeAuthorizationError,
@@ -57,7 +58,7 @@ def _invoke(self, model: str, credentials: dict,
5758
}
5859

5960
try:
60-
response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
61+
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
6162
except Exception as e:
6263
raise InvokeConnectionError(str(e))
6364

@@ -113,6 +114,27 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int
113114
# use GPT2Tokenizer to get num tokens
114115
num_tokens += self._get_num_tokens_by_gpt2(text)
115116
return num_tokens
117+
118+
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
    """
    Get customizable model schema for a text-embedding model

    :param model: model name, also used as the zh_Hans / en_US label
    :param credentials: model credentials; the optional ``context_size``
        entry sets the context size and defaults to 512 when it is
        missing, ``None`` or blank
    :return: model schema
    :raises ValueError: if ``context_size`` is set but is not an integer
    """
    # ``context_size`` is an optional free-text credential, so it may be
    # absent, ``None`` or an empty string. ``dict.get`` with a default only
    # covers the "absent" case; normalise the other two here so that
    # ``int('')`` / ``int(None)`` cannot raise for an unfilled field.
    raw_context_size = credentials.get('context_size')
    if raw_context_size is None or (
        isinstance(raw_context_size, str) and not raw_context_size.strip()
    ):
        raw_context_size = 512

    return AIModelEntity(
        model=model,
        label=I18nObject(zh_Hans=model, en_US=model),
        model_type=ModelType.TEXT_EMBEDDING,
        features=[],
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties={
            ModelPropertyKey.CONTEXT_SIZE: int(raw_context_size),
            ModelPropertyKey.MAX_CHUNKS: 1,
        },
        parameter_rules=[]
    )
116138

117139
def validate_credentials(self, model: str, credentials: dict) -> None:
118140
"""

0 commit comments

Comments
 (0)