From bd64bca51988d7f273848cf01574f76a86c3474c Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Wed, 21 May 2025 17:42:22 +0200
Subject: [PATCH 01/10] Fix: fixes volume update in modal dashboard
---
cognee/eval_framework/modal_eval_dashboard.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/cognee/eval_framework/modal_eval_dashboard.py b/cognee/eval_framework/modal_eval_dashboard.py
index 6e123e22f9..acc0c3aa91 100644
--- a/cognee/eval_framework/modal_eval_dashboard.py
+++ b/cognee/eval_framework/modal_eval_dashboard.py
@@ -45,6 +45,8 @@ def run():
# Streamlit Dashboard Application Logic
# ----------------------------------------------------------------------------
def main():
+ metrics_volume.reload()
+
st.set_page_config(page_title="Metrics Dashboard", layout="wide")
st.title("📊 Cognee Evaluations Dashboard")
From a99caaae5c30c0d722a96846d20d566674c6dd07 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 11:10:44 +0200
Subject: [PATCH 02/10] Adds prompts for GraphCompletionCoT
---
.../llm/prompts/cot_followup_system_prompt.txt | 3 +++
.../llm/prompts/cot_followup_user_prompt.txt | 14 ++++++++++++++
.../llm/prompts/cot_validation_system_prompt.txt | 2 ++
.../llm/prompts/cot_validation_user_prompt.txt | 11 +++++++++++
4 files changed, 30 insertions(+)
create mode 100644 cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt
create mode 100644 cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt
create mode 100644 cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt
create mode 100644 cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt
diff --git a/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt
new file mode 100644
index 0000000000..5c7ac10356
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_followup_system_prompt.txt
@@ -0,0 +1,3 @@
+You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,
+to collect the missing piece of information needed to fully answer the user’s original query.
+Respond with the question only (no extra text, no punctuation beyond what’s needed).
diff --git a/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt
new file mode 100644
index 0000000000..5547ca58dc
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_followup_user_prompt.txt
@@ -0,0 +1,14 @@
+Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer.
+Think in a way that with the followup question you are exploring a knowledge graph which contains entities, entity types and document chunks
+
+
+`{{ query}}`
+
+
+
+`{{ answer }}`
+
+
+
+`{{ reasoning }}`
+
diff --git a/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt
new file mode 100644
index 0000000000..b8066d0818
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_validation_system_prompt.txt
@@ -0,0 +1,2 @@
+You are a helpful agent who are allowed to use only the provided question answer and context.
+I want to you find reasoning what is missing from the context or why the answer is not answering the question or not correct strictly based on the context.
diff --git a/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt
new file mode 100644
index 0000000000..b989595de7
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/cot_validation_user_prompt.txt
@@ -0,0 +1,11 @@
+
+`{{ query}}`
+
+
+
+`{{ answer }}`
+
+
+
+`{{ context }}`
+
From 79fa98546cd6d21158586bf10e3be6385df0f588 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 11:13:40 +0200
Subject: [PATCH 03/10] feat: Adds GraphCompletionCOT retriever
---
.../graph_completion_cot_retriever.py | 84 +++++++++++++++++++
cognee/modules/search/methods/search.py | 5 ++
cognee/modules/search/types/SearchType.py | 1 +
3 files changed, 90 insertions(+)
create mode 100644 cognee/modules/retrieval/graph_completion_cot_retriever.py
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
new file mode 100644
index 0000000000..942f3c0d31
--- /dev/null
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -0,0 +1,84 @@
+from typing import Any, Optional, List
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
+
+logger = get_logger()
+
+
+class GraphCompletionCotRetriever(GraphCompletionRetriever):
+ def __init__(
+ self,
+ user_prompt_path: str = "graph_context_for_question.txt",
+ system_prompt_path: str = "answer_simple_question.txt",
+ validation_user_prompt_path: str = "cot_validation_user_prompt.txt",
+ validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
+ followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
+ followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
+ top_k: Optional[int] = 5,
+ ):
+ super().__init__(
+ user_prompt_path=user_prompt_path,
+ system_prompt_path=system_prompt_path,
+ top_k=top_k,
+ )
+ self.validation_system_prompt_path = validation_system_prompt_path
+ self.validation_user_prompt_path = validation_user_prompt_path
+ self.followup_system_prompt_path = followup_system_prompt_path
+ self.followup_user_prompt_path = followup_user_prompt_path
+
+ async def get_completion(
+ self, query: str, context: Optional[Any] = None, max_iter=4
+ ) -> List[str]:
+ llm_client = get_llm_client()
+ followup_question = ""
+ triplets = []
+ answer: List[str] = [""]
+
+ for round_idx in range(max_iter + 1):
+ if round_idx == 0:
+ if context is None:
+ context = await self.get_context(query)
+ else:
+ triplets += await self.get_triplets(followup_question)
+ context = await self.resolve_edges_to_text(list(set(triplets)))
+
+ answer = await generate_completion(
+ query=query,
+ context=context,
+ user_prompt_path=self.user_prompt_path,
+ system_prompt_path=self.system_prompt_path,
+ )
+ logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
+ if round_idx < max_iter:
+ valid_args = {"query": query, "answer": answer, "context": context}
+ valid_user_prompt = render_prompt(
+ filename=self.validation_user_prompt_path, context=valid_args
+ )
+ valid_system_prompt = read_query_prompt(
+ prompt_file_name=self.validation_system_prompt_path
+ )
+
+ reasoning = await llm_client.acreate_structured_output(
+ text_input=valid_user_prompt,
+ system_prompt=valid_system_prompt,
+ response_model=str,
+ )
+ followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
+ followup_prompt = render_prompt(
+ filename=self.followup_user_prompt_path, context=followup_args
+ )
+ followup_system = read_query_prompt(
+ prompt_file_name=self.followup_system_prompt_path
+ )
+
+ followup_question = await llm_client.acreate_structured_output(
+ text_input=followup_prompt, system_prompt=followup_system, response_model=str
+ )
+ logger.info(
+ f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
+ )
+
+ return [answer]
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index 5901001d98..9c9144d59a 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -20,6 +20,7 @@
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
from ..operations import log_query, log_result
+from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
async def search(
@@ -70,6 +71,10 @@ async def specific_search(
system_prompt_path=system_prompt_path,
top_k=top_k,
).get_completion,
+ SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
+ system_prompt_path=system_prompt_path,
+ top_k=top_k,
+ ).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
).get_completion,
diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py
index 9c6a844833..aa85249123 100644
--- a/cognee/modules/search/types/SearchType.py
+++ b/cognee/modules/search/types/SearchType.py
@@ -11,3 +11,4 @@ class SearchType(Enum):
CODE = "CODE"
CYPHER = "CYPHER"
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
+ GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
From 3de9b4ebcd6e7f3080ad326f58e157a0a693d5ef Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 11:14:20 +0200
Subject: [PATCH 04/10] feat: adds GraphCompletionCoT retriever to eval
framework
---
.../answer_generation/answer_generation_executor.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index 67eb025784..6875e41ebe 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Any
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
@@ -8,8 +9,9 @@
from cognee.modules.retrieval.base_retriever import BaseRetriever
-retriever_options: Dict[str, BaseRetriever] = {
+retriever_options: Dict[str, Any] = {
"cognee_graph_completion": GraphCompletionRetriever,
+ "cognee_graph_completion_cot": GraphCompletionCotRetriever,
"cognee_completion": CompletionRetriever,
"graph_summary_completion": GraphSummaryCompletionRetriever,
}
From 0ee2ba1fb91fe3261899d3846c20e8121a5d9796 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 11:31:44 +0200
Subject: [PATCH 05/10] feat: adds GrapCompletionCoT retriever unit tests
---
.../graph_completion_retriever_cot_test.py | 183 ++++++++++++++++++
1 file changed, 183 insertions(+)
create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
new file mode 100644
index 0000000000..855f1f857c
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
@@ -0,0 +1,183 @@
+import os
+import pytest
+import pathlib
+from typing import Optional, Union
+
+import cognee
+from cognee.low_level import setup, DataPoint
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
+
+
+class TestGraphCompletionRetriever:
+ @pytest.mark.asyncio
+ async def test_graph_completion_context_simple(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+ person1 = Person(name="Steve Rodger", works_for=company1)
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person3 = Person(name="Jason Statham", works_for=company1)
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person5 = Person(name="Christina Mayer", works_for=company2)
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ retriever = GraphCompletionCotRetriever()
+
+ context = await retriever.get_context("Who works at Canva?")
+
+ assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
+ assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
+
+ answer = await retriever.get_completion("Who works at Canva?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+ @pytest.mark.asyncio
+ async def test_graph_completion_context_complex(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ metadata: dict = {"index_fields": ["name"]}
+
+ class Car(DataPoint):
+ brand: str
+ model: str
+ year: int
+
+ class Location(DataPoint):
+ country: str
+ city: str
+
+ class Home(DataPoint):
+ location: Location
+ rooms: int
+ sqm: int
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ owns: Optional[list[Union[Car, Home]]] = None
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+
+ person1 = Person(name="Mike Rodger", works_for=company1)
+ person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
+
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person2.owns = [
+ Car(brand="Tesla", model="Model S", year=2021),
+ Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
+ ]
+
+ person3 = Person(name="Jason Statham", works_for=company1)
+
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
+
+ person5 = Person(name="Christina Mayer", works_for=company2)
+ person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ retriever = GraphCompletionCotRetriever(top_k=20)
+
+ context = await retriever.get_context("Who works at Figma?")
+
+ print(context)
+
+ assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
+ assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
+ assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+ @pytest.mark.asyncio
+ async def test_get_graph_completion_context_on_empty_graph(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ retriever = GraphCompletionCotRetriever()
+
+ with pytest.raises(DatabaseNotCreatedError):
+ await retriever.get_context("Who works at Figma?")
+
+ await setup()
+
+ context = await retriever.get_context("Who works at Figma?")
+ assert context == "", "Context should be empty on an empty graph"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+if __name__ == "__main__":
+ from asyncio import run
+
+ test = TestGraphCompletionRetriever()
+
+ async def main():
+ await test.test_graph_completion_context_simple()
+ await test.test_graph_completion_context_complex()
+ await test.test_get_graph_completion_context_on_empty_graph()
+
+ run(main())
From 87f89f94d413248f6ff2b67d8107f09f8392d6c1 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 12:43:34 +0200
Subject: [PATCH 06/10] chore: fixes false typing
---
cognee/modules/retrieval/graph_completion_cot_retriever.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
index 942f3c0d31..012e9574d3 100644
--- a/cognee/modules/retrieval/graph_completion_cot_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -35,7 +35,7 @@ async def get_completion(
llm_client = get_llm_client()
followup_question = ""
triplets = []
- answer: List[str] = [""]
+ answer = [""]
for round_idx in range(max_iter + 1):
if round_idx == 0:
From 10c852a9efea891ed922160c30c4a66400e8d916 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 16:29:17 +0200
Subject: [PATCH 07/10] feat: Adds graph completion context extension to search
---
..._completion_context_extension_retriever.py | 74 +++++++++++++++++++
cognee/modules/search/methods/search.py | 9 ++-
cognee/modules/search/types/SearchType.py | 1 +
3 files changed, 83 insertions(+), 1 deletion(-)
create mode 100644 cognee/modules/retrieval/graph_completion_context_extension_retriever.py
diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
new file mode 100644
index 0000000000..26364b2ff0
--- /dev/null
+++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
@@ -0,0 +1,74 @@
+from typing import Any, Optional, List
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
+
+logger = get_logger()
+
+
+class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
+ def __init__(
+ self,
+ user_prompt_path: str = "graph_context_for_question.txt",
+ system_prompt_path: str = "answer_simple_question.txt",
+ top_k: Optional[int] = 5,
+ ):
+ super().__init__(
+ user_prompt_path=user_prompt_path,
+ system_prompt_path=system_prompt_path,
+ top_k=top_k,
+ )
+
+ async def get_completion(
+ self, query: str, context: Optional[Any] = None, context_extension_rounds=4
+ ) -> List[str]:
+ triplets = []
+
+ if context is None:
+ triplets += await self.get_triplets(query)
+ context = await self.resolve_edges_to_text(triplets)
+
+ round_idx = 1
+
+ while round_idx <= context_extension_rounds:
+ prev_size = len(triplets)
+
+ logger.info(
+ f"Context extension: round {round_idx} - generating next graph locational query."
+ )
+ completion = await generate_completion(
+ query=query,
+ context=context,
+ user_prompt_path=self.user_prompt_path,
+ system_prompt_path=self.system_prompt_path,
+ )
+
+ triplets += await self.get_triplets(completion)
+ triplets = list(set(triplets))
+ context = await self.resolve_edges_to_text(triplets)
+
+ num_triplets = len(triplets)
+
+ if num_triplets == prev_size:
+ logger.info(
+ f"Context extension: round {round_idx} – no new triplets found; stopping early."
+ )
+ break
+
+ logger.info(
+ f"Context extension: round {round_idx} - "
+ f"number of unique retrieved triplets: {num_triplets}"
+ )
+
+ round_idx += 1
+
+ answer = await generate_completion(
+ query=query,
+ context=context,
+ user_prompt_path=self.user_prompt_path,
+ system_prompt_path=self.system_prompt_path,
+ )
+
+ return [answer]
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index 9c9144d59a..2e3257016e 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -11,6 +11,10 @@
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
+from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
+from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
+ GraphCompletionContextExtensionRetriever,
+)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
@@ -20,7 +24,6 @@
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
from ..operations import log_query, log_result
-from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
async def search(
@@ -75,6 +78,10 @@ async def specific_search(
system_prompt_path=system_prompt_path,
top_k=top_k,
).get_completion,
+ SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
+ system_prompt_path=system_prompt_path,
+ top_k=top_k,
+ ).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
).get_completion,
diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py
index aa85249123..1c672f0f04 100644
--- a/cognee/modules/search/types/SearchType.py
+++ b/cognee/modules/search/types/SearchType.py
@@ -12,3 +12,4 @@ class SearchType(Enum):
CYPHER = "CYPHER"
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
+ GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
From 2b26e19447cc94d4b94f9eb189358737179832a9 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 16:29:54 +0200
Subject: [PATCH 08/10] feat: adds context extension search to eval framework
---
.../answer_generation/answer_generation_executor.py | 4 ++++
cognee/eval_framework/eval_config.py | 4 +---
2 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index 6875e41ebe..6f166657e3 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -1,5 +1,8 @@
from typing import List, Dict, Any
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
+ GraphCompletionContextExtensionRetriever,
+)
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
@@ -12,6 +15,7 @@
retriever_options: Dict[str, Any] = {
"cognee_graph_completion": GraphCompletionRetriever,
"cognee_graph_completion_cot": GraphCompletionCotRetriever,
+ "cognee_graph_completion_context_extension": GraphCompletionContextExtensionRetriever,
"cognee_completion": CompletionRetriever,
"graph_summary_completion": GraphSummaryCompletionRetriever,
}
diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py
index 1a0825f752..7baf58cf95 100644
--- a/cognee/eval_framework/eval_config.py
+++ b/cognee/eval_framework/eval_config.py
@@ -14,9 +14,7 @@ class EvalConfig(BaseSettings):
# Question answering params
answering_questions: bool = True
- qa_engine: str = (
- "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion'
- )
+ qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
# Evaluation params
evaluating_answers: bool = True
From 3e4f6c42be84f8b6eb1e7a8e06fee1820b1259e6 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 16:30:34 +0200
Subject: [PATCH 09/10] feat: adds unit test for context extension search
---
...letion_retriever_context_extension_test.py | 185 ++++++++++++++++++
.../graph_completion_retriever_cot_test.py | 6 +-
2 files changed, 188 insertions(+), 3 deletions(-)
create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py
new file mode 100644
index 0000000000..2e2f2e197c
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py
@@ -0,0 +1,185 @@
+import os
+import pytest
+import pathlib
+from typing import Optional, Union
+
+import cognee
+from cognee.low_level import setup, DataPoint
+from cognee.tasks.storage import add_data_points
+from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
+from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
+ GraphCompletionContextExtensionRetriever,
+)
+
+
+class TestGraphCompletionRetriever:
+ @pytest.mark.asyncio
+ async def test_graph_completion_extension_context_simple(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+ person1 = Person(name="Steve Rodger", works_for=company1)
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person3 = Person(name="Jason Statham", works_for=company1)
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person5 = Person(name="Christina Mayer", works_for=company2)
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ retriever = GraphCompletionContextExtensionRetriever()
+
+ context = await retriever.get_context("Who works at Canva?")
+
+ assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
+ assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
+
+ answer = await retriever.get_completion("Who works at Canva?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+ @pytest.mark.asyncio
+ async def test_graph_completion_extension_context_complex(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+ await setup()
+
+ class Company(DataPoint):
+ name: str
+ metadata: dict = {"index_fields": ["name"]}
+
+ class Car(DataPoint):
+ brand: str
+ model: str
+ year: int
+
+ class Location(DataPoint):
+ country: str
+ city: str
+
+ class Home(DataPoint):
+ location: Location
+ rooms: int
+ sqm: int
+
+ class Person(DataPoint):
+ name: str
+ works_for: Company
+ owns: Optional[list[Union[Car, Home]]] = None
+
+ company1 = Company(name="Figma")
+ company2 = Company(name="Canva")
+
+ person1 = Person(name="Mike Rodger", works_for=company1)
+ person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]
+
+ person2 = Person(name="Ike Loma", works_for=company1)
+ person2.owns = [
+ Car(brand="Tesla", model="Model S", year=2021),
+ Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
+ ]
+
+ person3 = Person(name="Jason Statham", works_for=company1)
+
+ person4 = Person(name="Mike Broski", works_for=company2)
+ person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]
+
+ person5 = Person(name="Christina Mayer", works_for=company2)
+ person5.owns = [Car(brand="Honda", model="Civic", year=2023)]
+
+ entities = [company1, company2, person1, person2, person3, person4, person5]
+
+ await add_data_points(entities)
+
+ retriever = GraphCompletionContextExtensionRetriever(top_k=20)
+
+ context = await retriever.get_context("Who works at Figma?")
+
+ print(context)
+
+ assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
+ assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
+ assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+ @pytest.mark.asyncio
+ async def test_get_graph_completion_extension_context_on_empty_graph(self):
+ system_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+ )
+ cognee.config.system_root_directory(system_directory_path)
+ data_directory_path = os.path.join(
+ pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+ )
+ cognee.config.data_root_directory(data_directory_path)
+
+ await cognee.prune.prune_data()
+ await cognee.prune.prune_system(metadata=True)
+
+ retriever = GraphCompletionContextExtensionRetriever()
+
+ with pytest.raises(DatabaseNotCreatedError):
+ await retriever.get_context("Who works at Figma?")
+
+ await setup()
+
+ context = await retriever.get_context("Who works at Figma?")
+ assert context == "", "Context should be empty on an empty graph"
+
+ answer = await retriever.get_completion("Who works at Figma?")
+
+ assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
+ assert all(isinstance(item, str) and item.strip() for item in answer), (
+ "Answer must contain only non-empty strings"
+ )
+
+
+if __name__ == "__main__":
+ from asyncio import run
+
+ test = TestGraphCompletionRetriever()
+
+ async def main():
+ await test.test_graph_completion_context_simple()
+ await test.test_graph_completion_context_complex()
+ await test.test_get_graph_completion_context_on_empty_graph()
+
+ run(main())
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
index 855f1f857c..5723afd4a4 100644
--- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py
@@ -12,7 +12,7 @@
class TestGraphCompletionRetriever:
@pytest.mark.asyncio
- async def test_graph_completion_context_simple(self):
+ async def test_graph_completion_cot_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
)
@@ -60,7 +60,7 @@ class Person(DataPoint):
)
@pytest.mark.asyncio
- async def test_graph_completion_context_complex(self):
+ async def test_graph_completion_cot_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
)
@@ -139,7 +139,7 @@ class Person(DataPoint):
)
@pytest.mark.asyncio
- async def test_get_graph_completion_context_on_empty_graph(self):
+ async def test_get_graph_completion_cot_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
)
From 76d562e1dea81a8ca9d7ee416d44aaea44876bb6 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 22 May 2025 16:40:30 +0200
Subject: [PATCH 10/10] chore: changes imports in search
---
cognee/modules/search/methods/search.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index 4cef49cab7..63c25924f2 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -23,8 +23,7 @@
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
-from ..operations import log_query, log_result
-from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
+from cognee.modules.search.operations import log_query, log_result
async def search(