From 0073d2e3dac7b82955266b23201ae0c2fcaba182 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 23 May 2025 14:52:39 +0200 Subject: [PATCH 1/9] feat: Adds get_nodeset_subgraph to graphdb adapters --- .../databases/exceptions/exceptions.py | 14 ++++++ .../databases/graph/graph_db_interface.py | 9 +++- .../databases/graph/kuzu/adapter.py | 9 +++- .../graph/memgraph/memgraph_adapter.py | 9 +++- .../databases/graph/neo4j_driver/adapter.py | 50 ++++++++++++++++++- .../databases/graph/networkx/adapter.py | 10 +++- 6 files changed, 96 insertions(+), 5 deletions(-) diff --git a/cognee/infrastructure/databases/exceptions/exceptions.py b/cognee/infrastructure/databases/exceptions/exceptions.py index eacfc40951..b6d860909e 100644 --- a/cognee/infrastructure/databases/exceptions/exceptions.py +++ b/cognee/infrastructure/databases/exceptions/exceptions.py @@ -37,3 +37,17 @@ def __init__( status_code=status.HTTP_409_CONFLICT, ): super().__init__(message, name, status_code) + + +class NodesetFilterNotSupportedError(CogneeApiError): + """Nodeset filter is not supported by the current database""" + + def __init__( + self, + message: str = "The nodeset filter is not supported in the current graph database.", + name: str = "NodeSetFilterNotSupportedError", + status_code=status.HTTP_404_NOT_FOUND, + ): + self.message = message + self.name = name + self.status_code = status_code diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 8c582107f7..aea6a72d26 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -2,7 +2,7 @@ from functools import wraps from abc import abstractmethod, ABC from datetime import datetime, timezone -from typing import Optional, Dict, Any, List, Tuple +from typing import Optional, Dict, Any, List, Tuple, Type from uuid import NAMESPACE_OID, UUID, uuid5 from cognee.shared.logging_utils import get_logger from cognee.infrastructure.engine import DataPoint @@ -183,6 +183,13 @@ async def get_neighbors(self, node_id: str) -> List[NodeData]: """Get all neighboring nodes.""" raise NotImplementedError + @abstractmethod + async def get_nodeset_subgraph( + self, node_type: Type[Any], node_name: List[str] + ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: + """Get nodeset subgraph""" + raise NotImplementedError + @abstractmethod async def get_connections( self, node_id: str diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 380f3f7132..20fd4dfacd 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -1,11 +1,12 @@ """Adapter for Kuzu graph database.""" +from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError from cognee.shared.logging_utils import get_logger import json import os import shutil import asyncio -from typing import Dict, Any, List, Union, Optional, Tuple +from typing import Dict, Any, List, Union, Optional, Tuple, Type from datetime import datetime, timezone from uuid import UUID from contextlib import asynccontextmanager @@ -728,6 +729,12 @@ async def get_graph_data( logger.error(f"Failed to get graph data: {e}") raise + async def get_nodeset_subgraph( + self, node_type: Type[Any], node_name: List[str] + ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: + """Get nodeset subgraph""" + raise NodesetFilterNotSupportedError + async def get_filtered_graph_data( self, attribute_filters: List[Dict[str, List[Union[str, int]]]] ): diff --git a/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py b/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py index 5ef4380770..b86b7d08a8 100644 --- a/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py +++ b/cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py @@ -4,7 +4,7 @@ from cognee.shared.logging_utils import get_logger, ERROR import asyncio from textwrap import dedent -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Type, Tuple from contextlib import asynccontextmanager from uuid import UUID from neo4j import AsyncSession @@ -13,6 +13,7 @@ from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.modules.storage.utils import JSONEncoder +from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError logger = get_logger("MemgraphAdapter", level=ERROR) @@ -482,6 +483,12 @@ async def get_graph_data(self): return (nodes, edges) + async def get_nodeset_subgraph( + self, node_type: Type[Any], node_name: List[str] + ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: + """Get nodeset subgraph""" + raise NodesetFilterNotSupportedError + async def get_filtered_graph_data(self, attribute_filters): """ Fetches nodes and relationships filtered by specified attribute values. diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index ccc76cbcf1..906db02660 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -6,7 +6,7 @@ from cognee.shared.logging_utils import get_logger, ERROR import asyncio from textwrap import dedent -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Type, Tuple from contextlib import asynccontextmanager from uuid import UUID from neo4j import AsyncSession @@ -517,6 +517,54 @@ async def get_graph_data(self): return (nodes, edges) + async def get_nodeset_subgraph( + self, node_type: Type[Any], node_name: List[str] + ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: + label = node_type.__name__ + + query = f""" + UNWIND $names AS wantedName + MATCH (n:`{label}`) + WHERE n.name = wantedName + WITH collect(DISTINCT n) AS primary + UNWIND primary AS p + OPTIONAL MATCH (p)--(nbr) + WITH primary, collect(DISTINCT nbr) AS nbrs + WITH primary + nbrs AS nodelist + UNWIND nodelist AS node + WITH collect(DISTINCT node) AS nodes + MATCH (a)-[r]-(b) + WHERE a IN nodes AND b IN nodes + WITH nodes, collect(DISTINCT r) AS rels + RETURN + [n IN nodes | + {{ id: n.id, + properties: properties(n) }}] AS rawNodes, + [r IN rels | + {{ type: type(r), + properties: properties(r) }}] AS rawRels + """ + + result = await self.query(query, {"names": node_name}) + if not result: + return [], [] + + raw_nodes = result[0]["rawNodes"] + raw_rels = result[0]["rawRels"] + + nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes] + edges = [ + ( + r["properties"]["source_node_id"], + r["properties"]["target_node_id"], + r["type"], + r["properties"], + ) + for r in raw_rels + ] + + return nodes, edges + async def get_filtered_graph_data(self, attribute_filters): """ Fetches nodes and relationships filtered by specified attribute values. diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index 2c86dba0a6..9f57bfcb6d 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -4,8 +4,10 @@ import os import json import asyncio + +from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError from cognee.shared.logging_utils import get_logger -from typing import Dict, Any, List, Union +from typing import Dict, Any, List, Union, Type, Tuple from uuid import UUID import aiofiles import aiofiles.os as aiofiles_os @@ -396,6 +398,12 @@ async def delete_graph(self, file_path: str = None): logger.error("Failed to delete graph: %s", error) raise error + async def get_nodeset_subgraph( + self, node_type: Type[Any], node_name: List[str] + ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: + """Get nodeset subgraph""" + raise NodesetFilterNotSupportedError + async def get_filtered_graph_data( self, attribute_filters: List[Dict[str, List[Union[str, int]]]] ): From db15389edc1707b0f6ca3b993e6a7845251dcf25 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 23 May 2025 14:52:58 +0200 Subject: [PATCH 2/9] chore: removes Nodeset node embedding --- cognee/modules/engine/models/node_set.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cognee/modules/engine/models/node_set.py b/cognee/modules/engine/models/node_set.py index 33fe3f5577..480752bcab 100644 --- a/cognee/modules/engine/models/node_set.py +++ b/cognee/modules/engine/models/node_set.py @@ -5,4 +5,3 @@ class NodeSet(DataPoint): """NodeSet data point.""" name: str - metadata: dict = {"index_fields": ["name"]} From c42a58191db94d0c034fb26bbf9acba28062729c Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 23 May 2025 14:53:27 +0200 Subject: [PATCH 3/9] feat: adds nodeset to entity connections --- cognee/modules/graph/utils/expand_with_nodes_and_edges.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py index 7759306380..b392094aa8 100644 --- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py +++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py @@ -95,6 +95,7 @@ def expand_with_nodes_and_edges( name=ont_node_name, description=ont_node_name, ontology_valid=True, + belongs_to_set=data_chunk.belongs_to_set, ) for source, relation, target in ontology_entity_type_edges: @@ -144,6 +145,7 @@ def expand_with_nodes_and_edges( is_a=type_node, description=node.description, ontology_valid=ontology_validated_source_ent, + belongs_to_set=data_chunk.belongs_to_set, ) added_nodes_map[entity_node_key] = entity_node @@ -174,6 +176,7 @@ def expand_with_nodes_and_edges( name=ont_node_name, description=ont_node_name, ontology_valid=True, + belongs_to_set=data_chunk.belongs_to_set, ) for source, relation, target in ontology_entity_edges: From 1b6699eca8dcea5d536b350f9387e1784a1fa6eb Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 23 May 2025 14:55:09 +0200 Subject: [PATCH 4/9] feat: exposes node_type and node_name parameters along all the searches --- cognee/api/v1/search/search.py | 14 ++++++---- .../modules/graph/cognee_graph/CogneeGraph.py | 6 ++-- ..._completion_context_extension_retriever.py | 6 +++- .../graph_completion_cot_retriever.py | 6 +++- .../retrieval/graph_completion_retriever.py | 12 ++++++-- .../graph_summary_completion_retriever.py | 6 +++- .../utils/brute_force_triplet_search.py | 16 +++++++++-- cognee/modules/search/methods/search.py | 28 +++++++++++++++---- 8 files changed, 75 insertions(+), 19 deletions(-) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 723f41bdb4..0e817c14eb 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Optional, List, Type from cognee.modules.users.models import User from cognee.modules.search.types import SearchType @@ -13,6 +13,8 @@ async def search( datasets: Union[list[str], str, None] = None, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: # We use lists from now on for datasets if isinstance(datasets, str): @@ -22,12 +24,14 @@ async def search( user = await get_default_user() filtered_search_results = await search_function( - query_text, - query_type, - datasets, - user, + query_text=query_text, + query_type=query_type, + datasets=datasets, + user=user, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, ) return filtered_search_results diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index f79bbd010d..45d0d048fe 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,5 +1,5 @@ from cognee.shared.logging_utils import get_logger -from typing import List, Dict, Union +from typing import List, Dict, Union, Optional, Type from cognee.exceptions import InvalidValueError from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError @@ -61,10 +61,12 @@ async def project_graph_from_db( node_dimension=1, edge_dimension=1, memory_fragment_filter=[], + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> None: if node_dimension < 1 or edge_dimension < 1: raise InvalidValueError(message="Dimensions must be positive integers") - + # :TODO: NODESET IMPLEMENTATION CONTINUE HERE try: if len(memory_fragment_filter) == 0: nodes_data, edges_data = await adapter.get_graph_data() diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 26364b2ff0..fa418d3489 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List +from typing import Any, Optional, List, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -14,11 +14,15 @@ def __init__( user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", top_k: Optional[int] = 5, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ): super().__init__( user_prompt_path=user_prompt_path, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, ) async def get_completion( diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 012e9574d3..2df7b8b8b6 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List +from typing import Any, Optional, List, Type from cognee.shared.logging_utils import get_logger from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -18,11 +18,15 @@ def __init__( followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt", top_k: Optional[int] = 5, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ): super().__init__( user_prompt_path=user_prompt_path, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, ) self.validation_system_prompt_path = validation_system_prompt_path self.validation_user_prompt_path = validation_user_prompt_path diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 98ac5dd614..ecdfdf111f 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Type, List from collections import Counter import string @@ -18,11 +18,15 @@ def __init__( user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", top_k: Optional[int] = 5, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ): """Initialize retriever with prompt paths and search parameters.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 5 + self.node_type = node_type + self.node_name = node_name def _get_nodes(self, retrieved_edges: list) -> dict: """Creates a dictionary of nodes with their names and content.""" @@ -68,7 +72,11 @@ async def get_triplets(self, query: str) -> list: vector_index_collections.append(f"{subclass.__name__}_{field_name}") found_triplets = await brute_force_triplet_search( - query, top_k=self.top_k, collections=vector_index_collections or None + query, + top_k=self.top_k, + collections=vector_index_collections or None, + node_type=self.node_type, + node_name=self.node_name, ) return found_triplets diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 76ed5f5d47..bdd3dd6ebb 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type, List from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import summarize_text @@ -13,12 +13,16 @@ def __init__( system_prompt_path: str = "answer_simple_question.txt", summarize_prompt_path: str = "summarize_search_results.txt", top_k: Optional[int] = 5, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__( user_prompt_path=user_prompt_path, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, ) self.summarize_prompt_path = summarize_prompt_path diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 0a08fbd002..06f11ab26d 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Optional +from typing import List, Optional, Type from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError @@ -55,6 +55,8 @@ def filter_attributes(obj, attributes): async def get_memory_fragment( properties_to_project: Optional[List[str]] = None, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> CogneeGraph: """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" graph_engine = await get_graph_engine() @@ -68,6 +70,8 @@ async def get_memory_fragment( graph_engine, node_properties_to_project=properties_to_project, edge_properties_to_project=["relationship_name"], + node_type=node_type, + node_name=node_name, ) except EntityNotFoundError: pass @@ -82,6 +86,8 @@ async def brute_force_triplet_search( collections: List[str] = None, properties_to_project: List[str] = None, memory_fragment: Optional[CogneeGraph] = None, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: if user is None: user = await get_default_user() @@ -93,6 +99,8 @@ async def brute_force_triplet_search( collections=collections, properties_to_project=properties_to_project, memory_fragment=memory_fragment, + node_type=node_type, + node_name=node_name, ) return retrieved_results @@ -104,6 +112,8 @@ async def brute_force_search( collections: List[str] = None, properties_to_project: List[str] = None, memory_fragment: Optional[CogneeGraph] = None, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: """ Performs a brute force search to retrieve the top triplets from the graph. @@ -125,7 +135,9 @@ async def brute_force_search( raise ValueError("top_k must be a positive integer.") if memory_fragment is None: - memory_fragment = await get_memory_fragment(properties_to_project) + memory_fragment = await get_memory_fragment( + properties_to_project, node_type=node_type, node_name=node_name + ) if collections is None: collections = [ diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 63c25924f2..4b43ebf31b 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -1,5 +1,5 @@ import json -from typing import Callable +from typing import Callable, Optional, List, Type from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine.utils import parse_id @@ -33,12 +33,20 @@ async def search( user: User, system_prompt_path="answer_simple_question.txt", top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ): query = await log_query(query_text, query_type.value, user.id) own_document_ids = await get_document_ids_for_user(user.id, datasets) search_results = await specific_search( - query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k + query_type, + query_text, + user, + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, ) filtered_search_results = [] @@ -61,29 +69,39 @@ async def specific_search( user: User, system_prompt_path="answer_simple_question.txt", top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, SearchType.RAG_COMPLETION: CompletionRetriever( - system_prompt_path=system_prompt_path, - top_k=top_k, + system_prompt_path=system_prompt_path, top_k=top_k ).get_completion, SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, ).get_completion, SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, ).get_completion, SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( - system_prompt_path=system_prompt_path, top_k=top_k + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, ).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion, From e2b5cc5e7fcdde73c7988dbd95c9fb3df899e611 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 26 May 2025 13:42:32 +0200 Subject: [PATCH 5/9] feat: adds subgraph filtering to CogneeGraph --- .../modules/graph/cognee_graph/CogneeGraph.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 45d0d048fe..ada8821ebe 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -66,19 +66,28 @@ async def project_graph_from_db( ) -> None: if node_dimension < 1 or edge_dimension < 1: raise InvalidValueError(message="Dimensions must be positive integers") - # :TODO: NODESET IMPLEMENTATION CONTINUE HERE try: - if len(memory_fragment_filter) == 0: + if node_type is not None and node_name is not None: + nodes_data, edges_data = await adapter.get_nodeset_subgraph( + node_type=node_type, node_name=node_name + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError( + message="Nodeset does not exist, or empty nodetes projected from the database." + ) + elif len(memory_fragment_filter) == 0: nodes_data, edges_data = await adapter.get_graph_data() + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty graph projected from the database.") else: nodes_data, edges_data = await adapter.get_filtered_graph_data( attribute_filters=memory_fragment_filter ) - if not nodes_data: - raise EntityNotFoundError(message="No node data retrieved from the database.") - if not edges_data: - raise EntityNotFoundError(message="No edge data retrieved from the database.") + if not nodes_data or not edges_data: + raise EntityNotFoundError( + message="Empty filtered graph projected from the database." + ) for node_id, properties in nodes_data: node_attributes = {key: properties.get(key) for key in node_properties_to_project} From 6fc27e811d9ba82ce76690b9d7cdb40a736d103d Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 26 May 2025 13:43:34 +0200 Subject: [PATCH 6/9] feat: Adds Kuzu support for subgraph filter retriever --- .../databases/graph/kuzu/adapter.py | 59 ++++++++++++++++++- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 20fd4dfacd..033e06e7a4 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -731,9 +731,62 @@ async def get_graph_data( async def get_nodeset_subgraph( self, node_type: Type[Any], node_name: List[str] - ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: - """Get nodeset subgraph""" - raise NodesetFilterNotSupportedError + ) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]: + label = node_type.__name__ + primary_query = """ + UNWIND $names AS wantedName + MATCH (n:Node) + WHERE n.type = $label AND n.name = wantedName + RETURN DISTINCT n.id + """ + primary_rows = await self.query(primary_query, {"names": node_name, "label": label}) + primary_ids = [row[0] for row in primary_rows] + if not primary_ids: + return [], [] + + neighbor_query = """ + MATCH (n:Node)-[:EDGE]-(nbr:Node) + WHERE n.id IN $ids + RETURN DISTINCT nbr.id + """ + nbr_rows = await self.query(neighbor_query, {"ids": primary_ids}) + neighbor_ids = [row[0] for row in nbr_rows] + + all_ids = list({*primary_ids, *neighbor_ids}) + + nodes_query = """ + MATCH (n:Node) + WHERE n.id IN $ids + RETURN n.id, n.name, n.type, n.properties + """ + node_rows = await self.query(nodes_query, {"ids": all_ids}) + nodes: List[Tuple[str, dict]] = [] + for node_id, name, typ, props in node_rows: + data = {"id": node_id, "name": name, "type": typ} + if props: + try: + data.update(json.loads(props)) + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON props for node {node_id}") + nodes.append((node_id, data)) + + edges_query = """ + MATCH (a:Node)-[r:EDGE]-(b:Node) + WHERE a.id IN $ids AND b.id IN $ids + RETURN a.id, b.id, r.relationship_name, r.properties + """ + edge_rows = await self.query(edges_query, {"ids": all_ids}) + edges: List[Tuple[str, str, str, dict]] = [] + for from_id, to_id, rel_type, props in edge_rows: + data = {} + if props: + try: + data = json.loads(props) + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}") + edges.append((from_id, to_id, rel_type, data)) + + return nodes, edges async def get_filtered_graph_data( self, attribute_filters: List[Dict[str, List[Union[str, int]]]] From 3aa37e4f06b2c7eca7ab848852410b01408b3d09 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 26 May 2025 14:00:52 +0200 Subject: [PATCH 7/9] chore: adds new params to search unit test --- cognee/tests/unit/modules/search/search_methods_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index f8e440ca49..7a76d5f7cd 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from pylint.checkers.utils import node_type from cognee.exceptions import InvalidValueError from cognee.modules.search.methods.search import search, specific_search @@ -68,7 +69,13 @@ async def test_search( mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id) mock_get_document_ids.assert_called_once_with(mock_user.id, datasets) mock_specific_search.assert_called_once_with( - query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10 + query_type, + query_text, + mock_user, + system_prompt_path="answer_simple_question.txt", + top_k=10, + node_type=None, + node_name=None, ) # Only the first two results should be included (doc_id3 is filtered out) From b84e04a2ddf8a50b3c74c6cdd81b565e2ce774c2 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 26 May 2025 16:43:14 +0200 Subject: [PATCH 8/9] feat: adds Neo4j test for NodeSet search --- cognee/tests/test_neo4j.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index e87c656a0a..1d4af2b2ce 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -1,10 +1,12 @@ import os import pathlib import cognee +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.search.operations import get_history from cognee.modules.users.methods import get_default_user from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType +from cognee.modules.engine.models import NodeSet logger = get_logger() @@ -89,6 +91,30 @@ async def main(): assert len(history) == 6, "Search history is not correct." + nodeset_text = "Neo4j is a graph database that supports cypher." + + await cognee.add([nodeset_text], dataset_name, node_set=["first"]) + + await cognee.cognify([dataset_name]) + + context_nonempty = await GraphCompletionRetriever( + node_type=NodeSet, + node_name=["first"], + ).get_context("What is in the context?") + + context_empty = await GraphCompletionRetriever( + node_type=NodeSet, + node_name=["nonexistent"], + ).get_context("What is in the context?") + + assert isinstance(context_nonempty, str) and context_nonempty != "", ( + f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" + ) + + assert context_empty == "", ( + f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" + ) + await cognee.prune.prune_data() assert not os.path.isdir(data_directory_path), "Local data files are not deleted" From 94ae7f57ea9c85076d299e6dea61ee1c92f8fa55 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 26 May 2025 17:02:02 +0200 Subject: [PATCH 9/9] feat: adds kuzu nodeset search test --- cognee/tests/test_kuzu.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/cognee/tests/test_kuzu.py b/cognee/tests/test_kuzu.py index d0b03057ed..c2a2389abb 100644 --- a/cognee/tests/test_kuzu.py +++ b/cognee/tests/test_kuzu.py @@ -2,6 +2,9 @@ import shutil import cognee import pathlib + +from cognee.modules.engine.models import NodeSet +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.modules.search.operations import get_history @@ -84,6 +87,30 @@ async def main(): history = await get_history(user.id) assert len(history) == 6, "Search history is not correct." + nodeset_text = "Neo4j is a graph database that supports cypher." + + await cognee.add([nodeset_text], dataset_name, node_set=["first"]) + + await cognee.cognify([dataset_name]) + + context_nonempty = await GraphCompletionRetriever( + node_type=NodeSet, + node_name=["first"], + ).get_context("What is in the context?") + + context_empty = await GraphCompletionRetriever( + node_type=NodeSet, + node_name=["nonexistent"], + ).get_context("What is in the context?") + + assert isinstance(context_nonempty, str) and context_nonempty != "", ( + f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" + ) + + assert context_empty == "", ( + f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" + ) + await cognee.prune.prune_data() assert not os.path.isdir(data_directory_path), "Local data files are not deleted"