Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 67 additions & 82 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
from datetime import datetime, timezone

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model


def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):

if not added_nodes:
added_nodes = {}
if not added_edges:
added_edges = {}

def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid mutable default arguments to prevent unintended behavior

Using mutable default arguments like {} for added_nodes and added_edges can lead to unexpected behavior because the default dictionaries are shared across all function calls. It's recommended to use None as the default value and initialize the dictionaries within the function.

Apply this diff to fix the issue:

-def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
+def get_graph_from_model(data_point: DataPoint, include_root=True, added_nodes=None, added_edges=None):
+    if added_nodes is None:
+        added_nodes = {}
+    if added_edges is None:
+        added_edges = {}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
def get_graph_from_model(data_point: DataPoint, include_root=True, added_nodes=None, added_edges=None):
if added_nodes is None:
added_nodes = {}
if added_edges is None:
added_edges = {}

nodes = []
edges = []

Expand All @@ -20,94 +12,87 @@ def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=No
for field_name, field_value in data_point:
if field_name == "_metadata":
continue
elif isinstance(field_value, DataPoint):

if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
)

elif (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):

property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue

if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)

for item in field_value:
n_edges_before = len(edges)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, item, nodes, edges, added_nodes, added_edges
)
edges = edges[:n_edges_before] + [
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
for edge in edges[n_edges_before:]
]
else:
data_point_properties[field_name] = field_value
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]

if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure consistent key types in added_edges dictionary

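In line 62, added_edges[edge_key] = True omits the str() wrapper used in the other assignments. Because edge_key is already built by string concatenation, the stored key is identical either way, so this is a consistency fix rather than a correctness fix; wrapping it in str() simply keeps every added_edges assignment in the same form.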

Apply this diff to fix the inconsistency:

-                        added_edges[edge_key] = True
+                        added_edges[str(edge_key)] = True
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
added_edges[edge_key] = True
added_edges[str(edge_key)] = True


for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"metadata": {
"type": "list"
},
}))
added_edges[edge_key] = True
continue

data_point_properties[field_name] = field_value

SimpleDataPointModel = copy_model(
type(data_point),
include_fields={
include_fields = {
"_metadata": (dict, data_point._metadata),
},
exclude_fields=excluded_properties,
exclude_fields = excluded_properties,
)

nodes.append(SimpleDataPointModel(**data_point_properties))
if include_root:
nodes.append(SimpleDataPointModel(**data_point_properties))

return nodes, edges


def add_nodes_and_edges(
    data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
):
    """Merge the sub-graph of ``field_value`` into ``nodes``/``edges`` and link it to ``data_point``.

    The sub-graph is obtained from ``get_graph_from_model``, which receives
    *copies* of the ``added_nodes``/``added_edges`` registries. Nodes and edges
    from that sub-graph are appended only when their key is not yet registered
    in the caller's registries. Finally, one ``field_name`` edge is added from
    ``data_point`` to every node returned by ``get_own_properties``.

    Args:
        data_point: Parent object; its ``id`` is the source of the new edges.
        field_name: Name of the parent's field; used as the relationship name.
        field_value: Value handed to ``get_graph_from_model`` to build the sub-graph.
        nodes: List of collected nodes; appended to in place.
        edges: List of collected edge tuples; appended to in place.
        added_nodes: Registry mapping ``str(node.id)`` to ``True``; updated in place.
        added_edges: Registry mapping edge keys to ``True``; updated in place.

    Returns:
        The tuple ``(nodes, edges, added_nodes, added_edges)``.
    """
    # The recursive call works on snapshots, so only this function decides
    # what ends up in the caller's registries.
    child_nodes, child_edges = get_graph_from_model(
        field_value, dict(added_nodes), dict(added_edges)
    )

    for child_node in child_nodes:
        node_key = str(child_node.id)
        if node_key in added_nodes:
            continue
        nodes.append(child_node)
        added_nodes[node_key] = True

    for child_edge in child_edges:
        # Edge key: source id + target id + relationship name, concatenated.
        child_edge_key = str(child_edge[0]) + str(child_edge[1]) + child_edge[2]
        if child_edge_key in added_edges:
            continue
        edges.append(child_edge)
        added_edges[child_edge_key] = True

    for owned_node in get_own_properties(child_nodes, child_edges):
        link_key = str(data_point.id) + str(owned_node.id) + field_name
        if link_key in added_edges:
            continue

        link_properties = {
            "source_node_id": data_point.id,
            "target_node_id": owned_node.id,
            "relationship_name": field_name,
            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
        }
        edges.append((data_point.id, owned_node.id, field_name, link_properties))
        added_edges[link_key] = True

    return (nodes, edges, added_nodes, added_edges)


def get_own_properties(property_nodes, property_edges):
own_properties = []

Expand Down
44 changes: 16 additions & 28 deletions cognee/modules/graph/utils/get_model_instance_from_graph.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,29 @@
from typing import Callable

from pydantic_core import PydanticUndefined

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model


def get_model_instance_from_graph(
nodes: list[DataPoint],
edges: list[tuple[str, str, str, dict[str, str]]],
entity_id: str,
):
node_map = {node.id: node for node in nodes}
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
node_map = {}

for source_node_id, target_node_id, edge_label, edge_properties in edges:
source_node = node_map[source_node_id]
target_node = node_map[target_node_id]
for node in nodes:
node_map[node.id] = node

for edge in edges:
source_node = node_map[edge[0]]
target_node = node_map[edge[1]]
edge_label = edge[2]
edge_properties = edge[3] if len(edge) == 4 else {}
Comment on lines +12 to +16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add validation for edge tuple length to prevent index errors

Accessing elements of edge without ensuring it has the expected length can raise IndexError. Validate the length of edge before accessing its elements.

Apply this fix:

for edge in edges:
    if len(edge) < 3:
        # Handle error or skip invalid edge
        continue  # or raise an exception with a descriptive message
    source_node = node_map.get(edge[0])
    target_node = node_map.get(edge[1])
    if source_node is None or target_node is None:
        # Handle missing nodes
        continue  # or raise an exception
    edge_label = edge[2]
    edge_properties = edge[3] if len(edge) >= 4 else {}
    edge_metadata = edge_properties.get("metadata", {})
    # rest of the code

edge_metadata = edge_properties.get("metadata", {})
edge_type = edge_metadata.get("type", "default")
edge_type = edge_metadata.get("type")

if edge_type == "list":
NewModel = copy_model(
type(source_node),
{edge_label: (list[type(target_node)], PydanticUndefined)},
)
source_node_dict = source_node.model_dump()
source_node_edge_label_values = source_node_dict.get(edge_label, [])
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]

node_map[source_node_id] = NewModel(**source_node_dict)
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

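Confirm the minimum Python version for the list type used in copy_model

In line 21, list[type(target_node)] uses built-in generic subscripting (PEP 585), which is valid on Python 3.9 and later; the function signature in this file already uses list[DataPoint] in the same way. Only if older interpreters must be supported should it be replaced with List[type(target_node)] from the typing module.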

First, import List from typing:

+from typing import List

Then, correct line 21:

-            NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
+            NewModel = copy_model(type(source_node), { edge_label: (List[type(target_node)], PydanticUndefined) })

Committable suggestion skipped: line range outside the PR's diff.


node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
else:
NewModel = copy_model(
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
)
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })

node_map[target_node_id] = NewModel(
**source_node.model_dump(), **{edge_label: target_node}
)
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })

return node_map[entity_id]
Loading