14 changes: 13 additions & 1 deletion cognee/infrastructure/engine/models/Edge.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 from typing import Optional, Any, Dict


@@ -18,9 +18,21 @@ class Edge(BaseModel):

     # Mixed usage
     has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
+
+    # With edge_text for rich embedding representation
+    contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
     """

     weight: Optional[float] = None
     weights: Optional[Dict[str, float]] = None
     relationship_type: Optional[str] = None
     properties: Optional[Dict[str, Any]] = None
+    edge_text: Optional[str] = None
+
+    @field_validator("edge_text", mode="before")
+    @classmethod
+    def ensure_edge_text(cls, v, info):
+        """Auto-populate edge_text from relationship_type if not explicitly provided."""
+        if v is None and info.data.get("relationship_type"):
+            return info.data["relationship_type"]
+        return v
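
A quick sketch of how the new `edge_text` fallback behaves. One caveat, which is my reading of Pydantic v2 semantics rather than something this diff states: a `field_validator` only runs for values that are actually passed, so the fallback fires when `edge_text=None` is supplied explicitly rather than omitted.

```python
from cognee.infrastructure.engine.models.Edge import Edge

# edge_text given explicitly: kept as-is
e1 = Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_name: Alice")
assert e1.edge_text.startswith("relationship_name")

# edge_text passed as None: ensure_edge_text falls back to relationship_type
e2 = Edge(relationship_type="contains", edge_text=None)
assert e2.edge_text == "contains"

# neither given: the default None is used without running the validator
e3 = Edge(weight=0.5)
assert e3.edge_text is None
```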
3 changes: 2 additions & 1 deletion cognee/modules/chunking/models/DocumentChunk.py
@@ -1,6 +1,7 @@
 from typing import List, Union

 from cognee.infrastructure.engine import DataPoint
+from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.engine.models import Entity
 from cognee.tasks.temporal_graph.models import Event
@@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
     chunk_index: int
     cut_type: str
     is_part_of: Document
-    contains: List[Union[Entity, Event]] = None
+    contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None

     metadata: dict = {"index_fields": ["text"]}
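
For illustration, a self-contained sketch of what the widened `contains` annotation admits; the stub classes stand in for the real cognee models, which need more setup than this diff shows:

```python
from typing import List, Union

class Edge: ...
class Entity: ...
class Event: ...

ContainsItem = Union[Entity, Event, tuple[Edge, Entity]]

# All three member shapes are now valid in `contains`:
contains: List[ContainsItem] = [
    Entity(),            # bare entity, as before
    Event(),             # bare event, as before
    (Edge(), Entity()),  # new: edge metadata travels with the entity
]
```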
6 changes: 4 additions & 2 deletions cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -171,8 +171,10 @@ async def map_vector_distances_to_graph_edges(
         embedding_map = {result.payload["text"]: result.score for result in edge_distances}

         for edge in self.edges:
-            relationship_type = edge.attributes.get("relationship_type")
-            distance = embedding_map.get(relationship_type, None)
+            edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
+                "relationship_type"
+            )
+            distance = embedding_map.get(edge_key, None)
             if distance is not None:
                 edge.attributes["vector_distance"] = distance

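A minimal sketch of the lookup fallback above, assuming `edge.attributes` is a plain dict: edges projected with an `edge_text` use it as the embedding-map key, while older edges fall back to `relationship_type`, so both keying schemes keep resolving distances.

```python
# Legacy edge: no edge_text, so relationship_type is the key
attributes = {"relationship_type": "contains"}
edge_key = attributes.get("edge_text") or attributes.get("relationship_type")
assert edge_key == "contains"

# New-style edge: edge_text wins when present
attributes = {
    "relationship_type": "contains",
    "edge_text": "relationship_name: contains; entity_name: Alice",
}
edge_key = attributes.get("edge_text") or attributes.get("relationship_type")
assert edge_key.startswith("relationship_name")
```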
21 changes: 19 additions & 2 deletions cognee/modules/graph/utils/expand_with_nodes_and_edges.py
@@ -1,5 +1,6 @@
 from typing import Optional

+from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.chunking.models import DocumentChunk
 from cognee.modules.engine.models import Entity, EntityType
 from cognee.modules.engine.utils import (
@@ -243,10 +244,26 @@ def _process_graph_nodes(
             ontology_relationships,
         )

-        # Add entity to data chunk
         if data_chunk.contains is None:
             data_chunk.contains = []
-        data_chunk.contains.append(entity_node)
+
+        edge_text = "; ".join(
+            [
+                "relationship_name: contains",
+                f"entity_name: {entity_node.name}",
+                f"entity_description: {entity_node.description}",
+            ]
+        )
+
+        data_chunk.contains.append(
+            (
+                Edge(
+                    relationship_type="contains",
+                    edge_text=edge_text,
+                ),
+                entity_node,
+            )
+        )


def _process_graph_edges(
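
The string built here is what gets embedded for each chunk-entity edge. A standalone sketch of its shape, with invented entity values:

```python
# Hypothetical entity values, for illustration only
entity_name = "Alice"
entity_description = "A person mentioned in the chunk"

edge_text = "; ".join(
    [
        "relationship_name: contains",
        f"entity_name: {entity_name}",
        f"entity_description: {entity_description}",
    ]
)
print(edge_text)
# relationship_name: contains; entity_name: Alice; entity_description: A person mentioned in the chunk
```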
97 changes: 48 additions & 49 deletions cognee/modules/graph/utils/resolve_edges_to_text.py
@@ -1,71 +1,70 @@
-from typing import List
-from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
-
-
-async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
-    """
-    Converts retrieved graph edges into a human-readable string format.
-
-    Parameters:
-    -----------
-
-    - retrieved_edges (list): A list of edges retrieved from the graph.
-
-    Returns:
-    --------
-
-    - str: A formatted string representation of the nodes and their connections.
-    """
-
-    def _get_nodes(retrieved_edges: List[Edge]) -> dict:
-        def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
-            def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
-                """Concatenates the top N frequent words in text."""
-                if stop_words is None:
-                    from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
-
-                    stop_words = DEFAULT_STOP_WORDS
-
-                import string
-
-                words = [word.lower().strip(string.punctuation) for word in text.split()]
-
-                if stop_words:
-                    words = [word for word in words if word and word not in stop_words]
-
-                from collections import Counter
-
-                top_words = [word for word, freq in Counter(words).most_common(top_n)]
-
-                return separator.join(top_words)
-
-            """Creates a title, by combining first words with most frequent words from the text."""
-            first_words = text.split()[:first_n_words]
-            top_words = _top_n_words(text, top_n=first_n_words)
-            return f"{' '.join(first_words)}... [{top_words}]"
-
-        """Creates a dictionary of nodes with their names and content."""
-        nodes = {}
-        for edge in retrieved_edges:
-            for node in (edge.node1, edge.node2):
-                if node.id not in nodes:
-                    text = node.attributes.get("text")
-                    if text:
-                        name = _get_title(text)
-                        content = text
-                    else:
-                        name = node.attributes.get("name", "Unnamed Node")
-                        content = node.attributes.get("description", name)
-                    nodes[node.id] = {"node": node, "name": name, "content": content}
-        return nodes
-
-    nodes = _get_nodes(retrieved_edges)
-
-    node_section = "\n".join(
-        f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
-        for info in nodes.values()
-    )
-    connection_section = "\n".join(
-        f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
-        for edge in retrieved_edges
-    )
-
-    return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
+import string
+from typing import List
+
+from collections import Counter
+
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
+
+
+def _get_top_n_frequent_words(
+    text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
+) -> str:
+    """Concatenates the top N frequent words in text."""
+    if stop_words is None:
+        stop_words = DEFAULT_STOP_WORDS
+
+    words = [word.lower().strip(string.punctuation) for word in text.split()]
+    words = [word for word in words if word and word not in stop_words]
+
+    top_words = [word for word, freq in Counter(words).most_common(top_n)]
+    return separator.join(top_words)
+
+
+def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
+    """Creates a title by combining first words with most frequent words from the text."""
+    first_words = text.split()[:first_n_words]
+    top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
+    return f"{' '.join(first_words)}... [{top_words}]"
+
+
+def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
+    """Creates a dictionary of nodes with their names and content."""
+    nodes = {}
+
+    for edge in retrieved_edges:
+        for node in (edge.node1, edge.node2):
+            if node.id in nodes:
+                continue
+
+            text = node.attributes.get("text")
+            if text:
+                name = _create_title_from_text(text)
+                content = text
+            else:
+                name = node.attributes.get("name", "Unnamed Node")
+                content = node.attributes.get("description", name)
+
+            nodes[node.id] = {"node": node, "name": name, "content": content}
+
+    return nodes
+
+
+async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
+    """Converts retrieved graph edges into a human-readable string format."""
+    nodes = _extract_nodes_from_edges(retrieved_edges)
+
+    node_section = "\n".join(
+        f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
+        for info in nodes.values()
+    )
+
+    connections = []
+    for edge in retrieved_edges:
+        source_name = nodes[edge.node1.id]["name"]
+        target_name = nodes[edge.node2.id]["name"]
+        edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
+        connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
+
+    connection_section = "\n".join(connections)
+
+    return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
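
A standalone sketch of the title heuristic above, with a reduced stop-word set (the real module uses `DEFAULT_STOP_WORDS`):

```python
import string
from collections import Counter

STOP_WORDS = {"the", "a", "of", "and", "to", "in"}

def create_title(text: str, first_n_words: int = 7, top_n: int = 3) -> str:
    # Mirror _create_title_from_text: leading words plus most frequent words
    words = [w.lower().strip(string.punctuation) for w in text.split()]
    words = [w for w in words if w and w not in STOP_WORDS]
    top = [w for w, _ in Counter(words).most_common(top_n)]
    first = text.split()[:first_n_words]
    return f"{' '.join(first)}... [{', '.join(top)}]"

text = "Alice met Bob in Paris. Alice and Bob discussed the Paris summit agenda."
print(create_title(text))
# Alice met Bob in Paris. Alice and... [alice, bob, paris]
```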
@@ -71,7 +71,7 @@ async def get_memory_fragment(
     await memory_fragment.project_graph_from_db(
         graph_engine,
         node_properties_to_project=properties_to_project,
-        edge_properties_to_project=["relationship_name"],
+        edge_properties_to_project=["relationship_name", "edge_text"],
         node_type=node_type,
         node_name=node_name,
     )
55 changes: 33 additions & 22 deletions cognee/tasks/storage/index_data_points.py
@@ -8,47 +8,58 @@


 async def index_data_points(data_points: list[DataPoint]):
-    created_indexes = {}
-    index_points = {}
+    """Index data points in the vector engine by creating embeddings for specified fields.
+
+    Process:
+    1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
+    2. Creates vector indexes for each (type, field) combination on first encounter
+    3. Batches points per (type, field) and creates async indexing tasks
+    4. Executes all indexing tasks in parallel for efficient embedding generation
+
+    Args:
+        data_points: List of DataPoint objects to index. Each DataPoint's metadata must
+            contain an 'index_fields' list specifying which fields to embed.
+
+    Returns:
+        The original data_points list.
+    """
+    data_points_by_type = {}

     vector_engine = get_vector_engine()

     for data_point in data_points:
         data_point_type = type(data_point)
+        type_name = data_point_type.__name__

         for field_name in data_point.metadata["index_fields"]:
             if getattr(data_point, field_name, None) is None:
                 continue

-            index_name = f"{data_point_type.__name__}_{field_name}"
+            if type_name not in data_points_by_type:
+                data_points_by_type[type_name] = {}

-            if index_name not in created_indexes:
-                await vector_engine.create_vector_index(data_point_type.__name__, field_name)
-                created_indexes[index_name] = True
-
-            if index_name not in index_points:
-                index_points[index_name] = []
+            if field_name not in data_points_by_type[type_name]:
+                await vector_engine.create_vector_index(type_name, field_name)
+                data_points_by_type[type_name][field_name] = []

             indexed_data_point = data_point.model_copy()
             indexed_data_point.metadata["index_fields"] = [field_name]
-            index_points[index_name].append(indexed_data_point)
+            data_points_by_type[type_name][field_name].append(indexed_data_point)

-    tasks: list[asyncio.Task] = []
     batch_size = vector_engine.embedding_engine.get_batch_size()

-    for index_name_and_field, points in index_points.items():
-        first = index_name_and_field.index("_")
-        index_name = index_name_and_field[:first]
-        field_name = index_name_and_field[first + 1 :]
-
-        # Create embedding requests per batch to run in parallel later
-        for i in range(0, len(points), batch_size):
-            batch = points[i : i + batch_size]
-            tasks.append(
-                asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
-            )
+    batches = (
+        (type_name, field_name, points[i : i + batch_size])
+        for type_name, fields in data_points_by_type.items()
+        for field_name, points in fields.items()
+        for i in range(0, len(points), batch_size)
+    )
+
+    tasks = [
+        asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
+        for type_name, field_name, batch_points in batches
+    ]

     # Run all embedding requests in parallel
     await asyncio.gather(*tasks)

     return data_points
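
A minimal sketch of the grouping-and-batching shape used above, with plain strings standing in for DataPoint objects and an arbitrary batch size. Keying the nested dict by type and field also avoids the old round-trip through the composite f"{type}_{field}" string, which had to be split apart again before indexing:

```python
data_points_by_type = {
    "DocumentChunk": {"text": ["p1", "p2", "p3"]},
    "Entity": {"name": ["p4"], "description": ["p5"]},
}
batch_size = 2

batches = (
    (type_name, field_name, points[i : i + batch_size])
    for type_name, fields in data_points_by_type.items()
    for field_name, points in fields.items()
    for i in range(0, len(points), batch_size)
)

for type_name, field_name, batch in batches:
    print(type_name, field_name, batch)
# DocumentChunk text ['p1', 'p2']
# DocumentChunk text ['p3']
# Entity name ['p4']
# Entity description ['p5']
```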