diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index 67eb025784..6875e41ebe 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Any
 from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
 from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
 from cognee.modules.retrieval.graph_summary_completion_retriever import (
     GraphSummaryCompletionRetriever,
@@ -8,8 +9,9 @@
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 
 
-retriever_options: Dict[str, BaseRetriever] = {
+retriever_options: Dict[str, Any] = {
     "cognee_graph_completion": GraphCompletionRetriever,
+    "cognee_graph_completion_cot": GraphCompletionCotRetriever,
     "cognee_completion": CompletionRetriever,
     "graph_summary_completion": GraphSummaryCompletionRetriever,
 }
diff --git a/cognee/eval_framework/modal_eval_dashboard.py b/cognee/eval_framework/modal_eval_dashboard.py
index 6e123e22f9..acc0c3aa91 100644
--- a/cognee/eval_framework/modal_eval_dashboard.py
+++ b/cognee/eval_framework/modal_eval_dashboard.py
@@ -45,6 +45,8 @@ def run():
 # Streamlit Dashboard Application Logic
 # ----------------------------------------------------------------------------
 def main():
+    metrics_volume.reload()
+
     st.set_page_config(page_title="Metrics Dashboard", layout="wide")
     st.title("📊 Cognee Evaluations Dashboard")
 
diff --git a/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt
new file mode 100644
index 0000000000..5c7ac10356
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt
@@ -0,0 +1,3 @@
+You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,
+to collect the missing piece of information needed to fully answer the user’s original query.
+Respond with the question only (no extra text, no punctuation beyond what’s needed).
diff --git a/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt
new file mode 100644
index 0000000000..5547ca58dc
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt
@@ -0,0 +1,14 @@
+Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer.
+Frame the follow-up question as if you were exploring a knowledge graph that contains entities, entity types, and document chunks.
+
+<question>
+`{{ query }}`
+</question>
+
+<answer>
+`{{ answer }}`
+</answer>
+
+<reasoning>
+`{{ reasoning }}`
+</reasoning>
diff --git a/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt
new file mode 100644
index 0000000000..b8066d0818
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt
@@ -0,0 +1,2 @@
+You are a helpful agent who is allowed to use only the provided question, answer, and context.
+Explain what is missing from the context, or why the answer does not answer the question or is not correct, based strictly on the context.
diff --git a/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt
new file mode 100644
index 0000000000..b989595de7
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt
@@ -0,0 +1,11 @@
+<question>
+`{{ query }}`
+</question>
+
+<answer>
+`{{ answer }}`
+</answer>
+
+<context>
+`{{ context }}`
+</context>
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
new file mode 100644
index 0000000000..012e9574d3
--- /dev/null
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -0,0 +1,84 @@
+from typing import Any, Optional, List
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
+
+logger = get_logger()
+
+
+class GraphCompletionCotRetriever(GraphCompletionRetriever):
+    def __init__(
+        self,
+        user_prompt_path: str = "graph_context_for_question.txt",
+        system_prompt_path: str = "answer_simple_question.txt",
+        validation_user_prompt_path: str = "cot_validation_user_prompt.txt",
+        validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
+        followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
+        followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
+        top_k: Optional[int] = 5,
+    ):
+        super().__init__(
+            user_prompt_path=user_prompt_path,
+            system_prompt_path=system_prompt_path,
+            top_k=top_k,
+        )
+        self.validation_system_prompt_path = validation_system_prompt_path
+        self.validation_user_prompt_path = validation_user_prompt_path
+        self.followup_system_prompt_path = followup_system_prompt_path
+        self.followup_user_prompt_path = followup_user_prompt_path
+
+    async def get_completion(
+        self, query: str, context: Optional[Any] = None, max_iter=4
+    ) -> List[str]:
+        llm_client = get_llm_client()
+        followup_question = ""
+        triplets = []
+        answer = [""]
+
+        for round_idx in range(max_iter + 1):
+            if round_idx == 0:
+                if context is None:
+                    context = await self.get_context(query)
+            else:
+                triplets += await self.get_triplets(followup_question)
+                context = await self.resolve_edges_to_text(list(set(triplets)))
+
+            answer = await generate_completion(
+                query=query,
+                context=context,
+                user_prompt_path=self.user_prompt_path,
+                system_prompt_path=self.system_prompt_path,
+            )
+            logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
+            if round_idx < max_iter:
+                valid_args = {"query": query, "answer": answer, "context": context}
+                valid_user_prompt = render_prompt(
+                    filename=self.validation_user_prompt_path, context=valid_args
+                )
+                valid_system_prompt = read_query_prompt(
+                    prompt_file_name=self.validation_system_prompt_path
+                )
+
+                reasoning = await llm_client.acreate_structured_output(
+                    text_input=valid_user_prompt,
+                    system_prompt=valid_system_prompt,
+                    response_model=str,
+                )
+                followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
+                followup_prompt = render_prompt(
+                    filename=self.followup_user_prompt_path, context=followup_args
+                )
+                followup_system = read_query_prompt(
+                    prompt_file_name=self.followup_system_prompt_path
+                )
+
+                followup_question = await llm_client.acreate_structured_output(
+                    text_input=followup_prompt, system_prompt=followup_system, response_model=str
+                )
+                logger.info(
+                    f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
+                )
+
+        return [answer]
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index 5901001d98..9c9144d59a 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -20,6 +20,7 @@
 from cognee.modules.users.permissions.methods import get_document_ids_for_user
 from cognee.shared.utils import send_telemetry
 from ..operations import log_query, log_result
+from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
 
 
 async def search(
@@ -70,6 +71,10 @@ async def specific_search(
             system_prompt_path=system_prompt_path,
             top_k=top_k,
         ).get_completion,
+        SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
+            system_prompt_path=system_prompt_path,
+            top_k=top_k,
+        ).get_completion,
         SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
             system_prompt_path=system_prompt_path, top_k=top_k
         ).get_completion,
diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py
index 9c6a844833..aa85249123 100644
--- a/cognee/modules/search/types/SearchType.py
+++ b/cognee/modules/search/types/SearchType.py
@@ -11,3 +11,4 @@ class SearchType(Enum):
     CODE = "CODE"
     CYPHER = "CYPHER"
     NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
+    GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
new file mode 100644
index 0000000000..855f1f857c
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
@@ -0,0 +1,183 @@
+import os
+import pytest
+import pathlib
+from typing import Optional, Union
+
+import cognee
+from cognee.low_level import setup, DataPoint
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
+
+
+class TestGraphCompletionRetriever:
+    @pytest.mark.asyncio
+    async def test_graph_completion_context_simple(self):
+        system_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
+        )
+        cognee.config.system_root_directory(system_directory_path)
+        data_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
+        )
+        cognee.config.data_root_directory(data_directory_path)
+
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+        await setup()
+
+        class Company(DataPoint):
+            name: str
+
+        class Person(DataPoint):
+            name: str
+            works_for: Company
+
+        company1 = Company(name="Figma")
+        company2 = Company(name="Canva")
+        person1 = Person(name="Steve Rodger", works_for=company1)
+        person2 = Person(name="Ike Loma", works_for=company1)
+        person3 = Person(name="Jason Statham", works_for=company1)
+        person4 = Person(name="Mike Broski", works_for=company2)
+        person5 = Person(name="Christina Mayer", works_for=company2)
+
+        entities = [company1, company2, person1, person2, person3, person4, person5]
+
+        await add_data_points(entities)
+
+        retriever = GraphCompletionCotRetriever()
+
+        context = await retriever.get_context("Who works at Canva?")
+
+        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
+        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
+
+        answer = await retriever.get_completion("Who works at Canva?")
+
+        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+        assert all(isinstance(item, str) and item.strip() for item in answer), (
+            "Answer must contain only non-empty strings"
+        )
+
+    @pytest.mark.asyncio
+    async def test_graph_completion_context_complex(self):
+        system_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+        )
+        cognee.config.system_root_directory(system_directory_path)
+        data_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+        )
+        cognee.config.data_root_directory(data_directory_path)
+
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+        await setup()
+
+        class Company(DataPoint):
+            name: str
+            metadata: dict = {"index_fields": ["name"]}
+
+        class Car(DataPoint):
+            brand: str
+            model: str
+            year: int
+
+        class Location(DataPoint):
+            country: str
+            city: str
+
+        class Home(DataPoint):
+            location: Location
+            rooms: int
+            sqm: int
+
+        class Person(DataPoint):
+            name: str
+            works_for: Company
+            owns: Optional[list[Union[Car, Home]]] = None
+
+        company1 = Company(name="Figma")
+        company2 = Company(name="Canva")
+
+        person1 = Person(name="Mike Rodger", works_for=company1)
+        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
+
+        person2 = Person(name="Ike Loma", works_for=company1)
+        person2.owns = [
+            Car(brand="Tesla", model="Model S", year=2021),
+            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
+        ]
+
+        person3 = Person(name="Jason Statham", works_for=company1)
+
+        person4 = Person(name="Mike Broski", works_for=company2)
+        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
+
+        person5 = Person(name="Christina Mayer", works_for=company2)
+        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
+
+        entities = [company1, company2, person1, person2, person3, person4, person5]
+
+        await add_data_points(entities)
+
+        retriever = GraphCompletionCotRetriever(top_k=20)
+
+        context = await retriever.get_context("Who works at Figma?")
+
+        print(context)
+
+        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
+        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
+        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
+
+        answer = await retriever.get_completion("Who works at Figma?")
+
+        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+        assert all(isinstance(item, str) and item.strip() for item in answer), (
+            "Answer must contain only non-empty strings"
+        )
+
+    @pytest.mark.asyncio
+    async def test_get_graph_completion_context_on_empty_graph(self):
+        system_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+        )
+        cognee.config.system_root_directory(system_directory_path)
+        data_directory_path = os.path.join(
+            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+        )
+        cognee.config.data_root_directory(data_directory_path)
+
+        await cognee.prune.prune_data()
+        await cognee.prune.prune_system(metadata=True)
+
+        retriever = GraphCompletionCotRetriever()
+
+        with pytest.raises(DatabaseNotCreatedError):
+            await retriever.get_context("Who works at Figma?")
+
+        await setup()
+
+        context = await retriever.get_context("Who works at Figma?")
+        assert context == "", "Context should be empty on an empty graph"
+
+        answer = await retriever.get_completion("Who works at Figma?")
+
+        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+        assert all(isinstance(item, str) and item.strip() for item in answer), (
+            "Answer must contain only non-empty strings"
+        )
+
+
+if __name__ == "__main__":
+    from asyncio import run
+
+    test = TestGraphCompletionRetriever()
+
+    async def main():
+        await test.test_graph_completion_context_simple()
+        await test.test_graph_completion_context_complex()
+        await test.test_get_graph_completion_context_on_empty_graph()
+
+    run(main())
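
Reviewer note (not part of the patch): a minimal, hypothetical usage sketch of the new chain-of-thought retriever. The `GraphCompletionCotRetriever` constructor and `get_completion(query, context=None, max_iter=...)` signature are taken from `graph_completion_cot_retriever.py` above; the script assumes a cognee instance that has already been configured and populated with cognified data and a working LLM, as in the unit tests. Through the public search flow, the same retriever is reached via `SearchType.GRAPH_COMPLETION_COT`.

import asyncio

from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


async def main():
    # Assumes cognee has already been set up and data added/cognified beforehand;
    # on an uninitialized system, get_context() raises DatabaseNotCreatedError.
    retriever = GraphCompletionCotRetriever(top_k=5)

    # Runs the answer -> validation -> follow-up loop for up to max_iter rounds,
    # expanding the graph context with triplets retrieved for each follow-up
    # question, and returns a single-element list containing the final answer.
    answers = await retriever.get_completion("Who works at Canva?", max_iter=2)
    print(answers[0])


asyncio.run(main())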