Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cognee/modules/retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .code_graph_retrieval import code_graph_retrieval
from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval
16 changes: 16 additions & 0 deletions cognee/modules/retrieval/base_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseRetriever(ABC):
    """Abstract interface that every retriever implementation must satisfy.

    Subclasses provide two async operations: fetching raw context for a
    query, and producing a completion (which may simply be the context).
    """

    @abstractmethod
    async def get_context(self, query: str) -> Any:
        """Retrieves context based on the query."""

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a response using the query and optional context."""
20 changes: 20 additions & 0 deletions cognee/modules/retrieval/chunks_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class ChunksRetriever(BaseRetriever):
    """Retriever for handling document chunk-based searches."""

    def __init__(self, top_k: int = 5):
        """Initialize retriever with the maximum number of chunks to return.

        Args:
            top_k: Maximum number of chunks returned by the vector search
                (previously hardcoded to 5; the default preserves that).
        """
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Retrieves document chunk payloads matching the query.

        Searches the "DocumentChunk_text" vector collection and returns the
        raw payload of each hit.
        """
        vector_engine = get_vector_engine()
        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
        return [result.payload for result in found_chunks]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using document chunks context.

        No LLM call is made here; the retrieved context itself is returned.
        """
        if context is None:
            context = await self.get_context(query)
        return context
57 changes: 57 additions & 0 deletions cognee/modules/retrieval/code_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Any, Optional

from cognee.low_level import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search


class CodeRetriever(BaseRetriever):
    """Retriever for handling code-based searches."""

    def __init__(self, top_k: int = 5):
        """Initialize retriever with search parameters.

        Args:
            top_k: Number of triplets requested from the brute-force search.
        """
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Find relevant code files based on the query.

        Collects every vector index collection declared by DataPoint
        subclasses, runs a brute-force triplet search projected onto
        code-related properties, and returns one entry per unique file path.
        """
        subclasses = get_all_subclasses(DataPoint)
        vector_index_collections = []

        for subclass in subclasses:
            index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
            for field_name in index_fields:
                vector_index_collections.append(f"{subclass.__name__}_{field_name}")

        found_triplets = await brute_force_triplet_search(
            query,
            top_k=self.top_k,
            collections=vector_index_collections or None,
            properties_to_project=["id", "file_path", "source_code"],
        )

        # Deduplicate by file path. Using .get() instead of direct indexing
        # avoids a KeyError when a node was not projected with the expected
        # properties; the same logic applies symmetrically to both nodes.
        retrieved_files = {}
        for triplet in found_triplets:
            for node in (triplet.node1, triplet.node2):
                source_code = node.attributes.get("source_code")
                if source_code:
                    retrieved_files[node.attributes.get("file_path")] = source_code

        return [
            {
                "name": file_path,
                "description": file_path,
                "content": source_code,
            }
            for file_path, source_code in retrieved_files.items()
        ]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the code files context; no LLM completion is generated here."""
        if context is None:
            context = await self.get_context(query)
        return context
40 changes: 40 additions & 0 deletions cognee/modules/retrieval/completion_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound


class CompletionRetriever(BaseRetriever):
    """Retriever for handling LLM-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: int = 1,
    ):
        """Initialize retriever with optional custom prompt paths.

        Args:
            user_prompt_path: Template rendered with the question and context.
            system_prompt_path: System prompt for the completion call.
            top_k: Number of chunks fetched as context (previously hardcoded
                to 1; the default preserves that behavior).
        """
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Retrieves relevant document chunk text as context.

        Raises:
            NoRelevantDataFound: when the vector search returns no chunks.
        """
        vector_engine = get_vector_engine()
        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
        if len(found_chunks) == 0:
            raise NoRelevantDataFound
        # With the default top_k=1 this is identical to returning the single
        # chunk's text; larger top_k concatenates all retrieved chunks.
        return "\n".join(chunk.payload["text"] for chunk in found_chunks)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates an LLM completion using the context."""
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]
71 changes: 71 additions & 0 deletions cognee/modules/retrieval/graph_completion_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Any, Optional

from cognee.infrastructure.engine import ExtendableDataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.tasks.completion.exceptions import NoRelevantDataFound


class GraphCompletionRetriever(BaseRetriever):
    """Retriever for handling graph-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with prompt paths and search parameters."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a human-readable string format.

        Each edge is rendered as "node -- relationship -- node"; entries are
        joined with a "---" separator line.
        """

        def node_label(node):
            # Prefer the node's text payload; fall back to its name.
            return node.attributes.get("text") or node.attributes.get("name")

        rendered = [
            f"{node_label(edge.node1)} -- {edge.attributes['relationship_type']} -- {node_label(edge.node2)}"
            for edge in retrieved_edges
        ]
        return "\n---\n".join(rendered)

    async def get_triplets(self, query: str) -> list:
        """Retrieves relevant graph triplets.

        Raises:
            NoRelevantDataFound: when the search yields no triplets.
        """
        vector_index_collections = [
            f"{subclass.__name__}_{field_name}"
            for subclass in get_all_subclasses(ExtendableDataPoint)
            for field_name in subclass.model_fields["metadata"].default.get("index_fields", [])
        ]

        found_triplets = await brute_force_triplet_search(
            query, top_k=self.top_k, collections=vector_index_collections or None
        )

        if not found_triplets:
            raise NoRelevantDataFound

        return found_triplets

    async def get_context(self, query: str) -> Any:
        """Retrieves and resolves graph triplets into context."""
        return await self.resolve_edges_to_text(await self.get_triplets(query))

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using graph connections context."""
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]
36 changes: 36 additions & 0 deletions cognee/modules/retrieval/graph_summary_completion_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever


class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
    """Retriever for handling graph-based completion searches with summarized context."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        summarize_prompt_path: str = "summarize_search_results.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with default prompt paths and search parameters."""
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            top_k=top_k,
        )
        self.summarize_prompt_path = summarize_prompt_path

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a summary without redundancies.

        The parent class renders the edges verbatim; that text is then passed
        through the LLM with the summarization prompt.
        """
        raw_context = await super().resolve_edges_to_text(retrieved_edges)
        summary_prompt = read_query_prompt(self.summarize_prompt_path)

        return await get_llm_client().acreate_structured_output(
            text_input=raw_context,
            system_prompt=summary_prompt,
            response_model=str,
        )
66 changes: 66 additions & 0 deletions cognee/modules/retrieval/insights_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import asyncio
from typing import Any, Optional

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class InsightsRetriever(BaseRetriever):
    """Retriever for handling graph connection-based insights."""

    def __init__(
        self, exploration_levels: int = 1, top_k: int = 5, score_threshold: float = 0.5
    ):
        """Initialize retriever with exploration levels and search parameters.

        Args:
            exploration_levels: Stored but not used by get_context in this
                implementation; kept for interface compatibility.
            top_k: Maximum number of vector hits per collection.
            score_threshold: Vector results with a score at or above this
                value are discarded (previously hardcoded to 0.5; the default
                preserves that behavior).
        """
        self.exploration_levels = exploration_levels
        self.top_k = top_k
        self.score_threshold = score_threshold

    async def get_context(self, query: str) -> Any:
        """Find the neighbours of a given node in the graph.

        The query is first tried as an exact node id; if no such node exists,
        a vector search over the entity collections supplies candidate nodes.
        Returns a deduplicated list of graph connections (empty list when
        nothing relevant is found).
        """
        if query is None:
            return []

        node_id = query
        graph_engine = await get_graph_engine()
        exact_node = await graph_engine.extract_node(node_id)

        if exact_node is not None and "id" in exact_node:
            node_connections = await graph_engine.get_connections(str(exact_node["id"]))
        else:
            # NOTE(review): collection names are hardcoded; consider making
            # them dynamic so this does not fail when entity extraction
            # produced nothing.
            vector_engine = get_vector_engine()
            results = await asyncio.gather(
                vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
                vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
            )
            results = [*results[0], *results[1]]
            relevant_results = [
                result for result in results if result.score < self.score_threshold
            ][: self.top_k]

            if len(relevant_results) == 0:
                return []

            node_connections_results = await asyncio.gather(
                *[graph_engine.get_connections(result.id) for result in relevant_results]
            )

            node_connections = []
            for neighbours in node_connections_results:
                node_connections.extend(neighbours)

        # Deduplicate connections by (source id, relationship name, target id).
        unique_node_connections_map = {}
        unique_node_connections = []

        for node_connection in node_connections:
            # Skip malformed connections that are missing node ids.
            if "id" not in node_connection[0] or "id" not in node_connection[2]:
                continue

            unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
            if unique_id not in unique_node_connections_map:
                unique_node_connections_map[unique_id] = True
                unique_node_connections.append(node_connection)

        return unique_node_connections

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the graph connections context; no LLM completion is made."""
        if context is None:
            context = await self.get_context(query)
        return context
24 changes: 24 additions & 0 deletions cognee/modules/retrieval/summaries_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any, Optional

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever


class SummariesRetriever(BaseRetriever):
    """Retriever for handling summary-based searches."""

    def __init__(self, limit: int = 5):
        """Initialize retriever with the maximum number of summaries to fetch."""
        self.limit = limit

    async def get_context(self, query: str) -> Any:
        """Retrieves summary context based on the query.

        Searches the "TextSummary_text" vector collection and returns the
        payload of each hit.
        """
        engine = get_vector_engine()
        hits = await engine.search("TextSummary_text", query, limit=self.limit)
        return [hit.payload for hit in hits]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using summaries context (returns the context)."""
        return context if context is not None else await self.get_context(query)
Empty file.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# TODO: delete after merging COG-1365, see COG-1403
from cognee.low_level import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from .brute_force_triplet_search import brute_force_triplet_search
Expand Down
23 changes: 23 additions & 0 deletions cognee/modules/retrieval/utils/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt


async def generate_completion(
    query: str,
    context: str,
    user_prompt_path: str,
    system_prompt_path: str,
) -> str:
    """Generates a completion using LLM with given context and prompts.

    The user prompt template is rendered with the question and context;
    the system prompt is read verbatim from its file.
    """
    rendered_user_prompt = render_prompt(
        user_prompt_path, {"question": query, "context": context}
    )
    rendered_system_prompt = read_query_prompt(system_prompt_path)

    return await get_llm_client().acreate_structured_output(
        text_input=rendered_user_prompt,
        system_prompt=rendered_system_prompt,
        response_model=str,
    )
Loading
Loading