Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cognee/modules/retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .code_graph_retrieval import code_graph_retrieval
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
16 changes: 16 additions & 0 deletions cognee/modules/retrieval/base_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseRetriever(ABC):
    """Abstract interface shared by every retrieval implementation.

    Concrete subclasses must provide both `get_context` and
    `get_completion`; instantiating this class directly raises TypeError.
    """

    @abstractmethod
    async def get_context(self, query: str) -> Any:
        """Return the context relevant to `query`."""
        ...

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Produce a response for `query`, deriving context when not supplied."""
        ...
20 changes: 20 additions & 0 deletions cognee/modules/retrieval/chunks_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class ChunksRetriever(BaseRetriever):
    """Retriever for handling document chunk-based searches."""

    async def get_context(self, query: str) -> Any:
        """Retrieves document chunks context based on the query."""
        # Nearest 5 chunks from the "DocumentChunk_text" vector collection.
        engine = get_vector_engine()
        matches = await engine.search("DocumentChunk_text", query, limit=5)
        return [match.payload for match in matches]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using document chunks context."""
        # No LLM call here: the retrieved chunk payloads are returned as-is.
        return context if context is not None else await self.get_context(query)
146 changes: 146 additions & 0 deletions cognee/modules/retrieval/code_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from typing import Any, Optional, List, Dict
import asyncio
import logging

import aiofiles
from pydantic import BaseModel

from cognee.low_level import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt

logger = logging.getLogger(__name__)


class CodeRetriever(BaseRetriever):
    """Retriever for handling code-based searches.

    Extracts filenames and code snippets from the query via the LLM,
    searches the corresponding vector collections, expands the hits through
    the graph, and returns the contents of every referenced source file.
    """

    class CodeQueryInfo(BaseModel):
        """Response model for information extraction from the query."""

        filenames: List[str] = []
        sourcecode: str

    def __init__(self, limit: int = 3):
        """Initialize retriever with search parameters.

        Args:
            limit: Maximum number of vector-search hits per collection/query.
        """
        self.limit = limit
        self.file_name_collections = ["CodeFile_name"]
        self.classes_and_functions_collections = [
            "ClassDefinition_source_code",
            "FunctionDefinition_source_code",
        ]

    async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
        """Process the query using LLM to extract file names and source code parts.

        Raises:
            RuntimeError: If the LLM call fails for any reason.
        """
        system_prompt = read_query_prompt("codegraph_retriever_system.txt")
        llm_client = get_llm_client()
        try:
            return await llm_client.acreate_structured_output(
                text_input=query,
                system_prompt=system_prompt,
                response_model=self.CodeQueryInfo,
            )
        except Exception as e:
            raise RuntimeError("Failed to retrieve structured output from LLM") from e

    @staticmethod
    async def _search_collections(
        vector_engine, collections: List[str], queries: List[str], limit: int
    ) -> List[Dict]:
        """Search every (collection, query) pair and flatten the hits.

        Consolidates the search logic that was previously duplicated across
        the extracted-query and raw-query branches of `get_context`.
        """
        results: List[Dict] = []
        for collection in collections:
            for search_query in queries:
                search_results = await vector_engine.search(
                    collection, search_query, limit=limit
                )
                results.extend(
                    {"id": res.id, "score": res.score, "payload": res.payload}
                    for res in search_results
                )
        return results

    async def get_context(self, query: str) -> Any:
        """Find relevant code files based on the query.

        Returns:
            A list of {"name", "description", "content"} dicts, one per file.

        Raises:
            ValueError: If the query is empty or not a string.
            RuntimeError: If database engines cannot be initialized.
        """
        if not query or not isinstance(query, str):
            raise ValueError("The query must be a non-empty string.")

        try:
            vector_engine = get_vector_engine()
            graph_engine = await get_graph_engine()
        except Exception as e:
            raise RuntimeError("Database initialization error in code_graph_retriever.") from e

        files_and_codeparts = await self._process_query(query)

        # The extracted fields are used only when BOTH are present;
        # otherwise fall back to searching with the raw query.
        use_extracted = bool(
            files_and_codeparts.filenames and files_and_codeparts.sourcecode
        )
        filename_queries = files_and_codeparts.filenames if use_extracted else [query]
        code_queries = [files_and_codeparts.sourcecode] if use_extracted else [query]

        similar_filenames = await self._search_collections(
            vector_engine, self.file_name_collections, filename_queries, self.limit
        )
        similar_codepieces = await self._search_collections(
            vector_engine,
            self.classes_and_functions_collections,
            code_queries,
            self.limit,
        )

        file_ids = [str(item["id"]) for item in similar_filenames]
        code_ids = [str(item["id"]) for item in similar_codepieces]

        # NOTE(review): all connection lookups are dispatched concurrently;
        # if the number of IDs grows large, consider batching or a semaphore
        # to avoid overwhelming the graph database.
        relevant_triplets = await asyncio.gather(
            *[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids]
        )

        # Collect every file path mentioned on either end of a triplet.
        paths = set()
        for sublist in relevant_triplets:
            for tpl in sublist:
                if isinstance(tpl, tuple) and len(tpl) >= 3:
                    if "file_path" in tpl[0]:
                        paths.add(tpl[0]["file_path"])
                    if "file_path" in tpl[2]:
                        paths.add(tpl[2]["file_path"])

        retrieved_files: Dict[str, str] = {}

        async def read_file(fp: str) -> None:
            """Read one file into retrieved_files; unreadable files map to ""."""
            try:
                async with aiofiles.open(fp, "r", encoding="utf-8") as f:
                    retrieved_files[fp] = await f.read()
            except Exception as e:
                logger.error("Error reading %s: %s", fp, e)
                retrieved_files[fp] = ""

        # Single coroutine defined once and fanned out over all paths
        # (previously it was redefined inside the loop on every iteration).
        await asyncio.gather(*[read_file(file_path) for file_path in paths])

        return [
            {
                "name": file_path,
                "description": file_path,
                "content": retrieved_files[file_path],
            }
            for file_path in paths
        ]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the code files context."""
        if context is None:
            context = await self.get_context(query)
        return context
40 changes: 40 additions & 0 deletions cognee/modules/retrieval/completion_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound


class CompletionRetriever(BaseRetriever):
    """Retriever for handling LLM-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: int = 1,
    ):
        """Initialize retriever with optional custom prompt paths.

        Args:
            user_prompt_path: Prompt template used for the user message.
            system_prompt_path: Prompt template used for the system message.
            top_k: Number of chunks to retrieve as context. The default of 1
                preserves the previously hard-coded behavior.
        """
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Retrieves relevant document chunks as context.

        Raises:
            NoRelevantDataFound: If no chunks match the query.
        """
        vector_engine = get_vector_engine()
        found_chunks = await vector_engine.search(
            "DocumentChunk_text", query, limit=self.top_k
        )
        if len(found_chunks) == 0:
            raise NoRelevantDataFound
        # With the default top_k=1 this is exactly the single chunk's text,
        # matching the original behavior; larger top_k concatenates chunks.
        return "\n".join(chunk.payload["text"] for chunk in found_chunks)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates an LLM completion using the context."""
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]
71 changes: 71 additions & 0 deletions cognee/modules/retrieval/graph_completion_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Any, Optional

from cognee.infrastructure.engine import ExtendableDataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound


class GraphCompletionRetriever(BaseRetriever):
    """Retriever for handling graph-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with prompt paths and search parameters."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a human-readable string format."""

        def node_label(node):
            # Prefer the node's text attribute; fall back to its name.
            return node.attributes.get("text") or node.attributes.get("name")

        rendered = [
            f"{node_label(edge.node1)} -- {edge.attributes['relationship_type']} -- {node_label(edge.node2)}"
            for edge in retrieved_edges
        ]
        return "\n---\n".join(rendered)

    async def get_triplets(self, query: str) -> list:
        """Retrieves relevant graph triplets.

        Raises:
            NoRelevantDataFound: If the search returns no triplets.
        """
        # Build one vector collection name per indexed field of every
        # ExtendableDataPoint subclass.
        collections = []
        for subclass in get_all_subclasses(ExtendableDataPoint):
            index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
            collections.extend(f"{subclass.__name__}_{field}" for field in index_fields)

        found_triplets = await brute_force_triplet_search(
            query, top_k=self.top_k, collections=collections or None
        )

        if not found_triplets:
            raise NoRelevantDataFound

        return found_triplets

    async def get_context(self, query: str) -> Any:
        """Retrieves and resolves graph triplets into context."""
        return await self.resolve_edges_to_text(await self.get_triplets(query))

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using graph connections context."""
        if context is None:
            context = await self.get_context(query)

        return [
            await generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
            )
        ]
36 changes: 36 additions & 0 deletions cognee/modules/retrieval/graph_summary_completion_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever


class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
    """Retriever for handling graph-based completion searches with summarized context."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        summarize_prompt_path: str = "summarize_search_results.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with default prompt paths and search parameters."""
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            top_k=top_k,
        )
        self.summarize_prompt_path = summarize_prompt_path

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a summary without redundancies."""
        # Render the edges with the parent implementation, then ask the LLM
        # to condense that text into a de-duplicated summary.
        raw_context = await super().resolve_edges_to_text(retrieved_edges)
        summary_prompt = read_query_prompt(self.summarize_prompt_path)
        return await get_llm_client().acreate_structured_output(
            text_input=raw_context,
            system_prompt=summary_prompt,
            response_model=str,
        )
Loading
Loading