
Commit c448dfb

Fix langfuse
1 parent 92ecd8a commit c448dfb

5 files changed: +509 -575 lines

cognee/api/v1/cognify/code_graph_pipeline.py
Lines changed: 5 additions & 1 deletion

@@ -17,6 +17,10 @@
 from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
 from cognee.tasks.graph import extract_graph_from_code
 from cognee.tasks.storage import add_data_points
+from cognee.base_config import get_base_config
+from cognee.shared.data_models import MonitoringTool
+if MonitoringTool.LANGFUSE:
+    from langfuse.decorators import observe

 logger = logging.getLogger("code_graph_pipeline")

@@ -49,7 +53,7 @@ async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User

     return await asyncio.gather(*awaitables)

-
+@observe
 async def run_pipeline(dataset: Dataset, user: User):
     data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
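A note on the import guard: `MonitoringTool.LANGFUSE` is an enum member, which is always truthy, so the `if` branch runs regardless of which monitoring tool is configured; the import only fails if the `langfuse` package is absent. A sketch of a stricter variant (not part of this commit) that keys off the configured tool and degrades to a no-op decorator:

    # Sketch only, assuming the imports already present in this file.
    from cognee.base_config import get_base_config
    from cognee.shared.data_models import MonitoringTool

    if get_base_config().monitoring_tool == MonitoringTool.LANGFUSE:
        from langfuse.decorators import observe
    else:
        def observe(*args, **kwargs):
            # No-op stand-in usable both as @observe and as @observe(...).
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            return lambda func: func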

cognee/base_config.py
Lines changed: 3 additions & 1 deletion

@@ -10,7 +10,9 @@ class BaseConfig(BaseSettings):
     monitoring_tool: object = MonitoringTool.LANGFUSE
     graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
     graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")
-
+    langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
+    langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
+    langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
     model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

     def to_dict(self) -> dict:
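The three new fields follow the existing Graphistry pattern: each reads an environment variable, and `.env` is honored via `model_config`. A minimal wiring sketch; the key values are placeholders and the host shown is Langfuse's cloud default:

    # .env -- placeholder credentials
    LANGFUSE_PUBLIC_KEY=pk-lf-...
    LANGFUSE_SECRET_KEY=sk-lf-...
    LANGFUSE_HOST=https://cloud.langfuse.com

    # Python: the fields are Optional[str], so they stay None when unset.
    from cognee.base_config import get_base_config

    config = get_base_config()
    if config.langfuse_public_key and config.langfuse_secret_key:
        print(f"Langfuse enabled against {config.langfuse_host}")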

cognee/infrastructure/llm/openai/adapter.py
Lines changed: 45 additions & 33 deletions

@@ -6,26 +6,31 @@
 import litellm
 import instructor
 from pydantic import BaseModel
-
+from cognee.shared.data_models import MonitoringTool
 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.llm.llm_interface import LLMInterface
 from cognee.infrastructure.llm.prompts import read_query_prompt
+from cognee.base_config import get_base_config
+
+if MonitoringTool.LANGFUSE:
+    from langfuse.decorators import observe

 class OpenAIAdapter(LLMInterface):
     name = "OpenAI"
     model: str
     api_key: str
     api_version: str
-
+
     """Adapter for OpenAI's GPT-3, GPT=4 API"""
+
     def __init__(
-        self,
-        api_key: str,
-        endpoint: str,
-        api_version: str,
-        model: str,
-        transcription_model: str,
-        streaming: bool = False,
+            self,
+            api_key: str,
+            endpoint: str,
+            api_version: str,
+            model: str,
+            transcription_model: str,
+            streaming: bool = False,
     ):
         self.aclient = instructor.from_litellm(litellm.acompletion)
         self.client = instructor.from_litellm(litellm.completion)

@@ -35,45 +40,52 @@ def __init__(
         self.endpoint = endpoint
         self.api_version = api_version
         self.streaming = streaming
+        base_config = get_base_config()
+
+
+    @observe()
+    async def acreate_structured_output(self, text_input: str, system_prompt: str,
+                                        response_model: Type[BaseModel]) -> BaseModel:

-    async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
         """Generate a response from a user query."""

         return await self.aclient.chat.completions.create(
-            model = self.model,
-            messages = [{
+            model=self.model,
+            messages=[{
                 "role": "user",
                 "content": f"""Use the given format to
                 extract information from the following input: {text_input}. """,
             }, {
                 "role": "system",
                 "content": system_prompt,
             }],
-            api_key = self.api_key,
-            api_base = self.endpoint,
-            api_version = self.api_version,
-            response_model = response_model,
-            max_retries = 5,
+            api_key=self.api_key,
+            api_base=self.endpoint,
+            api_version=self.api_version,
+            response_model=response_model,
+            max_retries=5,
         )

-    def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
+    @observe
+    def create_structured_output(self, text_input: str, system_prompt: str,
+                                 response_model: Type[BaseModel]) -> BaseModel:
         """Generate a response from a user query."""

         return self.client.chat.completions.create(
-            model = self.model,
-            messages = [{
+            model=self.model,
+            messages=[{
                 "role": "user",
                 "content": f"""Use the given format to
                 extract information from the following input: {text_input}. """,
             }, {
                 "role": "system",
                 "content": system_prompt,
             }],
-            api_key = self.api_key,
-            api_base = self.endpoint,
-            api_version = self.api_version,
-            response_model = response_model,
-            max_retries = 5,
+            api_key=self.api_key,
+            api_base=self.endpoint,
+            api_version=self.api_version,
+            response_model=response_model,
+            max_retries=5,
         )

     def create_transcript(self, input):

@@ -86,12 +98,12 @@ def create_transcript(self, input):
         # audio_data = audio_file.read()

         transcription = litellm.transcription(
-            model = self.transcription_model,
-            file = Path(input),
+            model=self.transcription_model,
+            file=Path(input),
             api_key=self.api_key,
             api_base=self.endpoint,
             api_version=self.api_version,
-            max_retries = 5,
+            max_retries=5,
         )

         return transcription

@@ -101,8 +113,8 @@ def transcribe_image(self, input) -> BaseModel:
         encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

         return litellm.completion(
-            model = self.model,
-            messages = [{
+            model=self.model,
+            messages=[{
                 "role": "user",
                 "content": [
                     {

@@ -119,8 +131,8 @@ def transcribe_image(self, input) -> BaseModel:
             api_key=self.api_key,
             api_base=self.endpoint,
             api_version=self.api_version,
-            max_tokens = 300,
-            max_retries = 5,
+            max_tokens=300,
+            max_retries=5,
         )

     def show_prompt(self, text_input: str, system_prompt: str) -> str:

@@ -132,4 +144,4 @@ def show_prompt(self, text_input: str, system_prompt: str) -> str:
         system_prompt = read_query_prompt(system_prompt)

         formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
-        return formatted_prompt
+        return formatted_prompt
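With `@observe()` on the async method and `@observe` on the sync one, Langfuse traces each structured-output call; the Langfuse client picks up the `LANGFUSE_*` variables that `BaseConfig` now surfaces. A usage sketch, where the model names, endpoint, and `Summary` schema are illustrative rather than taken from this commit:

    import asyncio
    from pydantic import BaseModel

    from cognee.infrastructure.llm.openai.adapter import OpenAIAdapter

    class Summary(BaseModel):
        # Illustrative response schema for instructor's structured output.
        title: str
        key_points: list[str]

    adapter = OpenAIAdapter(
        api_key="sk-...",                      # placeholder
        endpoint="https://api.openai.com/v1",  # illustrative endpoint
        api_version=None,                      # unused for plain OpenAI via litellm
        model="gpt-4o-mini",                   # illustrative model name
        transcription_model="whisper-1",       # illustrative
    )

    summary = asyncio.run(adapter.acreate_structured_output(
        text_input="Cognee turns repositories into code graphs.",
        system_prompt="Summarize the input as a title plus key points.",
        response_model=Summary,
    ))
    print(summary.title)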
