New file: OllamaEmbeddingEngine.py
@@ -0,0 +1,101 @@
import asyncio
import httpx
import logging
from typing import List, Optional
import os

from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

logger = logging.getLogger("OllamaEmbeddingEngine")


class OllamaEmbeddingEngine(EmbeddingEngine):
    model: str
    dimensions: int
    max_tokens: int
    endpoint: str
    mock: bool
    huggingface_tokenizer_name: str

    MAX_RETRIES = 5

    def __init__(
        self,
        model: Optional[str] = "avr/sfr-embedding-mistral:latest",
        dimensions: Optional[int] = 1024,
        max_tokens: int = 512,
        endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
        huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
    ):
        self.model = model
        self.dimensions = dimensions
        self.max_tokens = max_tokens
        self.endpoint = endpoint
        self.huggingface_tokenizer_name = huggingface_tokenizer
        self.tokenizer = self.get_tokenizer()

        enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
        if isinstance(enable_mocking, bool):
            enable_mocking = str(enable_mocking).lower()
        self.mock = enable_mocking in ("true", "1", "yes")

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Given a list of text prompts, returns a list of embedding vectors.
        """
        if self.mock:
            return [[0.0] * self.dimensions for _ in text]

        embeddings = []
        async with httpx.AsyncClient() as client:
            for prompt in text:
                embedding = await self._get_embedding(client, prompt)
                embeddings.append(embedding)
        return embeddings

    async def _get_embedding(self, client: httpx.AsyncClient, prompt: str) -> List[float]:
        """
        Internal method to call the Ollama embeddings endpoint for a single prompt.
        """
        payload = {
            "model": self.model,
            "prompt": prompt,
        }
        headers = {}
        api_key = os.getenv("LLM_API_KEY")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        retries = 0
        while retries < self.MAX_RETRIES:
            try:
                response = await client.post(
                    self.endpoint, json=payload, headers=headers, timeout=60.0
                )
                response.raise_for_status()
                data = response.json()
                return data["embedding"]
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))
            except Exception as e:
                logger.error(f"Error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))
        raise EmbeddingException(
            f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries"
        )

    def get_vector_size(self) -> int:
        return self.dimensions

    def get_tokenizer(self):
        logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
        tokenizer = HuggingFaceTokenizer(
            model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
        )
        logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
        return tokenizer
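
For context, a minimal usage sketch of the engine above (not part of the PR). The constructor defaults and the MOCK_EMBEDDING flag are taken from the code; the driver script and the import path are assumptions:

import asyncio
import os

# Import path is an assumption based on the imports above; the exact module
# location is not shown in this diff.
from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import (
    OllamaEmbeddingEngine,
)

# Assumed driver: enable mock mode so no Ollama server is required.
# Note that the HuggingFace tokenizer is still loaded in __init__.
os.environ["MOCK_EMBEDDING"] = "true"


async def main():
    engine = OllamaEmbeddingEngine()  # defaults: avr/sfr-embedding-mistral:latest, 1024 dims
    vectors = await engine.embed_text(["hello world", "ollama embeddings"])
    print(len(vectors), len(vectors[0]))  # -> 2 1024


asyncio.run(main())
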
@@ -11,6 +11,7 @@ class EmbeddingConfig(BaseSettings):
    embedding_api_key: Optional[str] = None
    embedding_api_version: Optional[str] = None
    embedding_max_tokens: Optional[int] = 8191
    huggingface_tokenizer: Optional[str] = None
    model_config = SettingsConfigDict(env_file=".env", extra="allow")


@@ -16,10 +16,19 @@ def get_embedding_engine() -> EmbeddingEngine:
            max_tokens=config.embedding_max_tokens,
        )

    if config.embedding_provider == "ollama":
        from .OllamaEmbeddingEngine import OllamaEmbeddingEngine

        return OllamaEmbeddingEngine(
            model=config.embedding_model,
            dimensions=config.embedding_dimensions,
            max_tokens=config.embedding_max_tokens,
            huggingface_tokenizer=config.huggingface_tokenizer,
        )

    from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

    return LiteLLMEmbeddingEngine(
        # If OpenAI API is used for embeddings, litellm needs only the api_key.
        provider=config.embedding_provider,
        api_key=config.embedding_api_key or llm_config.llm_api_key,
        endpoint=config.embedding_endpoint,
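
For reference, a hypothetical configuration sketch for exercising the new branch (not part of the PR). The variable names assume that EmbeddingConfig, as a pydantic BaseSettings class, reads unprefixed environment variables; the values mirror the OllamaEmbeddingEngine defaults:

import os

# Hypothetical settings — with these set (or placed in .env), get_embedding_engine()
# takes the "ollama" branch above instead of falling through to LiteLLM.
os.environ["EMBEDDING_PROVIDER"] = "ollama"
os.environ["EMBEDDING_MODEL"] = "avr/sfr-embedding-mistral:latest"
os.environ["EMBEDDING_DIMENSIONS"] = "1024"
os.environ["EMBEDDING_MAX_TOKENS"] = "512"
os.environ["HUGGINGFACE_TOKENIZER"] = "Salesforce/SFR-Embedding-Mistral"
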
3 changes: 2 additions & 1 deletion cognee/infrastructure/llm/get_llm_client.py
@@ -4,6 +4,7 @@

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter


# Define an Enum for LLM Providers
@@ -52,7 +53,7 @@ def get_llm_client():

        from .generic_llm_api.adapter import GenericAPIAdapter

-        return GenericAPIAdapter(
+        return OllamaAPIAdapter(
            llm_config.llm_endpoint,
            llm_config.llm_api_key,
            llm_config.llm_model,
Empty file.
44 changes: 44 additions & 0 deletions cognee/infrastructure/llm/ollama/adapter.py
@@ -0,0 +1,44 @@
from typing import Type
from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from openai import OpenAI


class OllamaAPIAdapter(LLMInterface):
    """Adapter for the Ollama API, using instructor with an OpenAI-compatible backend."""

    def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
        self.name = name
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.max_tokens = max_tokens

        self.aclient = instructor.from_openai(
            OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
        )

    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Generate a structured output from the LLM using the provided text and system prompt."""

        response = self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"Use the given format to extract information from the following input: {text_input}",
                },
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
            max_retries=5,
            response_model=response_model,
        )

        return response
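
For context, an assumed usage sketch of the adapter (not part of the PR). The endpoint, model name, and response model are hypothetical; Ollama conventionally serves its OpenAI-compatible API under /v1, and as currently written the adapter's create call runs synchronously inside the async method:

import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter


class PersonInfo(BaseModel):  # hypothetical response model
    name: str
    age: int


async def main():
    adapter = OllamaAPIAdapter(
        endpoint="http://localhost:11434/v1",  # assumed OpenAI-compatible Ollama endpoint
        api_key="ollama",  # Ollama ignores the key, but the OpenAI client requires one
        model="llama3.1",  # hypothetical locally pulled model
        name="ollama",
        max_tokens=512,
    )
    person = await adapter.acreate_structured_output(
        text_input="Alice is 30 years old.",
        system_prompt="Extract the person's name and age.",
        response_model=PersonInfo,
    )
    print(person)  # e.g. PersonInfo(name='Alice', age=30)


asyncio.run(main())
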
Comment on lines +23 to +44

🛠️ Refactor suggestion

Add error handling and configuration options.

The method should handle API errors and allow configuration of model parameters.

     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
     ) -> BaseModel:
-        """Generate a structured output from the LLM using the provided text and system prompt."""
+        """Generate a structured output from Ollama using the provided text and system prompt.
+        
+        Args:
+            text_input: The input text to process
+            system_prompt: The system prompt to guide the model
+            response_model: Pydantic model for response structure
+            
+        Returns:
+            BaseModel: Structured response matching response_model
+            
+        Raises:
+            OpenAIError: If API call fails
+            ValueError: If input validation fails
+        """
+        if not text_input or not system_prompt:
+            raise ValueError("text_input and system_prompt are required")
 
-        response = self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"Use the given format to extract information from the following input: {text_input}",
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            max_retries=5,
-            response_model=response_model,
-        )
+        try:
+            response = await self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"Use the given format to extract information from the following input: {text_input}",
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                max_retries=5,
+                response_model=response_model,
+                temperature=0.7,  # Add configurable parameters
+                timeout=30,  # Add timeout
+            )
+            return response
+        except Exception as e:
+            raise OpenAIError(f"Failed to generate structured output: {str(e)}")
-        return response