diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index 67eb025784..6875e41ebe 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Type
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
@@ -8,8 +9,9 @@
from cognee.modules.retrieval.base_retriever import BaseRetriever
-retriever_options: Dict[str, BaseRetriever] = {
+retriever_options: Dict[str, Type[BaseRetriever]] = {
"cognee_graph_completion": GraphCompletionRetriever,
+ "cognee_graph_completion_cot": GraphCompletionCotRetriever,
"cognee_completion": CompletionRetriever,
"graph_summary_completion": GraphSummaryCompletionRetriever,
}
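
For reference, a minimal sketch of how the eval framework is expected to pick up the new `cognee_graph_completion_cot` option; the `qa_engine` name and the bare lookup/instantiation are illustrative assumptions, not code from this PR:

```python
from cognee.eval_framework.answer_generation.answer_generation_executor import retriever_options


def build_retriever(qa_engine: str):
    # Look up the retriever class registered under the configured name and instantiate it
    # with its default prompts and top_k (all registered classes accept no-arg construction).
    retriever_class = retriever_options[qa_engine]
    return retriever_class()


retriever = build_retriever("cognee_graph_completion_cot")  # -> GraphCompletionCotRetriever instance
```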
diff --git a/cognee/eval_framework/modal_eval_dashboard.py b/cognee/eval_framework/modal_eval_dashboard.py
index 6e123e22f9..acc0c3aa91 100644
--- a/cognee/eval_framework/modal_eval_dashboard.py
+++ b/cognee/eval_framework/modal_eval_dashboard.py
@@ -45,6 +45,8 @@ def run():
# Streamlit Dashboard Application Logic
# ----------------------------------------------------------------------------
def main():
+ metrics_volume.reload()
+
st.set_page_config(page_title="Metrics Dashboard", layout="wide")
st.title("📊 Cognee Evaluations Dashboard")
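
A minimal sketch of the Volume read pattern the dashboard relies on, assuming Modal's standard semantics (files committed from another container only become visible after `reload()`); the volume name and the `/data/metrics.json` path are illustrative assumptions:

```python
import json

import modal

# Illustrative volume name; the real dashboard defines metrics_volume elsewhere in this module.
metrics_volume = modal.Volume.from_name("eval-metrics", create_if_missing=True)


def load_latest_metrics(path: str = "/data/metrics.json") -> dict:
    # Refresh the mounted volume so files committed by a separate eval run become visible here.
    metrics_volume.reload()
    with open(path) as f:
        return json.load(f)
```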
diff --git a/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt
new file mode 100644
index 0000000000..5c7ac10356
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt
@@ -0,0 +1,3 @@
+You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,
+to collect the missing piece of information needed to fully answer the user’s original query.
+Respond with the question only (no extra text, no punctuation beyond what’s needed).
diff --git a/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt
new file mode 100644
index 0000000000..5547ca58dc
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt
@@ -0,0 +1,14 @@
+Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer.
+Frame the follow-up question as if you were exploring a knowledge graph that contains entities, entity types, and document chunks.
+
+Original query:
+`{{ query }}`
+
+
+Current answer:
+`{{ answer }}`
+
+
+Validation reasoning:
+`{{ reasoning }}`
+
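
A small sketch of how this template is presumably filled in via `render_prompt` (the example values are made up):

```python
from cognee.infrastructure.llm.prompts import render_prompt

followup_args = {
    "query": "Who works at Canva?",
    "answer": "The context does not say.",
    "reasoning": "The context only contains works_for edges for Figma employees.",
}
# Renders cot_followup_user_prompt.txt with the query, answer, and validation reasoning substituted in.
followup_prompt = render_prompt(filename="cot_followup_user_prompt.txt", context=followup_args)
print(followup_prompt)
```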
diff --git a/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt
new file mode 100644
index 0000000000..b8066d0818
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt
@@ -0,0 +1,2 @@
+You are a helpful agent who is allowed to use only the provided question, answer, and context.
+Explain what is missing from the context, or why the answer does not answer the question or is not correct, based strictly on the context.
diff --git a/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt
new file mode 100644
index 0000000000..b989595de7
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt
@@ -0,0 +1,11 @@
+Question:
+`{{ query }}`
+
+
+Answer:
+`{{ answer }}`
+
+
+Context:
+`{{ context }}`
+
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
new file mode 100644
index 0000000000..012e9574d3
--- /dev/null
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -0,0 +1,84 @@
+from typing import Any, Optional, List
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
+
+logger = get_logger()
+
+
+class GraphCompletionCotRetriever(GraphCompletionRetriever):
+    """Graph completion retriever that iteratively refines its answer with validation and follow-up questions."""
+
+ def __init__(
+ self,
+ user_prompt_path: str = "graph_context_for_question.txt",
+ system_prompt_path: str = "answer_simple_question.txt",
+ validation_user_prompt_path: str = "cot_validation_user_prompt.txt",
+ validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
+ followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
+ followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
+ top_k: Optional[int] = 5,
+ ):
+ super().__init__(
+ user_prompt_path=user_prompt_path,
+ system_prompt_path=system_prompt_path,
+ top_k=top_k,
+ )
+ self.validation_system_prompt_path = validation_system_prompt_path
+ self.validation_user_prompt_path = validation_user_prompt_path
+ self.followup_system_prompt_path = followup_system_prompt_path
+ self.followup_user_prompt_path = followup_user_prompt_path
+
+    async def get_completion(
+        self, query: str, context: Optional[Any] = None, max_iter: int = 4
+    ) -> List[str]:
+        """Answer the query, then run up to `max_iter` validation and follow-up rounds to refine the answer."""
+        llm_client = get_llm_client()
+        followup_question = ""
+        triplets = []
+        answer = ""
+
+        for round_idx in range(max_iter + 1):
+            if round_idx == 0:
+                # First round: use the caller-provided context or retrieve one for the original query.
+                if context is None:
+                    context = await self.get_context(query)
+            else:
+                # Later rounds: retrieve triplets for the follow-up question and rebuild the context.
+                triplets += await self.get_triplets(followup_question)
+                context = await self.resolve_edges_to_text(list(set(triplets)))
+
+ answer = await generate_completion(
+ query=query,
+ context=context,
+ user_prompt_path=self.user_prompt_path,
+ system_prompt_path=self.system_prompt_path,
+ )
+ logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
+            if round_idx < max_iter:
+                # Validation step: ask the LLM what the current answer is missing relative to the context.
+                valid_args = {"query": query, "answer": answer, "context": context}
+ valid_user_prompt = render_prompt(
+ filename=self.validation_user_prompt_path, context=valid_args
+ )
+ valid_system_prompt = read_query_prompt(
+ prompt_file_name=self.validation_system_prompt_path
+ )
+
+ reasoning = await llm_client.acreate_structured_output(
+ text_input=valid_user_prompt,
+ system_prompt=valid_system_prompt,
+ response_model=str,
+ )
+                # Follow-up step: ask the LLM for exactly one question that targets the identified gap.
+                followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
+ followup_prompt = render_prompt(
+ filename=self.followup_user_prompt_path, context=followup_args
+ )
+ followup_system = read_query_prompt(
+ prompt_file_name=self.followup_system_prompt_path
+ )
+
+ followup_question = await llm_client.acreate_structured_output(
+ text_input=followup_prompt, system_prompt=followup_system, response_model=str
+ )
+ logger.info(
+ f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
+ )
+
+ return [answer]
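
A usage sketch for calling the new retriever directly, outside the search API; the question and `max_iter` value are arbitrary, and it assumes a graph has already been built (data added and cognified):

```python
import asyncio

from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


async def main():
    retriever = GraphCompletionCotRetriever(top_k=10)
    # Runs one initial completion plus up to two validate/follow-up rounds.
    answers = await retriever.get_completion("Who works at Canva?", max_iter=2)
    print(answers[0])


asyncio.run(main())
```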
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index 5901001d98..9c9144d59a 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -20,6 +20,7 @@
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
from ..operations import log_query, log_result
+from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
async def search(
@@ -70,6 +71,10 @@ async def specific_search(
system_prompt_path=system_prompt_path,
top_k=top_k,
).get_completion,
+ SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
+ system_prompt_path=system_prompt_path,
+ top_k=top_k,
+ ).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
).get_completion,
diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py
index 9c6a844833..aa85249123 100644
--- a/cognee/modules/search/types/SearchType.py
+++ b/cognee/modules/search/types/SearchType.py
@@ -11,3 +11,4 @@ class SearchType(Enum):
CODE = "CODE"
CYPHER = "CYPHER"
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
+ GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
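
And a sketch of reaching the same code path through the public search entry point, assuming the usual `cognee.search(query_type=..., query_text=...)` signature and that data has already been added and cognified; the import path mirrors the module touched above:

```python
import asyncio

import cognee
from cognee.modules.search.types.SearchType import SearchType


async def main():
    # Dispatches to GraphCompletionCotRetriever.get_completion via specific_search.
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION_COT,
        query_text="Who works at Canva?",
    )
    print(results)


asyncio.run(main())
```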
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
new file mode 100644
index 0000000000..855f1f857c
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
@@ -0,0 +1,183 @@
+import os
+import pytest
+import pathlib
+from typing import Optional, Union
+
+import cognee
+from cognee.low_level import setup, DataPoint
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
+
+
+class TestGraphCompletionCotRetriever:
+ @pytest.mark.asyncio
+ async def test_graph_completion_context_simple(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+ person1 = Person(name="Steve Rodger", works_for=company1)
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person3 = Person(name="Jason Statham", works_for=company1)
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person5 = Person(name="Christina Mayer", works_for=company2)
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ retriever = GraphCompletionCotRetriever()
+
+ context = await retriever.get_context("Who works at Canva?")
+
+ assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
+ assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
+
+ answer = await retriever.get_completion("Who works at Canva?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+ @pytest.mark.asyncio
+ async def test_graph_completion_context_complex(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ metadata: dict = {"index_fields": ["name"]}
+
+ class Car(DataPoint):
+ brand: str
+ model: str
+ year: int
+
+ class Location(DataPoint):
+ country: str
+ city: str
+
+ class Home(DataPoint):
+ location: Location
+ rooms: int
+ sqm: int
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ owns: Optional[list[Union[Car, Home]]] = None
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+
+ person1 = Person(name="Mike Rodger", works_for=company1)
+ person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
+
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person2.owns = [
+ Car(brand="Tesla", model="Model S", year=2021),
+ Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
+ ]
+
+ person3 = Person(name="Jason Statham", works_for=company1)
+
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
+
+ person5 = Person(name="Christina Mayer", works_for=company2)
+ person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ retriever = GraphCompletionCotRetriever(top_k=20)
+
+ context = await retriever.get_context("Who works at Figma?")
+
+ assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
+ assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
+ assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+ @pytest.mark.asyncio
+ async def test_get_graph_completion_context_on_empty_graph(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ retriever = GraphCompletionCotRetriever()
+
+ with pytest.raises(DatabaseNotCreatedError):
+ await retriever.get_context("Who works at Figma?")
+
+ await setup()
+
+ context = await retriever.get_context("Who works at Figma?")
+ assert context == "", "Context should be empty on an empty graph"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+if __name__ == "__main__":
+ from asyncio import run
+
+    test = TestGraphCompletionCotRetriever()
+
+ async def main():
+ await test.test_graph_completion_context_simple()
+ await test.test_graph_completion_context_complex()
+ await test.test_get_graph_completion_context_on_empty_graph()
+
+ run(main())