diff --git a/cognee/api/v1/cognify/routers/get_code_pipeline_router.py b/cognee/api/v1/cognify/routers/get_code_pipeline_router.py index 5cdb00f0e7..742a1bb8fb 100644 --- a/cognee/api/v1/cognify/routers/get_code_pipeline_router.py +++ b/cognee/api/v1/cognify/routers/get_code_pipeline_router.py @@ -4,7 +4,7 @@ from fastapi.responses import JSONResponse from cognee.api.DTO import InDTO from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline -from cognee.modules.retrieval import code_graph_retrieval +from cognee.modules.retrieval.code_retriever import CodeRetriever from cognee.modules.storage.utils import JSONEncoder @@ -43,7 +43,8 @@ async def code_pipeline_retrieve(payload: CodePipelineRetrievePayloadDTO): else payload.full_input ) - retrieved_files = await code_graph_retrieval(query) + retriever = CodeRetriever() + retrieved_files = await retriever.get_context(query) return json.dumps(retrieved_files, cls=JSONEncoder) except Exception as error: diff --git a/cognee/modules/retrieval/__init__.py b/cognee/modules/retrieval/__init__.py index 1c29be8cf1..75afb34c86 100644 --- a/cognee/modules/retrieval/__init__.py +++ b/cognee/modules/retrieval/__init__.py @@ -1 +1 @@ -from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval +from cognee.modules.retrieval.code_retriever import CodeRetriever diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py index 88313b253b..5fa39c53f4 100644 --- a/cognee/modules/retrieval/base_retriever.py +++ b/cognee/modules/retrieval/base_retriever.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, Callable class BaseRetriever(ABC): @@ -14,3 +14,8 @@ async def get_context(self, query: str) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """Generates a response using the query and optional context.""" pass + + @classmethod + def as_search(cls) -> Callable: + """Creates a search function from the retriever class.""" + return lambda query: cls().get_completion(query) diff --git a/cognee/modules/retrieval/utils/code_graph_retrieval.py b/cognee/modules/retrieval/utils/code_graph_retrieval.py deleted file mode 100644 index 151a4f732d..0000000000 --- a/cognee/modules/retrieval/utils/code_graph_retrieval.py +++ /dev/null @@ -1,128 +0,0 @@ -import asyncio -import aiofiles - -from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph -from typing import List, Dict, Any -from pydantic import BaseModel -from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.infrastructure.llm.prompts import read_query_prompt - - -class CodeQueryInfo(BaseModel): - """Response model for information extraction from the query""" - - filenames: List[str] = [] - sourcecode: str - - -async def code_graph_retrieval(query: str) -> list[dict[str, Any]]: - if not query or not isinstance(query, str): - raise ValueError("The query must be a non-empty string.") - - file_name_collections = ["CodeFile_name"] - classes_and_functions_collections = [ - "ClassDefinition_source_code", - "FunctionDefinition_source_code", - ] - - try: - vector_engine = get_vector_engine() - graph_engine = await get_graph_engine() - except Exception as e: - raise RuntimeError("Database initialization error in code_graph_retriever, ") from e - - system_prompt = read_query_prompt("codegraph_retriever_system.txt") - - llm_client = get_llm_client() - try: - files_and_codeparts = await llm_client.acreate_structured_output( - text_input=query, - system_prompt=system_prompt, - response_model=CodeQueryInfo, - ) - except Exception as e: - raise RuntimeError("Failed to retrieve structured output from LLM") from e - - similar_filenames = [] - similar_codepieces = [] - - if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode: - for collection in file_name_collections: - search_results_file = await vector_engine.search(collection, query, limit=3) - for res in search_results_file: - similar_filenames.append({"id": res.id, "score": res.score, "payload": res.payload}) - - for collection in classes_and_functions_collections: - search_results_code = await vector_engine.search(collection, query, limit=3) - for res in search_results_code: - similar_codepieces.append( - {"id": res.id, "score": res.score, "payload": res.payload} - ) - - else: - for collection in file_name_collections: - for file_from_query in files_and_codeparts.filenames: - search_results_file = await vector_engine.search( - collection, file_from_query, limit=3 - ) - for res in search_results_file: - similar_filenames.append( - {"id": res.id, "score": res.score, "payload": res.payload} - ) - - for collection in classes_and_functions_collections: - for code_from_query in files_and_codeparts.sourcecode: - search_results_code = await vector_engine.search( - collection, code_from_query, limit=3 - ) - for res in search_results_code: - similar_codepieces.append( - {"id": res.id, "score": res.score, "payload": res.payload} - ) - - file_ids = [str(item["id"]) for item in similar_filenames] - code_ids = [str(item["id"]) for item in similar_codepieces] - - relevant_triplets = await asyncio.gather( - *[graph_engine.get_connections(node_id) for node_id in code_ids + file_ids] - ) - - paths = set() - - for sublist in relevant_triplets: - for tpl in sublist: - if isinstance(tpl, tuple) and len(tpl) >= 3: - if "file_path" in tpl[0]: - paths.add(tpl[0]["file_path"]) - if "file_path" in tpl[2]: # Third tuple element - paths.add(tpl[2]["file_path"]) - - retrieved_files = {} - - read_tasks = [] - for file_path in paths: - - async def read_file(fp): - try: - async with aiofiles.open(fp, "r", encoding="utf-8") as f: - retrieved_files[fp] = await f.read() - except Exception as e: - print(f"Error reading {fp}: {e}") - retrieved_files[fp] = "" - - read_tasks.append(read_file(file_path)) - - await asyncio.gather(*read_tasks) - - result = [ - { - "name": file_path, - "description": file_path, - "content": retrieved_files[file_path], - } - for file_path in paths - ] - - return result diff --git a/cognee/modules/retrieval/utils/run_search_comparisons.py b/cognee/modules/retrieval/utils/run_search_comparisons.py deleted file mode 100644 index b5785c919c..0000000000 --- a/cognee/modules/retrieval/utils/run_search_comparisons.py +++ /dev/null @@ -1,221 +0,0 @@ -# TODO: delete after merging COG-1365, see COG-1403 -import asyncio -import json -import logging -import os -from typing import Any, Callable, Dict, Type - -from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.modules.retrieval.code_retriever import CodeRetriever -from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.retrieval.graph_summary_completion_retriever import ( - GraphSummaryCompletionRetriever, -) -from cognee.modules.retrieval.insights_retriever import InsightsRetriever -from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval -from cognee.tasks.chunks import query_chunks -from cognee.tasks.completion import ( - query_completion, - graph_query_completion, - graph_query_summary_completion, -) -from cognee.tasks.graph import query_graph_connections -from cognee.tasks.summarization import query_summaries -from examples.python.dynamic_steps_example import main as setup_main - - -CONTEXT_DUMP_DIR = "context_dumps" - -# Define retriever configurations -COMPLETION_RETRIEVERS = [ - { - "name": "completion", - "old_implementation": query_completion, - "retriever_class": CompletionRetriever, - "type": "completion", - }, - { - "name": "graph completion", - "old_implementation": graph_query_completion, - "retriever_class": GraphCompletionRetriever, - "type": "graph_completion", - }, - { - "name": "graph summary completion", - "old_implementation": graph_query_summary_completion, - "retriever_class": GraphSummaryCompletionRetriever, - "type": "graph_summary_completion", - }, -] - -BASIC_RETRIEVERS = [ - { - "name": "summaries search", - "old_implementation": query_summaries, - "retriever_class": SummariesRetriever, - }, - { - "name": "chunks search", - "old_implementation": query_chunks, - "retriever_class": ChunksRetriever, - }, - { - "name": "insights search", - "old_implementation": query_graph_connections, - "retriever_class": InsightsRetriever, - }, - { - "name": "code search", - "old_implementation": code_graph_retrieval, - "retriever_class": CodeRetriever, - }, -] - - -async def compare_completion(old_results: list, new_results: list) -> Dict: - """Compare two lists of completion results and print differences.""" - lengths_match = len(old_results) == len(new_results) - matches = [] - - if lengths_match: - print("Results length match") - matches = [old == new for old, new in zip(old_results, new_results)] - if all(matches): - print("All entries match") - else: - print(f"Differences found at indices: {[i for i, m in enumerate(matches) if not m]}") - print("\nDifferences:") - for i, (old, new) in enumerate(zip(old_results, new_results)): - if old != new: - print(f"\nIndex {i}:") - print("Old:", json.dumps(old, indent=2)) - print("New:", json.dumps(new, indent=2)) - else: - print(f"Results length mismatch: {len(old_results)} vs {len(new_results)}") - print("\nOld results:", json.dumps(old_results, indent=2)) - print("\nNew results:", json.dumps(new_results, indent=2)) - - return { - "old_results": old_results, - "new_results": new_results, - "lengths_match": lengths_match, - "element_matches": matches, - } - - -async def compare_retriever( - query: str, old_implementation: Callable, new_retriever: Any, name: str -) -> Dict: - """Compare old and new retriever implementations.""" - print(f"\nComparing {name}...") - - # Get results from both implementations - old_results = await old_implementation(query) - new_results = await new_retriever.get_completion(query) - - return await compare_completion(old_results, new_results) - - -async def compare_completion_context( - query: str, old_implementation: Callable, retriever_class: Type, name: str, retriever_type: str -) -> Dict: - """Compare context between old completion implementation and new retriever.""" - print(f"\nComparing {name} contexts...") - - # Get context from old implementation with dumping - context_path = f"{CONTEXT_DUMP_DIR}/{retriever_type}_{hash(query)}_context.json" - os.makedirs(CONTEXT_DUMP_DIR, exist_ok=True) - await old_implementation(query, save_context_path=context_path) - - # Get context from new implementation - retriever = retriever_class() - new_context = await retriever.get_context(query) - - # Read dumped context - with open(context_path, "r") as f: - old_context = json.load(f) - - # Compare contexts - contexts_match = old_context == new_context - if contexts_match: - print("Contexts match exactly") - else: - print("Contexts differ:") - print("\nOld context:", json.dumps(old_context, indent=2)) - print("\nNew context:", json.dumps(new_context, indent=2)) - - return { - "old_context": old_context, - "new_context": new_context, - "contexts_match": contexts_match, - } - - -async def main(query: str, comparisons: Dict[str, bool], setup_steps: Dict[str, bool]): - """Run comparison tests for selected retrievers with the given setup configuration.""" - # Ensure retriever is always False in setup steps - setup_steps["retriever"] = False - await setup_main(setup_steps) - - # Compare contexts for completion-based retrievers - for retriever in COMPLETION_RETRIEVERS: - context_key = f"{retriever['type']}_context" - if comparisons.get(context_key, False): - await compare_completion_context( - query=query, - old_implementation=retriever["old_implementation"], - retriever_class=retriever["retriever_class"], - name=retriever["name"], - retriever_type=retriever["type"], - ) - - # Run completion comparisons - for retriever in COMPLETION_RETRIEVERS: - if comparisons.get(retriever["type"], False): - await compare_retriever( - query=query, - old_implementation=retriever["old_implementation"], - new_retriever=retriever["retriever_class"](), - name=retriever["name"], - ) - - # Run basic retriever comparisons - for retriever in BASIC_RETRIEVERS: - retriever_type = retriever["name"].split()[0] - if comparisons.get(retriever_type, False): - await compare_retriever( - query=query, - old_implementation=retriever["old_implementation"], - new_retriever=retriever["retriever_class"](), - name=retriever["name"], - ) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.ERROR) - - test_query = "Who has experience in data science?" - comparisons = { - # Context comparisons - "completion_context": True, - "graph_completion_context": True, - "graph_summary_completion_context": True, - # Result comparisons - "summaries": True, - "chunks": True, - "insights": True, - "code": False, - "completion": True, - "graph_completion": True, - "graph_summary_completion": True, - } - setup_steps = { - "prune_data": True, - "prune_system": True, - "add_text": True, - "cognify": True, - } - - asyncio.run(main(test_query, comparisons, setup_steps)) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 7e8f08ae2a..a88bd815aa 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -3,18 +3,20 @@ from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine.utils import parse_id -from cognee.modules.retrieval.utils.code_graph_retrieval import code_graph_retrieval +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.insights_retriever import InsightsRetriever +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) +from cognee.modules.retrieval.code_retriever import CodeRetriever from cognee.modules.search.types import SearchType from cognee.modules.storage.utils import JSONEncoder from cognee.modules.users.models import User from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.shared.utils import send_telemetry -from cognee.tasks.chunks import query_chunks -from cognee.tasks.graph import query_graph_connections -from cognee.tasks.summarization import query_summaries -from cognee.tasks.completion import query_completion -from cognee.tasks.completion import graph_query_completion -from cognee.tasks.completion import graph_query_summary_completion from ..operations import log_query, log_result @@ -44,15 +46,14 @@ async def search( async def specific_search(query_type: SearchType, query: str, user: User) -> list: - # TODO: update after merging COG-1365, see COG-1403 search_tasks: dict[SearchType, Callable] = { - SearchType.SUMMARIES: query_summaries, - SearchType.INSIGHTS: query_graph_connections, - SearchType.CHUNKS: query_chunks, - SearchType.COMPLETION: query_completion, - SearchType.GRAPH_COMPLETION: graph_query_completion, - SearchType.GRAPH_SUMMARY_COMPLETION: graph_query_summary_completion, - SearchType.CODE: code_graph_retrieval, + SearchType.SUMMARIES: SummariesRetriever.as_search(), + SearchType.INSIGHTS: InsightsRetriever.as_search(), + SearchType.CHUNKS: ChunksRetriever.as_search(), + SearchType.COMPLETION: CompletionRetriever.as_search(), + SearchType.GRAPH_COMPLETION: GraphCompletionRetriever.as_search(), + SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever.as_search(), + SearchType.CODE: CodeRetriever.as_search(), } search_task = search_tasks.get(query_type) diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py index e92658562f..22ce96be83 100644 --- a/cognee/tasks/chunks/__init__.py +++ b/cognee/tasks/chunks/__init__.py @@ -1,4 +1,3 @@ -from .query_chunks import query_chunks from .chunk_by_word import chunk_by_word from .chunk_by_sentence import chunk_by_sentence from .chunk_by_paragraph import chunk_by_paragraph diff --git a/cognee/tasks/chunks/query_chunks.py b/cognee/tasks/chunks/query_chunks.py deleted file mode 100644 index ed4d15d2d5..0000000000 --- a/cognee/tasks/chunks/query_chunks.py +++ /dev/null @@ -1,27 +0,0 @@ -# TODO: delete after merging COG-1365, see COG-1403 -from cognee.infrastructure.databases.vector import get_vector_engine - - -async def query_chunks(query: str) -> list[dict]: - """ - - Queries the vector database to retrieve chunks related to the given query string. - - Parameters: - - query (str): The query string to filter nodes by. - - Returns: - - list(dict): A list of objects providing information about the chunks related to query. - - Notes: - - The function uses the `search` method of the vector engine to find matches. - - Limits the results to the top 5 matching chunks to balance performance and relevance. - - Ensure that the vector database is properly initialized and contains the "DocumentChunk_text" collection. - """ - vector_engine = get_vector_engine() - - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=5) - - chunks = [result.payload for result in found_chunks] - - return chunks diff --git a/cognee/tasks/completion/__init__.py b/cognee/tasks/completion/__init__.py index 76759c3718..93901e0d79 100644 --- a/cognee/tasks/completion/__init__.py +++ b/cognee/tasks/completion/__init__.py @@ -1,3 +1 @@ -from .query_completion import query_completion -from .graph_query_completion import graph_query_completion -from .graph_query_summary_completion import graph_query_summary_completion +from cognee.tasks.completion.exceptions import NoRelevantDataFound diff --git a/cognee/tasks/completion/graph_query_completion.py b/cognee/tasks/completion/graph_query_completion.py deleted file mode 100644 index c7b781d439..0000000000 --- a/cognee/tasks/completion/graph_query_completion.py +++ /dev/null @@ -1,95 +0,0 @@ -# TODO: delete after merging COG-1365, see COG-1403 -import json -import logging -import os -from cognee.infrastructure.engine import ExtendableDataPoint -from cognee.infrastructure.engine.models.DataPoint import DataPoint -from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses -from cognee.tasks.completion.exceptions import NoRelevantDataFound -from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt -from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search -from typing import Callable - - -logger = logging.getLogger(__name__) - - -async def retrieved_edges_to_string(retrieved_edges: list) -> str: - """ - Converts a list of retrieved graph edges into a human-readable string format. - - """ - edge_strings = [] - for edge in retrieved_edges: - node1_string = edge.node1.attributes.get("text") or edge.node1.attributes.get("name") - node2_string = edge.node2.attributes.get("text") or edge.node2.attributes.get("name") - edge_string = edge.attributes["relationship_type"] - edge_str = f"{node1_string} -- {edge_string} -- {node2_string}" - edge_strings.append(edge_str) - return "\n---\n".join(edge_strings) - - -async def graph_query_completion( - query: str, context_resolver: Callable = None, save_context_path: str = None -) -> list: - """ - Executes a query on the graph database and retrieves a relevant completion based on the found data. - - Parameters: - - query (str): The query string to compute. - - context_resolver (Callable): A function to convert retrieved edges to a string. - - save_context_path (str): Path to save the retrieved context. - - Returns: - - list: Answer to the query. - - Notes: - - The `brute_force_triplet_search` is used to retrieve relevant graph data. - - Prompts are dynamically rendered and provided to the LLM for contextual understanding. - - Ensure that the LLM client and graph database are properly configured and accessible. - """ - subclasses = get_all_subclasses(DataPoint) - - vector_index_collections = [] - - for subclass in subclasses: - index_fields = subclass.model_fields["metadata"].default.get("index_fields", []) - for field_name in index_fields: - vector_index_collections.append(f"{subclass.__name__}_{field_name}") - - found_triplets = await brute_force_triplet_search( - query, top_k=5, collections=vector_index_collections or None - ) - - if len(found_triplets) == 0: - raise NoRelevantDataFound - - if not context_resolver: - context_resolver = retrieved_edges_to_string - - # Get context and optionally dump it - context = await context_resolver(found_triplets) - if save_context_path: - try: - os.makedirs(os.path.dirname(save_context_path), exist_ok=True) - with open(save_context_path, "w") as f: - json.dump(context, f, indent=2) - except (OSError, TypeError, ValueError) as e: - logger.error(f"Failed to save context to {save_context_path}: {str(e)}") - # Consider whether to raise or continue silently - args = { - "question": query, - "context": context, - } - user_prompt = render_prompt("graph_context_for_question.txt", args) - system_prompt = read_query_prompt("answer_simple_question.txt") - - llm_client = get_llm_client() - computed_answer = await llm_client.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, - ) - - return [computed_answer] diff --git a/cognee/tasks/completion/graph_query_summary_completion.py b/cognee/tasks/completion/graph_query_summary_completion.py deleted file mode 100644 index 262148e620..0000000000 --- a/cognee/tasks/completion/graph_query_summary_completion.py +++ /dev/null @@ -1,30 +0,0 @@ -# TODO: delete after merging COG-1365, see COG-1403 -from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.infrastructure.llm.prompts import read_query_prompt -from cognee.tasks.completion.graph_query_completion import ( - graph_query_completion, - retrieved_edges_to_string, -) - - -async def retrieved_edges_to_summary(retrieved_edges: list) -> str: - """ - Converts a list of retrieved graph edges into a summary without redundancies. - - """ - edges_string = await retrieved_edges_to_string(retrieved_edges) - system_prompt = read_query_prompt("summarize_search_results.txt") - llm_client = get_llm_client() - summarized_context = await llm_client.acreate_structured_output( - text_input=edges_string, - system_prompt=system_prompt, - response_model=str, - ) - return summarized_context - - -async def graph_query_summary_completion(query: str, save_context_path: str = None) -> list: - """Executes a query on the graph database and retrieves a summarized completion with optional context saving.""" - return await graph_query_completion( - query, context_resolver=retrieved_edges_to_summary, save_context_path=save_context_path - ) diff --git a/cognee/tasks/completion/query_completion.py b/cognee/tasks/completion/query_completion.py deleted file mode 100644 index 88074c46ad..0000000000 --- a/cognee/tasks/completion/query_completion.py +++ /dev/null @@ -1,63 +0,0 @@ -# TODO: delete after merging COG-1365, see COG-1403 -import json -import logging -import os -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.tasks.completion.exceptions import NoRelevantDataFound -from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt - - -logger = logging.getLogger(__name__) - - -async def query_completion(query: str, save_context_path: str = None) -> list: - """ - - Executes a query against a vector database and computes a relevant response using an LLM. - - Parameters: - - query (str): The query string to compute. - - save_context_path (str): The path to save the context. - - Returns: - - list: Answer to the query. - - Notes: - - Limits the search to the top 1 matching chunk for simplicity and relevance. - - Ensure that the vector database and LLM client are properly configured and accessible. - - The response model used for the LLM output is expected to be a string. - - """ - vector_engine = get_vector_engine() - - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=1) - - if len(found_chunks) == 0: - raise NoRelevantDataFound - - # Get context and optionally dump it - context = found_chunks[0].payload["text"] - if save_context_path: - try: - os.makedirs(os.path.dirname(save_context_path), exist_ok=True) - with open(save_context_path, "w", encoding="utf-8") as f: - json.dump(context, f, indent=2, ensure_ascii=False) - except OSError as e: - logger.error(f"Failed to save context to {save_context_path}: {str(e)}") - # Continue execution as context saving is optional - args = { - "question": query, - "context": context, - } - user_prompt = render_prompt("context_for_question.txt", args) - system_prompt = read_query_prompt("answer_simple_question.txt") - - llm_client = get_llm_client() - computed_answer = await llm_client.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, - ) - - return [computed_answer] diff --git a/cognee/tasks/graph/__init__.py b/cognee/tasks/graph/__init__.py index eafc129211..a96f394269 100644 --- a/cognee/tasks/graph/__init__.py +++ b/cognee/tasks/graph/__init__.py @@ -1,3 +1,2 @@ from .extract_graph_from_data import extract_graph_from_data from .extract_graph_from_code import extract_graph_from_code -from .query_graph_connections import query_graph_connections diff --git a/cognee/tasks/graph/query_graph_connections.py b/cognee/tasks/graph/query_graph_connections.py deleted file mode 100644 index 0fe039dab4..0000000000 --- a/cognee/tasks/graph/query_graph_connections.py +++ /dev/null @@ -1,62 +0,0 @@ -# TODO: delete after merging COG-1365, see COG-1403 -import asyncio -from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine - - -async def query_graph_connections(query: str, exploration_levels=1) -> list[(str, str, str)]: - """ - Find the neighbours of a given node in the graph and return formed sentences. - - Parameters: - - query (str): The query string to filter nodes by. - - exploration_levels (int): The number of jumps through edges to perform. - - Returns: - - list[(str, str, str)]: A list containing the source and destination nodes and relationship. - """ - if query is None: - return [] - - node_id = query - - graph_engine = await get_graph_engine() - - exact_node = await graph_engine.extract_node(node_id) - - if exact_node is not None and "id" in exact_node: - node_connections = await graph_engine.get_connections(str(exact_node["id"])) - else: - vector_engine = get_vector_engine() - results = await asyncio.gather( - vector_engine.search("Entity_name", query_text=query, limit=5), - vector_engine.search("EntityType_name", query_text=query, limit=5), - ) - results = [*results[0], *results[1]] - relevant_results = [result for result in results if result.score < 0.5][:5] - - if len(relevant_results) == 0: - return [] - - node_connections_results = await asyncio.gather( - *[graph_engine.get_connections(result.id) for result in relevant_results] - ) - - node_connections = [] - for neighbours in node_connections_results: - node_connections.extend(neighbours) - - unique_node_connections_map = {} - unique_node_connections = [] - for node_connection in node_connections: - if "id" not in node_connection[0] or "id" not in node_connection[2]: - continue - - unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}" - - if unique_id not in unique_node_connections_map: - unique_node_connections_map[unique_id] = True - - unique_node_connections.append(node_connection) - - return unique_node_connections diff --git a/cognee/tasks/summarization/__init__.py b/cognee/tasks/summarization/__init__.py index a9a330cb2a..1d50848ade 100644 --- a/cognee/tasks/summarization/__init__.py +++ b/cognee/tasks/summarization/__init__.py @@ -1,3 +1,2 @@ -from .query_summaries import query_summaries from .summarize_code import summarize_code from .summarize_text import summarize_text diff --git a/cognee/tasks/summarization/query_summaries.py b/cognee/tasks/summarization/query_summaries.py deleted file mode 100644 index 40ad00d321..0000000000 --- a/cognee/tasks/summarization/query_summaries.py +++ /dev/null @@ -1,19 +0,0 @@ -# TODO: delete after merging COG-1365, see COG-1403 -from cognee.infrastructure.databases.vector import get_vector_engine - - -async def query_summaries(query: str) -> list: - """ - Parameters: - - query (str): The query string to filter summaries by. - - Returns: - - list[str, UUID]: A list of objects providing information about the summaries related to query. - """ - vector_engine = get_vector_engine() - - summaries_results = await vector_engine.search("TextSummary_text", query, limit=5) - - summaries = [summary.payload for summary in summaries_results] - - return summaries diff --git a/evals/qa_context_provider_utils.py b/evals/qa_context_provider_utils.py index 7bf6c63c66..bba98f052c 100644 --- a/evals/qa_context_provider_utils.py +++ b/evals/qa_context_provider_utils.py @@ -2,7 +2,7 @@ from cognee.modules.search.types import SearchType from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search -from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from functools import partial from cognee.api.v1.cognify.cognify_v2 import get_default_tasks import logging @@ -122,7 +122,8 @@ async def get_context_with_brute_force_triplet_search(instance: dict) -> str: found_triplets = await brute_force_triplet_search(instance["question"], top_k=5) - search_results_str = await retrieved_edges_to_string(found_triplets) + retriever = GraphCompletionRetriever() + search_results_str = await retriever.resolve_edges_to_text(found_triplets) return search_results_str diff --git a/examples/python/graphiti_example.py b/examples/python/graphiti_example.py index 2de8db41e4..a4729e86e9 100644 --- a/examples/python/graphiti_example.py +++ b/examples/python/graphiti_example.py @@ -12,7 +12,7 @@ index_and_transform_graphiti_nodes_and_edges, ) from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search -from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt from cognee.infrastructure.llm.get_llm_client import get_llm_client @@ -49,9 +49,12 @@ async def main(): collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"], ) + retriever = GraphCompletionRetriever() + context = await retriever.resolve_edges_to_text(triplets) + args = { "question": query, - "context": await retrieved_edges_to_string(triplets), + "context": context, } user_prompt = render_prompt("graph_context_for_question.txt", args) diff --git a/notebooks/cognee_graphiti_demo.ipynb b/notebooks/cognee_graphiti_demo.ipynb index 96c84254ee..6907c44b42 100644 --- a/notebooks/cognee_graphiti_demo.ipynb +++ b/notebooks/cognee_graphiti_demo.ipynb @@ -37,7 +37,7 @@ " index_and_transform_graphiti_nodes_and_edges,\n", ")\n", "from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search\n", - "from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n", + "from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever\n", "from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n", "from cognee.infrastructure.llm.get_llm_client import get_llm_client" ] @@ -186,7 +186,8 @@ ")\n", "\n", "# Step 3: Preparing the Context for the LLM\n", - "context = await retrieved_edges_to_string(triplets)\n", + "retriever = GraphCompletionRetriever()\n", + "context = await retriever.resolve_edges_to_text(triplets)\n", "\n", "args = {\"question\": query, \"context\": context}\n", "\n",