-
Notifications
You must be signed in to change notification settings - Fork 966
Feat/cog 1365 unify retrievers #572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
468de67
d789dd0
49c2355
7619df2
5a5eb5e
8f0cbee
beacdea
4b71081
7631b11
62f8ac3
58c7eaf
c07cf22
2ef174a
5910fb7
2f70de4
3d0b839
4903d7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| from .code_graph_retrieval import code_graph_retrieval | ||
| from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Optional | ||
|
|
||
|
|
||
class BaseRetriever(ABC):
    """Abstract interface shared by every retriever implementation.

    Concrete subclasses must supply both an async context lookup and an
    async completion step built on top of that context.
    """

    @abstractmethod
    async def get_context(self, query: str) -> Any:
        """Return the retrieval context relevant to ``query``."""
        ...

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Produce an answer for ``query``, deriving context when none is given."""
        ...
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.databases.vector import get_vector_engine | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
|
|
||
|
|
||
class ChunksRetriever(BaseRetriever):
    """Retriever for handling document chunk-based searches."""

    def __init__(self, limit: int = 5):
        """Initialize retriever with search parameters.

        Args:
            limit: Maximum number of chunks returned per search. Defaults to 5,
                which was previously hard-coded; exposed for consistency with
                SummariesRetriever.
        """
        self.limit = limit

    async def get_context(self, query: str) -> Any:
        """Retrieves document chunk payloads relevant to the query.

        Searches the "DocumentChunk_text" vector collection and returns the
        raw payloads of the best-matching chunks.
        """
        vector_engine = get_vector_engine()
        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.limit)
        return [result.payload for result in found_chunks]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the chunk context itself; no LLM call is made here."""
        if context is None:
            context = await self.get_context(query)
        return context
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.low_level import DataPoint | ||
| from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
| from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search | ||
|
|
||
|
|
||
class CodeRetriever(BaseRetriever):
    """Retriever for handling code-based searches."""

    def __init__(self, top_k: int = 5):
        """Initialize retriever with search parameters."""
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Find relevant code files based on the query.

        Collects every vector index collection declared by DataPoint
        subclasses, runs a triplet search against them, and returns the
        deduplicated source files attached to the matched nodes.
        """
        collections = [
            f"{subclass.__name__}_{field_name}"
            for subclass in get_all_subclasses(DataPoint)
            for field_name in subclass.model_fields["metadata"].default.get("index_fields", [])
        ]

        triplets = await brute_force_triplet_search(
            query,
            top_k=self.top_k,
            collections=collections or None,
            properties_to_project=["id", "file_path", "source_code"],
        )

        # Deduplicate by file path; a later hit for the same file overwrites the earlier one.
        files_by_path = {}
        for triplet in triplets:
            for node in (triplet.node1, triplet.node2):
                if node.attributes["source_code"]:
                    files_by_path[node.attributes["file_path"]] = node.attributes["source_code"]

        return [
            {
                "name": path,
                "description": path,
                "content": code,
            }
            for path, code in files_by_path.items()
        ]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the code files context."""
        if context is None:
            context = await self.get_context(query)
        return context
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.databases.vector import get_vector_engine | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
| from cognee.modules.retrieval.utils.completion import generate_completion | ||
| from cognee.tasks.completion.exceptions import NoRelevantDataFound | ||
|
|
||
|
|
||
class CompletionRetriever(BaseRetriever):
    """Retriever for handling LLM-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: int = 1,
    ):
        """Initialize retriever with optional custom prompt paths.

        Args:
            user_prompt_path: Template rendered with the question and context.
            system_prompt_path: System prompt for the LLM call.
            top_k: Number of chunks fetched as context. Defaults to 1, which
                was previously hard-coded (see review discussion).
        """
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Retrieves relevant document chunks as context.

        Raises:
            NoRelevantDataFound: When the vector search returns no chunks.
        """
        vector_engine = get_vector_engine()
        found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)

        if len(found_chunks) == 0:
            raise NoRelevantDataFound
        # With the default top_k=1 this is exactly the single chunk's text,
        # matching the previous behavior; multiple chunks are newline-joined.
        return "\n".join(chunk.payload["text"] for chunk in found_chunks)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates an LLM completion using the context."""
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.engine import ExtendableDataPoint | ||
| from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
| from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search | ||
| from cognee.modules.retrieval.utils.completion import generate_completion | ||
| from cognee.tasks.completion.exceptions import NoRelevantDataFound | ||
|
|
||
|
|
||
class GraphCompletionRetriever(BaseRetriever):
    """Retriever for handling graph-based completion searches."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with prompt paths and search parameters."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a human-readable string format."""

        def node_label(node):
            # Prefer the node's text payload, falling back to its name.
            return node.attributes.get("text") or node.attributes.get("name")

        rendered = [
            f"{node_label(edge.node1)} -- {edge.attributes['relationship_type']} -- {node_label(edge.node2)}"
            for edge in retrieved_edges
        ]
        return "\n---\n".join(rendered)

    async def get_triplets(self, query: str) -> list:
        """Retrieves relevant graph triplets.

        Raises:
            NoRelevantDataFound: When the triplet search yields nothing.
        """
        collections = []
        for subclass in get_all_subclasses(ExtendableDataPoint):
            for field_name in subclass.model_fields["metadata"].default.get("index_fields", []):
                collections.append(f"{subclass.__name__}_{field_name}")

        triplets = await brute_force_triplet_search(
            query, top_k=self.top_k, collections=collections or None
        )

        if not triplets:
            raise NoRelevantDataFound

        return triplets

    async def get_context(self, query: str) -> Any:
        """Retrieves and resolves graph triplets into context."""
        return await self.resolve_edges_to_text(await self.get_triplets(query))

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using graph connections context."""
        if context is None:
            context = await self.get_context(query)

        return [
            await generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
            )
        ]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| from typing import Optional | ||
|
|
||
| from cognee.infrastructure.llm.get_llm_client import get_llm_client | ||
| from cognee.infrastructure.llm.prompts import read_query_prompt | ||
| from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever | ||
|
|
||
|
|
||
class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
    """Retriever for handling graph-based completion searches with summarized context."""

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        summarize_prompt_path: str = "summarize_search_results.txt",
        top_k: int = 5,
    ):
        """Initialize retriever with default prompt paths and search parameters."""
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            top_k=top_k,
        )
        self.summarize_prompt_path = summarize_prompt_path

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Summarize the plain-text edge rendering produced by the parent class."""
        raw_text = await super().resolve_edges_to_text(retrieved_edges)
        summary_prompt = read_query_prompt(self.summarize_prompt_path)

        return await get_llm_client().acreate_structured_output(
            text_input=raw_text,
            system_prompt=summary_prompt,
            response_model=str,
        )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| import asyncio | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.databases.graph import get_graph_engine | ||
| from cognee.infrastructure.databases.vector import get_vector_engine | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
|
|
||
|
|
||
class InsightsRetriever(BaseRetriever):
    """Retriever for handling graph connection-based insights."""

    def __init__(self, exploration_levels: int = 1, top_k: int = 5):
        """Initialize retriever with exploration levels and search parameters."""
        # NOTE(review): exploration_levels is stored but not read in get_context — confirm intent.
        self.exploration_levels = exploration_levels
        self.top_k = top_k

    async def get_context(self, query: str) -> Any:
        """Find the neighbours of a given node in the graph.

        Treats the query as a node id first; when no such node exists, falls
        back to vector search over entity and entity-type names.
        """
        if query is None:
            return []

        graph_engine = await get_graph_engine()
        exact_node = await graph_engine.extract_node(query)

        if exact_node is not None and "id" in exact_node:
            # The query named an existing node directly; expand from it.
            connections = await graph_engine.get_connections(str(exact_node["id"]))
        else:
            vector_engine = get_vector_engine()
            entity_hits, type_hits = await asyncio.gather(
                vector_engine.search("Entity_name", query_text=query, limit=self.top_k),
                vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
            )
            # Keep hits with score below 0.5, capped at top_k
            # (assumes a lower score means a closer match — TODO confirm).
            close_matches = [hit for hit in [*entity_hits, *type_hits] if hit.score < 0.5]
            close_matches = close_matches[: self.top_k]

            if not close_matches:
                return []

            neighbour_lists = await asyncio.gather(
                *(graph_engine.get_connections(hit.id) for hit in close_matches)
            )
            connections = [edge for neighbours in neighbour_lists for edge in neighbours]

        # Deduplicate edges by (source id, relationship, target id), keeping first occurrence.
        deduped = {}
        for edge in connections:
            if "id" not in edge[0] or "id" not in edge[2]:
                continue

            key = f"{edge[0]['id']} {edge[1]['relationship_name']} {edge[2]['id']}"
            deduped.setdefault(key, edge)

        return list(deduped.values())

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the graph connections context."""
        if context is None:
            context = await self.get_context(query)
        return context
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from cognee.infrastructure.databases.vector import get_vector_engine | ||
| from cognee.modules.retrieval.base_retriever import BaseRetriever | ||
|
|
||
|
|
||
class SummariesRetriever(BaseRetriever):
    """Retriever for handling summary-based searches."""

    def __init__(self, limit: int = 5):
        """Initialize retriever with search parameters."""
        self.limit = limit

    async def get_context(self, query: str) -> Any:
        """Look up the best-matching text summaries for the query."""
        engine = get_vector_engine()
        hits = await engine.search("TextSummary_text", query, limit=self.limit)
        return [hit.payload for hit in hits]

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the summaries context; no LLM call is made here."""
        if context is None:
            context = await self.get_context(query)
        return context
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| from typing import Optional | ||
|
|
||
| from cognee.infrastructure.llm.get_llm_client import get_llm_client | ||
| from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt | ||
|
|
||
|
|
||
async def generate_completion(
    query: str,
    context: str,
    user_prompt_path: str,
    system_prompt_path: str,
) -> str:
    """Generates a completion using LLM with given context and prompts.

    Renders the user prompt template with the question/context pair, reads
    the system prompt, and asks the LLM client for a plain-string answer.
    """
    user_prompt = render_prompt(user_prompt_path, {"question": query, "context": context})
    system_prompt = read_query_prompt(system_prompt_path)

    return await get_llm_client().acreate_structured_output(
        text_input=user_prompt,
        system_prompt=system_prompt,
        response_model=str,
    )
Uh oh!
There was an error while loading. Please reload this page.