From 02b17786588b7be4c582a2fdfe93a5412c074cda Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Tue, 25 Nov 2025 12:22:15 +0530 Subject: [PATCH 1/4] Adding support for audio/image transcription for all other providers --- cognee/infrastructure/llm/LLMGateway.py | 13 -- .../llm/anthropic/adapter.py | 27 ++-- .../litellm_instructor/llm/gemini/adapter.py | 59 +++---- .../llm/generic_llm_api/adapter.py | 132 +++++++++++++++- .../litellm_instructor/llm/get_llm_client.py | 13 +- .../litellm_instructor/llm/llm_interface.py | 47 +++++- .../litellm_instructor/llm/mistral/adapter.py | 66 ++++++-- .../litellm_instructor/llm/ollama/adapter.py | 118 ++------------ .../litellm_instructor/llm/openai/adapter.py | 148 +++--------------- uv.lock | 2 +- 10 files changed, 313 insertions(+), 312 deletions(-) diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py index ab5bb35d78..66a364110f 100644 --- a/cognee/infrastructure/llm/LLMGateway.py +++ b/cognee/infrastructure/llm/LLMGateway.py @@ -34,19 +34,6 @@ def acreate_structured_output( text_input=text_input, system_prompt=system_prompt, response_model=response_model ) - @staticmethod - def create_structured_output( - text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( - get_llm_client, - ) - - llm_client = get_llm_client() - return llm_client.create_structured_output( - text_input=text_input, system_prompt=system_prompt, response_model=response_model - ) - @staticmethod def create_transcript(input) -> Coroutine: from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index dbf0dfbeaf..818d3adb79 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -3,7 +3,9 @@ from pydantic import BaseModel import litellm import instructor +import anthropic from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe from tenacity import ( retry, stop_after_delay, @@ -12,27 +14,32 @@ before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.config import get_llm_config logger = get_logger() +observe = get_observe() -class AnthropicAdapter(LLMInterface): +class AnthropicAdapter(GenericAPIAdapter): """ Adapter for interfacing with the Anthropic API, enabling structured output generation and prompt display. 
""" - name = "Anthropic" - model: str default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): - import anthropic - + def __init__( + self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None + ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Anthropic", + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.patch( @@ -40,9 +47,7 @@ def __init__(self, max_completion_tokens: int, model: str = None, instructor_mod mode=instructor.Mode(self.instructor_mode), ) - self.model = model - self.max_completion_tokens = max_completion_tokens - + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 226f291d7d..bae6650526 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -1,4 +1,4 @@ -"""Adapter for Generic API LLM provider API""" +"""Adapter for Gemini API LLM provider""" import litellm import instructor @@ -8,12 +8,7 @@ from litellm.exceptions import ContentPolicyViolationError from instructor.core import InstructorRetryException -from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, -) import logging -from cognee.shared.logging_utils import get_logger from tenacity import ( retry, stop_after_delay, @@ -22,55 +17,65 @@ before_sleep_log, ) +from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, +) +from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe + logger = get_logger() +observe = get_observe() -class GeminiAdapter(LLMInterface): +class GeminiAdapter(GenericAPIAdapter): """ Adapter for Gemini API LLM provider. This class initializes the API adapter with necessary credentials and configurations for interacting with the gemini LLM models. It provides methods for creating structured outputs - based on user input and system prompts. + based on user input and system prompts, as well as multimodal processing capabilities. 
Public methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) -> BaseModel + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel + - create_transcript(input) -> BaseModel: Transcribe audio files to text + - transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter """ - name: str - model: str - api_key: str default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - api_version: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens - - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Gemini", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -118,7 +123,7 @@ async def acreate_structured_output( }, ], api_key=self.api_key, - max_retries=5, + max_retries=self.MAX_RETRIES, api_base=self.endpoint, api_version=self.api_version, response_model=response_model, @@ -152,7 +157,7 @@ async def acreate_structured_output( "content": system_prompt, }, ], - max_retries=5, + max_retries=self.MAX_RETRIES, api_key=self.fallback_api_key, api_base=self.fallback_endpoint, response_model=response_model, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 9d7f25fc53..9987711b97 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -1,8 +1,10 @@ """Adapter for Generic API LLM provider API""" +import base64 +import mimetypes import litellm import instructor -from typing import Type +from typing import Type, Optional from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError @@ -12,6 +14,8 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( LLMInterface, ) +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from cognee.modules.observability.get_observe import get_observe import logging from cognee.shared.logging_utils import get_logger from tenacity import ( @@ -23,6 +27,7 @@ ) logger = get_logger() +observe = get_observe() class GenericAPIAdapter(LLMInterface): @@ -38,18 +43,19 @@ class GenericAPIAdapter(LLMInterface): Type[BaseModel]) -> BaseModel """ - name: str - 
model: str - api_key: str + MAX_RETRIES = 5 default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - name: str, max_completion_tokens: int, + name: str, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, + image_transcribe_model: str = None, instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, @@ -58,9 +64,11 @@ def __init__( self.name = name self.model = model self.api_key = api_key + self.api_version = api_version self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens - + self.transcription_model = transcription_model or model + self.image_transcribe_model = image_transcribe_model or model self.fallback_model = fallback_model self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint @@ -71,6 +79,7 @@ def __init__( litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -170,3 +179,112 @@ async def acreate_structured_output( raise ContentPolicyFilterError( f"The provided input contains content that is not aligned with our content policy: {text_input}" ) from error + + @observe(as_type="transcription") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def create_transcript(self, input) -> Optional[BaseModel]: + """ + Generate an audio transcript from a user query. + + This method creates a transcript from the specified audio file, raising a + FileNotFoundError if the file does not exist. The audio file is processed and the + transcription is retrieved from the API. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + The generated transcription of the audio file. + """ + async with open_data_file(input, mode="rb") as audio_file: + encoded_string = base64.b64encode(audio_file.read()).decode("utf-8") + mime_type, _ = mimetypes.guess_type(input) + if not mime_type or not mime_type.startswith("audio/"): + raise ValueError( + f"Could not determine MIME type for audio file: {input}. Is the extension correct?" + ) + return litellm.completion( + model=self.transcription_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "file", + "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"}, + }, + {"type": "text", "text": "Transcribe the following audio precisely."}, + ], + } + ], + api_key=self.api_key, + api_version=self.api_version, + max_completion_tokens=self.max_completion_tokens, + api_base=self.endpoint, + max_retries=self.MAX_RETRIES, + ) + + @observe(as_type="transcribe_image") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def transcribe_image(self, input) -> Optional[BaseModel]: + """ + Generate a transcription of an image from a user query. + + This method encodes the image and sends a request to the API to obtain a + description of the contents of the image. + + Parameters: + ----------- + - input: The path to the image file that needs to be transcribed. 
+ + Returns: + -------- + - BaseModel: A structured output generated by the model, returned as an instance of + BaseModel. + """ + async with open_data_file(input, mode="rb") as image_file: + encoded_image = base64.b64encode(image_file.read()).decode("utf-8") + mime_type, _ = mimetypes.guess_type(input) + if not mime_type or not mime_type.startswith("image/"): + raise ValueError( + f"Could not determine MIME type for image file: {input}. Is the extension correct?" + ) + return litellm.completion( + model=self.image_transcribe_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this image?", + }, + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{encoded_image}", + }, + }, + ], + } + ], + api_key=self.api_key, + api_base=self.endpoint, + api_version=self.api_version, + max_completion_tokens=300, + max_retries=self.MAX_RETRIES, + ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 39558f36d1..de6cfaf19a 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -97,11 +97,10 @@ def get_llm_client(raise_api_key_error: bool = True): ) return OllamaAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, - "Ollama", - max_completion_tokens=max_completion_tokens, + max_completion_tokens, + llm_config.llm_endpoint, instructor_mode=llm_config.llm_instructor_mode.lower(), ) @@ -111,8 +110,9 @@ def get_llm_client(raise_api_key_error: bool = True): ) return AnthropicAdapter( - max_completion_tokens=max_completion_tokens, - model=llm_config.llm_model, + llm_config.llm_api_key, + llm_config.llm_model, + max_completion_tokens, instructor_mode=llm_config.llm_instructor_mode.lower(), ) @@ -125,11 +125,10 @@ def get_llm_client(raise_api_key_error: bool = True): ) return GenericAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, + max_completion_tokens, "Custom", - max_completion_tokens=max_completion_tokens, instructor_mode=llm_config.llm_instructor_mode.lower(), fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py index b02105484b..f8352737db 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py @@ -1,6 +1,6 @@ """LLM Interface""" -from typing import Type, Protocol +from typing import Type, Protocol, Optional from abc import abstractmethod from pydantic import BaseModel from cognee.infrastructure.llm.LLMGateway import LLMGateway @@ -8,13 +8,12 @@ class LLMInterface(Protocol): """ - Define an interface for LLM models with methods for structured output and prompt - display. + Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display. 
Methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) - - show_prompt(text_input: str, system_prompt: str) + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) + - create_transcript(input): Transcribe audio files to text + - transcribe_image(input): Analyze image files and return text description """ @abstractmethod @@ -36,3 +35,39 @@ async def acreate_structured_output( output. """ raise NotImplementedError + + @abstractmethod + async def create_transcript(self, input) -> Optional[BaseModel]: + """ + Transcribe audio content to text. + + This method should be implemented by subclasses that support audio transcription. + If not implemented, returns None and should be handled gracefully by callers. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + - BaseModel: A structured output containing the transcription, or None if not supported. + """ + raise NotImplementedError + + @abstractmethod + async def transcribe_image(self, input) -> Optional[BaseModel]: + """ + Analyze image content and return text description. + + This method should be implemented by subclasses that support image analysis. + If not implemented, returns None and should be handled gracefully by callers. + + Parameters: + ----------- + - input: The path to the image file that needs to be analyzed. + + Returns: + -------- + - BaseModel: A structured output containing the image description, or None if not supported. + """ + raise NotImplementedError diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 355cdae0b0..0fa35923f4 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -2,12 +2,12 @@ import instructor from pydantic import BaseModel from typing import Type -from litellm import JSONSchemaValidationError +from litellm import JSONSchemaValidationError, transcription from cognee.shared.logging_utils import get_logger from cognee.modules.observability.get_observe import get_observe -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.config import get_llm_config @@ -19,12 +19,13 @@ retry_if_not_exception_type, before_sleep_log, ) +from mistralai import Mistral logger = get_logger() observe = get_observe() -class MistralAdapter(LLMInterface): +class MistralAdapter(GenericAPIAdapter): """ Adapter for Mistral AI API, for structured output generation and prompt display. 
@@ -33,10 +34,6 @@ class MistralAdapter(LLMInterface): - show_prompt """ - name = "Mistral" - model: str - api_key: str - max_completion_tokens: int default_instructor_mode = "mistral_tools" def __init__( @@ -45,12 +42,21 @@ def __init__( model: str, max_completion_tokens: int, endpoint: str = None, + transcription_model: str = None, + image_transcribe_model: str = None, instructor_mode: str = None, ): from mistralai import Mistral - self.model = model - self.max_completion_tokens = max_completion_tokens + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Mistral", + endpoint=endpoint, + transcription_model=transcription_model, + image_transcribe_model=image_transcribe_model, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode @@ -60,6 +66,7 @@ def __init__( api_key=get_llm_config().llm_api_key, ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -117,3 +124,42 @@ async def acreate_structured_output( logger.error(f"Schema validation failed: {str(e)}") logger.debug(f"Raw response: {e.raw_response}") raise ValueError(f"Response failed schema validation: {str(e)}") + + @observe(as_type="transcription") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def create_transcript(self, input): + """ + Generate an audio transcript from a user query. + + This method creates a transcript from the specified audio file. + The audio file is processed and the transcription is retrieved from the API. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + The generated transcription of the audio file. + """ + transcription_model = self.transcription_model + if self.transcription_model.startswith("mistral"): + transcription_model = self.transcription_model.split("/")[-1] + file_name = input.split("/")[-1] + client = Mistral(api_key=self.api_key) + with open(input, "rb") as f: + transcription_response = client.audio.transcriptions.complete( + model=transcription_model, + file={ + "content": f, + "file_name": file_name, + }, + ) + # TODO: We need to standardize return type of create_transcript across different models. 
+ return transcription_response diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index aabd19867b..163637a953 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -5,12 +5,12 @@ from typing import Type from openai import OpenAI from pydantic import BaseModel - -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, -) from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, +) from tenacity import ( retry, stop_after_delay, @@ -20,9 +20,10 @@ ) logger = get_logger() +observe = get_observe() -class OllamaAPIAdapter(LLMInterface): +class OllamaAPIAdapter(GenericAPIAdapter): """ Adapter for a Generic API LLM provider using instructor with an OpenAI backend. @@ -46,18 +47,20 @@ class OllamaAPIAdapter(LLMInterface): def __init__( self, - endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int, + endpoint: str, instructor_mode: str = None, ): - self.name = name - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.max_completion_tokens = max_completion_tokens + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Ollama", + endpoint=endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode @@ -66,6 +69,7 @@ def __init__( mode=instructor.Mode(self.instructor_mode), ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -113,95 +117,3 @@ async def acreate_structured_output( ) return response - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def create_transcript(self, input_file: str) -> str: - """ - Generate an audio transcript from a user query. - - This synchronous method takes an input audio file and returns its transcription. Raises - a FileNotFoundError if the input file does not exist, and raises a ValueError if - transcription fails or returns no text. - - Parameters: - ----------- - - - input_file (str): The path to the audio file to be transcribed. - - Returns: - -------- - - - str: The transcription of the audio as a string. - """ - - async with open_data_file(input_file, mode="rb") as audio_file: - transcription = self.aclient.audio.transcriptions.create( - model="whisper-1", # Ensure the correct model for transcription - file=audio_file, - language="en", - ) - - # Ensure the response contains a valid transcript - if not hasattr(transcription, "text"): - raise ValueError("Transcription failed. 
No text returned.") - - return transcription.text - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def transcribe_image(self, input_file: str) -> str: - """ - Transcribe content from an image using base64 encoding. - - This synchronous method takes an input image file, encodes it as base64, and returns the - transcription of its content. Raises a FileNotFoundError if the input file does not - exist, and raises a ValueError if the transcription fails or no valid response is - received. - - Parameters: - ----------- - - - input_file (str): The path to the image file to be transcribed. - - Returns: - -------- - - - str: The transcription of the image's content as a string. - """ - - async with open_data_file(input_file, mode="rb") as image_file: - encoded_image = base64.b64encode(image_file.read()).decode("utf-8") - - response = self.aclient.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, - }, - ], - } - ], - max_completion_tokens=300, - ) - - # Ensure response is valid before accessing .choices[0].message.content - if not hasattr(response, "choices") or not response.choices: - raise ValueError("Image transcription failed. No response received.") - - return response.choices[0].message.content diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 778c8eec77..e9943c3354 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -1,4 +1,3 @@ -import base64 import litellm import instructor from typing import Type @@ -16,8 +15,8 @@ before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.exceptions import ( ContentPolicyFilterError, @@ -31,7 +30,7 @@ observe = get_observe() -class OpenAIAdapter(LLMInterface): +class OpenAIAdapter(GenericAPIAdapter): """ Adapter for OpenAI's GPT-3, GPT-4 API. 
@@ -52,12 +51,7 @@ class OpenAIAdapter(LLMInterface): - MAX_RETRIES """ - name = "OpenAI" - model: str - api_key: str - api_version: str default_instructor_mode = "json_schema_mode" - MAX_RETRIES = 5 """Adapter for OpenAI's GPT-3, GPT=4 API""" @@ -65,17 +59,29 @@ class OpenAIAdapter(LLMInterface): def __init__( self, api_key: str, - endpoint: str, - api_version: str, model: str, - transcription_model: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="OpenAI", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. @@ -90,18 +96,8 @@ def __init__( self.aclient = instructor.from_litellm(litellm.acompletion) self.client = instructor.from_litellm(litellm.completion) - self.transcription_model = transcription_model - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens self.streaming = streaming - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - @observe(as_type="generation") @retry( stop=stop_after_delay(128), @@ -174,7 +170,7 @@ async def acreate_structured_output( }, ], api_key=self.fallback_api_key, - # api_base=self.fallback_endpoint, + api_base=self.fallback_endpoint, response_model=response_model, max_retries=self.MAX_RETRIES, ) @@ -193,57 +189,7 @@ async def acreate_structured_output( f"The provided input contains content that is not aligned with our content policy: {text_input}" ) from error - @observe - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - def create_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - """ - Generate a response from a user query. - - This method creates structured output by sending a synchronous request to the OpenAI API - using the provided parameters to generate a completion based on the user input and - system prompt. - - Parameters: - ----------- - - - text_input (str): The input text provided by the user for generating a response. - - system_prompt (str): The system's prompt to guide the model's response. - - response_model (Type[BaseModel]): The expected model type for the response. - - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. 
- """ - - return self.client.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": f"""{text_input}""", - }, - { - "role": "system", - "content": system_prompt, - }, - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - response_model=response_model, - max_retries=self.MAX_RETRIES, - ) - + @observe(as_type="transcription") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -282,56 +228,4 @@ async def create_transcript(self, input): return transcription - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def transcribe_image(self, input) -> BaseModel: - """ - Generate a transcription of an image from a user query. - - This method encodes the image and sends a request to the OpenAI API to obtain a - description of the contents of the image. - - Parameters: - ----------- - - - input: The path to the image file that needs to be transcribed. - - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. - """ - async with open_data_file(input, mode="rb") as image_file: - encoded_image = base64.b64encode(image_file.read()).decode("utf-8") - - return litellm.completion( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What's in this image?", - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{encoded_image}", - }, - }, - ], - } - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - max_completion_tokens=300, - max_retries=self.MAX_RETRIES, - ) + # transcribe image inherited from GenericAdapter diff --git a/uv.lock b/uv.lock index cc66c3d7e6..d8fb3805bd 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'", From 09fbf2276828043b8ed1458f50b3ab7efcaa04d2 Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Tue, 25 Nov 2025 12:24:30 +0530 Subject: [PATCH 2/4] uv lock version revert --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index d8fb3805bd..cc66c3d7e6 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'", From d57d1884599257a1305395212935fc352c1caf84 Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Wed, 10 Dec 2025 10:52:10 +0530 Subject: [PATCH 3/4] resolving merge conflicts --- .../litellm_instructor/llm/ollama/adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index ec7addcaf5..abcd21f862 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -69,7 +69,7 @@ def __init__( @retry( stop=stop_after_delay(128), - 
wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -117,7 +117,7 @@ async def acreate_structured_output(
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,

From 6260f9eb82c9078eded523a80f035f4054d7091c Mon Sep 17 00:00:00 2001
From: rajeevrajeshuni
Date: Thu, 11 Dec 2025 06:53:36 +0530
Subject: [PATCH 4/4] standardizing return type for transcription and some CR changes
---
 .../llm/anthropic/adapter.py | 2 +-
 .../litellm_instructor/llm/gemini/adapter.py | 1 -
 .../llm/generic_llm_api/adapter.py | 11 ++++++++--
 .../litellm_instructor/llm/get_llm_client.py | 3 ++-
 .../litellm_instructor/llm/mistral/adapter.py | 22 +++++++++----------
 .../litellm_instructor/llm/openai/adapter.py | 11 ++++++----
 .../litellm_instructor/llm/types.py | 9 ++++++++
 7 files changed, 39 insertions(+), 20 deletions(-)
 create mode 100644 cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py

diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
index 4d75c886a6..49b13fcaa2 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
@@ -44,7 +44,7 @@ def __init__(
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
 
         self.aclient = instructor.patch(
-            create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
+            create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
             mode=instructor.Mode(self.instructor_mode),
         )
 
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
index ffb7bf77b7..99dfd61791 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
@@ -10,7 +10,6 @@
 import logging
 
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
-from cognee.shared.logging_utils import get_logger
 
 from tenacity import (
     retry,
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
index 408058d3e9..7905c25bf1 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
@@ -27,6 +27,8 @@
     before_sleep_log,
 )
 
+from ..types import TranscriptionReturnType
+
 logger = get_logger()
 observe = get_observe()
 
@@ -191,7 +193,7 @@ async def acreate_structured_output(
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input) -> Optional[BaseModel]:
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
         """
         Generate an audio transcript from a user query.
 
@@ -214,7 +216,7 @@ async def create_transcript(self, input) -> Optional[BaseModel]:
                 raise ValueError(
                     f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
                 )
-            return litellm.completion(
+            response = litellm.completion(
                 model=self.transcription_model,
                 messages=[
                     {
@@ -234,6 +236,11 @@ async def create_transcript(self, input) -> Optional[BaseModel]:
                 api_base=self.endpoint,
                 max_retries=self.MAX_RETRIES,
             )
+        if response and response.choices and len(response.choices) > 0:
+            return TranscriptionReturnType(response.choices[0].message.content, response)
+        else:
+            return None
+
 
     @observe(as_type="transcribe_image")
     @retry(
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
index de6cfaf19a..e5f4bd1b16 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
@@ -97,10 +97,11 @@ def get_llm_client(raise_api_key_error: bool = True):
         )
 
         return OllamaAPIAdapter(
+            llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
+            "Ollama",
             max_completion_tokens,
-            llm_config.llm_endpoint,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
 
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
index 954510a258..b141f75858 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
@@ -1,9 +1,9 @@
 import litellm
 import instructor
 from pydantic import BaseModel
-from typing import Type
-from litellm import JSONSchemaValidationError, transcription
-
+from typing import Type, Optional
+from litellm import JSONSchemaValidationError
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.observability.get_observe import get_observe
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
@@ -20,6 +20,7 @@
     retry_if_not_exception_type,
     before_sleep_log,
 )
+from ..types import TranscriptionReturnType
 from mistralai import Mistral
 
 logger = get_logger()
@@ -47,8 +48,6 @@ def __init__(
         image_transcribe_model: str = None,
         instructor_mode: str = None,
     ):
-        from mistralai import Mistral
-
         super().__init__(
             api_key=api_key,
             model=model,
@@ -66,6 +65,7 @@ def __init__(
             mode=instructor.Mode(self.instructor_mode),
             api_key=get_llm_config().llm_api_key,
         )
+        self.mistral_client = Mistral(api_key=self.api_key)
 
     @observe(as_type="generation")
     @retry(
@@ -135,7 +135,7 @@ async def acreate_structured_output(
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input):
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
         """
         Generate an audio transcript from a user query.
 
@@ -154,14 +154,14 @@ async def create_transcript(self, input):
         transcription_model = self.transcription_model
         if self.transcription_model.startswith("mistral"):
             transcription_model = self.transcription_model.split("/")[-1]
         file_name = input.split("/")[-1]
-        client = Mistral(api_key=self.api_key)
-        with open(input, "rb") as f:
-            transcription_response = client.audio.transcriptions.complete(
+        async with open_data_file(input, mode="rb") as f:
+            transcription_response = self.mistral_client.audio.transcriptions.complete(
                 model=transcription_model,
                 file={
                     "content": f,
                     "file_name": file_name,
                 },
             )
-        # TODO: We need to standardize return type of create_transcript across different models.
-        return transcription_response
+        if transcription_response:
+            return TranscriptionReturnType(transcription_response.text, transcription_response)
+        return None
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
index 57b6d339a8..94c6aed6df 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
@@ -1,6 +1,6 @@
 import litellm
 import instructor
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 from openai import ContentFilterFinishReasonError
 from litellm.exceptions import ContentPolicyViolationError
@@ -25,6 +25,7 @@
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.modules.observability.get_observe import get_observe
 from cognee.shared.logging_utils import get_logger
+from ..types import TranscriptionReturnType
 
 
 logger = get_logger()
@@ -200,7 +201,7 @@ async def acreate_structured_output(
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
    )
-    async def create_transcript(self, input):
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
         """
         Generate an audio transcript from a user query.
 
@@ -228,7 +229,9 @@ async def create_transcript(self, input):
             api_version=self.api_version,
             max_retries=self.MAX_RETRIES,
         )
+        if transcription:
+            return TranscriptionReturnType(transcription.text, transcription)
 
-        return transcription
+        return None
 
-    # transcribe image inherited from GenericAdapter
+    # transcribe_image is inherited from GenericAPIAdapter
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py
new file mode 100644
index 0000000000..887cdd88dd
--- /dev/null
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py
@@ -0,0 +1,9 @@
+from pydantic import BaseModel
+
+class TranscriptionReturnType:
+    text: str
+    payload: BaseModel
+
+    def __init__(self, text: str, payload: BaseModel):
+        self.text = text
+        self.payload = payload
\ No newline at end of file
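
Note for reviewers: patch 1 reworks the provider adapters so that AnthropicAdapter, GeminiAdapter, MistralAdapter, OllamaAPIAdapter and OpenAIAdapter all subclass GenericAPIAdapter and inherit its transcription and image handling. A minimal sketch of what a new provider adapter would now look like; ExampleProviderAdapter and its name are hypothetical, only the constructor keywords come from this series:

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
        GenericAPIAdapter,
    )


    class ExampleProviderAdapter(GenericAPIAdapter):
        # Hypothetical adapter: create_transcript and transcribe_image are
        # inherited from GenericAPIAdapter, so only provider-specific knobs
        # are configured here.
        default_instructor_mode = "json_mode"

        def __init__(self, api_key: str, model: str, max_completion_tokens: int):
            super().__init__(
                api_key=api_key,
                model=model,
                max_completion_tokens=max_completion_tokens,
                name="ExampleProvider",
            )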
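
The multimodal requests in GenericAPIAdapter all reduce to litellm content parts that embed the media as a base64 data URL. The sketch below isolates that message shape for the audio path; build_audio_message is an illustrative helper, not a function added by this patch:

    import base64
    import mimetypes


    def build_audio_message(path: str, raw_bytes: bytes) -> dict:
        # Infer the MIME type from the file extension, as the adapter does;
        # litellm expects a data URL of the form "data:<mime>;base64,<payload>".
        mime_type, _ = mimetypes.guess_type(path)
        if not mime_type or not mime_type.startswith("audio/"):
            raise ValueError(f"Could not determine audio MIME type for: {path}")
        encoded = base64.b64encode(raw_bytes).decode("utf-8")
        return {
            "role": "user",
            "content": [
                {
                    "type": "file",
                    "file": {"file_data": f"data:{mime_type};base64,{encoded}"},
                },
                {"type": "text", "text": "Transcribe the following audio precisely."},
            ],
        }

The image path in transcribe_image follows the same pattern with an "image_url" content part instead of a "file" part.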
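
After patch 4, every litellm-based adapter in this series normalizes audio transcription to TranscriptionReturnType (or None), so callers can rely on .text regardless of provider. A consumer-side usage sketch, assuming an LLM provider is already configured for cognee; the media file paths are hypothetical:

    import asyncio

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
        get_llm_client,
    )


    async def main():
        client = get_llm_client()  # resolves the adapter selected by llm_config

        result = await client.create_transcript("meeting.mp3")
        if result is not None:
            print(result.text)     # normalized transcript text
            print(result.payload)  # raw provider response, kept for debugging

        # Image transcription still returns the provider's raw completion response.
        description = await client.transcribe_image("diagram.png")
        print(description)


    asyncio.run(main())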