diff --git a/cognee/modules/retrieval/__init__.py b/cognee/modules/retrieval/__init__.py index 59e0379904..1c29be8cf1 100644 --- a/cognee/modules/retrieval/__init__.py +++ b/cognee/modules/retrieval/__init__.py @@ -1 +1 @@ -from .code_graph_retrieval import code_graph_retrieval +from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py new file mode 100644 index 0000000000..88313b253b --- /dev/null +++ b/cognee/modules/retrieval/base_retriever.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional + + +class BaseRetriever(ABC): + """Base class for all retrieval operations.""" + + @abstractmethod + async def get_context(self, query: str) -> Any: + """Retrieves context based on the query.""" + pass + + @abstractmethod + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + """Generates a response using the query and optional context.""" + pass diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py new file mode 100644 index 0000000000..61427b6f9e --- /dev/null +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -0,0 +1,20 @@ +from typing import Any, Optional + +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.retrieval.base_retriever import BaseRetriever + + +class ChunksRetriever(BaseRetriever): + """Retriever for handling document chunk-based searches.""" + + async def get_context(self, query: str) -> Any: + """Retrieves document chunks context based on the query.""" + vector_engine = get_vector_engine() + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5) + return [result.payload for result in found_chunks] + + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + """Generates a completion using document chunks context.""" + 
if context is None: + context = await self.get_context(query) + return context diff --git a/cognee/modules/retrieval/code_retriever.py b/cognee/modules/retrieval/code_retriever.py new file mode 100644 index 0000000000..27c601b609 --- /dev/null +++ b/cognee/modules/retrieval/code_retriever.py @@ -0,0 +1,146 @@ +from typing import Any, Optional, List, Dict +import asyncio +import aiofiles +from pydantic import BaseModel + +from cognee.low_level import DataPoint +from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses +from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt + + +class CodeRetriever(BaseRetriever): + """Retriever for handling code-based searches.""" + + class CodeQueryInfo(BaseModel): + """Response model for information extraction from the query""" + + filenames: List[str] = [] + sourcecode: str + + def __init__(self, limit: int = 3): + """Initialize retriever with search parameters.""" + self.limit = limit + self.file_name_collections = ["CodeFile_name"] + self.classes_and_functions_collections = [ + "ClassDefinition_source_code", + "FunctionDefinition_source_code", + ] + + async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo": + """Process the query using LLM to extract file names and source code parts.""" + system_prompt = read_query_prompt("codegraph_retriever_system.txt") + llm_client = get_llm_client() + try: + return await llm_client.acreate_structured_output( + text_input=query, + system_prompt=system_prompt, + response_model=self.CodeQueryInfo, + ) + except Exception as e: + raise RuntimeError("Failed to retrieve 
structured output from LLM") from e + + async def get_context(self, query: str) -> Any: + """Find relevant code files based on the query.""" + if not query or not isinstance(query, str): + raise ValueError("The query must be a non-empty string.") + + try: + vector_engine = get_vector_engine() + graph_engine = await get_graph_engine() + except Exception as e: + raise RuntimeError("Database initialization error in code_graph_retriever, ") from e + + files_and_codeparts = await self._process_query(query) + + similar_filenames = [] + similar_codepieces = [] + + if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode: + for collection in self.file_name_collections: + search_results_file = await vector_engine.search( + collection, query, limit=self.limit + ) + for res in search_results_file: + similar_filenames.append( + {"id": res.id, "score": res.score, "payload": res.payload} + ) + + for collection in self.classes_and_functions_collections: + search_results_code = await vector_engine.search( + collection, query, limit=self.limit + ) + for res in search_results_code: + similar_codepieces.append( + {"id": res.id, "score": res.score, "payload": res.payload} + ) + else: + for collection in self.file_name_collections: + for file_from_query in files_and_codeparts.filenames: + search_results_file = await vector_engine.search( + collection, file_from_query, limit=self.limit + ) + for res in search_results_file: + similar_filenames.append( + {"id": res.id, "score": res.score, "payload": res.payload} + ) + + for collection in self.classes_and_functions_collections: + search_results_code = await vector_engine.search( + collection, files_and_codeparts.sourcecode, limit=self.limit + ) + for res in search_results_code: + similar_codepieces.append( + {"id": res.id, "score": res.score, "payload": res.payload} + ) + + file_ids = [str(item["id"]) for item in similar_filenames] + code_ids = [str(item["id"]) for item in similar_codepieces] + + relevant_triplets = await 
asyncio.gather( + *[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids] + ) + + paths = set() + for sublist in relevant_triplets: + for tpl in sublist: + if isinstance(tpl, tuple) and len(tpl) >= 3: + if "file_path" in tpl[0]: + paths.add(tpl[0]["file_path"]) + if "file_path" in tpl[2]: + paths.add(tpl[2]["file_path"]) + + retrieved_files = {} + read_tasks = [] + for file_path in paths: + + async def read_file(fp): + try: + async with aiofiles.open(fp, "r", encoding="utf-8") as f: + retrieved_files[fp] = await f.read() + except Exception as e: + print(f"Error reading {fp}: {e}") + retrieved_files[fp] = "" + + read_tasks.append(read_file(file_path)) + + await asyncio.gather(*read_tasks) + + return [ + { + "name": file_path, + "description": file_path, + "content": retrieved_files[file_path], + } + for file_path in paths + ] + + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + """Returns the code files context.""" + if context is None: + context = await self.get_context(query) + return context diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py new file mode 100644 index 0000000000..f2427f0624 --- /dev/null +++ b/cognee/modules/retrieval/completion_retriever.py @@ -0,0 +1,40 @@ +from typing import Any, Optional + +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.utils.completion import generate_completion +from cognee.tasks.completion.exceptions import NoRelevantDataFound + + +class CompletionRetriever(BaseRetriever): + """Retriever for handling LLM-based completion searches.""" + + def __init__( + self, + user_prompt_path: str = "context_for_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + ): + """Initialize retriever with optional custom prompt paths.""" + self.user_prompt_path = user_prompt_path + 
self.system_prompt_path = system_prompt_path + + async def get_context(self, query: str) -> Any: + """Retrieves relevant document chunks as context.""" + vector_engine = get_vector_engine() + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1) + if len(found_chunks) == 0: + raise NoRelevantDataFound + return found_chunks[0].payload["text"] + + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + """Generates an LLM completion using the context.""" + if context is None: + context = await self.get_context(query) + + completion = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py new file mode 100644 index 0000000000..034c1d40d8 --- /dev/null +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -0,0 +1,71 @@ +from typing import Any, Optional + +from cognee.infrastructure.engine import ExtendableDataPoint +from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses +from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.utils.completion import generate_completion +from cognee.tasks.completion.exceptions import NoRelevantDataFound + + +class GraphCompletionRetriever(BaseRetriever): + """Retriever for handling graph-based completion searches.""" + + def __init__( + self, + user_prompt_path: str = "graph_context_for_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + top_k: int = 5, + ): + """Initialize retriever with prompt paths and search parameters.""" + self.user_prompt_path = user_prompt_path + self.system_prompt_path = system_prompt_path + self.top_k = top_k + + async 
def resolve_edges_to_text(self, retrieved_edges: list) -> str: + """Converts retrieved graph edges into a human-readable string format.""" + edge_strings = [] + for edge in retrieved_edges: + node1_string = edge.node1.attributes.get("text") or edge.node1.attributes.get("name") + node2_string = edge.node2.attributes.get("text") or edge.node2.attributes.get("name") + edge_string = edge.attributes["relationship_type"] + edge_str = f"{node1_string} -- {edge_string} -- {node2_string}" + edge_strings.append(edge_str) + return "\n---\n".join(edge_strings) + + async def get_triplets(self, query: str) -> list: + """Retrieves relevant graph triplets.""" + subclasses = get_all_subclasses(ExtendableDataPoint) + vector_index_collections = [] + + for subclass in subclasses: + index_fields = subclass.model_fields["metadata"].default.get("index_fields", []) + for field_name in index_fields: + vector_index_collections.append(f"{subclass.__name__}_{field_name}") + + found_triplets = await brute_force_triplet_search( + query, top_k=self.top_k, collections=vector_index_collections or None + ) + + if len(found_triplets) == 0: + raise NoRelevantDataFound + + return found_triplets + + async def get_context(self, query: str) -> Any: + """Retrieves and resolves graph triplets into context.""" + triplets = await self.get_triplets(query) + return await self.resolve_edges_to_text(triplets) + + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + """Generates a completion using graph connections context.""" + if context is None: + context = await self.get_context(query) + + completion = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + return [completion] diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py new file mode 100644 index 0000000000..db83cbee38 --- /dev/null +++ 
b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -0,0 +1,36 @@ +from typing import Optional + +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever + + +class GraphSummaryCompletionRetriever(GraphCompletionRetriever): + """Retriever for handling graph-based completion searches with summarized context.""" + + def __init__( + self, + user_prompt_path: str = "graph_context_for_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + summarize_prompt_path: str = "summarize_search_results.txt", + top_k: int = 5, + ): + """Initialize retriever with default prompt paths and search parameters.""" + super().__init__( + user_prompt_path=user_prompt_path, + system_prompt_path=system_prompt_path, + top_k=top_k, + ) + self.summarize_prompt_path = summarize_prompt_path + + async def resolve_edges_to_text(self, retrieved_edges: list) -> str: + """Converts retrieved graph edges into a summary without redundancies.""" + direct_text = await super().resolve_edges_to_text(retrieved_edges) + system_prompt = read_query_prompt(self.summarize_prompt_path) + + llm_client = get_llm_client() + return await llm_client.acreate_structured_output( + text_input=direct_text, + system_prompt=system_prompt, + response_model=str, + ) diff --git a/cognee/modules/retrieval/insights_retriever.py b/cognee/modules/retrieval/insights_retriever.py new file mode 100644 index 0000000000..021b39f95f --- /dev/null +++ b/cognee/modules/retrieval/insights_retriever.py @@ -0,0 +1,66 @@ +import asyncio +from typing import Any, Optional + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.retrieval.base_retriever import BaseRetriever + + +class InsightsRetriever(BaseRetriever): + """Retriever for handling 
graph connection-based insights.""" + + def __init__(self, exploration_levels: int = 1, top_k: int = 5): + """Initialize retriever with exploration levels and search parameters.""" + self.exploration_levels = exploration_levels + self.top_k = top_k + + async def get_context(self, query: str) -> Any: + """Find the neighbours of a given node in the graph.""" + if query is None: + return [] + + node_id = query + graph_engine = await get_graph_engine() + exact_node = await graph_engine.extract_node(node_id) + + if exact_node is not None and "id" in exact_node: + node_connections = await graph_engine.get_connections(str(exact_node["id"])) + else: + vector_engine = get_vector_engine() + results = await asyncio.gather( + vector_engine.search("Entity_name", query_text=query, limit=self.top_k), + vector_engine.search("EntityType_name", query_text=query, limit=self.top_k), + ) + results = [*results[0], *results[1]] + relevant_results = [result for result in results if result.score < 0.5][: self.top_k] + + if len(relevant_results) == 0: + return [] + + node_connections_results = await asyncio.gather( + *[graph_engine.get_connections(result.id) for result in relevant_results] + ) + + node_connections = [] + for neighbours in node_connections_results: + node_connections.extend(neighbours) + + unique_node_connections_map = {} + unique_node_connections = [] + + for node_connection in node_connections: + if "id" not in node_connection[0] or "id" not in node_connection[2]: + continue + + unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}" + if unique_id not in unique_node_connections_map: + unique_node_connections_map[unique_id] = True + unique_node_connections.append(node_connection) + + return unique_node_connections + + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + """Returns the graph connections context.""" + if context is None: + context = await self.get_context(query) + return 
context diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py new file mode 100644 index 0000000000..7356563e15 --- /dev/null +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -0,0 +1,24 @@ +from typing import Any, Optional + +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.retrieval.base_retriever import BaseRetriever + + +class SummariesRetriever(BaseRetriever): + """Retriever for handling summary-based searches.""" + + def __init__(self, limit: int = 5): + """Initialize retriever with search parameters.""" + self.limit = limit + + async def get_context(self, query: str) -> Any: + """Retrieves summary context based on the query.""" + vector_engine = get_vector_engine() + summaries_results = await vector_engine.search("TextSummary_text", query, limit=self.limit) + return [summary.payload for summary in summaries_results] + + async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + """Generates a completion using summaries context.""" + if context is None: + context = await self.get_context(query) + return context diff --git a/cognee/modules/retrieval/utils/__init__.py b/cognee/modules/retrieval/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cognee/modules/retrieval/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py similarity index 100% rename from cognee/modules/retrieval/brute_force_triplet_search.py rename to cognee/modules/retrieval/utils/brute_force_triplet_search.py diff --git a/cognee/modules/retrieval/code_graph_retrieval.py b/cognee/modules/retrieval/utils/code_graph_retrieval.py similarity index 100% rename from cognee/modules/retrieval/code_graph_retrieval.py rename to cognee/modules/retrieval/utils/code_graph_retrieval.py diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py new file mode 100644 
index 0000000000..512193c31c --- /dev/null +++ b/cognee/modules/retrieval/utils/completion.py @@ -0,0 +1,23 @@ +from typing import Optional + +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt + + +async def generate_completion( + query: str, + context: str, + user_prompt_path: str, + system_prompt_path: str, +) -> str: + """Generates a completion using LLM with given context and prompts.""" + args = {"question": query, "context": context} + user_prompt = render_prompt(user_prompt_path, args) + system_prompt = read_query_prompt(system_prompt_path) + + llm_client = get_llm_client() + return await llm_client.acreate_structured_output( + text_input=user_prompt, + system_prompt=system_prompt, + response_model=str, + ) diff --git a/cognee/modules/retrieval/description_to_codepart_search.py b/cognee/modules/retrieval/utils/description_to_codepart_search.py similarity index 100% rename from cognee/modules/retrieval/description_to_codepart_search.py rename to cognee/modules/retrieval/utils/description_to_codepart_search.py diff --git a/cognee/modules/retrieval/utils/run_search_comparisons.py b/cognee/modules/retrieval/utils/run_search_comparisons.py new file mode 100644 index 0000000000..b5785c919c --- /dev/null +++ b/cognee/modules/retrieval/utils/run_search_comparisons.py @@ -0,0 +1,221 @@ +# TODO: delete after merging COG-1365, see COG-1403 +import asyncio +import json +import logging +import os +from typing import Any, Callable, Dict, Type + +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.code_retriever import CodeRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) +from 
cognee.modules.retrieval.insights_retriever import InsightsRetriever +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval +from cognee.tasks.chunks import query_chunks +from cognee.tasks.completion import ( + query_completion, + graph_query_completion, + graph_query_summary_completion, +) +from cognee.tasks.graph import query_graph_connections +from cognee.tasks.summarization import query_summaries +from examples.python.dynamic_steps_example import main as setup_main + + +CONTEXT_DUMP_DIR = "context_dumps" + +# Define retriever configurations +COMPLETION_RETRIEVERS = [ + { + "name": "completion", + "old_implementation": query_completion, + "retriever_class": CompletionRetriever, + "type": "completion", + }, + { + "name": "graph completion", + "old_implementation": graph_query_completion, + "retriever_class": GraphCompletionRetriever, + "type": "graph_completion", + }, + { + "name": "graph summary completion", + "old_implementation": graph_query_summary_completion, + "retriever_class": GraphSummaryCompletionRetriever, + "type": "graph_summary_completion", + }, +] + +BASIC_RETRIEVERS = [ + { + "name": "summaries search", + "old_implementation": query_summaries, + "retriever_class": SummariesRetriever, + }, + { + "name": "chunks search", + "old_implementation": query_chunks, + "retriever_class": ChunksRetriever, + }, + { + "name": "insights search", + "old_implementation": query_graph_connections, + "retriever_class": InsightsRetriever, + }, + { + "name": "code search", + "old_implementation": code_graph_retrieval, + "retriever_class": CodeRetriever, + }, +] + + +async def compare_completion(old_results: list, new_results: list) -> Dict: + """Compare two lists of completion results and print differences.""" + lengths_match = len(old_results) == len(new_results) + matches = [] + + if lengths_match: + print("Results length match") + matches = [old == new for 
old, new in zip(old_results, new_results)] + if all(matches): + print("All entries match") + else: + print(f"Differences found at indices: {[i for i, m in enumerate(matches) if not m]}") + print("\nDifferences:") + for i, (old, new) in enumerate(zip(old_results, new_results)): + if old != new: + print(f"\nIndex {i}:") + print("Old:", json.dumps(old, indent=2)) + print("New:", json.dumps(new, indent=2)) + else: + print(f"Results length mismatch: {len(old_results)} vs {len(new_results)}") + print("\nOld results:", json.dumps(old_results, indent=2)) + print("\nNew results:", json.dumps(new_results, indent=2)) + + return { + "old_results": old_results, + "new_results": new_results, + "lengths_match": lengths_match, + "element_matches": matches, + } + + +async def compare_retriever( + query: str, old_implementation: Callable, new_retriever: Any, name: str +) -> Dict: + """Compare old and new retriever implementations.""" + print(f"\nComparing {name}...") + + # Get results from both implementations + old_results = await old_implementation(query) + new_results = await new_retriever.get_completion(query) + + return await compare_completion(old_results, new_results) + + +async def compare_completion_context( + query: str, old_implementation: Callable, retriever_class: Type, name: str, retriever_type: str +) -> Dict: + """Compare context between old completion implementation and new retriever.""" + print(f"\nComparing {name} contexts...") + + # Get context from old implementation with dumping + context_path = f"{CONTEXT_DUMP_DIR}/{retriever_type}_{hash(query)}_context.json" + os.makedirs(CONTEXT_DUMP_DIR, exist_ok=True) + await old_implementation(query, save_context_path=context_path) + + # Get context from new implementation + retriever = retriever_class() + new_context = await retriever.get_context(query) + + # Read dumped context + with open(context_path, "r") as f: + old_context = json.load(f) + + # Compare contexts + contexts_match = old_context == new_context + if 
contexts_match: + print("Contexts match exactly") + else: + print("Contexts differ:") + print("\nOld context:", json.dumps(old_context, indent=2)) + print("\nNew context:", json.dumps(new_context, indent=2)) + + return { + "old_context": old_context, + "new_context": new_context, + "contexts_match": contexts_match, + } + + +async def main(query: str, comparisons: Dict[str, bool], setup_steps: Dict[str, bool]): + """Run comparison tests for selected retrievers with the given setup configuration.""" + # Ensure retriever is always False in setup steps + setup_steps["retriever"] = False + await setup_main(setup_steps) + + # Compare contexts for completion-based retrievers + for retriever in COMPLETION_RETRIEVERS: + context_key = f"{retriever['type']}_context" + if comparisons.get(context_key, False): + await compare_completion_context( + query=query, + old_implementation=retriever["old_implementation"], + retriever_class=retriever["retriever_class"], + name=retriever["name"], + retriever_type=retriever["type"], + ) + + # Run completion comparisons + for retriever in COMPLETION_RETRIEVERS: + if comparisons.get(retriever["type"], False): + await compare_retriever( + query=query, + old_implementation=retriever["old_implementation"], + new_retriever=retriever["retriever_class"](), + name=retriever["name"], + ) + + # Run basic retriever comparisons + for retriever in BASIC_RETRIEVERS: + retriever_type = retriever["name"].split()[0] + if comparisons.get(retriever_type, False): + await compare_retriever( + query=query, + old_implementation=retriever["old_implementation"], + new_retriever=retriever["retriever_class"](), + name=retriever["name"], + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.ERROR) + + test_query = "Who has experience in data science?" 
+ comparisons = { + # Context comparisons + "completion_context": True, + "graph_completion_context": True, + "graph_summary_completion_context": True, + # Result comparisons + "summaries": True, + "chunks": True, + "insights": True, + "code": False, + "completion": True, + "graph_completion": True, + "graph_summary_completion": True, + } + setup_steps = { + "prune_data": True, + "prune_system": True, + "add_text": True, + "cognify": True, + } + + asyncio.run(main(test_query, comparisons, setup_steps)) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index dbc6eb8e59..7e8f08ae2a 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -3,7 +3,7 @@ from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine.utils import parse_id -from cognee.modules.retrieval.code_graph_retrieval import code_graph_retrieval +from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval from cognee.modules.search.types import SearchType from cognee.modules.storage.utils import JSONEncoder from cognee.modules.users.models import User @@ -44,6 +44,7 @@ async def search( async def specific_search(query_type: SearchType, query: str, user: User) -> list: + # TODO: update after merging COG-1365, see COG-1403 search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: query_summaries, SearchType.INSIGHTS: query_graph_connections, diff --git a/cognee/tasks/chunks/query_chunks.py b/cognee/tasks/chunks/query_chunks.py index 6263b519bf..ed4d15d2d5 100644 --- a/cognee/tasks/chunks/query_chunks.py +++ b/cognee/tasks/chunks/query_chunks.py @@ -1,3 +1,4 @@ +# TODO: delete after merging COG-1365, see COG-1403 from cognee.infrastructure.databases.vector import get_vector_engine diff --git a/cognee/tasks/completion/graph_query_completion.py b/cognee/tasks/completion/graph_query_completion.py index e7778ce88a..c7b781d439 100644 --- 
a/cognee/tasks/completion/graph_query_completion.py +++ b/cognee/tasks/completion/graph_query_completion.py @@ -1,13 +1,20 @@ +# TODO: delete after merging COG-1365, see COG-1403 +import json +import logging +import os from cognee.infrastructure.engine import ExtendableDataPoint from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses from cognee.tasks.completion.exceptions import NoRelevantDataFound from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from typing import Callable +logger = logging.getLogger(__name__) + + async def retrieved_edges_to_string(retrieved_edges: list) -> str: """ Converts a list of retrieved graph edges into a human-readable string format. @@ -23,12 +30,16 @@ async def retrieved_edges_to_string(retrieved_edges: list) -> str: return "\n---\n".join(edge_strings) -async def graph_query_completion(query: str, context_resolver: Callable = None) -> list: +async def graph_query_completion( + query: str, context_resolver: Callable = None, save_context_path: str = None +) -> list: """ Executes a query on the graph database and retrieves a relevant completion based on the found data. Parameters: - query (str): The query string to compute. + - context_resolver (Callable): A function to convert retrieved edges to a string. + - save_context_path (str): Path to save the retrieved context. Returns: - list: Answer to the query. @@ -38,7 +49,6 @@ async def graph_query_completion(query: str, context_resolver: Callable = None) - Prompts are dynamically rendered and provided to the LLM for contextual understanding. 
- Ensure that the LLM client and graph database are properly configured and accessible. """ - subclasses = get_all_subclasses(DataPoint) vector_index_collections = [] @@ -58,9 +68,19 @@ async def graph_query_completion(query: str, context_resolver: Callable = None) if not context_resolver: context_resolver = retrieved_edges_to_string + # Get context and optionally dump it + context = await context_resolver(found_triplets) + if save_context_path: + try: + os.makedirs(os.path.dirname(save_context_path), exist_ok=True) + with open(save_context_path, "w") as f: + json.dump(context, f, indent=2) + except (OSError, TypeError, ValueError) as e: + logger.error(f"Failed to save context to {save_context_path}: {str(e)}") + # Consider whether to raise or continue silently args = { "question": query, - "context": await context_resolver(found_triplets), + "context": context, } user_prompt = render_prompt("graph_context_for_question.txt", args) system_prompt = read_query_prompt("answer_simple_question.txt") diff --git a/cognee/tasks/completion/graph_query_summary_completion.py b/cognee/tasks/completion/graph_query_summary_completion.py index 6839f0cdaa..262148e620 100644 --- a/cognee/tasks/completion/graph_query_summary_completion.py +++ b/cognee/tasks/completion/graph_query_summary_completion.py @@ -1,3 +1,4 @@ +# TODO: delete after merging COG-1365, see COG-1403 from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.tasks.completion.graph_query_completion import ( @@ -22,5 +23,8 @@ async def retrieved_edges_to_summary(retrieved_edges: list) -> str: return summarized_context -async def graph_query_summary_completion(query: str) -> list: - return await graph_query_completion(query, context_resolver=retrieved_edges_to_summary) +async def graph_query_summary_completion(query: str, save_context_path: str = None) -> list: + """Executes a query on the graph database and retrieves a summarized 
completion with optional context saving.""" + return await graph_query_completion( + query, context_resolver=retrieved_edges_to_summary, save_context_path=save_context_path + ) diff --git a/cognee/tasks/completion/query_completion.py b/cognee/tasks/completion/query_completion.py index 5209bff672..88074c46ad 100644 --- a/cognee/tasks/completion/query_completion.py +++ b/cognee/tasks/completion/query_completion.py @@ -1,16 +1,24 @@ +# TODO: delete after merging COG-1365, see COG-1403 +import json +import logging +import os from cognee.infrastructure.databases.vector import get_vector_engine from cognee.tasks.completion.exceptions import NoRelevantDataFound from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt -async def query_completion(query: str) -> list: +logger = logging.getLogger(__name__) + + +async def query_completion(query: str, save_context_path: str = None) -> list: """ Executes a query against a vector database and computes a relevant response using an LLM. Parameters: - query (str): The query string to compute. + - save_context_path (str): The path to save the context. Returns: - list: Answer to the query. 
@@ -28,9 +36,19 @@ async def query_completion(query: str) -> list: if len(found_chunks) == 0: raise NoRelevantDataFound + # Get context and optionally dump it + context = found_chunks[0].payload["text"] + if save_context_path: + try: + os.makedirs(os.path.dirname(os.path.abspath(save_context_path)), exist_ok=True) + with open(save_context_path, "w", encoding="utf-8") as f: + json.dump(context, f, indent=2, ensure_ascii=False) + except OSError as e: + logger.error(f"Failed to save context to {save_context_path}: {str(e)}") + # Continue execution as context saving is optional args = { "question": query, - "context": found_chunks[0].payload["text"], + "context": context, } user_prompt = render_prompt("context_for_question.txt", args) system_prompt = read_query_prompt("answer_simple_question.txt") diff --git a/cognee/tasks/graph/query_graph_connections.py b/cognee/tasks/graph/query_graph_connections.py index 1bc7f9d69f..0fe039dab4 100644 --- a/cognee/tasks/graph/query_graph_connections.py +++ b/cognee/tasks/graph/query_graph_connections.py @@ -1,3 +1,4 @@ +# TODO: delete after merging COG-1365, see COG-1403 import asyncio from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine diff --git a/cognee/tasks/summarization/query_summaries.py b/cognee/tasks/summarization/query_summaries.py index 342322d3b2..40ad00d321 100644 --- a/cognee/tasks/summarization/query_summaries.py +++ b/cognee/tasks/summarization/query_summaries.py @@ -1,3 +1,4 @@ +# TODO: delete after merging COG-1365, see COG-1403 from cognee.infrastructure.databases.vector import get_vector_engine diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 124526314c..adef82ac87 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -3,7 +3,7 @@ import pathlib import cognee from cognee.modules.search.types import SearchType -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search 
+from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 9048702ee3..d5f224744f 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -5,7 +5,7 @@ from cognee.modules.data.models import Data from cognee.modules.search.types import SearchType -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.users.methods import get_default_user logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index a4bb298d50..5108f03119 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -3,7 +3,7 @@ import pathlib import cognee from cognee.modules.search.types import SearchType -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index 04ad893e84..a2b55d60ab 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -3,7 +3,7 @@ import pathlib import cognee from cognee.modules.search.types import SearchType -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search logging.basicConfig(level=logging.DEBUG) diff --git a/cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py b/cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py index 5cceade719..35e1ee027d 100644 --- 
a/cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py +++ b/cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py @@ -13,19 +13,19 @@ async def test_code_description_to_code_part_no_results(): with ( patch( - "cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", + "cognee.modules.retrieval.utils.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine, ), patch( - "cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", + "cognee.modules.retrieval.utils.description_to_codepart_search.get_graph_engine", return_value=AsyncMock(), ), patch( - "cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", + "cognee.modules.retrieval.utils.description_to_codepart_search.CogneeGraph", return_value=AsyncMock(), ), ): - from cognee.modules.retrieval.description_to_codepart_search import ( + from cognee.modules.retrieval.utils.description_to_codepart_search import ( code_description_to_code_part, ) @@ -41,7 +41,7 @@ async def test_code_description_to_code_part_invalid_query(): mock_user = AsyncMock() with pytest.raises(ValueError, match="The query must be a non-empty string."): - from cognee.modules.retrieval.description_to_codepart_search import ( + from cognee.modules.retrieval.utils.description_to_codepart_search import ( code_description_to_code_part, ) @@ -55,7 +55,7 @@ async def test_code_description_to_code_part_invalid_top_k(): mock_user = AsyncMock() with pytest.raises(ValueError, match="top_k must be a positive integer."): - from cognee.modules.retrieval.description_to_codepart_search import ( + from cognee.modules.retrieval.utils.description_to_codepart_search import ( code_description_to_code_part, ) @@ -70,15 +70,15 @@ async def test_code_description_to_code_part_initialization_error(): with ( patch( - "cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", + 
"cognee.modules.retrieval.utils.description_to_codepart_search.get_vector_engine", side_effect=Exception("Engine init failed"), ), patch( - "cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", + "cognee.modules.retrieval.utils.description_to_codepart_search.get_graph_engine", return_value=AsyncMock(), ), ): - from cognee.modules.retrieval.description_to_codepart_search import ( + from cognee.modules.retrieval.utils.description_to_codepart_search import ( code_description_to_code_part, ) @@ -99,19 +99,19 @@ async def test_code_description_to_code_part_execution_error(): with ( patch( - "cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", + "cognee.modules.retrieval.utils.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine, ), patch( - "cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", + "cognee.modules.retrieval.utils.description_to_codepart_search.get_graph_engine", return_value=AsyncMock(), ), patch( - "cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", + "cognee.modules.retrieval.utils.description_to_codepart_search.CogneeGraph", return_value=AsyncMock(), ), ): - from cognee.modules.retrieval.description_to_codepart_search import ( + from cognee.modules.retrieval.utils.description_to_codepart_search import ( code_description_to_code_part, ) diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py index 07036ccd26..511f99b928 100644 --- a/evals/eval_swe_bench.py +++ b/evals/eval_swe_bench.py @@ -10,7 +10,7 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt -from cognee.modules.retrieval.description_to_codepart_search import ( +from cognee.modules.retrieval.utils.description_to_codepart_search import ( code_description_to_code_part_search, ) from evals.eval_utils import 
download_github_repo diff --git a/evals/qa_context_provider_utils.py b/evals/qa_context_provider_utils.py index 6663b1d4a0..7bf6c63c66 100644 --- a/evals/qa_context_provider_utils.py +++ b/evals/qa_context_provider_utils.py @@ -1,7 +1,7 @@ import cognee from cognee.modules.search.types import SearchType from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string from functools import partial from cognee.api.v1.cognify.cognify_v2 import get_default_tasks diff --git a/examples/python/graphiti_example.py b/examples/python/graphiti_example.py index eba9ccea7b..2de8db41e4 100644 --- a/examples/python/graphiti_example.py +++ b/examples/python/graphiti_example.py @@ -11,7 +11,7 @@ from cognee.tasks.temporal_awareness.index_graphiti_objects import ( index_and_transform_graphiti_nodes_and_edges, ) -from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt from cognee.infrastructure.llm.get_llm_client import get_llm_client diff --git a/notebooks/cognee_graphiti_demo.ipynb b/notebooks/cognee_graphiti_demo.ipynb index ebc88c9cf9..96c84254ee 100644 --- a/notebooks/cognee_graphiti_demo.ipynb +++ b/notebooks/cognee_graphiti_demo.ipynb @@ -36,7 +36,7 @@ "from cognee.tasks.temporal_awareness.index_graphiti_objects import (\n", " index_and_transform_graphiti_nodes_and_edges,\n", ")\n", - "from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search\n", + "from 
cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search\n", "from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n", "from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n", "from cognee.infrastructure.llm.get_llm_client import get_llm_client"