-
Notifications
You must be signed in to change notification settings - Fork 8.2k
fix: Add IBM watsonx.ai support to EmbeddingModel #10677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
55cf5d9
d927cc6
ce8aabd
ec14524
9fad275
7079625
4dc4ec9
5444d54
04f4a78
9f3541f
968a006
b2ee642
3f93924
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,12 @@ | ||
| from typing import Any | ||
|
|
||
| from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames | ||
| from langchain_openai import OpenAIEmbeddings | ||
|
|
||
| from lfx.base.embeddings.model import LCEmbeddingsModel | ||
| from lfx.base.models.model_utils import get_ollama_models, is_valid_ollama_url | ||
| from lfx.base.models.openai_constants import OPENAI_EMBEDDING_MODEL_NAMES | ||
| from lfx.base.models.watsonx_constants import IBM_WATSONX_URLS, WATSONX_EMBEDDING_MODEL_NAMES | ||
| from lfx.base.models.watsonx_constants import IBM_WATSONX_URLS, WATSONX_DEFAULT_EMBEDDING_MODELS, WATSONX_EMBEDDING_MODEL_NAMES | ||
|
Check failure on line 9 in src/lfx/src/lfx/components/models_and_agents/embedding_model.py
|
||
|
||
| from lfx.field_typing import Embeddings | ||
| from lfx.io import ( | ||
| BoolInput, | ||
|
|
@@ -19,6 +20,7 @@ | |
| from lfx.log.logger import logger | ||
| from lfx.schema.dotdict import dotdict | ||
| from lfx.utils.util import transform_localhost_url | ||
| import requests | ||
|
||
|
|
||
| # Ollama API constants | ||
| HTTP_STATUS_OK = 200 | ||
|
|
@@ -77,6 +79,8 @@ | |
| options=OPENAI_EMBEDDING_MODEL_NAMES, | ||
| value=OPENAI_EMBEDDING_MODEL_NAMES[0], | ||
| info="Select the embedding model to use", | ||
| real_time_refresh=True, | ||
| refresh_button=True, | ||
| ), | ||
| SecretStrInput( | ||
| name="api_key", | ||
|
|
@@ -110,7 +114,38 @@ | |
| advanced=True, | ||
| info="Additional keyword arguments to pass to the model.", | ||
| ), | ||
| IntInput( | ||
| name="truncate_input_tokens", | ||
| display_name="Truncate Input Tokens", | ||
| advanced=True, | ||
| value=200, | ||
| show=False, | ||
| ), | ||
| BoolInput( | ||
| name="input_text", | ||
| display_name="Include the original text in the output", | ||
| value=True, | ||
| advanced=True, | ||
| show=False, | ||
| ), | ||
| ] | ||
| @staticmethod | ||
| def fetch_ibm_models(base_url: str) -> list[str]: | ||
| """Fetch available models from the watsonx.ai API.""" | ||
| try: | ||
| endpoint = f"{base_url}/ml/v1/foundation_model_specs" | ||
| params = { | ||
| "version": "2024-09-16", | ||
| "filters": "function_embedding,!lifecycle_withdrawn:and", | ||
|
||
| } | ||
| response = requests.get(endpoint, params=params, timeout=10) | ||
|
||
| response.raise_for_status() | ||
| data = response.json() | ||
| models = [model["model_id"] for model in data.get("resources", [])] | ||
| return sorted(models) | ||
| except Exception: # noqa: BLE001 | ||
| logger.exception("Error fetching models") | ||
| return WATSONX_EMBEDDING_MODEL_NAMES | ||
|
|
||
| def build_embeddings(self) -> Embeddings: | ||
| provider = self.provider | ||
|
|
@@ -188,15 +223,26 @@ | |
| msg = "Project ID is required for IBM watsonx.ai provider" | ||
| raise ValueError(msg) | ||
|
|
||
| from ibm_watsonx_ai import APIClient, Credentials | ||
|
|
||
| credentials = Credentials( | ||
| api_key=self.api_key, | ||
| url=base_url_ibm_watsonx or "https://us-south.ml.cloud.ibm.com", | ||
| ) | ||
|
|
||
| api_client = APIClient(credentials) | ||
|
|
||
| params = { | ||
| "model_id": model, | ||
| "url": base_url_ibm_watsonx or "https://us-south.ml.cloud.ibm.com", | ||
| "apikey": api_key, | ||
| EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: self.truncate_input_tokens, | ||
| EmbedTextParamsMetaNames.RETURN_OPTIONS: {"input_text": self.input_text}, | ||
| } | ||
|
|
||
| params["project_id"] = project_id | ||
|
|
||
| return WatsonxEmbeddings(**params) | ||
| return WatsonxEmbeddings( | ||
| model_id=model, | ||
| params=params, | ||
| watsonx_client=api_client, | ||
| project_id=project_id, | ||
| ) | ||
|
|
||
| msg = f"Unknown provider: {provider}" | ||
| raise ValueError(msg) | ||
|
|
@@ -217,7 +263,8 @@ | |
| build_config["ollama_base_url"]["show"] = False | ||
| build_config["project_id"]["show"] = False | ||
| build_config["base_url_ibm_watsonx"]["show"] = False | ||
|
|
||
| build_config["truncate_input_tokens"]["show"] = False | ||
| build_config["input_text"]["show"] = False | ||
| elif field_value == "Ollama": | ||
| build_config["ollama_base_url"]["show"] = True | ||
|
|
||
|
|
@@ -238,7 +285,8 @@ | |
| else: | ||
| build_config["model"]["options"] = [] | ||
| build_config["model"]["value"] = "" | ||
|
|
||
| build_config["truncate_input_tokens"]["show"] = False | ||
| build_config["input_text"]["show"] = False | ||
| build_config["api_key"]["display_name"] = "API Key (Optional)" | ||
| build_config["api_key"]["required"] = False | ||
| build_config["api_key"]["show"] = False | ||
|
|
@@ -247,16 +295,20 @@ | |
| build_config["base_url_ibm_watsonx"]["show"] = False | ||
|
|
||
| elif field_value == "IBM watsonx.ai": | ||
| build_config["model"]["options"] = WATSONX_EMBEDDING_MODEL_NAMES | ||
| build_config["model"]["value"] = WATSONX_EMBEDDING_MODEL_NAMES[0] | ||
| build_config["model"]["options"] = self.fetch_ibm_models(base_url=self.base_url_ibm_watsonx) | ||
| build_config["model"]["value"] = self.fetch_ibm_models(base_url=self.base_url_ibm_watsonx)[0] | ||
|
Comment on lines
+302
to
+303
|
||
| build_config["api_key"]["display_name"] = "IBM watsonx.ai API Key" | ||
| build_config["api_key"]["required"] = True | ||
| build_config["api_key"]["show"] = True | ||
| build_config["api_base"]["show"] = False | ||
| build_config["ollama_base_url"]["show"] = False | ||
| build_config["base_url_ibm_watsonx"]["show"] = True | ||
| build_config["project_id"]["show"] = True | ||
|
|
||
| build_config["truncate_input_tokens"]["show"] = True | ||
| build_config["input_text"]["show"] = True | ||
| elif field_name == "base_url_ibm_watsonx": | ||
| build_config["model"]["options"] = self.fetch_ibm_models(base_url=field_value) | ||
| build_config["model"]["value"] = self.fetch_ibm_models(base_url=field_value)[0] | ||
|
Comment on lines
+314
to
+315
|
||
| elif field_name == "ollama_base_url": | ||
| # # Refresh Ollama models when base URL changes | ||
| # if hasattr(self, "provider") and self.provider == "Ollama": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unnecessary blank lines. There are three consecutive blank lines here, which violates PEP 8 style guide that recommends at most two blank lines between top-level definitions. Remove the extra blank lines.