diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index c04c9198eb..2c86dba0a6 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -42,7 +42,7 @@ async def get_graph_data(self): async def query(self, query: str, params: dict): pass - async def has_node(self, node_id: str) -> bool: + async def has_node(self, node_id: UUID) -> bool: return self.graph.has_node(node_id) async def add_node(self, node: DataPoint) -> None: @@ -136,7 +136,7 @@ async def add_edges(self, edges: list[tuple[str, str, str, dict]]) -> None: logger.error(f"Failed to add edges: {e}") raise - async def get_edges(self, node_id: str): + async def get_edges(self, node_id: UUID): return list(self.graph.in_edges(node_id, data=True)) + list( self.graph.out_edges(node_id, data=True) ) @@ -174,13 +174,13 @@ async def get_disconnected_nodes(self) -> List[str]: return disconnected_nodes - async def extract_node(self, node_id: str) -> dict: + async def extract_node(self, node_id: UUID) -> dict: if self.graph.has_node(node_id): return self.graph.nodes[node_id] return None - async def extract_nodes(self, node_ids: List[str]) -> List[dict]: + async def extract_nodes(self, node_ids: List[UUID]) -> List[dict]: return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)] async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list: @@ -215,7 +215,7 @@ async def get_successors(self, node_id: UUID, edge_label: str = None) -> list: return nodes - async def get_neighbors(self, node_id: str) -> list: + async def get_neighbors(self, node_id: UUID) -> list: if not self.graph.has_node(node_id): return [] @@ -264,7 +264,7 @@ async def get_connections(self, node_id: UUID) -> list: return connections async def remove_connection_to_predecessors_of( - self, node_ids: list[str], edge_label: str + self, node_ids: list[UUID], edge_label: str ) -> None: for node_id in node_ids: if self.graph.has_node(node_id): @@ -275,7 +275,7 @@ async def remove_connection_to_predecessors_of( await self.save_graph_to_file(self.filename) async def remove_connection_to_successors_of( - self, node_ids: list[str], edge_label: str + self, node_ids: list[UUID], edge_label: str ) -> None: for node_id in node_ids: if self.graph.has_node(node_id): @@ -621,12 +621,12 @@ async def get_degree_one_nodes(self, node_type: str): nodes.append(node_data) return nodes - async def get_node(self, node_id: str) -> dict: + async def get_node(self, node_id: UUID) -> dict: if self.graph.has_node(node_id): return self.graph.nodes[node_id] return None - async def get_nodes(self, node_ids: List[str] = None) -> List[dict]: + async def get_nodes(self, node_ids: List[UUID] = None) -> List[dict]: if node_ids is None: return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)] return [