-
Notifications
You must be signed in to change notification settings - Fork 966
fix: fixes cognify duplicated edges and resets the methods to an olde… #242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,16 +1,8 @@ | ||||||
| from datetime import datetime, timezone | ||||||
|
|
||||||
| from cognee.infrastructure.engine import DataPoint | ||||||
| from cognee.modules.storage.utils import copy_model | ||||||
|
|
||||||
|
|
||||||
| def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None): | ||||||
|
|
||||||
| if not added_nodes: | ||||||
| added_nodes = {} | ||||||
| if not added_edges: | ||||||
| added_edges = {} | ||||||
|
|
||||||
| def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}): | ||||||
| nodes = [] | ||||||
| edges = [] | ||||||
|
|
||||||
|
|
@@ -20,94 +12,87 @@ def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=No | |||||
| for field_name, field_value in data_point: | ||||||
| if field_name == "_metadata": | ||||||
| continue | ||||||
| elif isinstance(field_value, DataPoint): | ||||||
|
|
||||||
| if isinstance(field_value, DataPoint): | ||||||
| excluded_properties.add(field_name) | ||||||
| nodes, edges, added_nodes, added_edges = add_nodes_and_edges( | ||||||
| data_point, | ||||||
| field_name, | ||||||
| field_value, | ||||||
| nodes, | ||||||
| edges, | ||||||
| added_nodes, | ||||||
| added_edges, | ||||||
| ) | ||||||
|
|
||||||
| elif ( | ||||||
| isinstance(field_value, list) | ||||||
| and len(field_value) > 0 | ||||||
| and isinstance(field_value[0], DataPoint) | ||||||
| ): | ||||||
|
|
||||||
| property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) | ||||||
|
|
||||||
| for node in property_nodes: | ||||||
| if str(node.id) not in added_nodes: | ||||||
| nodes.append(node) | ||||||
| added_nodes[str(node.id)] = True | ||||||
|
|
||||||
| for edge in property_edges: | ||||||
| edge_key = str(edge[0]) + str(edge[1]) + edge[2] | ||||||
|
|
||||||
| if str(edge_key) not in added_edges: | ||||||
| edges.append(edge) | ||||||
| added_edges[str(edge_key)] = True | ||||||
|
|
||||||
| for property_node in get_own_properties(property_nodes, property_edges): | ||||||
| edge_key = str(data_point.id) + str(property_node.id) + field_name | ||||||
|
|
||||||
| if str(edge_key) not in added_edges: | ||||||
| edges.append((data_point.id, property_node.id, field_name, { | ||||||
| "source_node_id": data_point.id, | ||||||
| "target_node_id": property_node.id, | ||||||
| "relationship_name": field_name, | ||||||
| "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), | ||||||
| })) | ||||||
| added_edges[str(edge_key)] = True | ||||||
| continue | ||||||
|
|
||||||
| if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): | ||||||
| excluded_properties.add(field_name) | ||||||
|
|
||||||
| for item in field_value: | ||||||
| n_edges_before = len(edges) | ||||||
| nodes, edges, added_nodes, added_edges = add_nodes_and_edges( | ||||||
| data_point, field_name, item, nodes, edges, added_nodes, added_edges | ||||||
| ) | ||||||
| edges = edges[:n_edges_before] + [ | ||||||
| (*edge[:3], {**edge[3], "metadata": {"type": "list"}}) | ||||||
| for edge in edges[n_edges_before:] | ||||||
| ] | ||||||
| else: | ||||||
| data_point_properties[field_name] = field_value | ||||||
| property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) | ||||||
|
|
||||||
| for node in property_nodes: | ||||||
| if str(node.id) not in added_nodes: | ||||||
| nodes.append(node) | ||||||
| added_nodes[str(node.id)] = True | ||||||
|
|
||||||
| for edge in property_edges: | ||||||
| edge_key = str(edge[0]) + str(edge[1]) + edge[2] | ||||||
|
|
||||||
| if str(edge_key) not in added_edges: | ||||||
| edges.append(edge) | ||||||
| added_edges[edge_key] = True | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure consistent key types in In line 62, Apply this diff to fix the inconsistency: - added_edges[edge_key] = True
+ added_edges[str(edge_key)] = True📝 Committable suggestion
Suggested change
|
||||||
|
|
||||||
| for property_node in get_own_properties(property_nodes, property_edges): | ||||||
| edge_key = str(data_point.id) + str(property_node.id) + field_name | ||||||
|
|
||||||
| if str(edge_key) not in added_edges: | ||||||
| edges.append((data_point.id, property_node.id, field_name, { | ||||||
| "source_node_id": data_point.id, | ||||||
| "target_node_id": property_node.id, | ||||||
| "relationship_name": field_name, | ||||||
| "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), | ||||||
| "metadata": { | ||||||
| "type": "list" | ||||||
| }, | ||||||
| })) | ||||||
| added_edges[edge_key] = True | ||||||
| continue | ||||||
|
|
||||||
| data_point_properties[field_name] = field_value | ||||||
|
|
||||||
| SimpleDataPointModel = copy_model( | ||||||
| type(data_point), | ||||||
| include_fields={ | ||||||
| include_fields = { | ||||||
| "_metadata": (dict, data_point._metadata), | ||||||
| }, | ||||||
| exclude_fields=excluded_properties, | ||||||
| exclude_fields = excluded_properties, | ||||||
| ) | ||||||
|
|
||||||
| nodes.append(SimpleDataPointModel(**data_point_properties)) | ||||||
| if include_root: | ||||||
| nodes.append(SimpleDataPointModel(**data_point_properties)) | ||||||
|
|
||||||
| return nodes, edges | ||||||
|
|
||||||
|
|
||||||
| def add_nodes_and_edges( | ||||||
| data_point, field_name, field_value, nodes, edges, added_nodes, added_edges | ||||||
| ): | ||||||
|
|
||||||
| property_nodes, property_edges = get_graph_from_model( | ||||||
| field_value, dict(added_nodes), dict(added_edges) | ||||||
| ) | ||||||
|
|
||||||
| for node in property_nodes: | ||||||
| if str(node.id) not in added_nodes: | ||||||
| nodes.append(node) | ||||||
| added_nodes[str(node.id)] = True | ||||||
|
|
||||||
| for edge in property_edges: | ||||||
| edge_key = str(edge[0]) + str(edge[1]) + edge[2] | ||||||
|
|
||||||
| if str(edge_key) not in added_edges: | ||||||
| edges.append(edge) | ||||||
| added_edges[str(edge_key)] = True | ||||||
|
|
||||||
| for property_node in get_own_properties(property_nodes, property_edges): | ||||||
| edge_key = str(data_point.id) + str(property_node.id) + field_name | ||||||
|
|
||||||
| if str(edge_key) not in added_edges: | ||||||
| edges.append( | ||||||
| ( | ||||||
| data_point.id, | ||||||
| property_node.id, | ||||||
| field_name, | ||||||
| { | ||||||
| "source_node_id": data_point.id, | ||||||
| "target_node_id": property_node.id, | ||||||
| "relationship_name": field_name, | ||||||
| "updated_at": datetime.now(timezone.utc).strftime( | ||||||
| "%Y-%m-%d %H:%M:%S" | ||||||
| ), | ||||||
| }, | ||||||
| ) | ||||||
| ) | ||||||
| added_edges[str(edge_key)] = True | ||||||
|
|
||||||
| return (nodes, edges, added_nodes, added_edges) | ||||||
|
|
||||||
|
|
||||||
| def get_own_properties(property_nodes, property_edges): | ||||||
| own_properties = [] | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,41 +1,29 @@ | ||
| from typing import Callable | ||
|
|
||
| from pydantic_core import PydanticUndefined | ||
|
|
||
| from cognee.infrastructure.engine import DataPoint | ||
| from cognee.modules.storage.utils import copy_model | ||
|
|
||
|
|
||
| def get_model_instance_from_graph( | ||
| nodes: list[DataPoint], | ||
| edges: list[tuple[str, str, str, dict[str, str]]], | ||
| entity_id: str, | ||
| ): | ||
| node_map = {node.id: node for node in nodes} | ||
| def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str): | ||
| node_map = {} | ||
|
|
||
| for source_node_id, target_node_id, edge_label, edge_properties in edges: | ||
| source_node = node_map[source_node_id] | ||
| target_node = node_map[target_node_id] | ||
| for node in nodes: | ||
| node_map[node.id] = node | ||
|
|
||
| for edge in edges: | ||
| source_node = node_map[edge[0]] | ||
| target_node = node_map[edge[1]] | ||
| edge_label = edge[2] | ||
| edge_properties = edge[3] if len(edge) == 4 else {} | ||
|
Comment on lines
+12
to
+16
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add validation for edge tuple length to prevent index errors Accessing elements of Apply this fix: for edge in edges:
if len(edge) < 3:
# Handle error or skip invalid edge
continue # or raise an exception with a descriptive message
source_node = node_map.get(edge[0])
target_node = node_map.get(edge[1])
if source_node is None or target_node is None:
# Handle missing nodes
continue # or raise an exception
edge_label = edge[2]
edge_properties = edge[3] if len(edge) >= 4 else {}
edge_metadata = edge_properties.get("metadata", {})
# rest of the code |
||
| edge_metadata = edge_properties.get("metadata", {}) | ||
| edge_type = edge_metadata.get("type", "default") | ||
| edge_type = edge_metadata.get("type") | ||
|
|
||
| if edge_type == "list": | ||
| NewModel = copy_model( | ||
| type(source_node), | ||
| {edge_label: (list[type(target_node)], PydanticUndefined)}, | ||
| ) | ||
| source_node_dict = source_node.model_dump() | ||
| source_node_edge_label_values = source_node_dict.get(edge_label, []) | ||
| source_node_dict[edge_label] = source_node_edge_label_values + [target_node] | ||
|
|
||
| node_map[source_node_id] = NewModel(**source_node_dict) | ||
| NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) }) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct the syntax for specifying list type in In line 21, First, import +from typing import ListThen, correct line 21: - NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
+ NewModel = copy_model(type(source_node), { edge_label: (List[type(target_node)], PydanticUndefined) })
|
||
|
|
||
| node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] }) | ||
| else: | ||
| NewModel = copy_model( | ||
| type(source_node), {edge_label: (type(target_node), PydanticUndefined)} | ||
| ) | ||
| NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) }) | ||
|
|
||
| node_map[target_node_id] = NewModel( | ||
| **source_node.model_dump(), **{edge_label: target_node} | ||
| ) | ||
| node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node }) | ||
|
|
||
| return node_map[entity_id] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid mutable default arguments to prevent unintended behavior
Using mutable default arguments like
{}foradded_nodesandadded_edgescan lead to unexpected behavior because the default dictionaries are shared across all function calls. It's recommended to useNoneas the default value and initialize the dictionaries within the function.Apply this diff to fix the issue:
📝 Committable suggestion