66import litellm
77import instructor
88from pydantic import BaseModel
9-
9+ from cognee . shared . data_models import MonitoringTool
1010from cognee .exceptions import InvalidValueError
1111from cognee .infrastructure .llm .llm_interface import LLMInterface
1212from cognee .infrastructure .llm .prompts import read_query_prompt
13+ from cognee .base_config import get_base_config
14+
# Langfuse's `observe` decorator is applied to the structured-output methods
# below.
# NOTE(review): the original guard `if MonitoringTool.LANGFUSE:` tested an
# enum *member*, which is always truthy, so the import was effectively
# unconditional — and had the guard ever been false, every `@observe` usage
# below would have raised NameError at import time.  Import defensively and
# fall back to a no-op decorator so the module stays importable either way.
try:
    from langfuse.decorators import observe
except ImportError:
    def observe(*args, **kwargs):
        """No-op stand-in for langfuse's `observe` decorator.

        Supports both the bare `@observe` and the parameterized `@observe()`
        forms, mirroring langfuse's decorator (both forms are used in this
        file).
        """
        # Bare usage: `@observe` hands us the function directly.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]

        # Parameterized usage: `@observe(...)` must return a decorator.
        def _decorator(func):
            return func
        return _decorator
1317
class OpenAIAdapter(LLMInterface):
    """Adapter for OpenAI's GPT-3 / GPT-4 APIs, routed through litellm.

    Provides structured (pydantic-validated) chat completions in sync and
    async form, audio transcription, and image transcription.

    NOTE(review): in the original, this docstring sat *after* the class
    attributes (and read "GPT=4"), so it was a discarded string expression
    rather than the class docstring — promoted and fixed here.
    """

    # Provider display name.
    name = "OpenAI"
    # Connection/model settings, populated by __init__.
    model: str
    api_key: str
    api_version: str
2126 def __init__ (
22- self ,
23- api_key : str ,
24- endpoint : str ,
25- api_version : str ,
26- model : str ,
27- transcription_model : str ,
28- streaming : bool = False ,
27+ self ,
28+ api_key : str ,
29+ endpoint : str ,
30+ api_version : str ,
31+ model : str ,
32+ transcription_model : str ,
33+ streaming : bool = False ,
2934 ):
3035 self .aclient = instructor .from_litellm (litellm .acompletion )
3136 self .client = instructor .from_litellm (litellm .completion )
@@ -35,45 +40,52 @@ def __init__(
3540 self .endpoint = endpoint
3641 self .api_version = api_version
3742 self .streaming = streaming
43+ base_config = get_base_config ()
44+
45+
46+ @observe ()
47+ async def acreate_structured_output (self , text_input : str , system_prompt : str ,
48+ response_model : Type [BaseModel ]) -> BaseModel :
3849
39- async def acreate_structured_output (self , text_input : str , system_prompt : str , response_model : Type [BaseModel ]) -> BaseModel :
4050 """Generate a response from a user query."""
4151
4252 return await self .aclient .chat .completions .create (
43- model = self .model ,
44- messages = [{
53+ model = self .model ,
54+ messages = [{
4555 "role" : "user" ,
4656 "content" : f"""Use the given format to
4757 extract information from the following input: { text_input } . """ ,
4858 }, {
4959 "role" : "system" ,
5060 "content" : system_prompt ,
5161 }],
52- api_key = self .api_key ,
53- api_base = self .endpoint ,
54- api_version = self .api_version ,
55- response_model = response_model ,
56- max_retries = 5 ,
62+ api_key = self .api_key ,
63+ api_base = self .endpoint ,
64+ api_version = self .api_version ,
65+ response_model = response_model ,
66+ max_retries = 5 ,
5767 )
5868
59- def create_structured_output (self , text_input : str , system_prompt : str , response_model : Type [BaseModel ]) -> BaseModel :
69+ @observe
70+ def create_structured_output (self , text_input : str , system_prompt : str ,
71+ response_model : Type [BaseModel ]) -> BaseModel :
6072 """Generate a response from a user query."""
6173
6274 return self .client .chat .completions .create (
63- model = self .model ,
64- messages = [{
75+ model = self .model ,
76+ messages = [{
6577 "role" : "user" ,
6678 "content" : f"""Use the given format to
6779 extract information from the following input: { text_input } . """ ,
6880 }, {
6981 "role" : "system" ,
7082 "content" : system_prompt ,
7183 }],
72- api_key = self .api_key ,
73- api_base = self .endpoint ,
74- api_version = self .api_version ,
75- response_model = response_model ,
76- max_retries = 5 ,
84+ api_key = self .api_key ,
85+ api_base = self .endpoint ,
86+ api_version = self .api_version ,
87+ response_model = response_model ,
88+ max_retries = 5 ,
7789 )
7890
7991 def create_transcript (self , input ):
@@ -86,12 +98,12 @@ def create_transcript(self, input):
8698 # audio_data = audio_file.read()
8799
88100 transcription = litellm .transcription (
89- model = self .transcription_model ,
90- file = Path (input ),
101+ model = self .transcription_model ,
102+ file = Path (input ),
91103 api_key = self .api_key ,
92104 api_base = self .endpoint ,
93105 api_version = self .api_version ,
94- max_retries = 5 ,
106+ max_retries = 5 ,
95107 )
96108
97109 return transcription
@@ -101,8 +113,8 @@ def transcribe_image(self, input) -> BaseModel:
101113 encoded_image = base64 .b64encode (image_file .read ()).decode ('utf-8' )
102114
103115 return litellm .completion (
104- model = self .model ,
105- messages = [{
116+ model = self .model ,
117+ messages = [{
106118 "role" : "user" ,
107119 "content" : [
108120 {
@@ -119,8 +131,8 @@ def transcribe_image(self, input) -> BaseModel:
119131 api_key = self .api_key ,
120132 api_base = self .endpoint ,
121133 api_version = self .api_version ,
122- max_tokens = 300 ,
123- max_retries = 5 ,
134+ max_tokens = 300 ,
135+ max_retries = 5 ,
124136 )
125137
126138 def show_prompt (self , text_input : str , system_prompt : str ) -> str :
@@ -132,4 +144,4 @@ def show_prompt(self, text_input: str, system_prompt: str) -> str:
132144 system_prompt = read_query_prompt (system_prompt )
133145
134146 formatted_prompt = f"""System Prompt:\n { system_prompt } \n \n User Input:\n { text_input } \n """ if system_prompt else None
135- return formatted_prompt
147+ return formatted_prompt