diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py index 6875e41ebe..6f166657e3 100644 --- a/cognee/eval_framework/answer_generation/answer_generation_executor.py +++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py @@ -1,5 +1,8 @@ from typing import List, Dict, Any from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_summary_completion_retriever import ( @@ -12,6 +15,7 @@ retriever_options: Dict[str, Any] = { "cognee_graph_completion": GraphCompletionRetriever, "cognee_graph_completion_cot": GraphCompletionCotRetriever, + "cognee_graph_completion_context_extension": GraphCompletionContextExtensionRetriever, "cognee_completion": CompletionRetriever, "graph_summary_completion": GraphSummaryCompletionRetriever, } diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py index 1a0825f752..7baf58cf95 100644 --- a/cognee/eval_framework/eval_config.py +++ b/cognee/eval_framework/eval_config.py @@ -14,9 +14,7 @@ class EvalConfig(BaseSettings): # Question answering params answering_questions: bool = True - qa_engine: str = ( - "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' - ) + qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' # Evaluation params evaluating_answers: bool = True diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py new file mode 100644 index 0000000000..26364b2ff0 --- /dev/null +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -0,0 +1,74 @@ +from typing import Any, Optional, List +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.utils.completion import generate_completion +from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt + +logger = get_logger() + + +class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): + def __init__( + self, + user_prompt_path: str = "graph_context_for_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + top_k: Optional[int] = 5, + ): + super().__init__( + user_prompt_path=user_prompt_path, + system_prompt_path=system_prompt_path, + top_k=top_k, + ) + + async def get_completion( + self, query: str, context: Optional[Any] = None, context_extension_rounds=4 + ) -> List[str]: + triplets = [] + + if context is None: + triplets += await self.get_triplets(query) + context = await self.resolve_edges_to_text(triplets) + + round_idx = 1 + + while round_idx <= context_extension_rounds: + prev_size = len(triplets) + + logger.info( + f"Context extension: round {round_idx} - generating next graph locational query." + ) + completion = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + + triplets += await self.get_triplets(completion) + triplets = list(set(triplets)) + context = await self.resolve_edges_to_text(triplets) + + num_triplets = len(triplets) + + if num_triplets == prev_size: + logger.info( + f"Context extension: round {round_idx} – no new triplets found; stopping early." + ) + break + + logger.info( + f"Context extension: round {round_idx} - " + f"number of unique retrieved triplets: {num_triplets}" + ) + + round_idx += 1 + + answer = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + + return [answer] diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 9c9144d59a..63c25924f2 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -11,6 +11,10 @@ from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) from cognee.modules.retrieval.code_retriever import CodeRetriever from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever @@ -19,8 +23,7 @@ from cognee.modules.users.models import User from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.shared.utils import send_telemetry -from ..operations import log_query, log_result -from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.search.operations import log_query, log_result async def search( @@ -75,6 +78,10 @@ async def specific_search( system_prompt_path=system_prompt_path, top_k=top_k, ).get_completion, + SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, top_k=top_k ).get_completion, diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py index aa85249123..1c672f0f04 100644 --- a/cognee/modules/search/types/SearchType.py +++ b/cognee/modules/search/types/SearchType.py @@ -12,3 +12,4 @@ class SearchType(Enum): CYPHER = "CYPHER" NATURAL_LANGUAGE = "NATURAL_LANGUAGE" GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" + GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py new file mode 100644 index 0000000000..2e2f2e197c --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -0,0 +1,185 @@ +import os +import pytest +import pathlib +from typing import Optional, Union + +import cognee +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) + + +class TestGraphCompletionRetriever: + @pytest.mark.asyncio + async def test_graph_completion_extension_context_simple(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + retriever = GraphCompletionContextExtensionRetriever() + + context = await retriever.get_context("Who works at Canva?") + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + @pytest.mark.asyncio + async def test_graph_completion_extension_context_complex(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + retriever = GraphCompletionContextExtensionRetriever(top_k=20) + + context = await retriever.get_context("Who works at Figma?") + + print(context) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + @pytest.mark.asyncio + async def test_get_graph_completion_extension_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + retriever = GraphCompletionContextExtensionRetriever() + + with pytest.raises(DatabaseNotCreatedError): + await retriever.get_context("Who works at Figma?") + + await setup() + + context = await retriever.get_context("Who works at Figma?") + assert context == "", "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +if __name__ == "__main__": + from asyncio import run + + test = TestGraphCompletionRetriever() + + async def main(): + await test.test_graph_completion_context_simple() + await test.test_graph_completion_context_complex() + await test.test_get_graph_completion_context_on_empty_graph() + + run(main()) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 855f1f857c..5723afd4a4 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -12,7 +12,7 @@ class TestGraphCompletionRetriever: @pytest.mark.asyncio - async def test_graph_completion_context_simple(self): + async def test_graph_completion_cot_context_simple(self): system_directory_path = os.path.join( pathlib.Path(__file__).parent, ".cognee_system/test_graph_context" ) @@ -60,7 +60,7 @@ class Person(DataPoint): ) @pytest.mark.asyncio - async def test_graph_completion_context_complex(self): + async def test_graph_completion_cot_context_complex(self): system_directory_path = os.path.join( pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" ) @@ -139,7 +139,7 @@ class Person(DataPoint): ) @pytest.mark.asyncio - async def test_get_graph_completion_context_on_empty_graph(self): + async def test_get_graph_completion_cot_context_on_empty_graph(self): system_directory_path = os.path.join( pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" )