Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e0ffdfb
refactor: move entity extractors and context providers to modules
lxobr Feb 27, 2025
e6cbb14
refactor: extract memory fragment creation
lxobr Feb 27, 2025
b45c6e2
feat: add brute force triplet search provider
lxobr Feb 27, 2025
30d0613
feat: add entity completion retriever
lxobr Feb 27, 2025
d0bb3ea
refactor: move dummy entity extractor
lxobr Feb 27, 2025
32fad89
feat: improve entity completion retriever
lxobr Feb 27, 2025
ed88559
delete: old entity completion
lxobr Feb 27, 2025
188e79a
refactor: extract summarization
lxobr Feb 27, 2025
13e544e
feat: add summarization context provider
lxobr Feb 27, 2025
b46ac3e
refactor: move dummy entity extractor
lxobr Feb 27, 2025
8233e6d
feat: add entity completion comparison example
lxobr Feb 28, 2025
9a2da29
Merge branch 'dev' into feat/COG-1325-entity-brute-force-triplet-search
borisarzentar Mar 3, 2025
70bfa9f
Merge branch 'dev' into feat/COG-1325-entity-brute-force-triplet-search
lxobr Mar 4, 2025
139055e
Merge branch 'dev' into feat/COG-1325-entity-brute-force-triplet-search
lxobr Mar 4, 2025
eb6d625
refactor: update names
lxobr Mar 4, 2025
7337255
refactor: update prompt handling, remove redundant init
lxobr Mar 4, 2025
376a53a
refactor: move memory fragment creation to start of brute_force_search
lxobr Mar 4, 2025
02b30f9
fix: update imports
lxobr Mar 4, 2025
5923a9e
Merge branch 'dev' into feat/COG-1325-entity-brute-force-triplet-search
borisarzentar Mar 4, 2025
acad4bc
refactor: update retriever name
lxobr Mar 4, 2025
c53fd6b
Merge branch 'dev' into feat/COG-1325-entity-brute-force-triplet-search
lxobr Mar 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions cognee/modules/retrieval/EntityCompletionRetriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Optional, List
import logging

from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.completion import generate_completion


logger = logging.getLogger("entity_completion_retriever")


class EntityCompletionRetriever(BaseRetriever):
    """Retriever that uses entity-based completion for generating responses.

    Entities are extracted from the query by ``extractor``, converted into
    context by ``context_provider``, and the context is handed to
    ``generate_completion`` together with the configured prompt paths.
    """

    def __init__(
        self,
        extractor: BaseEntityExtractor,
        context_provider: BaseContextProvider,
        user_prompt_path: str = "context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
    ):
        """Store the collaborators and prompt paths used by later calls.

        Args:
            extractor: Object whose ``extract_entities(query)`` is awaited to
                obtain the entities for a query.
            context_provider: Object whose ``get_context(entities, query)`` is
                awaited to build the context.
            user_prompt_path: Forwarded to ``generate_completion`` as
                ``user_prompt_path``.
            system_prompt_path: Forwarded to ``generate_completion`` as
                ``system_prompt_path``.
        """
        self.extractor = extractor
        self.context_provider = context_provider
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path

    async def get_context(self, query: str) -> Any:
        """Get context using entity extraction and the context provider.

        Args:
            query: The user query; only its first 100 characters are logged.

        Returns:
            Whatever the context provider returns, or ``None`` when no
            entities were extracted, the provider returned a falsy value, or
            any exception was raised along the way.
        """
        try:
            logger.info("Processing query: %s", query[:100])

            entities = await self.extractor.extract_entities(query)
            if not entities:
                logger.info("No entities extracted")
                return None

            context = await self.context_provider.get_context(entities, query)
            if not context:
                logger.info("No context retrieved")
                return None

            return context

        except Exception as e:
            # Best-effort by design: callers get None instead of an exception.
            # logger.exception keeps the traceback so the swallowed failure
            # can still be diagnosed from the logs.
            logger.exception("Context retrieval failed: %s", e)
            return None

    async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
        """Generate completion using provided context or fetch new context.

        Args:
            query: The user query.
            context: Pre-computed context; when ``None`` it is obtained via
                ``get_context(query)``.

        Returns:
            A single-element list holding the completion, or a single-element
            list with a fallback message when no context is available or the
            completion call raises.
        """
        try:
            if context is None:
                context = await self.get_context(query)

            # Still None after the lookup: nothing to ground the answer on.
            if context is None:
                return ["No relevant entities found for the query."]

            completion = await generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
            )
            return [completion]

        except Exception as e:
            # Keep the traceback in the log; the caller only sees the fallback.
            logger.exception("Completion generation failed: %s", e)
            return ["Completion generation failed"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import List, Optional

from cognee.modules.retrieval.utils.completion import summarize_text
from cognee.modules.retrieval.context_providers.TripletSearchContextProvider import (
TripletSearchContextProvider,
)


class SummarizedTripletSearchContextProvider(TripletSearchContextProvider):
    """Triplet-search context provider that emits a summary per entity."""

    async def _format_triplets(
        self, triplets: List, entity_name: str, summarize_prompt_path: Optional[str] = None
    ) -> str:
        """Render the triplets via the parent class, then summarize that text.

        Args:
            triplets: Triplets found for the entity.
            entity_name: Label used in the returned header line.
            summarize_prompt_path: Prompt passed to ``summarize_text``;
                ``"summarize_search_results.txt"`` is used when ``None``.

        Returns:
            A ``"Summary for <entity_name>:"`` section terminated by ``---``.
        """
        parent_rendering = await super()._format_triplets(triplets, entity_name)

        prompt_path = (
            "summarize_search_results.txt"
            if summarize_prompt_path is None
            else summarize_prompt_path
        )

        condensed = await summarize_text(parent_rendering, prompt_path)
        return f"Summary for {entity_name}:\n{condensed}\n---\n"
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import List, Optional
import asyncio

from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search,
format_triplets,
get_memory_fragment,
)
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User


class TripletSearchContextProvider(BaseContextProvider):
    """Context provider that uses brute force triplet search for each entity."""

    def __init__(
        self,
        top_k: int = 3,
        collections: Optional[List[str]] = None,
        properties_to_project: Optional[List[str]] = None,
    ):
        """Store the search settings forwarded to every triplet search.

        Args:
            top_k: Passed to ``brute_force_triplet_search`` as ``top_k``.
            collections: Passed through as ``collections``.
            properties_to_project: Passed through as ``properties_to_project``
                and also used when building the shared memory fragment.
        """
        self.top_k = top_k
        self.collections = collections
        self.properties_to_project = properties_to_project

    def _get_entity_text(self, entity: DataPoint) -> Optional[str]:
        """Concatenates available entity text fields with graceful fallback.

        Returns:
            The truthy ``name``, ``description`` and ``text`` attributes joined
            by single spaces (in that order), or ``None`` when none is set.
        """
        texts = [
            value
            for field in ("name", "description", "text")
            if (value := getattr(entity, field, None))
        ]
        return " ".join(texts) if texts else None

    def _get_search_tasks(
        self,
        entities: List[DataPoint],
        query: str,
        user: User,
        memory_fragment: CogneeGraph,
    ) -> List:
        """Creates search tasks for valid entities.

        Entities for which ``_get_entity_text`` returns ``None`` are skipped,
        so the result may be shorter than ``entities``.

        Returns:
            One un-awaited ``brute_force_triplet_search`` call per entity that
            has text, in input order.
        """
        return [
            brute_force_triplet_search(
                query=f"{entity_text} {query}",
                user=user,
                top_k=self.top_k,
                collections=self.collections,
                properties_to_project=self.properties_to_project,
                memory_fragment=memory_fragment,
            )
            for entity in entities
            if (entity_text := self._get_entity_text(entity)) is not None
        ]

    async def _format_triplets(self, triplets: List, entity_name: str) -> str:
        """Format triplets into readable text under a per-entity header."""
        direct_text = format_triplets(triplets)
        return f"Context for {entity_name}:\n{direct_text}\n---\n"

    async def _results_to_context(self, entities: List[DataPoint], results: List) -> str:
        """Formats search results into context string.

        ``entities`` and ``results`` are paired positionally, so they must be
        index-aligned.

        Returns:
            The formatted sections joined by newlines, or
            ``"No relevant context found."`` when there is nothing to format.
        """
        sections = []

        for entity, entity_triplets in zip(entities, results):
            # Fall through to str(entity) even when an attribute exists but is
            # empty, so the header never reads "Context for None".
            entity_name = (
                getattr(entity, "name", None)
                or getattr(entity, "description", None)
                or getattr(entity, "text", None)
                or str(entity)
            )
            sections.append(await self._format_triplets(entity_triplets, entity_name))

        return "\n".join(sections) if sections else "No relevant context found."

    async def get_context(self, entities: List[DataPoint], query: str) -> str:
        """Get context for each entity using brute force triplet search.

        Args:
            entities: Entities to search around.
            query: The user query, appended to each entity's text to form the
                search query.

        Returns:
            The combined context string, or an explanatory message when no
            entities were given or none of them has usable text.
        """
        if not entities:
            return "No entities provided for context search."

        user = await get_default_user()
        # One memory fragment is built up front and shared by all searches.
        memory_fragment = await get_memory_fragment(self.properties_to_project)

        # Search tasks are only created for entities that have text. Keep that
        # exact subset so each result is paired with the entity it was
        # produced for rather than with whatever sits at the same index in the
        # unfiltered input.
        searchable_entities = [
            entity for entity in entities if self._get_entity_text(entity) is not None
        ]
        search_tasks = self._get_search_tasks(searchable_entities, query, user, memory_fragment)

        if not search_tasks:
            return "No valid entities found for context search."

        results = await asyncio.gather(*search_tasks)
        return await self._results_to_context(searchable_entities, results)
12 changes: 2 additions & 10 deletions cognee/modules/retrieval/graph_summary_completion_retriever.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import summarize_text


class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
Expand All @@ -26,11 +25,4 @@ def __init__(
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""Converts retrieved graph edges into a summary without redundancies."""
direct_text = await super().resolve_edges_to_text(retrieved_edges)
system_prompt = read_query_prompt(self.summarize_prompt_path)

llm_client = get_llm_client()
return await llm_client.acreate_structured_output(
text_input=direct_text,
system_prompt=system_prompt,
response_model=str,
)
return await summarize_text(direct_text, self.summarize_prompt_path)
52 changes: 36 additions & 16 deletions cognee/modules/retrieval/utils/brute_force_triplet_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import List
from typing import List, Optional

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
Expand Down Expand Up @@ -49,12 +49,32 @@ def filter_attributes(obj, attributes):
return "".join(triplets)


async def get_memory_fragment(
    properties_to_project: Optional[List[str]] = None,
) -> CogneeGraph:
    """Build a ``CogneeGraph`` projected from the graph database.

    Args:
        properties_to_project: Node properties to project. When ``None``,
            ``id``, ``description``, ``name``, ``type`` and ``text`` are used.

    Returns:
        The graph after ``project_graph_from_db`` has been awaited on it. Only
        the ``relationship_name`` edge property is projected.
    """
    engine = await get_graph_engine()
    fragment = CogneeGraph()

    node_properties = (
        ["id", "description", "name", "type", "text"]
        if properties_to_project is None
        else properties_to_project
    )

    await fragment.project_graph_from_db(
        engine,
        node_properties_to_project=node_properties,
        edge_properties_to_project=["relationship_name"],
    )
    return fragment


async def brute_force_triplet_search(
query: str,
user: User = None,
top_k: int = 5,
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
) -> list:
if user is None:
user = await get_default_user()
Expand All @@ -63,7 +83,12 @@ async def brute_force_triplet_search(
raise PermissionError("No user found in the system. Please create a user.")

retrieved_results = await brute_force_search(
query, user, top_k, collections=collections, properties_to_project=properties_to_project
query,
user,
top_k,
collections=collections,
properties_to_project=properties_to_project,
memory_fragment=memory_fragment,
)
return retrieved_results

Expand All @@ -74,6 +99,7 @@ async def brute_force_search(
top_k: int,
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
) -> list:
"""
Performs a brute force search to retrieve the top triplets from the graph.
Expand All @@ -82,7 +108,9 @@ async def brute_force_search(
query (str): The search query.
user (User): The user performing the search.
top_k (int): The number of top results to retrieve.
collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.
collections (Optional[List[str]]): List of collections to query.
properties_to_project (Optional[List[str]]): List of properties to project.
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.

Returns:
list: The top triplet results.
Expand All @@ -92,6 +120,9 @@ async def brute_force_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")

if memory_fragment is None:
memory_fragment = await get_memory_fragment(properties_to_project)

if collections is None:
collections = [
"Entity_name",
Expand All @@ -102,9 +133,8 @@ async def brute_force_search(

try:
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
except Exception as e:
logging.error("Failed to initialize engines: %s", e)
logging.error("Failed to initialize vector engine: %s", e)
raise RuntimeError("Initialization error") from e

send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
Expand All @@ -119,22 +149,12 @@ async def brute_force_search(

node_distances = {collection: result for collection, result in zip(collections, results)}

memory_fragment = CogneeGraph()

await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project
or ["id", "description", "name", "type", "text"],
edge_properties_to_project=["relationship_name"],
)

await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)

await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)

results = await memory_fragment.calculate_top_triplet_importances(k=top_k)

send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id)

return results

Expand Down
15 changes: 15 additions & 0 deletions cognee/modules/retrieval/utils/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,18 @@ async def generate_completion(
system_prompt=system_prompt,
response_model=str,
)


async def summarize_text(
    text: str,
    prompt_path: str = "summarize_search_results.txt",
) -> str:
    """Summarize ``text`` with the LLM client.

    Args:
        text: Sent to the client as ``text_input``.
        prompt_path: Passed to ``read_query_prompt`` to obtain the system
            prompt.

    Returns:
        The awaited result of ``acreate_structured_output`` requested with
        ``response_model=str``.
    """
    # The prompt is read before the client is created, as in the call order
    # callers already rely on.
    request = {
        "text_input": text,
        "system_prompt": read_query_prompt(prompt_path),
        "response_model": str,
    }
    client = get_llm_client()
    return await client.acreate_structured_output(**request)
Loading
Loading