Skip to content
194 changes: 93 additions & 101 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,50 @@
from cognee.modules.storage.utils import copy_model


def _extract_field_info(field_value: Any) -> Tuple[str, Any, Optional[Edge]]:
"""Extract field type, actual value, and edge metadata from a field value."""

# Handle tuple[Edge, DataPoint]
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
and isinstance(field_value[1], DataPoint)
):
return "single_datapoint_with_edge", field_value[1], field_value[0]

# Handle tuple[Edge, list[DataPoint]]
def _extract_field_data(field_value: Any) -> List[Tuple[Optional[Edge], List[DataPoint]]]:
"""Extract edge metadata and datapoints from a field value."""
# Handle single DataPoint
if isinstance(field_value, DataPoint):
return [(None, [field_value])]

# Handle list - could contain DataPoints, edge tuples, or mixed
if isinstance(field_value, list) and len(field_value) > 0:
result = []
for item in field_value:
# Handle tuple[Edge, DataPoint or list[DataPoint]]
if isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Edge):
edge, data_value = item
if isinstance(data_value, DataPoint):
result.append((edge, [data_value]))
elif (
isinstance(data_value, list)
and len(data_value) > 0
and isinstance(data_value[0], DataPoint)
):
result.append((edge, data_value))
# Handle single DataPoint in list
elif isinstance(item, DataPoint):
result.append((None, [item]))
return result

# Handle tuple[Edge, DataPoint or list[DataPoint]]
if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
and isinstance(field_value[1], list)
and len(field_value[1]) > 0
and isinstance(field_value[1][0], DataPoint)
):
return "list_datapoint_with_edge", field_value[1], field_value[0]

# Handle single DataPoint
if isinstance(field_value, DataPoint):
return "single_datapoint", field_value, None

# Handle list of DataPoints
if (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):
return "list_datapoint", field_value, None
edge_metadata, data_value = field_value
if isinstance(data_value, DataPoint):
return [(edge_metadata, [data_value])]
elif (
isinstance(data_value, list)
and len(data_value) > 0
and isinstance(data_value[0], DataPoint)
):
return [(edge_metadata, data_value)]

# Regular property
return "property", field_value, None
# Regular property or empty list
return []


def _create_edge_properties(
Expand Down Expand Up @@ -80,30 +87,49 @@ def _get_relationship_key(field_name: str, edge_metadata: Optional[Edge]) -> str

def _generate_property_key(data_point_id: str, relationship_key: str, target_id: str) -> str:
"""Generate a unique property key for visited_properties tracking."""
return f"{data_point_id}{relationship_key}{target_id}"
return f"{data_point_id}_{relationship_key}_{target_id}"


def _process_datapoint_field(
data_point_id: str,
field_name: str,
datapoints: List[DataPoint],
edge_metadata: Optional[Edge],
edge_datapoint_pairs: List[Tuple[Optional[Edge], List[DataPoint]]],
visited_properties: Dict[str, bool],
properties_to_visit: set,
excluded_properties: set,
) -> None:
"""Process a field containing DataPoint(s), handling both single and list cases."""
"""Process a field containing DataPoints, always working with lists."""
excluded_properties.add(field_name)
relationship_key = _get_relationship_key(field_name, edge_metadata)

for index, datapoint in enumerate(datapoints):
property_key = _generate_property_key(data_point_id, relationship_key, str(datapoint.id))
if property_key in visited_properties:
for edge_metadata, datapoints in edge_datapoint_pairs:
relationship_key = _get_relationship_key(field_name, edge_metadata)

for datapoint in datapoints:
property_key = _generate_property_key(
data_point_id, relationship_key, str(datapoint.id)
)
if property_key in visited_properties:
continue

# Always use field_name since we're working with lists
properties_to_visit.add(field_name)


def _targets_generator(
data_point: DataPoint,
properties_to_visit: set,
) -> Tuple[DataPoint, str, Optional[Edge]]:
"""Generator that yields (target_datapoint, field_name, edge_metadata) tuples."""
for field_name in properties_to_visit:
field_value = getattr(data_point, field_name)
edge_datapoint_pairs = _extract_field_data(field_value)

if not edge_datapoint_pairs:
continue

# For single datapoint, use field_name; for list, use field_name.index
field_identifier = field_name if len(datapoints) == 1 else f"{field_name}.{index}"
properties_to_visit.add(field_identifier)
for edge_metadata, datapoints in edge_datapoint_pairs:
for target_datapoint in datapoints:
yield target_datapoint, field_name, edge_metadata


async def get_graph_from_model(
Expand Down Expand Up @@ -143,26 +169,17 @@ async def get_graph_from_model(
if field_name == "metadata":
continue

field_type, actual_value, edge_metadata = _extract_field_info(field_value)
edge_datapoint_pairs = _extract_field_data(field_value)

if field_type == "property":
if not edge_datapoint_pairs:
# Regular property
data_point_properties[field_name] = field_value
elif field_type in ["single_datapoint", "single_datapoint_with_edge"]:
_process_datapoint_field(
data_point_id,
field_name,
[actual_value],
edge_metadata,
visited_properties,
properties_to_visit,
excluded_properties,
)
elif field_type in ["list_datapoint", "list_datapoint_with_edge"]:
else:
# DataPoint relationship
_process_datapoint_field(
data_point_id,
field_name,
actual_value,
edge_metadata,
edge_datapoint_pairs,
visited_properties,
properties_to_visit,
excluded_properties,
Expand All @@ -176,65 +193,40 @@ async def get_graph_from_model(
nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[data_point_id] = True

# Process all relationships
for field_name_with_index in properties_to_visit:
# Parse field name and index
if "." in field_name_with_index:
field_name, index_str = field_name_with_index.split(".")
index = int(index_str)
else:
field_name, index = field_name_with_index, None

# Get field value and extract edge metadata
field_value = getattr(data_point, field_name)
edge_metadata = None

if (
isinstance(field_value, tuple)
and len(field_value) == 2
and isinstance(field_value[0], Edge)
):
edge_metadata, field_value = field_value

# Get specific datapoint - handle both single and list cases
if index is not None:
# List case: extract specific item by index
target_datapoint = field_value[index]
elif isinstance(field_value, list):
# Single datapoint case that was wrapped in a list
target_datapoint = field_value[0]
else:
# True single datapoint case
target_datapoint = field_value
# Process all relationships using generator
for target_datapoint, field_name, edge_metadata in _targets_generator(
data_point, properties_to_visit
):
relationship_name = _get_relationship_key(field_name, edge_metadata)

# Create edge if not already added
edge_key = f"{data_point_id}{target_datapoint.id}{field_name}"
edge_key = f"{data_point_id}_{target_datapoint.id}_{field_name}"
if edge_key not in added_edges:
relationship_name = _get_relationship_key(field_name, edge_metadata)
edge_properties = _create_edge_properties(
data_point.id, target_datapoint.id, relationship_name, edge_metadata
)
edges.append((data_point.id, target_datapoint.id, relationship_name, edge_properties))
added_edges[edge_key] = True

# Mark property as visited - CRITICAL for preventing infinite loops
relationship_key = _get_relationship_key(field_name, edge_metadata)
property_key = _generate_property_key(
data_point_id, relationship_key, str(target_datapoint.id)
data_point_id, relationship_name, str(target_datapoint.id)
)
visited_properties[property_key] = True

# Recursively process target node if not already processed
if str(target_datapoint.id) not in added_nodes:
child_nodes, child_edges = await get_graph_from_model(
target_datapoint,
include_root=True,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
nodes.extend(child_nodes)
edges.extend(child_edges)
if str(target_datapoint.id) in added_nodes:
continue

child_nodes, child_edges = await get_graph_from_model(
target_datapoint,
include_root=True,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
nodes.extend(child_nodes)
edges.extend(child_edges)

return nodes, edges

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from typing import List
from cognee.infrastructure.engine import DataPoint
from typing import List, Any
from cognee.infrastructure.engine import DataPoint, Edge

from cognee.modules.graph.utils import get_graph_from_model

Expand Down Expand Up @@ -28,7 +28,20 @@ class Entity(DataPoint):
metadata: dict = {"index_fields": ["name"]}


class Company(DataPoint):
name: str
employees: List[Any] = None # Allow flexible edge system with tuples
metadata: dict = {"index_fields": ["name"]}


class Employee(DataPoint):
name: str
role: str
metadata: dict = {"index_fields": ["name"]}


DocumentChunk.model_rebuild()
Company.model_rebuild()


@pytest.mark.asyncio
Expand All @@ -50,7 +63,7 @@ async def test_get_graph_from_model_simple_structure():
assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edges, got {len(edges)}"

edge_key = str(entity.id) + str(entitytype.id) + "is_type"
edge_key = f"{str(entity.id)}_{str(entitytype.id)}_is_type"
assert edge_key in added_edges, f"Edge {edge_key} not found"


Expand Down Expand Up @@ -149,3 +162,48 @@ async def test_get_graph_from_model_no_contains():

assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"


@pytest.mark.asyncio
async def test_get_graph_from_model_flexible_edges():
"""Tests the new flexible edge system with mixed relationships"""
# Create employees
manager = Employee(name="Manager", role="Manager")
sales1 = Employee(name="Sales1", role="Sales")
sales2 = Employee(name="Sales2", role="Sales")
admin1 = Employee(name="Admin1", role="Admin")
admin2 = Employee(name="Admin2", role="Admin")

# Create company with mixed employee relationships
company = Company(
name="Test Company",
employees=[
# Weighted relationship
(Edge(weight=0.9, relationship_type="manages"), manager),
# Multiple weights relationship
(
Edge(weights={"performance": 0.8, "experience": 0.7}, relationship_type="employs"),
sales1,
),
# Simple relationship
sales2,
# Group relationship
(Edge(weights={"team_efficiency": 0.8}, relationship_type="employs"), [admin1, admin2]),
],
)

added_nodes = {}
added_edges = {}
visited_properties = {}

nodes, edges = await get_graph_from_model(company, added_nodes, added_edges, visited_properties)

# Should have 6 nodes: company + 5 employees
assert len(nodes) == 6, f"Expected 6 nodes, got {len(nodes)}"
# Should have 5 edges: 4 employee relationships
assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"

# Verify all employees are connected
employee_ids = {str(emp.id) for emp in [manager, sales1, sales2, admin1, admin2]}
edge_target_ids = {str(edge[1]) for edge in edges}
assert employee_ids.issubset(edge_target_ids), "Not all employees are connected"
Loading
Loading