diff --git a/cognee/infrastructure/engine/models/Edge.py b/cognee/infrastructure/engine/models/Edge.py
index 5ad9c84dd9..59f01a9aba 100644
--- a/cognee/infrastructure/engine/models/Edge.py
+++ b/cognee/infrastructure/engine/models/Edge.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 from typing import Optional, Any, Dict
 
 
@@ -18,9 +18,21 @@ class Edge(BaseModel):
 
         # Mixed usage
         has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
+
+        # With edge_text for rich embedding representation
+        contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
     """
 
     weight: Optional[float] = None
     weights: Optional[Dict[str, float]] = None
     relationship_type: Optional[str] = None
     properties: Optional[Dict[str, Any]] = None
+    edge_text: Optional[str] = None
+
+    @field_validator("edge_text", mode="before")
+    @classmethod
+    def ensure_edge_text(cls, v, info):
+        """Auto-populate edge_text from relationship_type if not explicitly provided."""
+        if v is None and info.data.get("relationship_type"):
+            return info.data["relationship_type"]
+        return v
diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py
index 9f8c57486e..e024bf00b8 100644
--- a/cognee/modules/chunking/models/DocumentChunk.py
+++ b/cognee/modules/chunking/models/DocumentChunk.py
@@ -1,6 +1,7 @@
 from typing import List, Union
 
 from cognee.infrastructure.engine import DataPoint
+from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.engine.models import Entity
 from cognee.tasks.temporal_graph.models import Event
@@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
     chunk_index: int
     cut_type: str
     is_part_of: Document
-    contains: List[Union[Entity, Event]] = None
+    contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
 
     metadata: dict = {"index_fields": ["text"]}
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index 9703928f06..cb75624223 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -171,8 +171,10 @@ async def map_vector_distances_to_graph_edges(
             embedding_map = {result.payload["text"]: result.score for result in edge_distances}
 
             for edge in self.edges:
-                relationship_type = edge.attributes.get("relationship_type")
-                distance = embedding_map.get(relationship_type, None)
+                edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
+                    "relationship_type"
+                )
+                distance = embedding_map.get(edge_key, None)
                 if distance is not None:
                     edge.attributes["vector_distance"] = distance
 
diff --git a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py
index 3b01f5af4b..c68eb494d4 100644
--- a/cognee/modules/graph/utils/expand_with_nodes_and_edges.py
+++ b/cognee/modules/graph/utils/expand_with_nodes_and_edges.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.chunking.models import DocumentChunk
 from cognee.modules.engine.models import Entity, EntityType
 from cognee.modules.engine.utils import (
@@ -243,10 +244,26 @@ def _process_graph_nodes(
                 ontology_relationships,
             )
 
-        # Add entity to data chunk
         if data_chunk.contains is None:
             data_chunk.contains = []
+
+        edge_text = "; ".join(
+            [
+                "relationship_name: contains",
+                f"entity_name: {entity_node.name}",
+                f"entity_description: {entity_node.description}",
+            ]
+        )
-        data_chunk.contains.append(entity_node)
+
+        data_chunk.contains.append(
+            (
+                Edge(
+                    relationship_type="contains",
+                    edge_text=edge_text,
+                ),
+                entity_node,
+            )
+        )
 
 
 def _process_graph_edges(
diff --git a/cognee/modules/graph/utils/resolve_edges_to_text.py b/cognee/modules/graph/utils/resolve_edges_to_text.py
index eb5bedd2c5..5deb13ba82 100644
--- a/cognee/modules/graph/utils/resolve_edges_to_text.py
+++ b/cognee/modules/graph/utils/resolve_edges_to_text.py
@@ -1,71 +1,70 @@
+import string
 from typing import List
-from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
-
+from collections import Counter
 
-async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
-    """
-    Converts retrieved graph edges into a human-readable string format.
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 
-    Parameters:
-    -----------
 
-        - retrieved_edges (list): A list of edges retrieved from the graph.
+def _get_top_n_frequent_words(
+    text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
+) -> str:
+    """Concatenates the top N frequent words in text."""
+    if stop_words is None:
+        stop_words = DEFAULT_STOP_WORDS
 
-    Returns:
-    --------
+    words = [word.lower().strip(string.punctuation) for word in text.split()]
+    words = [word for word in words if word and word not in stop_words]
 
-        - str: A formatted string representation of the nodes and their connections.
-    """
+    top_words = [word for word, freq in Counter(words).most_common(top_n)]
+    return separator.join(top_words)
 
-    def _get_nodes(retrieved_edges: List[Edge]) -> dict:
-        def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
-            def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
-                """Concatenates the top N frequent words in text."""
-                if stop_words is None:
-                    from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 
-                    stop_words = DEFAULT_STOP_WORDS
+def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
+    """Creates a title by combining first words with most frequent words from the text."""
+    first_words = text.split()[:first_n_words]
+    top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
+    return f"{' '.join(first_words)}... [{top_words}]"
 
-                import string
 
-                words = [word.lower().strip(string.punctuation) for word in text.split()]
+def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
+    """Creates a dictionary of nodes with their names and content."""
+    nodes = {}
 
-                if stop_words:
-                    words = [word for word in words if word and word not in stop_words]
+    for edge in retrieved_edges:
+        for node in (edge.node1, edge.node2):
+            if node.id in nodes:
+                continue
 
-                from collections import Counter
+            text = node.attributes.get("text")
+            if text:
+                name = _create_title_from_text(text)
+                content = text
+            else:
+                name = node.attributes.get("name", "Unnamed Node")
+                content = node.attributes.get("description", name)
 
-                top_words = [word for word, freq in Counter(words).most_common(top_n)]
+            nodes[node.id] = {"node": node, "name": name, "content": content}
 
-                return separator.join(top_words)
+    return nodes
 
-            """Creates a title, by combining first words with most frequent words from the text."""
-            first_words = text.split()[:first_n_words]
-            top_words = _top_n_words(text, top_n=first_n_words)
-            return f"{' '.join(first_words)}... [{top_words}]"
 
-        """Creates a dictionary of nodes with their names and content."""
-        nodes = {}
-        for edge in retrieved_edges:
-            for node in (edge.node1, edge.node2):
-                if node.id not in nodes:
-                    text = node.attributes.get("text")
-                    if text:
-                        name = _get_title(text)
-                        content = text
-                    else:
-                        name = node.attributes.get("name", "Unnamed Node")
-                        content = node.attributes.get("description", name)
-                    nodes[node.id] = {"node": node, "name": name, "content": content}
-        return nodes
+async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
+    """Converts retrieved graph edges into a human-readable string format."""
+    nodes = _extract_nodes_from_edges(retrieved_edges)
 
-    nodes = _get_nodes(retrieved_edges)
     node_section = "\n".join(
         f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
         for info in nodes.values()
     )
-    connection_section = "\n".join(
-        f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
-        for edge in retrieved_edges
-    )
+
+    connections = []
+    for edge in retrieved_edges:
+        source_name = nodes[edge.node1.id]["name"]
+        target_name = nodes[edge.node2.id]["name"]
+        edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
+        connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
+
+    connection_section = "\n".join(connections)
+
     return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
index 1ef7545c21..f8bdbb97d7 100644
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -71,7 +71,7 @@ async def get_memory_fragment(
     await memory_fragment.project_graph_from_db(
         graph_engine,
         node_properties_to_project=properties_to_project,
-        edge_properties_to_project=["relationship_name"],
+        edge_properties_to_project=["relationship_name", "edge_text"],
         node_type=node_type,
         node_name=node_name,
     )
diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py
index 902789c808..b0ec3a5b4a 100644
--- a/cognee/tasks/storage/index_data_points.py
+++ b/cognee/tasks/storage/index_data_points.py
@@ -8,47 +8,58 @@ async def index_data_points(data_points: list[DataPoint]):
-    created_indexes = {}
-    index_points = {}
+    """Index data points in the vector engine by creating embeddings for specified fields.
+
+    Process:
+    1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
+    2. Creates vector indexes for each (type, field) combination on first encounter
+    3. Batches points per (type, field) and creates async indexing tasks
+    4. Executes all indexing tasks in parallel for efficient embedding generation
+
+    Args:
+        data_points: List of DataPoint objects to index. Each DataPoint's metadata must
+            contain an 'index_fields' list specifying which fields to embed.
+
+    Returns:
+        The original data_points list.
+    """
+    data_points_by_type = {}
     vector_engine = get_vector_engine()
 
     for data_point in data_points:
         data_point_type = type(data_point)
+        type_name = data_point_type.__name__
 
         for field_name in data_point.metadata["index_fields"]:
            if getattr(data_point, field_name, None) is None:
                 continue
 
-            index_name = f"{data_point_type.__name__}_{field_name}"
+            if type_name not in data_points_by_type:
+                data_points_by_type[type_name] = {}
 
-            if index_name not in created_indexes:
-                await vector_engine.create_vector_index(data_point_type.__name__, field_name)
-                created_indexes[index_name] = True
-
-            if index_name not in index_points:
-                index_points[index_name] = []
+            if field_name not in data_points_by_type[type_name]:
+                await vector_engine.create_vector_index(type_name, field_name)
+                data_points_by_type[type_name][field_name] = []
 
             indexed_data_point = data_point.model_copy()
             indexed_data_point.metadata["index_fields"] = [field_name]
-            index_points[index_name].append(indexed_data_point)
+            data_points_by_type[type_name][field_name].append(indexed_data_point)
 
-    tasks: list[asyncio.Task] = []
     batch_size = vector_engine.embedding_engine.get_batch_size()
 
-    for index_name_and_field, points in index_points.items():
-        first = index_name_and_field.index("_")
-        index_name = index_name_and_field[:first]
-        field_name = index_name_and_field[first + 1 :]
+    batches = (
+        (type_name, field_name, points[i : i + batch_size])
+        for type_name, fields in data_points_by_type.items()
+        for field_name, points in fields.items()
+        for i in range(0, len(points), batch_size)
+    )
 
-        # Create embedding requests per batch to run in parallel later
-        for i in range(0, len(points), batch_size):
-            batch = points[i : i + batch_size]
-            tasks.append(
-                asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
-            )
+    tasks = [
+        asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
+        for type_name, field_name, batch_points in batches
+    ]
 
-    # Run all embedding requests in parallel
     await asyncio.gather(*tasks)
 
     return data_points
diff --git a/cognee/tasks/storage/index_graph_edges.py b/cognee/tasks/storage/index_graph_edges.py
index 4fa8cfc75c..03b5a25a53 100644
--- a/cognee/tasks/storage/index_graph_edges.py
+++ b/cognee/tasks/storage/index_graph_edges.py
@@ -1,17 +1,44 @@
-import asyncio
+from collections import Counter
+from typing import Optional, Dict, Any, List, Tuple, Union
 
 from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
 from cognee.shared.logging_utils import get_logger
-from collections import Counter
-from typing import Optional, Dict, Any, List, Tuple, Union
 
-from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.graph.models.EdgeType import EdgeType
 from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
+from cognee.tasks.storage.index_data_points import index_data_points
 
 logger = get_logger()
 
+
+def _get_edge_text(item: dict) -> str:
+    """Extract edge text for embedding - prefers edge_text field with fallback."""
+    if "edge_text" in item:
+        return item["edge_text"]
+
+    if "relationship_name" in item:
+        return item["relationship_name"]
+
+    return ""
+
+
+def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
+    """Transform raw edge data into EdgeType datapoints."""
+    edge_texts = [
+        _get_edge_text(item)
+        for edge in edges_data
+        for item in edge
+        if isinstance(item, dict) and "relationship_name" in item
+    ]
+
+    edge_types = Counter(edge_texts)
+
+    return [
+        EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count)
+        for text, count in edge_types.items()
+    ]
+
+
 async def index_graph_edges(
     edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
 ):
@@ -23,24 +50,17 @@ async def index_graph_edges(
     the `relationship_name` field.
 
     Steps:
-    1. Initialize the vector engine and graph engine.
-    2. Retrieve graph edge data and count relationship types (`relationship_name`).
-    3. Create vector indexes for `relationship_name` if they don't exist.
-    4. Transform the counted relationships into `EdgeType` objects.
-    5. Index the transformed data points in the vector engine.
+    1. Initialize the graph engine if needed and retrieve edge data.
+    2. Transform edge data into EdgeType datapoints.
+    3. Index the EdgeType datapoints using the standard indexing function.
 
     Raises:
-        RuntimeError: If initialization of the vector engine or graph engine fails.
+        RuntimeError: If initialization of the graph engine fails.
 
     Returns:
         None
     """
     try:
-        created_indexes = {}
-        index_points = {}
-
-        vector_engine = get_vector_engine()
-
         if edges_data is None:
             graph_engine = await get_graph_engine()
             _, edges_data = await graph_engine.get_graph_data()
@@ -51,47 +71,7 @@ async def index_graph_edges(
         logger.error("Failed to initialize engines: %s", e)
         raise RuntimeError("Initialization error") from e
 
-    edge_types = Counter(
-        item.get("relationship_name")
-        for edge in edges_data
-        for item in edge
-        if isinstance(item, dict) and "relationship_name" in item
-    )
-
-    for text, count in edge_types.items():
-        edge = EdgeType(
-            id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
-        )
-        data_point_type = type(edge)
-
-        for field_name in edge.metadata["index_fields"]:
-            index_name = f"{data_point_type.__name__}.{field_name}"
-
-            if index_name not in created_indexes:
-                await vector_engine.create_vector_index(data_point_type.__name__, field_name)
-                created_indexes[index_name] = True
-
-            if index_name not in index_points:
-                index_points[index_name] = []
-
-            indexed_data_point = edge.model_copy()
-            indexed_data_point.metadata["index_fields"] = [field_name]
-            index_points[index_name].append(indexed_data_point)
-
-    # Get maximum batch size for embedding model
-    batch_size = vector_engine.embedding_engine.get_batch_size()
-    tasks: list[asyncio.Task] = []
-
-    for index_name, indexable_points in index_points.items():
-        index_name, field_name = index_name.split(".")
-
-        # Create embedding tasks to run in parallel later
-        for start in range(0, len(indexable_points), batch_size):
-            batch = indexable_points[start : start + batch_size]
-
-            tasks.append(vector_engine.index_data_points(index_name, field_name, batch))
-
-    # Start all embedding tasks and wait for completion
-    await asyncio.gather(*tasks)
+    edge_type_datapoints = create_edge_type_datapoints(edges_data)
+    await index_data_points(edge_type_datapoints)
 
     return None
diff --git a/cognee/tests/test_edge_ingestion.py b/cognee/tests/test_edge_ingestion.py
index 5b23f78199..0d1407fabc 100755
--- a/cognee/tests/test_edge_ingestion.py
+++ b/cognee/tests/test_edge_ingestion.py
@@ -52,6 +52,33 @@ async def test_edge_ingestion():
     edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
 
+    "Tests edge_text presence and format"
+    contains_edges = [edge for edge in graph[1] if edge[2] == "contains"]
+    assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification"
+
+    edge_properties = contains_edges[0][3]
+    assert "edge_text" in edge_properties, "Expected edge_text in edge properties"
+
+    edge_text = edge_properties["edge_text"]
+    assert "relationship_name: contains" in edge_text, (
+        f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}"
+    )
+    assert "entity_name:" in edge_text, f"Expected 'entity_name:' in edge_text, got: {edge_text}"
+    assert "entity_description:" in edge_text, (
+        f"Expected 'entity_description:' in edge_text, got: {edge_text}"
+    )
+
+    all_edge_texts = [
+        edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3]
+    ]
+    expected_entities = ["dave", "ana", "bob", "dexter", "apples", "cognee"]
+    found_entity = any(
+        any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts
+    )
+    assert found_entity, (
+        f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}"
+    )
+
     "Tests the presence of basic nested edges"
     for basic_nested_edge in basic_nested_edges:
         assert edge_type_counts.get(basic_nested_edge, 0) >= 1, (
diff --git a/cognee/tests/unit/infrastructure/databases/test_index_data_points.py b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py
new file mode 100644
index 0000000000..21a5695ded
--- /dev/null
+++ b/cognee/tests/unit/infrastructure/databases/test_index_data_points.py
@@ -0,0 +1,27 @@
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+from cognee.tasks.storage.index_data_points import index_data_points
+from cognee.infrastructure.engine import DataPoint
+
+
+class TestDataPoint(DataPoint):
+    name: str
+    metadata: dict = {"index_fields": ["name"]}
+
+
+@pytest.mark.asyncio
+async def test_index_data_points_calls_vector_engine():
+    """Test that index_data_points creates vector index and indexes data."""
+    data_points = [TestDataPoint(name="test1")]
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
+
+    with patch.dict(
+        index_data_points.__globals__,
+        {"get_vector_engine": lambda: mock_vector_engine},
+    ):
+        await index_data_points(data_points)
+
+    assert mock_vector_engine.create_vector_index.await_count >= 1
+    assert mock_vector_engine.index_data_points.await_count >= 1
diff --git a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py
index 48bbc53e3a..cee0896c2f 100644
--- a/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py
+++ b/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py
@@ -5,8 +5,7 @@
 @pytest.mark.asyncio
 async def test_index_graph_edges_success():
-    """Test that index_graph_edges uses the index datapoints and creates vector index."""
-    # Create the mocks for the graph and vector engines.
+ """Test that index_graph_edges retrieves edges and delegates to index_data_points.""" mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = ( None, @@ -15,26 +14,23 @@ async def test_index_graph_edges_success(): [{"relationship_name": "rel2"}], ], ) - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100) + mock_index_data_points = AsyncMock() - # Patch the globals of the function so that when it does: - # vector_engine = get_vector_engine() - # graph_engine = await get_graph_engine() - # it uses the mocked versions. with patch.dict( index_graph_edges.__globals__, { "get_graph_engine": AsyncMock(return_value=mock_graph_engine), - "get_vector_engine": lambda: mock_vector_engine, + "index_data_points": mock_index_data_points, }, ): await index_graph_edges() - # Assertions on the mock calls. mock_graph_engine.get_graph_data.assert_awaited_once() - assert mock_vector_engine.create_vector_index.await_count == 1 - assert mock_vector_engine.index_data_points.await_count == 1 + mock_index_data_points.assert_awaited_once() + + call_args = mock_index_data_points.call_args[0][0] + assert len(call_args) == 2 + assert all(hasattr(item, "relationship_name") for item in call_args) @pytest.mark.asyncio @@ -42,20 +38,22 @@ async def test_index_graph_edges_no_relationships(): """Test that index_graph_edges handles empty relationships correctly.""" mock_graph_engine = AsyncMock() mock_graph_engine.get_graph_data.return_value = (None, []) - mock_vector_engine = AsyncMock() + mock_index_data_points = AsyncMock() with patch.dict( index_graph_edges.__globals__, { "get_graph_engine": AsyncMock(return_value=mock_graph_engine), - "get_vector_engine": lambda: mock_vector_engine, + "index_data_points": mock_index_data_points, }, ): await index_graph_edges() mock_graph_engine.get_graph_data.assert_awaited_once() - mock_vector_engine.create_vector_index.assert_not_awaited() - mock_vector_engine.index_data_points.assert_not_awaited() + mock_index_data_points.assert_awaited_once() + + call_args = mock_index_data_points.call_args[0][0] + assert len(call_args) == 0 @pytest.mark.asyncio