Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions cognee/api/v1/search/search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional, List, Type

from cognee.modules.users.models import User
from cognee.modules.search.types import SearchType
Expand All @@ -13,6 +13,8 @@ async def search(
datasets: Union[list[str], str, None] = None,
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
# We use lists from now on for datasets
if isinstance(datasets, str):
Expand All @@ -22,12 +24,14 @@ async def search(
user = await get_default_user()

filtered_search_results = await search_function(
query_text,
query_type,
datasets,
user,
query_text=query_text,
query_type=query_type,
datasets=datasets,
user=user,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)

return filtered_search_results
14 changes: 14 additions & 0 deletions cognee/infrastructure/databases/exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ def __init__(
status_code=status.HTTP_409_CONFLICT,
):
super().__init__(message, name, status_code)


class NodesetFilterNotSupportedError(CogneeApiError):
"""Nodeset filter is not supported by the current database"""

def __init__(
self,
message: str = "The nodeset filter is not supported in the current graph database.",
name: str = "NodeSetFilterNotSupportedError",
status_code=status.HTTP_404_NOT_FOUND,
):
self.message = message
self.name = name
self.status_code = status_code
9 changes: 8 additions & 1 deletion cognee/infrastructure/databases/graph/graph_db_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import wraps
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple
from typing import Optional, Dict, Any, List, Tuple, Type
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
Expand Down Expand Up @@ -183,6 +183,13 @@ async def get_neighbors(self, node_id: str) -> List[NodeData]:
"""Get all neighboring nodes."""
raise NotImplementedError

@abstractmethod
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
"""Get nodeset subgraph"""
raise NotImplementedError

@abstractmethod
async def get_connections(
self, node_id: str
Expand Down
62 changes: 61 additions & 1 deletion cognee/infrastructure/databases/graph/kuzu/adapter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Adapter for Kuzu graph database."""

from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.shared.logging_utils import get_logger
import json
import os
import shutil
import asyncio
from typing import Dict, Any, List, Union, Optional, Tuple
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from datetime import datetime, timezone
from uuid import UUID
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -728,6 +729,65 @@ async def get_graph_data(
logger.error(f"Failed to get graph data: {e}")
raise

async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
label = node_type.__name__
primary_query = """
UNWIND $names AS wantedName
MATCH (n:Node)
WHERE n.type = $label AND n.name = wantedName
RETURN DISTINCT n.id
"""
primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
primary_ids = [row[0] for row in primary_rows]
if not primary_ids:
return [], []

neighbor_query = """
MATCH (n:Node)-[:EDGE]-(nbr:Node)
WHERE n.id IN $ids
RETURN DISTINCT nbr.id
"""
nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
neighbor_ids = [row[0] for row in nbr_rows]

all_ids = list({*primary_ids, *neighbor_ids})

nodes_query = """
MATCH (n:Node)
WHERE n.id IN $ids
RETURN n.id, n.name, n.type, n.properties
"""
node_rows = await self.query(nodes_query, {"ids": all_ids})
nodes: List[Tuple[str, dict]] = []
for node_id, name, typ, props in node_rows:
data = {"id": node_id, "name": name, "type": typ}
if props:
try:
data.update(json.loads(props))
except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON props for node {node_id}")
nodes.append((node_id, data))

edges_query = """
MATCH (a:Node)-[r:EDGE]-(b:Node)
WHERE a.id IN $ids AND b.id IN $ids
RETURN a.id, b.id, r.relationship_name, r.properties
"""
edge_rows = await self.query(edges_query, {"ids": all_ids})
edges: List[Tuple[str, str, str, dict]] = []
for from_id, to_id, rel_type, props in edge_rows:
data = {}
if props:
try:
data = json.loads(props)
except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
edges.append((from_id, to_id, rel_type, data))

return nodes, edges

async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from cognee.shared.logging_utils import get_logger, ERROR
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
Expand All @@ -13,6 +13,7 @@
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.storage.utils import JSONEncoder
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError

logger = get_logger("MemgraphAdapter", level=ERROR)

Expand Down Expand Up @@ -482,6 +483,12 @@ async def get_graph_data(self):

return (nodes, edges)

async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
"""Get nodeset subgraph"""
raise NodesetFilterNotSupportedError

async def get_filtered_graph_data(self, attribute_filters):
"""
Fetches nodes and relationships filtered by specified attribute values.
Expand Down
50 changes: 49 additions & 1 deletion cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from cognee.shared.logging_utils import get_logger, ERROR
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
Expand Down Expand Up @@ -517,6 +517,54 @@ async def get_graph_data(self):

return (nodes, edges)

async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
label = node_type.__name__

query = f"""
UNWIND $names AS wantedName
MATCH (n:`{label}`)
WHERE n.name = wantedName
WITH collect(DISTINCT n) AS primary
UNWIND primary AS p
OPTIONAL MATCH (p)--(nbr)
WITH primary, collect(DISTINCT nbr) AS nbrs
WITH primary + nbrs AS nodelist
UNWIND nodelist AS node
WITH collect(DISTINCT node) AS nodes
MATCH (a)-[r]-(b)
WHERE a IN nodes AND b IN nodes
WITH nodes, collect(DISTINCT r) AS rels
RETURN
[n IN nodes |
{{ id: n.id,
properties: properties(n) }}] AS rawNodes,
[r IN rels |
{{ type: type(r),
properties: properties(r) }}] AS rawRels
"""

result = await self.query(query, {"names": node_name})
if not result:
return [], []

raw_nodes = result[0]["rawNodes"]
raw_rels = result[0]["rawRels"]

nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
edges = [
(
r["properties"]["source_node_id"],
r["properties"]["target_node_id"],
r["type"],
r["properties"],
)
for r in raw_rels
]

return nodes, edges

async def get_filtered_graph_data(self, attribute_filters):
"""
Fetches nodes and relationships filtered by specified attribute values.
Expand Down
10 changes: 9 additions & 1 deletion cognee/infrastructure/databases/graph/networkx/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import os
import json
import asyncio

from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.shared.logging_utils import get_logger
from typing import Dict, Any, List, Union
from typing import Dict, Any, List, Union, Type, Tuple
from uuid import UUID
import aiofiles
import aiofiles.os as aiofiles_os
Expand Down Expand Up @@ -396,6 +398,12 @@ async def delete_graph(self, file_path: str = None):
logger.error("Failed to delete graph: %s", error)
raise error

async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
"""Get nodeset subgraph"""
raise NodesetFilterNotSupportedError

async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):
Expand Down
1 change: 0 additions & 1 deletion cognee/modules/engine/models/node_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ class NodeSet(DataPoint):
"""NodeSet data point."""

name: str
metadata: dict = {"index_fields": ["name"]}
25 changes: 18 additions & 7 deletions cognee/modules/graph/cognee_graph/CogneeGraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from cognee.shared.logging_utils import get_logger
from typing import List, Dict, Union
from typing import List, Dict, Union, Optional, Type

from cognee.exceptions import InvalidValueError
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
Expand Down Expand Up @@ -61,22 +61,33 @@ async def project_graph_from_db(
node_dimension=1,
edge_dimension=1,
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers")

try:
if len(memory_fragment_filter) == 0:
if node_type is not None and node_name is not None:
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodetes projected from the database."
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)

if not nodes_data:
raise EntityNotFoundError(message="No node data retrieved from the database.")
if not edges_data:
raise EntityNotFoundError(message="No edge data retrieved from the database.")
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Empty filtered graph projected from the database."
)

for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
Expand Down
3 changes: 3 additions & 0 deletions cognee/modules/graph/utils/expand_with_nodes_and_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def expand_with_nodes_and_edges(
name=ont_node_name,
description=ont_node_name,
ontology_valid=True,
belongs_to_set=data_chunk.belongs_to_set,
)

for source, relation, target in ontology_entity_type_edges:
Expand Down Expand Up @@ -144,6 +145,7 @@ def expand_with_nodes_and_edges(
is_a=type_node,
description=node.description,
ontology_valid=ontology_validated_source_ent,
belongs_to_set=data_chunk.belongs_to_set,
)

added_nodes_map[entity_node_key] = entity_node
Expand Down Expand Up @@ -174,6 +176,7 @@ def expand_with_nodes_and_edges(
name=ont_node_name,
description=ont_node_name,
ontology_valid=True,
belongs_to_set=data_chunk.belongs_to_set,
)

for source, relation, target in ontology_entity_edges:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
Expand All @@ -14,11 +14,15 @@ def __init__(
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)

async def get_completion(
Expand Down
6 changes: 5 additions & 1 deletion cognee/modules/retrieval/graph_completion_cot_retriever.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
Expand All @@ -18,11 +18,15 @@ def __init__(
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path
Expand Down
Loading
Loading