Skip to content

Commit a334291

Browse files
committed
Apply cosmetic changes and autoformat
1 parent 5b420eb commit a334291

File tree

3 files changed

+39
-14
lines changed

3 files changed

+39
-14
lines changed

cognee/modules/graph/utils/get_graph_from_model.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from datetime import datetime, timezone
2+
23
from cognee.infrastructure.engine import DataPoint
34
from cognee.modules.storage.utils import copy_model
45

56

6-
def get_graph_from_model(
7-
data_point: DataPoint, added_nodes=None, added_edges=None
8-
):
7+
def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):
98

109
if not added_nodes:
1110
added_nodes = {}
@@ -24,7 +23,13 @@ def get_graph_from_model(
2423
elif isinstance(field_value, DataPoint):
2524
excluded_properties.add(field_name)
2625
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
27-
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
26+
data_point,
27+
field_name,
28+
field_value,
29+
nodes,
30+
edges,
31+
added_nodes,
32+
added_edges,
2833
)
2934

3035
elif (
@@ -35,12 +40,13 @@ def get_graph_from_model(
3540
excluded_properties.add(field_name)
3641

3742
for item in field_value:
43+
n_edges_before = len(edges)
3844
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
3945
data_point, field_name, item, nodes, edges, added_nodes, added_edges
4046
)
41-
edges = [
47+
edges = edges[:n_edges_before] + [
4248
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
43-
for edge in edges
49+
for edge in edges[n_edges_before:]
4450
]
4551
else:
4652
data_point_properties[field_name] = field_value
Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from typing import Callable
2+
13
from pydantic_core import PydanticUndefined
4+
25
from cognee.infrastructure.engine import DataPoint
36
from cognee.modules.storage.utils import copy_model
47

5-
def merge_dicts(dict1, dict2, agg_fn):
8+
9+
def merge_dicts(dict1: dict, dict2: dict, agg_fn: Callable) -> dict:
610
merged_dict = {}
711
for key, value in dict1.items():
812
if key in dict2:
@@ -15,22 +19,38 @@ def merge_dicts(dict1, dict2, agg_fn):
1519
merged_dict[key] = value
1620
return merged_dict
1721

18-
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list[tuple[str, str, str, dict[str, str]]], entity_id: str):
22+
23+
def get_model_instance_from_graph(
24+
nodes: list[DataPoint],
25+
edges: list[tuple[str, str, str, dict[str, str]]],
26+
entity_id: str,
27+
):
1928
node_map = {node.id: node for node in nodes}
2029

2130
for source_node_id, target_node_id, edge_label, edge_properties in edges:
2231
source_node = node_map[source_node_id]
2332
target_node = node_map[target_node_id]
2433
edge_metadata = edge_properties.get("metadata", {})
25-
edge_type = edge_metadata.get("type")
34+
edge_type = edge_metadata.get("type", "default")
2635

2736
if edge_type == "list":
28-
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
29-
new_model_dict = merge_dicts(source_node.model_dump(), { edge_label: [target_node] }, lambda a, b: a + b)
37+
NewModel = copy_model(
38+
type(source_node),
39+
{edge_label: (list[type(target_node)], PydanticUndefined)},
40+
)
41+
new_model_dict = merge_dicts(
42+
source_node.model_dump(),
43+
{edge_label: [target_node]},
44+
lambda a, b: a + b,
45+
)
3046
node_map[source_node_id] = NewModel(**new_model_dict)
3147
else:
32-
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
48+
NewModel = copy_model(
49+
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
50+
)
3351

34-
node_map[target_node_id] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
52+
node_map[target_node_id] = NewModel(
53+
**source_node.model_dump(), **{edge_label: target_node}
54+
)
3555

3656
return node_map[entity_id]

cognee/tests/unit/interfaces/graph/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ def count_society(obj):
132132

133133

134134
def show_first_difference(str1, str2, str1_name, str2_name, context=30):
135-
"""Shows where two strings first diverge, with surrounding context."""
136135
for i, (c1, c2) in enumerate(zip(str1, str2)):
137136
if c1 != c2:
138137
start = max(0, i - context)

0 commit comments

Comments
 (0)