feat: add edge text, embed and expose it
lxobr committed Oct 28, 2025
commit be7d315f97b941f0f6dc4b407ecd40a0abef5b15
14 changes: 13 additions & 1 deletion cognee/infrastructure/engine/models/Edge.py
@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from typing import Optional, Any, Dict


@@ -18,9 +18,21 @@ class Edge(BaseModel):

# Mixed usage
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])

# With edge_text for rich embedding representation
contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
"""

weight: Optional[float] = None
weights: Optional[Dict[str, float]] = None
relationship_type: Optional[str] = None
properties: Optional[Dict[str, Any]] = None
edge_text: Optional[str] = None

@field_validator("edge_text", mode="before")
@classmethod
def ensure_edge_text(cls, v, info):
"""Auto-populate edge_text from relationship_type if not explicitly provided."""
if v is None and info.data.get("relationship_type"):
return info.data["relationship_type"]
return v
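A minimal usage sketch for the validator above (not part of the diff). Pydantic only runs field validators on values that are actually supplied unless validate_default=True is set, so the fallback is shown here with an explicit edge_text=None:

# Sketch: edge_text falls back to relationship_type via ensure_edge_text
from cognee.infrastructure.engine.models.Edge import Edge

edge = Edge(relationship_type="contains", edge_text=None)
assert edge.edge_text == "contains"  # auto-populated from relationship_type

rich_edge = Edge(
    relationship_type="contains",
    edge_text="relationship_name: contains; entity_description: Alice",
)
assert rich_edge.edge_text.startswith("relationship_name: contains")  # explicit value kept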
3 changes: 2 additions & 1 deletion cognee/modules/chunking/models/DocumentChunk.py
@@ -1,6 +1,7 @@
from typing import List, Union

from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
from cognee.tasks.temporal_graph.models import Event
@@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
contains: List[Union[Entity, Event]] = None
contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None

metadata: dict = {"index_fields": ["text"]}
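A hedged sketch of the widened contains field, mixing the legacy bare-entity form with the new (Edge, Entity) tuple; the Entity constructor arguments are illustrative assumptions:

# Sketch only: both element shapes now accepted by DocumentChunk.contains
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.engine.models import Entity

entity = Entity(name="Alice", description="Person mentioned in the chunk")  # assumed fields
contains = [
    entity,  # legacy form: bare Entity (or Event)
    (
        Edge(
            relationship_type="contains",
            edge_text="relationship_name: contains; entity_name: Alice",
        ),
        entity,
    ),  # new form: (Edge, Entity) tuple carrying the embeddable edge_text
]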
6 changes: 4 additions & 2 deletions cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -171,8 +171,10 @@ async def map_vector_distances_to_graph_edges(
embedding_map = {result.payload["text"]: result.score for result in edge_distances}

for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
distance = embedding_map.get(relationship_type, None)
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
"relationship_type"
)
distance = embedding_map.get(edge_key, None)
if distance is not None:
edge.attributes["vector_distance"] = distance

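A small illustration of the lookup fallback above: projected edges that carry edge_text are matched against the embedding payload by that richer text, while edges without it still match on relationship_type (attribute dicts below are illustrative):

# Sketch: resolving edge_key for new vs. legacy edges (illustrative data only)
embedding_map = {
    "relationship_name: contains; entity_name: Alice": 0.12,
    "works_at": 0.34,
}

new_edge_attrs = {
    "relationship_type": "contains",
    "edge_text": "relationship_name: contains; entity_name: Alice",
}
legacy_edge_attrs = {"relationship_type": "works_at"}  # no edge_text projected

for attrs in (new_edge_attrs, legacy_edge_attrs):
    edge_key = attrs.get("edge_text") or attrs.get("relationship_type")
    distance = embedding_map.get(edge_key)
    # new edge resolves to 0.12 via edge_text; legacy edge to 0.34 via relationship_type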
21 changes: 19 additions & 2 deletions cognee/modules/graph/utils/expand_with_nodes_and_edges.py
@@ -1,5 +1,6 @@
from typing import Optional

from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
@@ -243,10 +244,26 @@ def _process_graph_nodes(
ontology_relationships,
)

# Add entity to data chunk
if data_chunk.contains is None:
data_chunk.contains = []
data_chunk.contains.append(entity_node)

edge_text = "; ".join(
[
"relationship_name: contains",
f"entity_name: {entity_node.name}",
f"entity_description: {entity_node.description}",
]
)

data_chunk.contains.append(
(
Edge(
relationship_type="contains",
edge_text=edge_text,
),
entity_node,
)
)
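For a concrete sense of what gets embedded, the join above yields one string per contained entity, e.g. (entity values illustrative):

# Illustrative value of edge_text for an entity named Alice
edge_text = "relationship_name: contains; entity_name: Alice; entity_description: A person mentioned in the document"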


def _process_graph_edges(
@@ -71,7 +71,7 @@ async def get_memory_fragment(
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"],
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
)
55 changes: 33 additions & 22 deletions cognee/tasks/storage/index_data_points.py
@@ -8,47 +8,58 @@


async def index_data_points(data_points: list[DataPoint]):
created_indexes = {}
index_points = {}
"""Index data points in the vector engine by creating embeddings for specified fields.
Process:
1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
2. Creates vector indexes for each (type, field) combination on first encounter
3. Batches points per (type, field) and creates async indexing tasks
4. Executes all indexing tasks in parallel for efficient embedding generation
Args:
data_points: List of DataPoint objects to index. Each DataPoint's metadata must
contain an 'index_fields' list specifying which fields to embed.
Returns:
The original data_points list.
"""
data_points_by_type = {}

vector_engine = get_vector_engine()

for data_point in data_points:
data_point_type = type(data_point)
type_name = data_point_type.__name__

for field_name in data_point.metadata["index_fields"]:
if getattr(data_point, field_name, None) is None:
continue

index_name = f"{data_point_type.__name__}_{field_name}"
if type_name not in data_points_by_type:
data_points_by_type[type_name] = {}

if index_name not in created_indexes:
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
created_indexes[index_name] = True

if index_name not in index_points:
index_points[index_name] = []
if field_name not in data_points_by_type[type_name]:
await vector_engine.create_vector_index(type_name, field_name)
data_points_by_type[type_name][field_name] = []

indexed_data_point = data_point.model_copy()
indexed_data_point.metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)
data_points_by_type[type_name][field_name].append(indexed_data_point)

tasks: list[asyncio.Task] = []
batch_size = vector_engine.embedding_engine.get_batch_size()

for index_name_and_field, points in index_points.items():
first = index_name_and_field.index("_")
index_name = index_name_and_field[:first]
field_name = index_name_and_field[first + 1 :]
batches = (
(type_name, field_name, points[i : i + batch_size])
for type_name, fields in data_points_by_type.items()
for field_name, points in fields.items()
for i in range(0, len(points), batch_size)
)

# Create embedding requests per batch to run in parallel later
for i in range(0, len(points), batch_size):
batch = points[i : i + batch_size]
tasks.append(
asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
)
tasks = [
asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
for type_name, field_name, batch_points in batches
]

# Run all embedding requests in parallel
await asyncio.gather(*tasks)

return data_points
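A standalone sketch of the grouping step this refactor introduces, using stand-in objects rather than cognee types; the nested dict it builds is what later gets sliced into embedding batches and indexed in parallel:

# Sketch of the {type_name: {field_name: [points]}} grouping (stand-in types, not cognee's)
class FakePoint:
    metadata = {"index_fields": ["text"]}

    def __init__(self, text):
        self.text = text


points = [FakePoint("chunk one"), FakePoint("chunk two")]
data_points_by_type: dict[str, dict[str, list]] = {}

for point in points:
    type_name = type(point).__name__
    for field_name in point.metadata["index_fields"]:
        if getattr(point, field_name, None) is None:
            continue
        data_points_by_type.setdefault(type_name, {}).setdefault(field_name, []).append(point)

# -> {"FakePoint": {"text": [<FakePoint>, <FakePoint>]}}; each inner list is then cut into
#    batches of embedding_engine.get_batch_size() and indexed concurrently via asyncio.gather.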
94 changes: 37 additions & 57 deletions cognee/tasks/storage/index_graph_edges.py
@@ -1,17 +1,44 @@
import asyncio
from collections import Counter
from typing import Optional, Dict, Any, List, Tuple, Union

from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger
from collections import Counter
from typing import Optional, Dict, Any, List, Tuple, Union
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
from cognee.tasks.storage.index_data_points import index_data_points

logger = get_logger()


def _get_edge_text(item: dict) -> str:
"""Extract edge text for embedding - prefers edge_text field with fallback."""
if "edge_text" in item:
return item["edge_text"]

if "relationship_name" in item:
return item["relationship_name"]

return ""


def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
"""Transform raw edge data into EdgeType datapoints."""
edge_texts = [
_get_edge_text(item)
for edge in edges_data
for item in edge
if isinstance(item, dict) and "relationship_name" in item
]

edge_types = Counter(edge_texts)

return [
EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count)
for text, count in edge_types.items()
]
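A hedged worked example for the helper above; the edge tuples mimic the (source, target, name, properties) shape returned by the graph engine and the values are purely illustrative:

# Sketch: identical edge_text values collapse into a single counted EdgeType
from cognee.tasks.storage.index_graph_edges import create_edge_type_datapoints

edges_data = [
    ("chunk-1", "alice", "contains",
     {"relationship_name": "contains",
      "edge_text": "relationship_name: contains; entity_name: Alice"}),
    ("chunk-2", "alice", "contains",
     {"relationship_name": "contains",
      "edge_text": "relationship_name: contains; entity_name: Alice"}),
    ("alice", "acme", "works_at", {"relationship_name": "works_at"}),
]

edge_types = create_edge_type_datapoints(edges_data)
# -> one EdgeType with relationship_name set to the shared edge_text and number_of_edges=2,
#    plus one EdgeType for "works_at" with number_of_edges=1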


async def index_graph_edges(
edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
):
@@ -23,24 +50,17 @@ async def index_graph_edges(
the `relationship_name` field.

Steps:
1. Initialize the vector engine and graph engine.
2. Retrieve graph edge data and count relationship types (`relationship_name`).
3. Create vector indexes for `relationship_name` if they don't exist.
4. Transform the counted relationships into `EdgeType` objects.
5. Index the transformed data points in the vector engine.
1. Initialize the graph engine if needed and retrieve edge data.
2. Transform edge data into EdgeType datapoints.
3. Index the EdgeType datapoints using the standard indexing function.

Raises:
RuntimeError: If initialization of the vector engine or graph engine fails.
RuntimeError: If initialization of the graph engine fails.

Returns:
None
"""
try:
created_indexes = {}
index_points = {}

vector_engine = get_vector_engine()

if edges_data is None:
graph_engine = await get_graph_engine()
_, edges_data = await graph_engine.get_graph_data()
@@ -51,47 +71,7 @@
logger.error("Failed to initialize engines: %s", e)
raise RuntimeError("Initialization error") from e

edge_types = Counter(
item.get("relationship_name")
for edge in edges_data
for item in edge
if isinstance(item, dict) and "relationship_name" in item
)

for text, count in edge_types.items():
edge = EdgeType(
id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
)
data_point_type = type(edge)

for field_name in edge.metadata["index_fields"]:
index_name = f"{data_point_type.__name__}.{field_name}"

if index_name not in created_indexes:
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
created_indexes[index_name] = True

if index_name not in index_points:
index_points[index_name] = []

indexed_data_point = edge.model_copy()
indexed_data_point.metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)

# Get maximum batch size for embedding model
batch_size = vector_engine.embedding_engine.get_batch_size()
tasks: list[asyncio.Task] = []

for index_name, indexable_points in index_points.items():
index_name, field_name = index_name.split(".")

# Create embedding tasks to run in parallel later
for start in range(0, len(indexable_points), batch_size):
batch = indexable_points[start : start + batch_size]

tasks.append(vector_engine.index_data_points(index_name, field_name, batch))

# Start all embedding tasks and wait for completion
await asyncio.gather(*tasks)
edge_type_datapoints = create_edge_type_datapoints(edges_data)
await index_data_points(edge_type_datapoints)

return None
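Call sites are unchanged by the refactor; a minimal sketch of both supported invocations (engine configuration assumed):

# Sketch: the two ways to invoke index_graph_edges after the refactor
import asyncio

from cognee.tasks.storage.index_graph_edges import index_graph_edges


async def main():
    # 1) Let the function pull edges from the configured graph engine
    await index_graph_edges()

    # 2) Or pass pre-fetched edge tuples: (source, target, name, properties)
    edges = [("alice", "acme", "works_at", {"relationship_name": "works_at"})]
    await index_graph_edges(edges_data=edges)


# asyncio.run(main())  # requires configured graph and vector engines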
13 changes: 13 additions & 0 deletions cognee/tests/test_edge_ingestion.py
Expand Up @@ -52,6 +52,19 @@ async def test_edge_ingestion():

edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])

"Tests edge_text presence and format"
contains_edges = [edge for edge in graph[1] if edge[2] == "contains"]
assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification"
edge_properties = contains_edges[0][3]
assert "edge_text" in edge_properties, "Expected edge_text in edge properties"
edge_text = edge_properties["edge_text"]
assert "relationship_name: contains" in edge_text, (
f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}"
)
assert "entity_name:" in edge_text or "entity_description:" in edge_text, (
f"Expected entity info in edge_text, got: {edge_text}"
)

"Tests the presence of basic nested edges"
for basic_nested_edge in basic_nested_edges:
assert edge_type_counts.get(basic_nested_edge, 0) >= 1, (