From bd64bca51988d7f273848cf01574f76a86c3474c Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 21 May 2025 17:42:22 +0200 Subject: [PATCH 01/10] Fix: fixes volume update in modal dashboard --- cognee/eval_framework/modal_eval_dashboard.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cognee/eval_framework/modal_eval_dashboard.py b/cognee/eval_framework/modal_eval_dashboard.py index 6e123e22f9..acc0c3aa91 100644 --- a/cognee/eval_framework/modal_eval_dashboard.py +++ b/cognee/eval_framework/modal_eval_dashboard.py @@ -45,6 +45,8 @@ def run(): # Streamlit Dashboard Application Logic # ---------------------------------------------------------------------------- def main(): + metrics_volume.reload() + st.set_page_config(page_title="Metrics Dashboard", layout="wide") st.title("📊 Cognee Evaluations Dashboard") From a99caaae5c30c0d722a96846d20d566674c6dd07 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 11:10:44 +0200 Subject: [PATCH 02/10] Adds prompts for GraphCompletionCoT --- .../llm/prompts/cot_followup_system_prompt.txt | 3 +++ .../llm/prompts/cot_followup_user_prompt.txt | 14 ++++++++++++++ .../llm/prompts/cot_validation_system_prompt.txt | 2 ++ .../llm/prompts/cot_validation_user_prompt.txt | 11 +++++++++++ 4 files changed, 30 insertions(+) create mode 100644 cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt create mode 100644 cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt create mode 100644 cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt create mode 100644 cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt diff --git a/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt new file mode 100644 index 0000000000..5c7ac10356 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt @@ -0,0 
+1,3 @@ +You are a helpful assistant whose job is to ask exactly one clarifying follow-up question, +to collect the missing piece of information needed to fully answer the user’s original query. +Respond with the question only (no extra text, no punctuation beyond what’s needed). diff --git a/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt new file mode 100644 index 0000000000..5547ca58dc --- /dev/null +++ b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt @@ -0,0 +1,14 @@ +Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer. +Think in a way that with the followup question you are exploring a knowledge graph which contains entities, entity types and document chunks + + +`{{ query}}` + + + +`{{ answer }}` + + + +`{{ reasoning }}` + diff --git a/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt new file mode 100644 index 0000000000..b8066d0818 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt @@ -0,0 +1,2 @@ +You are a helpful agent who is allowed to use only the provided question, answer, and context. +I want you to find the reasoning for what is missing from the context, or why the answer does not answer the question or is not correct, strictly based on the context. 
diff --git a/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt new file mode 100644 index 0000000000..b989595de7 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt @@ -0,0 +1,11 @@ + +`{{ query}}` + + + +`{{ answer }}` + + + +`{{ context }}` + From 79fa98546cd6d21158586bf10e3be6385df0f588 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 11:13:40 +0200 Subject: [PATCH 03/10] feat: Adds GraphCompletionCOT retriever --- .../graph_completion_cot_retriever.py | 84 +++++++++++++++++++ cognee/modules/search/methods/search.py | 5 ++ cognee/modules/search/types/SearchType.py | 1 + 3 files changed, 90 insertions(+) create mode 100644 cognee/modules/retrieval/graph_completion_cot_retriever.py diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py new file mode 100644 index 0000000000..942f3c0d31 --- /dev/null +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -0,0 +1,84 @@ +from typing import Any, Optional, List +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.utils.completion import generate_completion +from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt + +logger = get_logger() + + +class GraphCompletionCotRetriever(GraphCompletionRetriever): + def __init__( + self, + user_prompt_path: str = "graph_context_for_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + validation_user_prompt_path: str = "cot_validation_user_prompt.txt", + validation_system_prompt_path: str = "cot_validation_system_prompt.txt", + followup_system_prompt_path: str = 
"cot_followup_system_prompt.txt", + followup_user_prompt_path: str = "cot_followup_user_prompt.txt", + top_k: Optional[int] = 5, + ): + super().__init__( + user_prompt_path=user_prompt_path, + system_prompt_path=system_prompt_path, + top_k=top_k, + ) + self.validation_system_prompt_path = validation_system_prompt_path + self.validation_user_prompt_path = validation_user_prompt_path + self.followup_system_prompt_path = followup_system_prompt_path + self.followup_user_prompt_path = followup_user_prompt_path + + async def get_completion( + self, query: str, context: Optional[Any] = None, max_iter=4 + ) -> List[str]: + llm_client = get_llm_client() + followup_question = "" + triplets = [] + answer: List[str] = [""] + + for round_idx in range(max_iter + 1): + if round_idx == 0: + if context is None: + context = await self.get_context(query) + else: + triplets += await self.get_triplets(followup_question) + context = await self.resolve_edges_to_text(list(set(triplets))) + + answer = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}") + if round_idx < max_iter: + valid_args = {"query": query, "answer": answer, "context": context} + valid_user_prompt = render_prompt( + filename=self.validation_user_prompt_path, context=valid_args + ) + valid_system_prompt = read_query_prompt( + prompt_file_name=self.validation_system_prompt_path + ) + + reasoning = await llm_client.acreate_structured_output( + text_input=valid_user_prompt, + system_prompt=valid_system_prompt, + response_model=str, + ) + followup_args = {"query": query, "answer": answer, "reasoning": reasoning} + followup_prompt = render_prompt( + filename=self.followup_user_prompt_path, context=followup_args + ) + followup_system = read_query_prompt( + prompt_file_name=self.followup_system_prompt_path + ) + + followup_question = await 
llm_client.acreate_structured_output( + text_input=followup_prompt, system_prompt=followup_system, response_model=str + ) + logger.info( + f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}" + ) + + return [answer] diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 5901001d98..9c9144d59a 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -20,6 +20,7 @@ from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.shared.utils import send_telemetry from ..operations import log_query, log_result +from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever async def search( @@ -70,6 +71,10 @@ async def specific_search( system_prompt_path=system_prompt_path, top_k=top_k, ).get_completion, + SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, top_k=top_k ).get_completion, diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py index 9c6a844833..aa85249123 100644 --- a/cognee/modules/search/types/SearchType.py +++ b/cognee/modules/search/types/SearchType.py @@ -11,3 +11,4 @@ class SearchType(Enum): CODE = "CODE" CYPHER = "CYPHER" NATURAL_LANGUAGE = "NATURAL_LANGUAGE" + GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" From 3de9b4ebcd6e7f3080ad326f58e157a0a693d5ef Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 11:14:20 +0200 Subject: [PATCH 04/10] feat: adds GraphCompletionCoT retriever to eval framework --- .../answer_generation/answer_generation_executor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py 
b/cognee/eval_framework/answer_generation/answer_generation_executor.py index 67eb025784..6875e41ebe 100644 --- a/cognee/eval_framework/answer_generation/answer_generation_executor.py +++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py @@ -1,5 +1,6 @@ -from typing import List, Dict +from typing import List, Dict, Any from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, @@ -8,8 +9,9 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever -retriever_options: Dict[str, BaseRetriever] = { +retriever_options: Dict[str, Any] = { "cognee_graph_completion": GraphCompletionRetriever, + "cognee_graph_completion_cot": GraphCompletionCotRetriever, "cognee_completion": CompletionRetriever, "graph_summary_completion": GraphSummaryCompletionRetriever, } From 0ee2ba1fb91fe3261899d3846c20e8121a5d9796 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 11:31:44 +0200 Subject: [PATCH 05/10] feat: adds GrapCompletionCoT retriever unit tests --- .../graph_completion_retriever_cot_test.py | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py new file mode 100644 index 0000000000..855f1f857c --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -0,0 +1,183 @@ +import os +import pytest +import pathlib +from typing import Optional, Union + +import cognee +from cognee.low_level import 
setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever + + +class TestGraphCompletionRetriever: + @pytest.mark.asyncio + async def test_graph_completion_context_simple(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + retriever = GraphCompletionCotRetriever() + + context = await retriever.get_context("Who works at Canva?") + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + @pytest.mark.asyncio + async 
def test_graph_completion_context_complex(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + retriever = GraphCompletionCotRetriever(top_k=20) + + context = await retriever.get_context("Who works at Figma?") + + print(context) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma 
--[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + @pytest.mark.asyncio + async def test_get_graph_completion_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + retriever = GraphCompletionCotRetriever() + + with pytest.raises(DatabaseNotCreatedError): + await retriever.get_context("Who works at Figma?") + + await setup() + + context = await retriever.get_context("Who works at Figma?") + assert context == "", "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +if __name__ == "__main__": + from asyncio import run + + test = TestGraphCompletionRetriever() + + async def main(): + await test.test_graph_completion_context_simple() + await test.test_graph_completion_context_complex() + await test.test_get_graph_completion_context_on_empty_graph() + + run(main()) From 87f89f94d413248f6ff2b67d8107f09f8392d6c1 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 
2025 12:43:34 +0200 Subject: [PATCH 06/10] chore: fixes false typing --- cognee/modules/retrieval/graph_completion_cot_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 942f3c0d31..012e9574d3 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -35,7 +35,7 @@ async def get_completion( llm_client = get_llm_client() followup_question = "" triplets = [] - answer: List[str] = [""] + answer = [""] for round_idx in range(max_iter + 1): if round_idx == 0: From 10c852a9efea891ed922160c30c4a66400e8d916 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 16:29:17 +0200 Subject: [PATCH 07/10] feat: Adds graph completion context extension to search --- ..._completion_context_extension_retriever.py | 74 +++++++++++++++++++ cognee/modules/search/methods/search.py | 9 ++- cognee/modules/search/types/SearchType.py | 1 + 3 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 cognee/modules/retrieval/graph_completion_context_extension_retriever.py diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py new file mode 100644 index 0000000000..26364b2ff0 --- /dev/null +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -0,0 +1,74 @@ +from typing import Any, Optional, List +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.utils.completion import generate_completion +from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt + +logger = 
get_logger() + + +class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): + def __init__( + self, + user_prompt_path: str = "graph_context_for_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + top_k: Optional[int] = 5, + ): + super().__init__( + user_prompt_path=user_prompt_path, + system_prompt_path=system_prompt_path, + top_k=top_k, + ) + + async def get_completion( + self, query: str, context: Optional[Any] = None, context_extension_rounds=4 + ) -> List[str]: + triplets = [] + + if context is None: + triplets += await self.get_triplets(query) + context = await self.resolve_edges_to_text(triplets) + + round_idx = 1 + + while round_idx <= context_extension_rounds: + prev_size = len(triplets) + + logger.info( + f"Context extension: round {round_idx} - generating next graph locational query." + ) + completion = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + + triplets += await self.get_triplets(completion) + triplets = list(set(triplets)) + context = await self.resolve_edges_to_text(triplets) + + num_triplets = len(triplets) + + if num_triplets == prev_size: + logger.info( + f"Context extension: round {round_idx} – no new triplets found; stopping early." 
+ ) + break + + logger.info( + f"Context extension: round {round_idx} - " + f"number of unique retrieved triplets: {num_triplets}" + ) + + round_idx += 1 + + answer = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + + return [answer] diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 9c9144d59a..2e3257016e 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -11,6 +11,10 @@ from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) from cognee.modules.retrieval.code_retriever import CodeRetriever from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever @@ -20,7 +24,6 @@ from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.shared.utils import send_telemetry from ..operations import log_query, log_result -from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever async def search( @@ -75,6 +78,10 @@ async def specific_search( system_prompt_path=system_prompt_path, top_k=top_k, ).get_completion, + SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, top_k=top_k ).get_completion, diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py index 
aa85249123..1c672f0f04 100644 --- a/cognee/modules/search/types/SearchType.py +++ b/cognee/modules/search/types/SearchType.py @@ -12,3 +12,4 @@ class SearchType(Enum): CYPHER = "CYPHER" NATURAL_LANGUAGE = "NATURAL_LANGUAGE" GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT" + GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" From 2b26e19447cc94d4b94f9eb189358737179832a9 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 16:29:54 +0200 Subject: [PATCH 08/10] feat: adds context extension search to eval framework --- .../answer_generation/answer_generation_executor.py | 4 ++++ cognee/eval_framework/eval_config.py | 4 +--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py index 6875e41ebe..6f166657e3 100644 --- a/cognee/eval_framework/answer_generation/answer_generation_executor.py +++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py @@ -1,5 +1,8 @@ from typing import List, Dict, Any from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_summary_completion_retriever import ( @@ -12,6 +15,7 @@ retriever_options: Dict[str, Any] = { "cognee_graph_completion": GraphCompletionRetriever, "cognee_graph_completion_cot": GraphCompletionCotRetriever, + "cognee_graph_completion_context_extension": GraphCompletionContextExtensionRetriever, "cognee_completion": CompletionRetriever, "graph_summary_completion": GraphSummaryCompletionRetriever, } diff --git 
a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py index 1a0825f752..7baf58cf95 100644 --- a/cognee/eval_framework/eval_config.py +++ b/cognee/eval_framework/eval_config.py @@ -14,9 +14,7 @@ class EvalConfig(BaseSettings): # Question answering params answering_questions: bool = True - qa_engine: str = ( - "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' - ) + qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' # Evaluation params evaluating_answers: bool = True From 3e4f6c42be84f8b6eb1e7a8e06fee1820b1259e6 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 16:30:34 +0200 Subject: [PATCH 09/10] feat: adds unit test for context extension search --- ...letion_retriever_context_extension_test.py | 185 ++++++++++++++++++ .../graph_completion_retriever_cot_test.py | 6 +- 2 files changed, 188 insertions(+), 3 deletions(-) create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py new file mode 100644 index 0000000000..2e2f2e197c --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -0,0 +1,185 @@ +import os +import pytest +import pathlib +from typing import Optional, Union + +import cognee +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) + + +class TestGraphCompletionRetriever: + 
@pytest.mark.asyncio + async def test_graph_completion_extension_context_simple(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + retriever = GraphCompletionContextExtensionRetriever() + + context = await retriever.get_context("Who works at Canva?") + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + @pytest.mark.asyncio + async def test_graph_completion_extension_context_complex(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + 
data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + retriever = GraphCompletionContextExtensionRetriever(top_k=20) + + context = await retriever.get_context("Who works at Figma?") + + print(context) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert 
isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + @pytest.mark.asyncio + async def test_get_graph_completion_extension_context_on_empty_graph(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + retriever = GraphCompletionContextExtensionRetriever() + + with pytest.raises(DatabaseNotCreatedError): + await retriever.get_context("Who works at Figma?") + + await setup() + + context = await retriever.get_context("Who works at Figma?") + assert context == "", "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +if __name__ == "__main__": + from asyncio import run + + test = TestGraphCompletionRetriever() + + async def main(): + await test.test_graph_completion_context_simple() + await test.test_graph_completion_context_complex() + await test.test_get_graph_completion_context_on_empty_graph() + + run(main()) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 855f1f857c..5723afd4a4 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ 
b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -12,7 +12,7 @@ class TestGraphCompletionRetriever: @pytest.mark.asyncio - async def test_graph_completion_context_simple(self): + async def test_graph_completion_cot_context_simple(self): system_directory_path = os.path.join( pathlib.Path(__file__).parent, ".cognee_system/test_graph_context" ) @@ -60,7 +60,7 @@ class Person(DataPoint): ) @pytest.mark.asyncio - async def test_graph_completion_context_complex(self): + async def test_graph_completion_cot_context_complex(self): system_directory_path = os.path.join( pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" ) @@ -139,7 +139,7 @@ class Person(DataPoint): ) @pytest.mark.asyncio - async def test_get_graph_completion_context_on_empty_graph(self): + async def test_get_graph_completion_cot_context_on_empty_graph(self): system_directory_path = os.path.join( pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context" ) From 76d562e1dea81a8ca9d7ee416d44aaea44876bb6 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 22 May 2025 16:40:30 +0200 Subject: [PATCH 10/10] chore: changes imports in search --- cognee/modules/search/methods/search.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 4cef49cab7..63c25924f2 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -23,8 +23,7 @@ from cognee.modules.users.models import User from cognee.modules.users.permissions.methods import get_document_ids_for_user from cognee.shared.utils import send_telemetry -from ..operations import log_query, log_result -from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.search.operations import log_query, log_result async def search(