diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 55792f2da3..9385507ced 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -165,7 +165,6 @@ async def get_default_tasks( task_config={"batch_size": 10}, ), Task(add_data_points, task_config={"batch_size": 10}), - Task(store_descriptive_metrics, include_optional=True), ] except Exception as error: send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id) diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index f2685e73d6..767b211423 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -531,16 +531,144 @@ async def get_filtered_graph_data(self, attribute_filters): return (nodes, edges) + async def graph_exists(self, graph_name="myGraph"): + query = "CALL gds.graph.list() YIELD graphName RETURN collect(graphName) AS graphNames;" + result = await self.query(query) + graph_names = result[0]["graphNames"] if result else [] + return graph_name in graph_names + + async def project_entire_graph(self, graph_name="myGraph"): + """ + Projects all node labels and all relationship types into an in-memory GDS graph. + """ + if await self.graph_exists(graph_name): + return + + node_labels_query = "CALL db.labels() YIELD label RETURN collect(label) AS labels;" + node_labels_result = await self.query(node_labels_query) + node_labels = node_labels_result[0]["labels"] if node_labels_result else [] + + relationship_types_query = "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) AS relationships;" + relationship_types_result = await self.query(relationship_types_query) + relationship_types = ( + relationship_types_result[0]["relationships"] if relationship_types_result else [] + ) + + if not node_labels or not relationship_types: + raise ValueError("No node labels or relationship types found in the database.") + + node_labels_str = "[" + ", ".join(f"'{label}'" for label in node_labels) + "]" + relationship_types_str = "[" + ", ".join(f"'{rel}'" for rel in relationship_types) + "]" + + query = f""" + CALL gds.graph.project( + '{graph_name}', + {node_labels_str}, + {relationship_types_str} + ) YIELD graphName; + """ + + await self.query(query) + + async def drop_graph(self, graph_name="myGraph"): + if await self.graph_exists(graph_name): + drop_query = f"CALL gds.graph.drop('{graph_name}');" + await self.query(drop_query) + async def get_graph_metrics(self, include_optional=False): - return { - "num_nodes": -1, - "num_edges": -1, - "mean_degree": -1, - "edge_density": -1, - "num_connected_components": -1, - "sizes_of_connected_components": -1, - "num_selfloops": -1, - "diameter": -1, - "avg_shortest_path_length": -1, - "avg_clustering": -1, + nodes, edges = await self.get_model_independent_graph_data() + graph_name = "myGraph" + await self.drop_graph(graph_name) + await self.project_entire_graph(graph_name) + + async def _get_edge_density(): + query = """ + MATCH (n) + WITH count(n) AS num_nodes + MATCH ()-[r]->() + WITH num_nodes, count(r) AS num_edges + RETURN CASE + WHEN num_nodes < 2 THEN 0 + ELSE num_edges * 1.0 / (num_nodes * (num_nodes - 1)) + END AS edge_density; + """ + result = await self.query(query) + return result[0]["edge_density"] if result else 0 + + async def _get_num_connected_components(): + await self.drop_graph(graph_name) + await self.project_entire_graph(graph_name) + + query = f""" + CALL gds.wcc.stream('{graph_name}') + YIELD componentId + RETURN count(DISTINCT componentId) AS num_connected_components; + """ + + result = await self.query(query) + return result[0]["num_connected_components"] if result else 0 + + async def _get_size_of_connected_components(): + await self.drop_graph(graph_name) + await self.project_entire_graph(graph_name) + + query = f""" + CALL gds.wcc.stream('{graph_name}') + YIELD componentId + RETURN componentId, count(*) AS size + ORDER BY size DESC; + """ + + result = await self.query(query) + return [record["size"] for record in result] if result else [] + + async def _count_self_loops(): + query = """ + MATCH (n)-[r]->(n) + RETURN count(r) AS self_loop_count; + """ + result = await self.query(query) + return result[0]["self_loop_count"] if result else 0 + + async def _get_diameter(): + logging.warning("Diameter calculation is not implemented for neo4j.") + return -1 + + async def _get_avg_shortest_path_length(): + logging.warning( + "Average shortest path length calculation is not implemented for neo4j." + ) + return -1 + + async def _get_avg_clustering(): + logging.warning("Average clustering calculation is not implemented for neo4j.") + return -1 + + num_nodes = len(nodes[0]["nodes"]) + num_edges = len(edges[0]["elements"]) + + mandatory_metrics = { + "num_nodes": num_nodes, + "num_edges": num_edges, + "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None, + "edge_density": await _get_edge_density(), + "num_connected_components": await _get_num_connected_components(), + "sizes_of_connected_components": await _get_size_of_connected_components(), } + + if include_optional: + optional_metrics = { + "num_selfloops": await _count_self_loops(), + "diameter": await _get_diameter(), + "avg_shortest_path_length": await _get_avg_shortest_path_length(), + "avg_clustering": await _get_avg_clustering(), + } + else: + optional_metrics = { + "num_selfloops": -1, + "diameter": -1, + "avg_shortest_path_length": -1, + "avg_clustering": -1, + } + + return mandatory_metrics | optional_metrics diff --git a/cognee/modules/data/models/GraphMetrics.py b/cognee/modules/data/models/GraphMetrics.py index 5ce0d7ded9..8e4624c943 100644 --- a/cognee/modules/data/models/GraphMetrics.py +++ b/cognee/modules/data/models/GraphMetrics.py @@ -24,5 +24,5 @@ class GraphMetrics(Base): avg_shortest_path_length = Column(Float, nullable=True) avg_clustering = Column(Float, nullable=True) - created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now())