Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
resolving merge conflicts
  • Loading branch information
rajeevrajeshuni committed Dec 10, 2025
commit 8e5f14da78a69e4e07ee7024759b12a1ecf5f364
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from instructor.core import InstructorRetryException

import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger

from tenacity import (
retry,
stop_after_delay,
Expand Down Expand Up @@ -110,24 +113,25 @@ async def acreate_structured_output(
"""

try:
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
max_retries=self.MAX_RETRIES,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
max_retries=2,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
Expand All @@ -145,23 +149,24 @@ async def acreate_structured_output(
)

try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=self.MAX_RETRIES,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=2,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from typing import Type
from openai import OpenAI
from pydantic import BaseModel

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from tenacity import (
retry,
stop_after_delay,
Expand All @@ -20,10 +21,9 @@
)

logger = get_logger()
observe = get_observe()


class OllamaAPIAdapter(GenericAPIAdapter):
class OllamaAPIAdapter(LLMInterface):
"""
Adapter for a Generic API LLM provider using instructor with an OpenAI backend.

Expand All @@ -47,20 +47,18 @@ class OllamaAPIAdapter(GenericAPIAdapter):

def __init__(
self,
endpoint: str,
api_key: str,
model: str,
name: str,
max_completion_tokens: int,
endpoint: str,
instructor_mode: str = None,
):
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Ollama",
endpoint=endpoint,
)
self.name = name
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens

self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

Expand All @@ -69,10 +67,9 @@ def __init__(
mode=instructor.Mode(self.instructor_mode),
)

@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(8, 128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
Expand Down Expand Up @@ -117,3 +114,95 @@ async def acreate_structured_output(
)

return response

@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(2, 128),
    retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def create_transcript(self, input_file: str) -> str:
    """
    Generate an audio transcript from an audio file.

    Opens ``input_file`` via ``open_data_file`` and submits it to the
    client's audio transcription endpoint using the ``whisper-1`` model
    with English as the target language. Transient failures are retried
    with exponential jitter for up to 128 seconds; a
    ``litellm.exceptions.NotFoundError`` is not retried and is re-raised
    immediately.

    NOTE(review): despite the ``aclient`` name, the ``create`` call is
    not awaited — presumably the underlying client is the synchronous
    OpenAI client; confirm, otherwise this binds an un-awaited coroutine
    and the ``hasattr`` check below would always fail.

    Parameters:
    -----------

    - input_file (str): The path to the audio file to be transcribed.

    Returns:
    --------

    - str: The transcription of the audio as a string.

    Raises:
    -------

    - FileNotFoundError: If the input file does not exist.
    - ValueError: If the transcription response has no ``text`` attribute.
    """

    async with open_data_file(input_file, mode="rb") as audio_file:
        transcription = self.aclient.audio.transcriptions.create(
            model="whisper-1",  # Ensure the correct model for transcription
            file=audio_file,
            language="en",
        )

    # Ensure the response contains a valid transcript
    if not hasattr(transcription, "text"):
        raise ValueError("Transcription failed. No text returned.")

    return transcription.text

@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(2, 128),
    retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def transcribe_image(self, input_file: str) -> str:
    """
    Transcribe the content of an image file.

    Reads ``input_file`` via ``open_data_file``, base64-encodes the raw
    bytes, and sends them to the chat completions endpoint as a
    ``data:image/jpeg`` image URL with the prompt "What's in this
    image?". Transient failures are retried with exponential jitter for
    up to 128 seconds; a ``litellm.exceptions.NotFoundError`` is not
    retried and is re-raised immediately.

    NOTE(review): the ``create`` call is not awaited — presumably the
    underlying client is the synchronous OpenAI client; confirm,
    otherwise this binds an un-awaited coroutine and the ``hasattr``
    check below would always fail.

    NOTE(review): the data URL hard-codes ``image/jpeg`` regardless of
    the actual image format — verify this is acceptable for the
    deployed models.

    Parameters:
    -----------

    - input_file (str): The path to the image file to be transcribed.

    Returns:
    --------

    - str: The transcription of the image's content as a string.

    Raises:
    -------

    - FileNotFoundError: If the input file does not exist.
    - ValueError: If the response contains no choices.
    """

    async with open_data_file(input_file, mode="rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

    response = self.aclient.chat.completions.create(
        model=self.model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                    },
                ],
            }
        ],
        max_completion_tokens=300,
    )

    # Ensure response is valid before accessing .choices[0].message.content
    if not hasattr(response, "choices") or not response.choices:
        raise ValueError("Image transcription failed. No response received.")

    return response.choices[0].message.content
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ async def acreate_structured_output(
"content": system_prompt,
},
],
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
)
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.