Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cognee/api/v1/cognify/cognify_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
Task(classify_documents),
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
Task(add_data_points, task_config = { "batch_size": 10 }),
Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
summarization_model = cognee_config.summarization_model,
task_config = { "batch_size": 10 }
),
Task(add_data_points, task_config = { "batch_size": 10 }),
]

pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
Expand Down
3 changes: 3 additions & 0 deletions cognee/modules/chunking/TextChunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def read(self):
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = chunk_data["cut_type"],
contains = [],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
Expand All @@ -52,6 +53,7 @@ def read(self):
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
contains = [],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
Expand All @@ -73,6 +75,7 @@ def read(self):
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
contains = [],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
Expand Down
4 changes: 3 additions & 1 deletion cognee/modules/chunking/models/DocumentChunk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from typing import List, Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity

class DocumentChunk(DataPoint):
__tablename__ = "document_chunk"
Expand All @@ -9,6 +10,7 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
contains: List[Entity] = None

_metadata: Optional[dict] = {
"index_fields": ["text"],
Expand Down
1 change: 1 addition & 0 deletions cognee/modules/chunking/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .DocumentChunk import DocumentChunk
2 changes: 0 additions & 2 deletions cognee/modules/engine/models/Entity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.models.EntityType import EntityType


Expand All @@ -8,7 +7,6 @@ class Entity(DataPoint):
name: str
is_a: EntityType
description: str
mentioned_in: DocumentChunk

_metadata: dict = {
"index_fields": ["name"],
Expand Down
3 changes: 0 additions & 3 deletions cognee/modules/engine/models/EntityType.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk


class EntityType(DataPoint):
__tablename__ = "entity_type"
name: str
type: str
description: str
exists_in: DocumentChunk

_metadata: dict = {
"index_fields": ["name"],
Expand Down
17 changes: 9 additions & 8 deletions cognee/modules/graph/utils/expand_with_nodes_and_edges.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
generate_edge_name,
Expand All @@ -11,17 +11,19 @@


def expand_with_nodes_and_edges(
graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
data_chunks: list[DocumentChunk],
chunk_graphs: list[KnowledgeGraph],
existing_edges_map: Optional[dict[str, bool]] = None,
):
if existing_edges_map is None:
existing_edges_map = {}

added_nodes_map = {}
relationships = []
data_points = []

for graph_source, graph in graph_node_index:
for index, data_chunk in enumerate(data_chunks):
graph = chunk_graphs[index]

if graph is None:
continue

Expand All @@ -38,7 +40,6 @@ def expand_with_nodes_and_edges(
name = type_node_name,
type = type_node_name,
description = type_node_name,
exists_in = graph_source,
)
added_nodes_map[f"{str(type_node_id)}_type"] = type_node
else:
Expand All @@ -50,9 +51,9 @@ def expand_with_nodes_and_edges(
name = node_name,
is_a = type_node,
description = node.description,
mentioned_in = graph_source,
)
data_points.append(entity_node)

data_chunk.contains.append(entity_node)
added_nodes_map[f"{str(node_id)}_entity"] = entity_node

# Add relationship that came from graphs.
Expand Down Expand Up @@ -80,4 +81,4 @@ def expand_with_nodes_and_edges(
)
existing_edges_map[edge_key] = True

return (data_points, relationships)
return (data_chunks, relationships)
152 changes: 57 additions & 95 deletions cognee/modules/graph/utils/get_graph_from_model.py
Original file line number Diff line number Diff line change
@@ -1,154 +1,116 @@
from datetime import datetime, timezone

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model

async def get_graph_from_model(
data_point: DataPoint,
added_nodes: dict,
added_edges: dict,
visited_properties: dict = None,
include_root = True,
added_nodes = None,
added_edges = None,
visited_properties = None,
):
nodes = []
edges = []
added_nodes = added_nodes or {}
added_edges = added_edges or {}
visited_properties = visited_properties or {}

data_point_properties = {}
excluded_properties = set()

if str(data_point.id) in added_nodes:
return nodes, edges
properties_to_visit = set()

for field_name, field_value in data_point:
if field_name == "_metadata":
continue

if field_value is None:
excluded_properties.add(field_name)
continue

if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)

property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
property_key = str(data_point.id) + field_name + str(field_value.id)

if property_key in visited_properties:
continue

visited_properties[property_key] = True

nodes, edges = await add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
)
properties_to_visit.add(field_name)

continue

if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)

for field_value_item in field_value:
property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
for index, item in enumerate(field_value):
property_key = str(data_point.id) + field_name + str(item.id)

if property_key in visited_properties:
continue

visited_properties[property_key] = True

nodes, edges = await add_nodes_and_edges(
data_point,
field_name,
field_value_item,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
)
properties_to_visit.add(f"{field_name}.{index}")

continue

data_point_properties[field_name] = field_value

if include_root:

if include_root and str(data_point.id) not in added_nodes:
SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
"_metadata": (dict, data_point._metadata),
"__tablename__": data_point.__tablename__,
"__tablename__": (str, data_point.__tablename__),
},
exclude_fields = excluded_properties,
exclude_fields = list(excluded_properties),
)
nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[str(data_point.id)] = True

return nodes, edges
for field_name in properties_to_visit:
index = None

if "." in field_name:
field_name, index = field_name.split(".")

async def add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
):
property_nodes, property_edges = await get_graph_from_model(
field_value,
True,
added_nodes,
added_edges,
visited_properties,
)
field_value = getattr(data_point, field_name)

for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
if index is not None:
field_value = field_value[int(index)]

if str(field_value.id) in added_nodes:
continue

property_nodes, property_edges = await get_graph_from_model(
field_value,
include_root = True,
added_nodes = added_nodes,
added_edges = added_edges,
visited_properties = visited_properties,
)

for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
for node in property_nodes:
nodes.append(node)

if str(edge_key) not in added_edges:
for edge in property_edges:
edges.append(edge)
added_edges[str(edge_key)] = True

for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append(
(
data_point.id,
property_node.id,
field_name,
{
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
},
)
)
added_edges[str(edge_key)] = True

return (nodes, edges)


def get_own_properties(property_nodes, property_edges):

for property_node in get_own_property_nodes(property_nodes, property_edges):
if str(data_point.id) == str(property_node.id):
continue

edge_key = str(data_point.id) + str(property_node.id) + field_name

if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True

property_key = str(data_point.id) + field_name + str(field_value.id)
visited_properties[property_key] = True

return nodes, edges


def get_own_property_nodes(property_nodes, property_edges):
own_properties = []

destination_nodes = [str(property_edge[1]) for property_edge in property_edges]
Expand Down
13 changes: 8 additions & 5 deletions cognee/modules/graph/utils/retrieve_existing_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,34 @@


async def retrieve_existing_edges(
graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
data_chunks: list[DataPoint],
chunk_graphs: list[KnowledgeGraph],
graph_engine: GraphDBInterface,
) -> dict[str, bool]:
processed_nodes = {}
type_node_edges = []
entity_node_edges = []
type_entity_edges = []

for graph_source, graph in graph_node_index:
for index, data_chunk in enumerate(data_chunks):
graph = chunk_graphs[index]

for node in graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)

if str(type_node_id) not in processed_nodes:
type_node_edges.append(
(str(graph_source), str(type_node_id), "exists_in")
(data_chunk.id, type_node_id, "exists_in")
)
processed_nodes[str(type_node_id)] = True

if str(entity_node_id) not in processed_nodes:
entity_node_edges.append(
(str(graph_source), entity_node_id, "mentioned_in")
(data_chunk.id, entity_node_id, "mentioned_in")
)
type_entity_edges.append(
(str(entity_node_id), str(type_node_id), "is_a")
(entity_node_id, type_node_id, "is_a")
)
processed_nodes[str(entity_node_id)] = True

Expand Down
Loading