Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import List, Dict, Any
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
Expand All @@ -12,6 +15,7 @@
# Lookup table from an engine-name string to the retriever class registered
# under that name. The values are the classes themselves, not instances, so
# the caller is responsible for constructing the retriever.
# NOTE(review): the keys look like the values accepted by the eval
# framework's `qa_engine` setting - confirm against the code that reads it.
retriever_options: Dict[str, Any] = {
    "cognee_graph_completion": GraphCompletionRetriever,
    "cognee_graph_completion_cot": GraphCompletionCotRetriever,
    "cognee_graph_completion_context_extension": GraphCompletionContextExtensionRetriever,
    "cognee_completion": CompletionRetriever,
    "graph_summary_completion": GraphSummaryCompletionRetriever,
}
Expand Down
4 changes: 1 addition & 3 deletions cognee/eval_framework/eval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ class EvalConfig(BaseSettings):

# Question answering params
answering_questions: bool = True
qa_engine: str = (
"cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion'
)
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'

# Evaluation params
evaluating_answers: bool = True
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Any, Optional, List
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt

logger = get_logger()


class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
    """Graph completion retriever that widens its context over several rounds.

    Each round generates an intermediate completion from the current context,
    uses that completion as a new retrieval query, and merges the triplets it
    returns into the set collected so far. The final answer is generated from
    the context built out of all collected triplets.
    """

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: Optional[int] = 5,
    ):
        """Forward the prompt paths and ``top_k`` to the parent retriever.

        Args:
            user_prompt_path: Prompt template file used for the user message.
            system_prompt_path: Prompt template file used for the system message.
            top_k: Retrieval size limit handed to the parent class.
        """
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            top_k=top_k,
        )

    async def get_completion(
        self,
        query: str,
        context: Optional[Any] = None,
        context_extension_rounds: int = 4,
    ) -> List[str]:
        """Answer ``query`` after iteratively extending the retrieved context.

        Args:
            query: The question to answer.
            context: Optional pre-built context. When ``None``, triplets are
                retrieved for ``query`` and resolved to text first.
            context_extension_rounds: Maximum number of extension rounds. The
                loop stops earlier if a round yields no new triplets.

        Returns:
            A single-element list holding the generated answer.
        """
        triplets: List[Any] = []

        if context is None:
            triplets += await self.get_triplets(query)
            context = await self.resolve_edges_to_text(triplets)

        # Keep only unique triplets (in retrieval order) so that the
        # "nothing new" comparison below counts distinct triplets exactly.
        triplets = list(dict.fromkeys(triplets))

        for round_idx in range(1, context_extension_rounds + 1):
            prev_size = len(triplets)

            logger.info(
                f"Context extension: round {round_idx} - generating next graph locational query."
            )
            completion = await generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
            )

            triplets += await self.get_triplets(completion)
            # dict.fromkeys de-duplicates like set() but preserves insertion
            # order, so the context text is built in a stable order.
            triplets = list(dict.fromkeys(triplets))

            num_triplets = len(triplets)

            if num_triplets == prev_size:
                # Nothing new was found: keep the current context untouched
                # (it may be the caller-supplied one) instead of rebuilding it.
                logger.info(
                    f"Context extension: round {round_idx} – no new triplets found; stopping early."
                )
                break

            # NOTE(review): from here on a caller-supplied ``context`` is
            # replaced by text built from the retrieved triplets only.
            context = await self.resolve_edges_to_text(triplets)

            logger.info(
                f"Context extension: round {round_idx} - "
                f"number of unique retrieved triplets: {num_triplets}"
            )

        answer = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )

        return [answer]
11 changes: 9 additions & 2 deletions cognee/modules/search/methods/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
Expand All @@ -19,8 +23,7 @@
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import get_document_ids_for_user
from cognee.shared.utils import send_telemetry
from ..operations import log_query, log_result
from ...retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.search.operations import log_query, log_result


async def search(
Expand Down Expand Up @@ -75,6 +78,10 @@ async def specific_search(
system_prompt_path=system_prompt_path,
top_k=top_k,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
).get_completion,
Expand Down
1 change: 1 addition & 0 deletions cognee/modules/search/types/SearchType.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ class SearchType(Enum):
CYPHER = "CYPHER"
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import os
import pytest
import pathlib
from typing import Optional, Union

import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)


class TestGraphCompletionRetriever:
    """Integration tests for ``GraphCompletionContextExtensionRetriever``.

    Every test points cognee at its own storage directories next to this
    file, wipes them, and then exercises ``get_context`` and
    ``get_completion`` against a small hand-built graph.
    """

    @staticmethod
    async def _reset_storage(storage_name: str) -> None:
        """Point cognee at ``storage_name`` directories and prune their contents."""
        tests_dir = pathlib.Path(__file__).parent
        system_directory_path = os.path.join(tests_dir, f".cognee_system/{storage_name}")
        cognee.config.system_root_directory(system_directory_path)
        data_directory_path = os.path.join(tests_dir, f".data_storage/{storage_name}")
        cognee.config.data_root_directory(data_directory_path)

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

    @staticmethod
    def _assert_valid_answer(answer) -> None:
        """Check that ``answer`` is a list containing only non-empty strings."""
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_simple(self):
        """People linked to two companies are retrievable by company name."""
        await self._reset_storage("test_graph_context")
        await setup()

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Steve Rodger", works_for=company1)
        person2 = Person(name="Ike Loma", works_for=company1)
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person5 = Person(name="Christina Mayer", works_for=company2)

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        retriever = GraphCompletionContextExtensionRetriever()

        context = await retriever.get_context("Who works at Canva?")

        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        answer = await retriever.get_completion("Who works at Canva?")

        self._assert_valid_answer(answer)

    @pytest.mark.asyncio
    async def test_graph_completion_extension_context_complex(self):
        """Retrieval still finds employees when people also own cars and homes."""
        await self._reset_storage("test_graph_completion_context")
        await setup()

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")

        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        person3 = Person(name="Jason Statham", works_for=company1)

        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        # A larger top_k than the default is requested for this bigger graph.
        retriever = GraphCompletionContextExtensionRetriever(top_k=20)

        context = await retriever.get_context("Who works at Figma?")

        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

        answer = await retriever.get_completion("Who works at Figma?")

        self._assert_valid_answer(answer)

    @pytest.mark.asyncio
    async def test_get_graph_completion_extension_context_on_empty_graph(self):
        """Before setup retrieval raises; after setup an empty graph gives empty context."""
        await self._reset_storage("test_graph_completion_context")

        retriever = GraphCompletionContextExtensionRetriever()

        # setup() has deliberately not run yet, so the databases do not exist.
        with pytest.raises(DatabaseNotCreatedError):
            await retriever.get_context("Who works at Figma?")

        await setup()

        context = await retriever.get_context("Who works at Figma?")
        assert context == "", "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")

        self._assert_valid_answer(answer)


if __name__ == "__main__":
    # Allows running this module directly, without the pytest runner.
    from asyncio import run

    test = TestGraphCompletionRetriever()

    async def main():
        """Run every test coroutine of the class sequentially."""
        # The method names must match the definitions on the class; the
        # previous names (without "extension") did not exist and raised
        # AttributeError.
        await test.test_graph_completion_extension_context_simple()
        await test.test_graph_completion_extension_context_complex()
        await test.test_get_graph_completion_extension_context_on_empty_graph()

    run(main())
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_context_simple(self):
async def test_graph_completion_cot_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
)
Expand Down Expand Up @@ -60,7 +60,7 @@ class Person(DataPoint):
)

@pytest.mark.asyncio
async def test_graph_completion_context_complex(self):
async def test_graph_completion_cot_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
)
Expand Down Expand Up @@ -139,7 +139,7 @@ class Person(DataPoint):
)

@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(self):
async def test_get_graph_completion_cot_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
)
Expand Down
Loading