Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions cognee/infrastructure/databases/graph/networkx/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def get_graph_data(self):
async def query(self, query: str, params: dict):
pass

async def has_node(self, node_id: str) -> bool:
async def has_node(self, node_id: UUID) -> bool:
return self.graph.has_node(node_id)

async def add_node(self, node: DataPoint) -> None:
Expand Down Expand Up @@ -136,7 +136,7 @@ async def add_edges(self, edges: list[tuple[str, str, str, dict]]) -> None:
logger.error(f"Failed to add edges: {e}")
raise

async def get_edges(self, node_id: str):
async def get_edges(self, node_id: UUID):
return list(self.graph.in_edges(node_id, data=True)) + list(
self.graph.out_edges(node_id, data=True)
)
Expand Down Expand Up @@ -174,13 +174,13 @@ async def get_disconnected_nodes(self) -> List[str]:

return disconnected_nodes

async def extract_node(self, node_id: str) -> dict:
async def extract_node(self, node_id: UUID) -> dict:
if self.graph.has_node(node_id):
return self.graph.nodes[node_id]

return None

async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
async def extract_nodes(self, node_ids: List[UUID]) -> List[dict]:
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]

async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
Expand Down Expand Up @@ -215,7 +215,7 @@ async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:

return nodes

async def get_neighbors(self, node_id: str) -> list:
async def get_neighbors(self, node_id: UUID) -> list:
if not self.graph.has_node(node_id):
return []

Expand Down Expand Up @@ -264,7 +264,7 @@ async def get_connections(self, node_id: UUID) -> list:
return connections

async def remove_connection_to_predecessors_of(
self, node_ids: list[str], edge_label: str
self, node_ids: list[UUID], edge_label: str
) -> None:
for node_id in node_ids:
if self.graph.has_node(node_id):
Expand All @@ -275,7 +275,7 @@ async def remove_connection_to_predecessors_of(
await self.save_graph_to_file(self.filename)

async def remove_connection_to_successors_of(
self, node_ids: list[str], edge_label: str
self, node_ids: list[UUID], edge_label: str
) -> None:
for node_id in node_ids:
if self.graph.has_node(node_id):
Expand Down Expand Up @@ -621,12 +621,12 @@ async def get_degree_one_nodes(self, node_type: str):
nodes.append(node_data)
return nodes

async def get_node(self, node_id: str) -> dict:
async def get_node(self, node_id: UUID) -> dict:
if self.graph.has_node(node_id):
return self.graph.nodes[node_id]
return None

async def get_nodes(self, node_ids: List[str] = None) -> List[dict]:
async def get_nodes(self, node_ids: List[UUID] = None) -> List[dict]:
if node_ids is None:
return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)]
return [
Expand Down
Loading