1+ from typing import Callable
2+
13from pydantic_core import PydanticUndefined
4+
25from cognee .infrastructure .engine import DataPoint
36from cognee .modules .storage .utils import copy_model
47
5- def merge_dicts (dict1 , dict2 , agg_fn ):
8+
9+ def merge_dicts (dict1 : dict , dict2 : dict , agg_fn : Callable ) -> dict :
610 merged_dict = {}
711 for key , value in dict1 .items ():
812 if key in dict2 :
@@ -15,22 +19,38 @@ def merge_dicts(dict1, dict2, agg_fn):
1519 merged_dict [key ] = value
1620 return merged_dict
1721
18- def get_model_instance_from_graph (nodes : list [DataPoint ], edges : list [tuple [str , str , str , dict [str , str ]]], entity_id : str ):
22+
23+ def get_model_instance_from_graph (
24+ nodes : list [DataPoint ],
25+ edges : list [tuple [str , str , str , dict [str , str ]]],
26+ entity_id : str ,
27+ ):
1928 node_map = {node .id : node for node in nodes }
2029
2130 for source_node_id , target_node_id , edge_label , edge_properties in edges :
2231 source_node = node_map [source_node_id ]
2332 target_node = node_map [target_node_id ]
2433 edge_metadata = edge_properties .get ("metadata" , {})
25- edge_type = edge_metadata .get ("type" )
34+ edge_type = edge_metadata .get ("type" , "default" )
2635
2736 if edge_type == "list" :
28- NewModel = copy_model (type (source_node ), { edge_label : (list [type (target_node )], PydanticUndefined ) })
29- new_model_dict = merge_dicts (source_node .model_dump (), { edge_label : [target_node ] }, lambda a , b : a + b )
37+ NewModel = copy_model (
38+ type (source_node ),
39+ {edge_label : (list [type (target_node )], PydanticUndefined )},
40+ )
41+ new_model_dict = merge_dicts (
42+ source_node .model_dump (),
43+ {edge_label : [target_node ]},
44+ lambda a , b : a + b ,
45+ )
3046 node_map [source_node_id ] = NewModel (** new_model_dict )
3147 else :
32- NewModel = copy_model (type (source_node ), { edge_label : (type (target_node ), PydanticUndefined ) })
48+ NewModel = copy_model (
49+ type (source_node ), {edge_label : (type (target_node ), PydanticUndefined )}
50+ )
3351
34- node_map [target_node_id ] = NewModel (** source_node .model_dump (), ** { edge_label : target_node })
52+ node_map [target_node_id ] = NewModel (
53+ ** source_node .model_dump (), ** {edge_label : target_node }
54+ )
3555
3656 return node_map [entity_id ]
0 commit comments