diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
new file mode 100644
index 0000000000..26f3bfd016
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
@@ -0,0 +1,100 @@
+import asyncio
+import httpx
+import logging
+import os
+from typing import List, Optional
+
+from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
+from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
+
+logger = logging.getLogger("OllamaEmbeddingEngine")
+
+
+class OllamaEmbeddingEngine(EmbeddingEngine):
+    model: str
+    dimensions: int
+    max_tokens: int
+    endpoint: str
+    mock: bool
+    huggingface_tokenizer_name: str
+
+    MAX_RETRIES = 5
+
+    def __init__(
+        self,
+        model: Optional[str] = "avr/sfr-embedding-mistral:latest",
+        dimensions: Optional[int] = 1024,
+        max_tokens: int = 512,
+        endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
+        huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
+    ):
+        self.model = model
+        self.dimensions = dimensions
+        self.max_tokens = max_tokens
+        self.endpoint = endpoint
+        self.huggingface_tokenizer_name = huggingface_tokenizer
+        self.tokenizer = self.get_tokenizer()
+
+        # MOCK_EMBEDDING lets tests run without a live Ollama server.
+        enable_mocking = os.getenv("MOCK_EMBEDDING", "false").lower()
+        self.mock = enable_mocking in ("true", "1", "yes")
+
+    async def embed_text(self, text: List[str]) -> List[List[float]]:
+        """
+        Given a list of text prompts, returns a list of embedding vectors.
+        """
+        if self.mock:
+            return [[0.0] * self.dimensions for _ in text]
+
+        embeddings = []
+        async with httpx.AsyncClient() as client:
+            for prompt in text:
+                embedding = await self._get_embedding(client, prompt)
+                embeddings.append(embedding)
+        return embeddings
+
+    async def _get_embedding(self, client: httpx.AsyncClient, prompt: str) -> List[float]:
+        """
+        Internal method to call the Ollama embeddings endpoint for a single prompt.
+        """
+        payload = {
+            "model": self.model,
+            "prompt": prompt,
+        }
+        headers = {}
+        api_key = os.getenv("LLM_API_KEY")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        retries = 0
+        while retries < self.MAX_RETRIES:
+            try:
+                response = await client.post(
+                    self.endpoint, json=payload, headers=headers, timeout=60.0
+                )
+                response.raise_for_status()
+                data = response.json()
+                return data["embedding"]
+            except httpx.HTTPStatusError as e:
+                logger.error(f"HTTP error on attempt {retries + 1}: {e}")
+                retries += 1
+                await asyncio.sleep(min(2**retries, 60))
+            except Exception as e:
+                logger.error(f"Error on attempt {retries + 1}: {e}")
+                retries += 1
+                await asyncio.sleep(min(2**retries, 60))
+        raise EmbeddingException(
+            f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries"
+        )
+
+    def get_vector_size(self) -> int:
+        return self.dimensions
+
+    def get_tokenizer(self):
+        logger.debug("Loading HuggingFaceTokenizer for OllamaEmbeddingEngine...")
+        tokenizer = HuggingFaceTokenizer(
+            model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
+        )
+        logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
+        return tokenizer
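For reviewers, here is a minimal usage sketch of the new engine (not part of the patch). It assumes a local Ollama server on `localhost:11434` with the default model already pulled via `ollama pull avr/sfr-embedding-mistral:latest`; setting `MOCK_EMBEDDING=true` instead returns zero vectors without any server.

```python
# Sketch only, not part of the patch. Assumes a running Ollama server,
# or MOCK_EMBEDDING=true to bypass the HTTP calls entirely.
import asyncio

from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import (
    OllamaEmbeddingEngine,
)


async def main():
    # Defaults: avr/sfr-embedding-mistral:latest, 1024 dimensions, and the
    # Salesforce/SFR-Embedding-Mistral tokenizer (fetched from HuggingFace).
    engine = OllamaEmbeddingEngine()
    vectors = await engine.embed_text(["knowledge graphs", "vector search"])
    print(len(vectors), len(vectors[0]))  # -> 2 1024


asyncio.run(main())
```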
+ """ + payload = { + "model": self.model, + "prompt": prompt, + } + headers = {} + api_key = os.getenv("LLM_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + retries = 0 + while retries < self.MAX_RETRIES: + try: + response = await client.post( + self.endpoint, json=payload, headers=headers, timeout=60.0 + ) + response.raise_for_status() + data = response.json() + return data["embedding"] + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error on attempt {retries + 1}: {e}") + retries += 1 + await asyncio.sleep(min(2**retries, 60)) + except Exception as e: + logger.error(f"Error on attempt {retries + 1}: {e}") + retries += 1 + await asyncio.sleep(min(2**retries, 60)) + raise EmbeddingException( + f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries" + ) + + def get_vector_size(self) -> int: + return self.dimensions + + def get_tokenizer(self): + logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...") + tokenizer = HuggingFaceTokenizer( + model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens + ) + logger.debug("Tokenizer loaded for OllamaEmbeddingEngine") + return tokenizer diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py index 315caf7eff..733548f89e 100644 --- a/cognee/infrastructure/databases/vector/embeddings/config.py +++ b/cognee/infrastructure/databases/vector/embeddings/config.py @@ -11,6 +11,7 @@ class EmbeddingConfig(BaseSettings): embedding_api_key: Optional[str] = None embedding_api_version: Optional[str] = None embedding_max_tokens: Optional[int] = 8191 + huggingface_tokenizer: Optional[str] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py index 5d23b21642..afd11cc9cd 100644 --- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py @@ -16,10 +16,19 @@ def get_embedding_engine() -> EmbeddingEngine: max_tokens=config.embedding_max_tokens, ) + if config.embedding_provider == "ollama": + from .OllamaEmbeddingEngine import OllamaEmbeddingEngine + + return OllamaEmbeddingEngine( + model=config.embedding_model, + dimensions=config.embedding_dimensions, + max_tokens=config.embedding_max_tokens, + huggingface_tokenizer=config.huggingface_tokenizer, + ) + from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine return LiteLLMEmbeddingEngine( - # If OpenAI API is used for embeddings, litellm needs only the api_key. 
        provider=config.embedding_provider,
        api_key=config.embedding_api_key or llm_config.llm_api_key,
        endpoint=config.embedding_endpoint,
diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py
index 5e26345e83..4a095d179f 100644
--- a/cognee/infrastructure/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/get_llm_client.py
@@ -4,6 +4,7 @@
 
 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.llm import get_llm_config
+from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter
 
 
 # Define an Enum for LLM Providers
@@ -52,7 +53,5 @@
 
-        from .generic_llm_api.adapter import GenericAPIAdapter
-
-        return GenericAPIAdapter(
+        return OllamaAPIAdapter(
             llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
diff --git a/cognee/infrastructure/llm/ollama/__init__.py b/cognee/infrastructure/llm/ollama/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/cognee/infrastructure/llm/ollama/adapter.py b/cognee/infrastructure/llm/ollama/adapter.py
new file mode 100644
index 0000000000..4eb3927394
--- /dev/null
+++ b/cognee/infrastructure/llm/ollama/adapter.py
@@ -0,0 +1,43 @@
+from typing import Type
+from pydantic import BaseModel
+import instructor
+from cognee.infrastructure.llm.llm_interface import LLMInterface
+from openai import AsyncOpenAI
+
+
+class OllamaAPIAdapter(LLMInterface):
+    """Adapter for the Ollama API, using instructor with an OpenAI-compatible client."""
+
+    def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
+        self.name = name
+        self.model = model
+        self.api_key = api_key
+        self.endpoint = endpoint
+        self.max_tokens = max_tokens
+
+        self.aclient = instructor.from_openai(
+            AsyncOpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
+        )
+
+    async def acreate_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """Generate a structured output from the LLM using the provided text and system prompt."""
+
+        response = await self.aclient.chat.completions.create(
+            model=self.model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": f"Use the given format to extract information from the following input: {text_input}",
+                },
+            ],
+            max_retries=5,
+            response_model=response_model,
+        )
+
+        return response
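Finally, a sketch of the adapter end to end. The endpoint, model name, and API key are placeholder assumptions: Ollama exposes an OpenAI-compatible API under `/v1` and ignores the key, but the OpenAI client requires one to be set.

```python
# Sketch only; endpoint, model, and key are placeholder assumptions.
import asyncio

from pydantic import BaseModel

from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter


class Person(BaseModel):  # hypothetical response model
    name: str
    age: int


async def main():
    adapter = OllamaAPIAdapter(
        endpoint="http://localhost:11434/v1",
        api_key="ollama",  # ignored by Ollama, required by the OpenAI client
        model="llama3",
        name="ollama",
        max_tokens=512,
    )
    person = await adapter.acreate_structured_output(
        text_input="Alice is 30 years old.",
        system_prompt="Extract the person described in the input.",
        response_model=Person,
    )
    print(person)  # -> name='Alice' age=30


asyncio.run(main())
```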