diff --git a/cognee/tasks/entity_completion/__init__.py b/cognee/modules/data/extraction/entity_extractors/__init__.py similarity index 100% rename from cognee/tasks/entity_completion/__init__.py rename to cognee/modules/data/extraction/entity_extractors/__init__.py diff --git a/cognee/modules/retrieval/EntityCompletionRetriever.py b/cognee/modules/retrieval/EntityCompletionRetriever.py new file mode 100644 index 0000000000..bb02869e11 --- /dev/null +++ b/cognee/modules/retrieval/EntityCompletionRetriever.py @@ -0,0 +1,68 @@ +from typing import Any, Optional, List +import logging + +from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor +from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider +from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.utils.completion import generate_completion + + +logger = logging.getLogger("entity_completion_retriever") + + +class EntityCompletionRetriever(BaseRetriever): + """Retriever that uses entity-based completion for generating responses.""" + + def __init__( + self, + extractor: BaseEntityExtractor, + context_provider: BaseContextProvider, + user_prompt_path: str = "context_for_question.txt", + system_prompt_path: str = "answer_simple_question.txt", + ): + self.extractor = extractor + self.context_provider = context_provider + self.user_prompt_path = user_prompt_path + self.system_prompt_path = system_prompt_path + + async def get_context(self, query: str) -> Any: + """Get context using entity extraction and context provider.""" + try: + logger.info(f"Processing query: {query[:100]}") + + entities = await self.extractor.extract_entities(query) + if not entities: + logger.info("No entities extracted") + return None + + context = await self.context_provider.get_context(entities, query) + if not context: + logger.info("No context retrieved") + return None + + return context + + except Exception as e: + logger.error(f"Context retrieval failed: {str(e)}") + return None + + async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]: + """Generate completion using provided context or fetch new context.""" + try: + if context is None: + context = await self.get_context(query) + + if context is None: + return ["No relevant entities found for the query."] + + completion = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) + return [completion] + + except Exception as e: + logger.error(f"Completion generation failed: {str(e)}") + return ["Completion generation failed"] diff --git a/cognee/tasks/entity_completion/context_providers/dummy_context_provider.py b/cognee/modules/retrieval/context_providers/DummyContextProvider.py similarity index 100% rename from cognee/tasks/entity_completion/context_providers/dummy_context_provider.py rename to cognee/modules/retrieval/context_providers/DummyContextProvider.py diff --git a/cognee/modules/retrieval/context_providers/SummarizedTripletSearchContextProvider.py b/cognee/modules/retrieval/context_providers/SummarizedTripletSearchContextProvider.py new file mode 100644 index 0000000000..ab2e14c157 --- /dev/null +++ b/cognee/modules/retrieval/context_providers/SummarizedTripletSearchContextProvider.py @@ -0,0 +1,22 @@ +from typing import List, Optional + +from cognee.modules.retrieval.utils.completion import summarize_text +from cognee.modules.retrieval.context_providers.TripletSearchContextProvider import ( + TripletSearchContextProvider, +) + + +class SummarizedTripletSearchContextProvider(TripletSearchContextProvider): + """Context provider that uses summarized triplet search results.""" + + async def _format_triplets( + self, triplets: List, entity_name: str, summarize_prompt_path: Optional[str] = None + ) -> str: + """Format triplets into a summarized text.""" + direct_text = await super()._format_triplets(triplets, entity_name) + + if summarize_prompt_path is None: + summarize_prompt_path = "summarize_search_results.txt" + + summary = await summarize_text(direct_text, summarize_prompt_path) + return f"Summary for {entity_name}:\n{summary}\n---\n" diff --git a/cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py b/cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py new file mode 100644 index 0000000000..d5626efdd5 --- /dev/null +++ b/cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py @@ -0,0 +1,97 @@ +from typing import List, Optional +import asyncio + +from cognee.infrastructure.context.BaseContextProvider import BaseContextProvider +from cognee.infrastructure.engine import DataPoint +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt +from cognee.modules.retrieval.utils.brute_force_triplet_search import ( + brute_force_triplet_search, + format_triplets, + get_memory_fragment, +) +from cognee.modules.users.methods import get_default_user +from cognee.modules.users.models import User + + +class TripletSearchContextProvider(BaseContextProvider): + """Context provider that uses brute force triplet search for each entity.""" + + def __init__( + self, + top_k: int = 3, + collections: List[str] = None, + properties_to_project: List[str] = None, + ): + self.top_k = top_k + self.collections = collections + self.properties_to_project = properties_to_project + + def _get_entity_text(self, entity: DataPoint) -> Optional[str]: + """Concatenates available entity text fields with graceful fallback.""" + texts = [] + if hasattr(entity, "name") and entity.name: + texts.append(entity.name) + if hasattr(entity, "description") and entity.description: + texts.append(entity.description) + if hasattr(entity, "text") and entity.text: + texts.append(entity.text) + + return " ".join(texts) if texts else None + + def _get_search_tasks( + self, + entities: List[DataPoint], + query: str, + user: User, + memory_fragment: CogneeGraph, + ) -> List: + """Creates search tasks for valid entities.""" + tasks = [ + brute_force_triplet_search( + query=f"{entity_text} {query}", + user=user, + top_k=self.top_k, + collections=self.collections, + properties_to_project=self.properties_to_project, + memory_fragment=memory_fragment, + ) + for entity in entities + if (entity_text := self._get_entity_text(entity)) is not None + ] + return tasks + + async def _format_triplets(self, triplets: List, entity_name: str) -> str: + """Format triplets into readable text.""" + direct_text = format_triplets(triplets) + return f"Context for {entity_name}:\n{direct_text}\n---\n" + + async def _results_to_context(self, entities: List[DataPoint], results: List) -> str: + """Formats search results into context string.""" + triplets = [] + + for entity, entity_triplets in zip(entities, results): + entity_name = ( + getattr(entity, "name", None) + or getattr(entity, "description", None) + or getattr(entity, "text", str(entity)) + ) + triplets.append(await self._format_triplets(entity_triplets, entity_name)) + + return "\n".join(triplets) if triplets else "No relevant context found." + + async def get_context(self, entities: List[DataPoint], query: str) -> str: + """Get context for each entity using brute force triplet search.""" + if not entities: + return "No entities provided for context search." + + user = await get_default_user() + memory_fragment = await get_memory_fragment(self.properties_to_project) + search_tasks = self._get_search_tasks(entities, query, user, memory_fragment) + + if not search_tasks: + return "No valid entities found for context search." + + results = await asyncio.gather(*search_tasks) + return await self._results_to_context(entities, results) diff --git a/cognee/tasks/entity_completion/context_providers/__init__.py b/cognee/modules/retrieval/context_providers/__init__.py similarity index 100% rename from cognee/tasks/entity_completion/context_providers/__init__.py rename to cognee/modules/retrieval/context_providers/__init__.py diff --git a/cognee/tasks/entity_completion/entity_extractors/dummy_entity_extractor.py b/cognee/modules/retrieval/entity_extractors/DummyEntityExtractor.py similarity index 100% rename from cognee/tasks/entity_completion/entity_extractors/dummy_entity_extractor.py rename to cognee/modules/retrieval/entity_extractors/DummyEntityExtractor.py diff --git a/cognee/tasks/entity_completion/entity_extractors/__init__.py b/cognee/modules/retrieval/entity_extractors/__init__.py similarity index 100% rename from cognee/tasks/entity_completion/entity_extractors/__init__.py rename to cognee/modules/retrieval/entity_extractors/__init__.py diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index db83cbee38..536bafe5d5 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -1,8 +1,7 @@ from typing import Optional -from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.utils.completion import summarize_text class GraphSummaryCompletionRetriever(GraphCompletionRetriever): @@ -26,11 +25,4 @@ def __init__( async def resolve_edges_to_text(self, retrieved_edges: list) -> str: """Converts retrieved graph edges into a summary without redundancies.""" direct_text = await super().resolve_edges_to_text(retrieved_edges) - system_prompt = read_query_prompt(self.summarize_prompt_path) - - llm_client = get_llm_client() - return await llm_client.acreate_structured_output( - text_input=direct_text, - system_prompt=system_prompt, - response_model=str, - ) + return await summarize_text(direct_text, self.summarize_prompt_path) diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 5b896b9211..fec16d676f 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import List +from typing import List, Optional from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine @@ -49,12 +49,32 @@ def filter_attributes(obj, attributes): return "".join(triplets) +async def get_memory_fragment( + properties_to_project: Optional[List[str]] = None, +) -> CogneeGraph: + """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" + graph_engine = await get_graph_engine() + memory_fragment = CogneeGraph() + + if properties_to_project is None: + properties_to_project = ["id", "description", "name", "type", "text"] + + await memory_fragment.project_graph_from_db( + graph_engine, + node_properties_to_project=properties_to_project, + edge_properties_to_project=["relationship_name"], + ) + + return memory_fragment + + async def brute_force_triplet_search( query: str, user: User = None, top_k: int = 5, collections: List[str] = None, properties_to_project: List[str] = None, + memory_fragment: Optional[CogneeGraph] = None, ) -> list: if user is None: user = await get_default_user() @@ -63,7 +83,12 @@ async def brute_force_triplet_search( raise PermissionError("No user found in the system. Please create a user.") retrieved_results = await brute_force_search( - query, user, top_k, collections=collections, properties_to_project=properties_to_project + query, + user, + top_k, + collections=collections, + properties_to_project=properties_to_project, + memory_fragment=memory_fragment, ) return retrieved_results @@ -74,6 +99,7 @@ async def brute_force_search( top_k: int, collections: List[str] = None, properties_to_project: List[str] = None, + memory_fragment: Optional[CogneeGraph] = None, ) -> list: """ Performs a brute force search to retrieve the top triplets from the graph. @@ -82,7 +108,9 @@ async def brute_force_search( query (str): The search query. user (User): The user performing the search. top_k (int): The number of top results to retrieve. - collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections. + collections (Optional[List[str]]): List of collections to query. + properties_to_project (Optional[List[str]]): List of properties to project. + memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. Returns: list: The top triplet results. @@ -92,6 +120,9 @@ async def brute_force_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") + if memory_fragment is None: + memory_fragment = await get_memory_fragment(properties_to_project) + if collections is None: collections = [ "Entity_name", @@ -102,9 +133,8 @@ async def brute_force_search( try: vector_engine = get_vector_engine() - graph_engine = await get_graph_engine() except Exception as e: - logging.error("Failed to initialize engines: %s", e) + logging.error("Failed to initialize vector engine: %s", e) raise RuntimeError("Initialization error") from e send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) @@ -119,22 +149,12 @@ async def brute_force_search( node_distances = {collection: result for collection, result in zip(collections, results)} - memory_fragment = CogneeGraph() - - await memory_fragment.project_graph_from_db( - graph_engine, - node_properties_to_project=properties_to_project - or ["id", "description", "name", "type", "text"], - edge_properties_to_project=["relationship_name"], - ) - await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) - await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query) results = await memory_fragment.calculate_top_triplet_importances(k=top_k) - send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id) + send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id) return results diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index 512193c31c..b77ff0f0e9 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -21,3 +21,18 @@ async def generate_completion( system_prompt=system_prompt, response_model=str, ) + + +async def summarize_text( + text: str, + prompt_path: str = "summarize_search_results.txt", +) -> str: + """Summarizes text using LLM with the specified prompt.""" + system_prompt = read_query_prompt(prompt_path) + llm_client = get_llm_client() + + return await llm_client.acreate_structured_output( + text_input=text, + system_prompt=system_prompt, + response_model=str, + ) diff --git a/cognee/tasks/entity_completion/entity_completion.py b/cognee/tasks/entity_completion/entity_completion.py deleted file mode 100644 index a5c9c66d0b..0000000000 --- a/cognee/tasks/entity_completion/entity_completion.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import List -import logging - -from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt -from cognee.infrastructure.entities.BaseEntityExtractor import ( - BaseEntityExtractor, -) -from cognee.infrastructure.context.BaseContextProvider import ( - BaseContextProvider, -) - -logger = logging.getLogger("entity_completion") - -# Default prompt template paths -DEFAULT_SYSTEM_PROMPT_TEMPLATE = "answer_simple_question.txt" -DEFAULT_USER_PROMPT_TEMPLATE = "context_for_question.txt" - - -async def get_llm_response( - query: str, - context: str, - system_prompt_template: str = None, - user_prompt_template: str = None, -) -> str: - """Generate LLM response based on query and context.""" - try: - args = { - "question": query, - "context": context, - } - user_prompt = render_prompt(user_prompt_template or DEFAULT_USER_PROMPT_TEMPLATE, args) - system_prompt = read_query_prompt(system_prompt_template or DEFAULT_SYSTEM_PROMPT_TEMPLATE) - - llm_client = get_llm_client() - return await llm_client.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, - ) - except Exception as e: - logger.error(f"LLM response generation failed: {str(e)}") - raise - - -async def entity_completion( - query: str, - extractor: BaseEntityExtractor, - context_provider: BaseContextProvider, - system_prompt_template: str = None, - user_prompt_template: str = None, -) -> List[str]: - """Execute entity-based completion using provided components.""" - if not query or not isinstance(query, str): - logger.error("Invalid query type or empty query") - return ["Invalid query input"] - - try: - logger.info(f"Processing query: {query[:100]}") - - entities = await extractor.extract_entities(query) - if not entities: - logger.info("No entities extracted") - return ["No entities found"] - - context = await context_provider.get_context(entities, query) - if not context: - logger.info("No context retrieved") - return ["No context found"] - - response = await get_llm_response( - query, context, system_prompt_template, user_prompt_template - ) - return [response] - - except Exception as e: - logger.error(f"Entity completion failed: {str(e)}") - return ["Entity completion failed"] - - -if __name__ == "__main__": - # For testing purposes, will be removed by the end of the sprint - import asyncio - import logging - from cognee.tasks.entity_completion.entity_extractors.dummy_entity_extractor import ( - DummyEntityExtractor, - ) - from cognee.tasks.entity_completion.context_providers.dummy_context_provider import ( - DummyContextProvider, - ) - - logging.basicConfig(level=logging.INFO) - - async def run_entity_completion(): - # Uses config defaults - result = await entity_completion( - "Tell me about Einstein", - DummyEntityExtractor(), - DummyContextProvider(), - ) - print(f"Query Response: {result[0]}") - - asyncio.run(run_entity_completion()) diff --git a/examples/python/entity_completion_comparison.py b/examples/python/entity_completion_comparison.py new file mode 100644 index 0000000000..dc62179def --- /dev/null +++ b/examples/python/entity_completion_comparison.py @@ -0,0 +1,163 @@ +import cognee +import asyncio +import logging + +from cognee.api.v1.search import SearchType +from cognee.shared.utils import setup_logging +from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever +from cognee.modules.retrieval.context_providers.TripletSearchContextProvider import ( + TripletSearchContextProvider, +) +from cognee.modules.retrieval.context_providers.SummarizedTripletSearchContextProvider import ( + SummarizedTripletSearchContextProvider, +) +from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor + +article_1 = """ +Title: The Theory of Relativity: A Revolutionary Breakthrough +Author: Dr. Sarah Chen + +Albert Einstein's theory of relativity fundamentally changed our understanding of space, time, and gravity. Published in 1915, the general theory of relativity describes gravity as a consequence of the curvature of spacetime caused by mass and energy. This groundbreaking work built upon his special theory of relativity from 1905, which introduced the famous equation E=mc². + +Einstein's work at the Swiss Patent Office gave him time to develop these revolutionary ideas. His mathematical framework predicted several phenomena that were later confirmed, including: +- The bending of light by gravity +- The precession of Mercury's orbit +- The existence of black holes + +The theory continues to be tested and validated today, most recently through the detection of gravitational waves by LIGO in 2015, exactly 100 years after its publication. +""" + +article_2 = """ +Title: The Manhattan Project and Its Scientific Director +Author: Prof. Michael Werner + +J. Robert Oppenheimer's leadership of the Manhattan Project marked a pivotal moment in scientific history. As scientific director of the Los Alamos Laboratory, he assembled and led an extraordinary team of physicists in the development of the atomic bomb during World War II. + +Oppenheimer's journey to Los Alamos began at Harvard and continued through his groundbreaking work in quantum mechanics and nuclear physics at Berkeley. His expertise in theoretical physics and exceptional leadership abilities made him the ideal candidate to head the secret weapons laboratory. + +Key aspects of his directorship included: +- Recruitment of top scientific talent from across the country +- Integration of theoretical physics with practical engineering challenges +- Development of implosion-type nuclear weapons +- Management of complex security and ethical considerations + +After witnessing the first nuclear test, codenamed Trinity, Oppenheimer famously quoted the Bhagavad Gita: "Now I am become Death, the destroyer of worlds." This moment reflected the profound moral implications of scientific advancement that would shape his later advocacy for international atomic controls. +""" + +article_3 = """ +Title: The Birth of Quantum Physics +Author: Dr. Lisa Martinez + +The early 20th century witnessed a revolutionary transformation in our understanding of the microscopic world. The development of quantum mechanics emerged from the collaborative efforts of numerous brilliant physicists grappling with phenomena that classical physics couldn't explain. + +Key contributors and their insights included: +- Max Planck's discovery of energy quantization (1900) +- Niels Bohr's model of the atom with discrete energy levels (1913) +- Werner Heisenberg's uncertainty principle (1927) +- Erwin Schrödinger's wave equation (1926) +- Paul Dirac's quantum theory of the electron (1928) + +Einstein's 1905 paper on the photoelectric effect, which demonstrated light's particle nature, was a crucial contribution to this field. The Copenhagen interpretation, developed primarily by Bohr and Heisenberg, became the standard understanding of quantum mechanics, despite ongoing debates about its philosophical implications. These foundational developments continue to influence modern physics, from quantum computing to quantum field theory. +""" + + +async def main(enable_steps): + # Step 1: Reset data and system state + if enable_steps.get("prune_data"): + await cognee.prune.prune_data() + print("Data pruned.") + + if enable_steps.get("prune_system"): + await cognee.prune.prune_system(metadata=True) + print("System pruned.") + + # Step 2: Add text + if enable_steps.get("add_text"): + text_list = [article_1, article_2, article_3] + for text in text_list: + await cognee.add(text) + print(f"Added text: {text[:50]}...") + + # Step 3: Create knowledge graph + if enable_steps.get("cognify"): + await cognee.cognify() + print("Knowledge graph created.") + + # Step 4: Query insights using our new retrievers + if enable_steps.get("retriever"): + # Common settings + search_settings = { + "top_k": 5, + "collections": ["Entity_name", "TextSummary_text"], + "properties_to_project": ["name", "description", "text"], + } + + # Create both context providers + direct_provider = TripletSearchContextProvider(**search_settings) + summary_provider = SummarizedTripletSearchContextProvider(**search_settings) + + # Create retrievers with different providers + direct_retriever = EntityCompletionRetriever( + extractor=DummyEntityExtractor(), + context_provider=direct_provider, + system_prompt_path="answer_simple_question.txt", + user_prompt_path="context_for_question.txt", + ) + + summary_retriever = EntityCompletionRetriever( + extractor=DummyEntityExtractor(), + context_provider=summary_provider, + system_prompt_path="answer_simple_question.txt", + user_prompt_path="context_for_question.txt", + ) + + query = "What were the early contributions to quantum physics?" + print("\nQuery:", query) + + # Try with direct triplets + print("\n=== Direct Triplets ===") + context = await direct_retriever.get_context(query) + print("\nEntity Context:") + print(context) + + result = await direct_retriever.get_completion(query) + print("\nEntity Completion:") + print(result) + + # Try with summarized triplets + print("\n=== Summarized Triplets ===") + context = await summary_retriever.get_context(query) + print("\nEntity Context:") + print(context) + + result = await summary_retriever.get_completion(query) + print("\nEntity Completion:") + print(result) + + # Compare with standard search + print("\n=== Standard Search ===") + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=query + ) + print(search_results) + + +if __name__ == "__main__": + setup_logging(logging.ERROR) + + rebuild_kg = True + retrieve = True + steps_to_enable = { + "prune_data": rebuild_kg, + "prune_system": rebuild_kg, + "add_text": rebuild_kg, + "cognify": rebuild_kg, + "retriever": retrieve, + } + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main(steps_to_enable)) + finally: + loop.run_until_complete(loop.shutdown_asyncgens())