-
Notifications
You must be signed in to change notification settings - Fork 8.2k
fix: Add IBM watsonx.ai support to EmbeddingModel #10677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
55cf5d9
d927cc6
ce8aabd
ec14524
9fad275
7079625
4dc4ec9
5444d54
04f4a78
9f3541f
968a006
b2ee642
3f93924
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,16 @@ | ||
| from typing import Any | ||
|
|
||
| import requests | ||
| from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames | ||
| from langchain_openai import OpenAIEmbeddings | ||
|
|
||
| from lfx.base.embeddings.model import LCEmbeddingsModel | ||
| from lfx.base.models.model_utils import get_ollama_models, is_valid_ollama_url | ||
| from lfx.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES | ||
| from lfx.base.models.watsonx_constants import IBM_WATSONX_URLS, WATSONX_EMBEDDING_MODEL_NAMES | ||
| from lfx.base.models.watsonx_constants import ( | ||
| IBM_WATSONX_URLS, | ||
| WATSONX_EMBEDDING_MODEL_NAMES, | ||
| ) | ||
| from lfx.field_typing import Embeddings | ||
| from lfx.io import ( | ||
| BoolInput, | ||
|
|
@@ -77,6 +82,8 @@ class EmbeddingModelComponent(LCEmbeddingsModel): | |
| options=OPENAI_EMBEDDING_MODEL_NAMES, | ||
| value=OPENAI_EMBEDDING_MODEL_NAMES[0], | ||
| info="Select the embedding model to use", | ||
| real_time_refresh=True, | ||
| refresh_button=True, | ||
| ), | ||
| SecretStrInput( | ||
| name="api_key", | ||
|
|
@@ -110,8 +117,40 @@ class EmbeddingModelComponent(LCEmbeddingsModel): | |
| advanced=True, | ||
| info="Additional keyword arguments to pass to the model.", | ||
| ), | ||
| IntInput( | ||
| name="truncate_input_tokens", | ||
| display_name="Truncate Input Tokens", | ||
| advanced=True, | ||
| value=200, | ||
| show=False, | ||
| ), | ||
| BoolInput( | ||
| name="input_text", | ||
| display_name="Include the original text in the output", | ||
| value=True, | ||
| advanced=True, | ||
| show=False, | ||
| ), | ||
| ] | ||
|
|
||
| @staticmethod | ||
| def fetch_ibm_models(base_url: str) -> list[str]: | ||
| """Fetch available models from the watsonx.ai API.""" | ||
| try: | ||
| endpoint = f"{base_url}/ml/v1/foundation_model_specs" | ||
| params = { | ||
| "version": "2024-09-16", | ||
| "filters": "function_embedding,!lifecycle_withdrawn:and", | ||
| } | ||
| response = requests.get(endpoint, params=params, timeout=10) | ||
|
||
| response.raise_for_status() | ||
| data = response.json() | ||
| models = [model["model_id"] for model in data.get("resources", [])] | ||
| return sorted(models) | ||
| except Exception: # noqa: BLE001 | ||
| logger.exception("Error fetching models") | ||
| return WATSONX_EMBEDDING_MODEL_NAMES | ||
|
|
||
| def build_embeddings(self) -> Embeddings: | ||
| provider = self.provider | ||
| model = self.model | ||
|
|
@@ -188,15 +227,26 @@ def build_embeddings(self) -> Embeddings: | |
| msg = "Project ID is required for IBM watsonx.ai provider" | ||
| raise ValueError(msg) | ||
|
|
||
| from ibm_watsonx_ai import APIClient, Credentials | ||
|
|
||
| credentials = Credentials( | ||
| api_key=self.api_key, | ||
| url=base_url_ibm_watsonx or "https://us-south.ml.cloud.ibm.com", | ||
| ) | ||
|
|
||
| api_client = APIClient(credentials) | ||
|
|
||
| params = { | ||
| "model_id": model, | ||
| "url": base_url_ibm_watsonx or "https://us-south.ml.cloud.ibm.com", | ||
| "apikey": api_key, | ||
| EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: self.truncate_input_tokens, | ||
| EmbedTextParamsMetaNames.RETURN_OPTIONS: {"input_text": self.input_text}, | ||
| } | ||
|
|
||
| params["project_id"] = project_id | ||
|
|
||
| return WatsonxEmbeddings(**params) | ||
| return WatsonxEmbeddings( | ||
| model_id=model, | ||
| params=params, | ||
| watsonx_client=api_client, | ||
| project_id=project_id, | ||
| ) | ||
|
|
||
| msg = f"Unknown provider: {provider}" | ||
| raise ValueError(msg) | ||
|
|
@@ -217,7 +267,8 @@ async def update_build_config( | |
| build_config["ollama_base_url"]["show"] = False | ||
| build_config["project_id"]["show"] = False | ||
| build_config["base_url_ibm_watsonx"]["show"] = False | ||
|
|
||
| build_config["truncate_input_tokens"]["show"] = False | ||
| build_config["input_text"]["show"] = False | ||
| elif field_value == "Ollama": | ||
| build_config["ollama_base_url"]["show"] = True | ||
|
|
||
|
|
@@ -238,7 +289,8 @@ async def update_build_config( | |
| else: | ||
| build_config["model"]["options"] = [] | ||
| build_config["model"]["value"] = "" | ||
|
|
||
| build_config["truncate_input_tokens"]["show"] = False | ||
| build_config["input_text"]["show"] = False | ||
| build_config["api_key"]["display_name"] = "API Key (Optional)" | ||
| build_config["api_key"]["required"] = False | ||
| build_config["api_key"]["show"] = False | ||
|
|
@@ -247,16 +299,20 @@ async def update_build_config( | |
| build_config["base_url_ibm_watsonx"]["show"] = False | ||
|
|
||
| elif field_value == "IBM watsonx.ai": | ||
| build_config["model"]["options"] = WATSONX_EMBEDDING_MODEL_NAMES | ||
| build_config["model"]["value"] = WATSONX_EMBEDDING_MODEL_NAMES[0] | ||
| build_config["model"]["options"] = self.fetch_ibm_models(base_url=self.base_url_ibm_watsonx) | ||
| build_config["model"]["value"] = self.fetch_ibm_models(base_url=self.base_url_ibm_watsonx)[0] | ||
|
Comment on lines
+302
to
+303
|
||
| build_config["api_key"]["display_name"] = "IBM watsonx.ai API Key" | ||
| build_config["api_key"]["required"] = True | ||
| build_config["api_key"]["show"] = True | ||
| build_config["api_base"]["show"] = False | ||
| build_config["ollama_base_url"]["show"] = False | ||
| build_config["base_url_ibm_watsonx"]["show"] = True | ||
| build_config["project_id"]["show"] = True | ||
|
|
||
| build_config["truncate_input_tokens"]["show"] = True | ||
| build_config["input_text"]["show"] = True | ||
| elif field_name == "base_url_ibm_watsonx": | ||
| build_config["model"]["options"] = self.fetch_ibm_models(base_url=field_value) | ||
| build_config["model"]["value"] = self.fetch_ibm_models(base_url=field_value)[0] | ||
|
Comment on lines
+314
to
+315
|
||
| elif field_name == "ollama_base_url": | ||
| # # Refresh Ollama models when base URL changes | ||
| # if hasattr(self, "provider") and self.provider == "Ollama": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The filter syntax
"function_embedding,!lifecycle_withdrawn:and"has an unusual:andsuffix at the end. Comparing with the similar implementation inlanguage_model.py(line 57), which uses"function_text_chat,!lifecycle_withdrawn"without the:andsuffix, this appears to be inconsistent. Consider removing the:andsuffix or verifying the correct filter syntax with the IBM watsonx.ai API documentation.