Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@



Build dynamic Agent memory using scalable, modular ECL (Extract, Cognify, Load) pipelines.
Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Extract, Cognify, Load) pipelines.

More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github.com/topoteretes/cognee/tree/main/evals)

Expand All @@ -55,7 +55,7 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
## Features

- Interconnect and retrieve your past conversations, documents, images and audio transcriptions
- Reduce hallucinations, developer effort, and cost.
- Replaces RAG systems and reduces developer effort, and cost.
- Load data to graph and vector databases using only Pydantic
- Manipulate your data while ingesting from 30+ data sources

Expand Down
93 changes: 50 additions & 43 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

logger = get_logger("Neo4jAdapter", level=ERROR)

BASE_LABEL = "__Node__"


class Neo4jAdapter(GraphDBInterface):
"""
Expand All @@ -48,6 +50,11 @@ def __init__(
graph_database_url,
auth=(graph_database_username, graph_database_password),
max_connection_lifetime=120,
notifications_min_severity="OFF",
)
# Create contraint/index
self.query(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GA tests are passing, but I just found out that the indexing was actually never awaited, it's a minor thing. I can fix it after we merge.

(f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;")
)

@asynccontextmanager
Expand Down Expand Up @@ -103,8 +110,8 @@ async def has_node(self, node_id: str) -> bool:
- bool: True if the node exists, otherwise False.
"""
results = self.query(
"""
MATCH (n)
f"""
MATCH (n:`{BASE_LABEL}`)
WHERE n.id = $node_id
RETURN COUNT(n) > 0 AS node_exists
""",
Expand All @@ -129,7 +136,7 @@ async def add_node(self, node: DataPoint):
serialized_properties = self.serialize_properties(node.model_dump())

query = dedent(
"""MERGE (node {id: $node_id})
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
ON CREATE SET node += $properties, node.updated_at = timestamp()
ON MATCH SET node += $properties, node.updated_at = timestamp()
WITH node, $node_label AS label
Expand Down Expand Up @@ -161,9 +168,9 @@ async def add_nodes(self, nodes: list[DataPoint]) -> None:

- None: None
"""
query = """
query = f"""
UNWIND $nodes AS node
MERGE (n {id: node.node_id})
MERGE (n: `{BASE_LABEL}`{{id: node.node_id}})
ON CREATE SET n += node.properties, n.updated_at = timestamp()
ON MATCH SET n += node.properties, n.updated_at = timestamp()
WITH n, node.label AS label
Expand Down Expand Up @@ -215,9 +222,9 @@ async def extract_nodes(self, node_ids: List[str]):

A list of nodes represented as dictionaries.
"""
query = """
query = f"""
UNWIND $node_ids AS id
MATCH (node {id: id})
MATCH (node: `{BASE_LABEL}`{{id: id}})
RETURN node"""

params = {"node_ids": node_ids}
Expand All @@ -240,7 +247,7 @@ async def delete_node(self, node_id: str):

The result of the query execution, typically indicating success or failure.
"""
query = "MATCH (node {id: $node_id}) DETACH DELETE node"
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id}

return await self.query(query, params)
Expand All @@ -259,9 +266,9 @@ async def delete_nodes(self, node_ids: list[str]) -> None:

- None: None
"""
query = """
query = f"""
UNWIND $node_ids AS id
MATCH (node {id: id})
MATCH (node: `{BASE_LABEL}`{{id: id}})
DETACH DELETE node"""

params = {"node_ids": node_ids}
Expand All @@ -284,16 +291,15 @@ async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> boo

- bool: True if the edge exists, otherwise False.
"""
query = """
MATCH (from_node)-[relationship]->(to_node)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
query = f"""
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
RETURN COUNT(relationship) > 0 AS edge_exists
"""

params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"edge_label": edge_label,
}
Comment on lines +294 to 303
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Security concern: Edge label should be parameterized to prevent Cypher injection.

The edge label is now interpolated directly into the query string instead of being passed as a parameter. This could lead to Cypher injection vulnerabilities if the edge label comes from untrusted sources.

Consider parameterizing the edge label or validating it against a whitelist of allowed labels before interpolation.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
294 to 303, the edge label is directly interpolated into the Cypher query
string, which risks Cypher injection. To fix this, avoid direct string
interpolation of the edge label. Instead, validate the edge label against a
predefined whitelist of allowed labels before including it in the query, or
refactor the query to use parameterized inputs if supported by the Neo4j driver.
This ensures only safe, expected edge labels are used in the query.


edge_exists = await self.query(query, params)
Expand Down Expand Up @@ -366,9 +372,9 @@ async def add_edge(

query = dedent(
f"""\
MATCH (from_node {{id: $from_node}}),
(to_node {{id: $to_node}})
MERGE (from_node)-[r:{relationship_name}]->(to_node)
MATCH (from_node :`{BASE_LABEL}`{{id: $from_node}}),
(to_node :`{BASE_LABEL}`{{id: $to_node}})
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
Comment on lines +375 to +377
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Security concern: Relationship name should be parameterized to prevent Cypher injection.

The relationship name is interpolated directly into the query string, creating a potential Cypher injection vulnerability.

Consider parameterizing the relationship name or validating it against a whitelist before interpolation:

-            MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
+            CALL apoc.merge.relationship(from_node, $relationship_name, {}, $properties, to_node) YIELD rel AS r

And update the params to include the relationship name:

         params = {
             "from_node": str(from_node),
             "to_node": str(to_node),
-            "relationship_name": relationship_name,
+            "relationship_name": relationship_name,
             "properties": serialized_properties,
         }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
375 to 377, the relationship name is directly interpolated into the Cypher query
string, which poses a Cypher injection risk. To fix this, validate the
relationship name against a predefined whitelist of allowed names before
including it in the query. Alternatively, if possible, parameterize the
relationship name safely. Update the query construction to use the validated or
parameterized relationship name and ensure the parameters dictionary includes
this validated value instead of direct string interpolation.

ON CREATE SET r += $properties, r.updated_at = timestamp()
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r
Expand Down Expand Up @@ -400,17 +406,17 @@ async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) ->

- None: None
"""
query = """
query = f"""
UNWIND $edges AS edge
MATCH (from_node {id: edge.from_node})
MATCH (to_node {id: edge.to_node})
MATCH (from_node: `{BASE_LABEL}`{{id: edge.from_node}})
MATCH (to_node: `{BASE_LABEL}`{{id: edge.to_node}})
CALL apoc.merge.relationship(
from_node,
edge.relationship_name,
{
{{
source_node_id: edge.from_node,
target_node_id: edge.to_node
},
}},
edge.properties,
to_node
) YIELD rel
Expand Down Expand Up @@ -451,8 +457,8 @@ async def get_edges(self, node_id: str):

A list of edges connecting to the specified node, represented as tuples of details.
"""
query = """
MATCH (n {id: $node_id})-[r]-(m)
query = f"""
MATCH (n: `{BASE_LABEL}`{{id: $node_id}})-[r]-(m)
RETURN n, r, m
"""

Expand Down Expand Up @@ -525,24 +531,23 @@ async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[s
- list[str]: A list of predecessor node IDs.
"""
if edge_label is not None:
query = """
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id AND type(r) = $edge_label
query = f"""
MATCH (node: `{BASE_LABEL}`)<-[r:`{edge_label}`]-(predecessor)
WHERE node.id = $node_id
Comment on lines +534 to +536
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Security concern: Edge label should be parameterized to prevent Cypher injection.

The edge label is interpolated directly into the query string, creating a potential Cypher injection vulnerability.

Consider using parameterized queries or validating the edge label against a whitelist before interpolation.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
534 to 536, the edge_label is directly interpolated into the Cypher query
string, which risks Cypher injection. To fix this, avoid direct string
interpolation of edge_label; instead, validate edge_label against a predefined
whitelist of allowed labels before including it in the query. Alternatively,
refactor the query to use parameterized queries if supported by the Neo4j
driver, ensuring edge_label is safely handled without direct string insertion.

RETURN predecessor
"""

results = await self.query(
query,
dict(
node_id=node_id,
edge_label=edge_label,
),
)

return [result["predecessor"] for result in results]
else:
query = """
MATCH (node)<-[r]-(predecessor)
query = f"""
MATCH (node: `{BASE_LABEL}`)<-[r]-(predecessor)
WHERE node.id = $node_id
RETURN predecessor
"""
Expand Down Expand Up @@ -572,9 +577,9 @@ async def get_successors(self, node_id: str, edge_label: str = None) -> list[str
- list[str]: A list of successor node IDs.
"""
if edge_label is not None:
query = """
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id AND type(r) = $edge_label
query = f"""
MATCH (node: `{BASE_LABEL}`)-[r:`{edge_label}`]->(successor)
WHERE node.id = $node_id
Comment on lines +580 to +582
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Security concern: Edge label should be parameterized to prevent Cypher injection.

The edge label is interpolated directly into the query string, creating a potential Cypher injection vulnerability.

Consider using parameterized queries or validating the edge label against a whitelist before interpolation.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
580 to 582, the edge label is directly interpolated into the Cypher query
string, which poses a Cypher injection risk. To fix this, avoid direct string
interpolation of the edge label; instead, validate the edge label against a
predefined whitelist of allowed labels before including it in the query.
Alternatively, restructure the query to use parameterized inputs for the edge
label if supported by the Neo4j driver, ensuring no untrusted input is directly
embedded in the query string.

RETURN successor
"""

Expand All @@ -588,8 +593,8 @@ async def get_successors(self, node_id: str, edge_label: str = None) -> list[str

return [result["successor"] for result in results]
else:
query = """
MATCH (node)-[r]->(successor)
query = f"""
MATCH (node: `{BASE_LABEL}`)-[r]->(successor)
WHERE node.id = $node_id
RETURN successor
"""
Expand Down Expand Up @@ -634,8 +639,8 @@ async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
- Optional[Dict[str, Any]]: The requested node as a dictionary, or None if it does
not exist.
"""
query = """
MATCH (node {id: $node_id})
query = f"""
MATCH (node: `{BASE_LABEL}`{{id: $node_id}})
RETURN node
"""
results = await self.query(query, {"node_id": node_id})
Expand All @@ -655,9 +660,9 @@ async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:

- List[Dict[str, Any]]: A list of nodes represented as dictionaries.
"""
query = """
query = f"""
UNWIND $node_ids AS id
MATCH (node {id: id})
MATCH (node:`{BASE_LABEL}` {{id: id}})
RETURN node
"""
results = await self.query(query, {"node_ids": node_ids})
Expand All @@ -677,13 +682,13 @@ async def get_connections(self, node_id: UUID) -> list:

- list: A list of connections represented as tuples of details.
"""
predecessors_query = """
MATCH (node)<-[relation]-(neighbour)
predecessors_query = f"""
MATCH (node:`{BASE_LABEL}`)<-[relation]-(neighbour)
WHERE node.id = $node_id
RETURN neighbour, relation, node
"""
successors_query = """
MATCH (node)-[relation]->(neighbour)
successors_query = f"""
MATCH (node:`{BASE_LABEL}`)-[relation]->(neighbour)
WHERE node.id = $node_id
RETURN node, relation, neighbour
"""
Expand Down Expand Up @@ -723,6 +728,7 @@ async def remove_connection_to_predecessors_of(

- None: None
"""
# Not understanding
query = f"""
UNWIND $node_ids AS id
MATCH (node:`{id}`)-[r:{edge_label}]->(predecessor)
Expand Down Expand Up @@ -751,6 +757,7 @@ async def remove_connection_to_successors_of(

- None: None
"""
# Not understanding
query = f"""
UNWIND $node_ids AS id
MATCH (node:`{id}`)<-[r:{edge_label}]-(successor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ async def get_num_connected_components(adapter: Neo4jAdapter, graph_name: str):
found.
"""
query = f"""
CALL gds.wcc.stream('{graph_name}')
YIELD componentId
RETURN count(DISTINCT componentId) AS num_connected_components;
CALL gds.wcc.stats('{graph_name}')
YIELD componentCount
RETURN componentCount AS num_connected_components;
"""

result = await adapter.query(query)
Expand Down Expand Up @@ -181,9 +181,9 @@ async def get_avg_clustering(adapter: Neo4jAdapter, graph_name: str):
The average clustering coefficient as a float, or 0 if no results are available.
"""
query = f"""
CALL gds.localClusteringCoefficient.stream('{graph_name}')
YIELD localClusteringCoefficient
RETURN avg(localClusteringCoefficient) AS avg_clustering;
CALL gds.localClusteringCoefficient.stats('{graph_name}')
YIELD averageClusteringCoefficient
RETURN averageClusteringCoefficient AS avg_clustering;
"""

result = await adapter.query(query)
Expand Down
6 changes: 2 additions & 4 deletions cognee/modules/users/methods/get_authenticated_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@ async def get_authenticated_user(authorization: str = Header(...)) -> SimpleName
token, os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret"), algorithms=["HS256"]
)

if payload["tenant_id"]:
if payload.get("tenant_id"):
# SimpleNamespace lets us access dictionary elements like attributes
auth_data = SimpleNamespace(
id=UUID(payload["user_id"]),
tenant_id=UUID(payload["tenant_id"]),
roles=payload["roles"],
)
else:
auth_data = SimpleNamespace(
id=UUID(payload["user_id"]), tenant_id=None, roles=payload["roles"]
)
auth_data = SimpleNamespace(id=UUID(payload["user_id"]), tenant_id=None, roles=[])

return auth_data

Expand Down
Loading