Merged
Changes from 8 commits
11 changes: 10 additions & 1 deletion cognee/api/v1/cognify/code_graph_pipeline.py
@@ -29,8 +29,17 @@
expand_dependency_graph,
get_repo_file_dependencies)
from cognee.tasks.storage import add_data_points

from cognee.base_config import get_base_config
from cognee.shared.data_models import MonitoringTool

monitoring = get_base_config().monitoring_tool
if monitoring == MonitoringTool.LANGFUSE:
from langfuse.decorators import observe

from cognee.tasks.summarization import summarize_code


logger = logging.getLogger("code_graph_pipeline")

update_status_lock = asyncio.Lock()
@@ -62,7 +71,7 @@ async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User

return await asyncio.gather(*awaitables)


@observe
async def run_pipeline(dataset: Dataset, user: User):
'''DEPRECATED: Use `run_code_graph_pipeline` instead. This function will be removed.'''
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
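Note that in this file `observe` is only imported when the configured monitoring tool is Langfuse, yet the `@observe` decorator on `run_pipeline` is applied unconditionally; with any other monitoring setting the module would raise a `NameError` at import time. A minimal no-op fallback sketch (a hedged suggestion, not code from this PR) that keeps the decorator usable either way:

```python
# Hedged sketch (not part of this PR): a no-op fallback so @observe stays
# defined when Langfuse is not the configured monitoring tool.
monitoring = get_base_config().monitoring_tool
if monitoring == MonitoringTool.LANGFUSE:
    from langfuse.decorators import observe
else:
    def observe(*args, **kwargs):
        # Supports both bare @observe and parameterized @observe(...) usage.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]

        def decorator(func):
            return func
        return decorator
```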
4 changes: 3 additions & 1 deletion cognee/base_config.py
@@ -10,7 +10,9 @@ class BaseConfig(BaseSettings):
monitoring_tool: object = MonitoringTool.LANGFUSE
graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")

langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

def to_dict(self) -> dict:
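For context, the three new fields mirror the environment variables the Langfuse SDK itself expects. A small usage sketch (assumed usage, not part of this diff) of reading them back through `get_base_config()`:

```python
# Hedged sketch (assumed usage): the new fields are populated from
# LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY, and LANGFUSE_HOST via os.getenv.
from cognee.base_config import get_base_config

config = get_base_config()
if config.langfuse_public_key and config.langfuse_secret_key:
    # https://cloud.langfuse.com is Langfuse's hosted default; self-hosted
    # deployments set LANGFUSE_HOST instead.
    host = config.langfuse_host or "https://cloud.langfuse.com"
    print(f"Langfuse tracing configured against {host}")
```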
78 changes: 45 additions & 33 deletions cognee/infrastructure/llm/openai/adapter.py
@@ -6,26 +6,31 @@
import litellm
import instructor
from pydantic import BaseModel

from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.base_config import get_base_config

monitoring = get_base_config().monitoring_tool
if monitoring == MonitoringTool.LANGFUSE:
    from langfuse.decorators import observe

class OpenAIAdapter(LLMInterface):
name = "OpenAI"
model: str
api_key: str
api_version: str

"""Adapter for OpenAI's GPT-3, GPT=4 API"""

def __init__(
self,
api_key: str,
endpoint: str,
api_version: str,
model: str,
transcription_model: str,
streaming: bool = False,
self,
api_key: str,
endpoint: str,
api_version: str,
model: str,
transcription_model: str,
streaming: bool = False,
):
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
@@ -35,45 +40,52 @@ def __init__(
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
base_config = get_base_config()


Comment on lines +43 to +45 (Contributor)

⚠️ Potential issue

Remove unused variable.

The base_config variable is assigned but never used.

-        base_config = get_base_config()
-
-
🪛 Ruff (0.8.2)

43-43: Local variable base_config is assigned to but never used. Remove assignment to unused variable base_config. (F841)

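If the assignment was meant to feed the new Langfuse settings into the adapter, one hypothetical wiring (not in this PR; the Langfuse decorator SDK reads its credentials from the environment) would be:

```python
# Hypothetical wiring (not part of this PR): export the BaseConfig Langfuse
# settings so langfuse.decorators can pick them up from the environment.
import os
from cognee.base_config import get_base_config

base_config = get_base_config()
if base_config.langfuse_public_key and base_config.langfuse_secret_key:
    os.environ.setdefault("LANGFUSE_PUBLIC_KEY", base_config.langfuse_public_key)
    os.environ.setdefault("LANGFUSE_SECRET_KEY", base_config.langfuse_secret_key)
    if base_config.langfuse_host:
        os.environ.setdefault("LANGFUSE_HOST", base_config.langfuse_host)
```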
@observe()
async def acreate_structured_output(self, text_input: str, system_prompt: str,
response_model: Type[BaseModel]) -> BaseModel:

async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query."""

return await self.aclient.chat.completions.create(
model = self.model,
messages = [{
model=self.model,
messages=[{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
}, {
"role": "system",
"content": system_prompt,
}],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model,
max_retries = 5,
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
max_retries=5,
)

def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
@observe
def create_structured_output(self, text_input: str, system_prompt: str,
response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query."""

return self.client.chat.completions.create(
model = self.model,
messages = [{
model=self.model,
messages=[{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
}, {
"role": "system",
"content": system_prompt,
}],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model,
max_retries = 5,
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
max_retries=5,
)

def create_transcript(self, input):
@@ -86,12 +98,12 @@ def create_transcript(self, input):
# audio_data = audio_file.read()

transcription = litellm.transcription(
model = self.transcription_model,
file = Path(input),
model=self.transcription_model,
file=Path(input),
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_retries = 5,
max_retries=5,
)

return transcription
@@ -101,8 +113,8 @@ def transcribe_image(self, input) -> BaseModel:
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

return litellm.completion(
model = self.model,
messages = [{
model=self.model,
messages=[{
"role": "user",
"content": [
{
@@ -119,8 +131,8 @@ def transcribe_image(self, input) -> BaseModel:
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_tokens = 300,
max_retries = 5,
max_tokens=300,
max_retries=5,
)

def show_prompt(self, text_input: str, system_prompt: str) -> str:
@@ -132,4 +144,4 @@ def show_prompt(self, text_input: str, system_prompt: str) -> str:
system_prompt = read_query_prompt(system_prompt)

formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
return formatted_prompt
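To illustrate the structured-output path end to end, here is a hedged usage sketch; the model ids, API key, and the `SummaryModel` schema are placeholders, not values from this PR:

```python
# Hypothetical usage sketch: all literal values below are placeholders.
from pydantic import BaseModel

class SummaryModel(BaseModel):
    title: str
    summary: str

adapter = OpenAIAdapter(
    api_key="sk-...",                  # placeholder key
    endpoint=None,                     # None lets litellm use the default OpenAI endpoint
    api_version=None,
    model="gpt-4o-mini",               # any litellm-supported chat model id
    transcription_model="whisper-1",
    streaming=False,
)

# instructor parses the completion into the Pydantic model.
result = adapter.create_structured_output(
    text_input="Cognee builds knowledge graphs from code repositories.",
    system_prompt="Summarize the input in one sentence.",
    response_model=SummaryModel,
)
print(result.title, "-", result.summary)
```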