Merged
Changes from 6 commits
4 changes: 2 additions & 2 deletions README.md
@@ -28,7 +28,7 @@



Build dynamic Agent memory using scalable, modular ECL (Extract, Cognify, Load) pipelines.
Build dynamic memory for Agents and replace RAG using scalable, modular ECL (Extract, Cognify, Load) pipelines.

More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github.com/topoteretes/cognee/tree/main/evals)

@@ -55,7 +55,7 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
## Features

- Interconnect and retrieve your past conversations, documents, images and audio transcriptions
- Reduce hallucinations, developer effort, and cost.
- Replaces RAG systems and reduces developer effort and cost.
- Load data to graph and vector databases using only Pydantic
- Manipulate your data while ingesting from 30+ data sources

93 changes: 50 additions & 43 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -29,6 +29,7 @@

logger = get_logger("Neo4jAdapter", level=ERROR)

BASE_LABEL = "__Node__"

class Neo4jAdapter(GraphDBInterface):
"""
@@ -48,6 +49,12 @@ def __init__(
graph_database_url,
auth=(graph_database_username, graph_database_password),
max_connection_lifetime=120,
notifications_min_severity="OFF",
)
# Create contraint/index
self.query(
Collaborator:

GA tests are passing, but I just found out that the indexing was actually never awaited, it's a minor thing. I can fix it after we merge.

("CREATE CONSTRAINT IF NOT EXISTS FOR "
f"(n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;")
)
Contributor:

⚠️ Potential issue

Fix async/await and trailing whitespace issues.

Two issues to address:

  1. The constraint creation uses a synchronous call to self.query() but query() is an async method
  2. Trailing whitespace on line 56

Apply this diff to fix both issues:

-            notifications_min_severity="OFF",
-        )
-        # Create contraint/index
-        self.query(
-            ("CREATE CONSTRAINT IF NOT EXISTS FOR " 
-            f"(n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;")
+            notifications_min_severity="OFF",
         )
+        # Create constraint/index asynchronously in first query call
+        self._constraint_created = False

Then add this method to handle lazy constraint creation:

async def _ensure_constraint(self):
    if not self._constraint_created:
        await self.query(
            f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;"
        )
        self._constraint_created = True

Call await self._ensure_constraint() at the beginning of methods that modify nodes.

🧰 Tools
🪛 Pylint (3.3.7)

[convention] 56-56: Trailing whitespace

(C0303)

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines 52
to 58, the call to self.query() is incorrectly synchronous while query() is an
async method, and there is trailing whitespace on line 56. To fix this, convert
the constraint creation call to use await with an async method, remove the
trailing whitespace, and implement the provided async _ensure_constraint()
method to lazily create the constraint. Then, call await
self._ensure_constraint() at the start of any methods that modify nodes to
ensure the constraint is created before modifications.
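To make the suggested pattern concrete, here is a minimal, self-contained sketch of the once-only async setup the bot describes; the class, the print-based query stub, and the method bodies are illustrative stand-ins, not the adapter's real implementation:

import asyncio

BASE_LABEL = "__Node__"

class LazyConstraintAdapter:
    """Toy stand-in for Neo4jAdapter, showing only the lazy-constraint idea."""

    def __init__(self):
        self._constraint_created = False

    async def query(self, cypher: str, params: dict | None = None):
        # Stand-in for the real async Neo4j call; just shows what would be sent.
        print("running:", " ".join(cypher.split()), params or {})

    async def _ensure_constraint(self):
        # Runs the CREATE CONSTRAINT statement once, on the first write.
        if not self._constraint_created:
            await self.query(
                f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;"
            )
            self._constraint_created = True

    async def add_node(self, node_id: str):
        await self._ensure_constraint()
        await self.query(f"MERGE (n:`{BASE_LABEL}` {{id: $id}})", {"id": node_id})

asyncio.run(LazyConstraintAdapter().add_node("node-1"))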


@asynccontextmanager
@@ -103,8 +110,8 @@ async def has_node(self, node_id: str) -> bool:
- bool: True if the node exists, otherwise False.
"""
results = self.query(
"""
MATCH (n)
f"""
MATCH (n:`{BASE_LABEL}`)
WHERE n.id = $node_id
RETURN COUNT(n) > 0 AS node_exists
""",
@@ -129,7 +136,7 @@ async def add_node(self, node: DataPoint):
serialized_properties = self.serialize_properties(node.model_dump())

query = dedent(
"""MERGE (node {id: $node_id})
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
ON CREATE SET node += $properties, node.updated_at = timestamp()
ON MATCH SET node += $properties, node.updated_at = timestamp()
WITH node, $node_label AS label
@@ -161,9 +168,9 @@ async def add_nodes(self, nodes: list[DataPoint]) -> None:

- None: None
"""
query = """
query = f"""
UNWIND $nodes AS node
MERGE (n {id: node.node_id})
MERGE (n: `{BASE_LABEL}`{{id: node.node_id}})
ON CREATE SET n += node.properties, n.updated_at = timestamp()
ON MATCH SET n += node.properties, n.updated_at = timestamp()
WITH n, node.label AS label
@@ -215,9 +222,9 @@ async def extract_nodes(self, node_ids: List[str]):

A list of nodes represented as dictionaries.
"""
query = """
query = f"""
UNWIND $node_ids AS id
MATCH (node {id: id})
MATCH (node: `{BASE_LABEL}`{{id: id}})
RETURN node"""

params = {"node_ids": node_ids}
@@ -240,7 +247,7 @@ async def delete_node(self, node_id: str):

The result of the query execution, typically indicating success or failure.
"""
query = "MATCH (node {id: $node_id}) DETACH DELETE node"
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id}

return await self.query(query, params)
@@ -259,9 +266,9 @@ async def delete_nodes(self, node_ids: list[str]) -> None:

- None: None
"""
query = """
query = f"""
UNWIND $node_ids AS id
MATCH (node {id: id})
MATCH (node: `{BASE_LABEL}`{{id: id}})
DETACH DELETE node"""

params = {"node_ids": node_ids}
@@ -284,16 +291,15 @@ async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> boo

- bool: True if the edge exists, otherwise False.
"""
query = """
MATCH (from_node)-[relationship]->(to_node)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
query = f"""
MATCH (from_node: `{BASE_LABEL}`)-[relationship:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
RETURN COUNT(relationship) > 0 AS edge_exists
"""

params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"edge_label": edge_label,
}
Comment on lines +294 to 303
Contributor:

⚠️ Potential issue

Security concern: Edge label should be parameterized to prevent Cypher injection.

The edge label is now interpolated directly into the query string instead of being passed as a parameter. This could lead to Cypher injection vulnerabilities if the edge label comes from untrusted sources.

Consider parameterizing the edge label or validating it against a whitelist of allowed labels before interpolation.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
294 to 303, the edge label is directly interpolated into the Cypher query
string, which risks Cypher injection. To fix this, avoid direct string
interpolation of the edge label. Instead, validate the edge label against a
predefined whitelist of allowed labels before including it in the query, or
refactor the query to use parameterized inputs if supported by the Neo4j driver.
This ensures only safe, expected edge labels are used in the query.
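One concrete shape for that guard, sketched here under the assumption that cognee's edge labels are plain identifier-style strings (the helper name and regex are illustrative, not existing code):

import re

# Hypothetical guard: only identifier-style labels may be interpolated into Cypher.
_SAFE_LABEL = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

def validate_label(label: str) -> str:
    if not _SAFE_LABEL.fullmatch(label):
        raise ValueError(f"Unsafe edge label: {label!r}")
    return label

has_edge could then interpolate validate_label(edge_label) rather than the raw value, and the same check would cover the relationship names interpolated in add_edge, get_predecessors, and get_successors further down.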


edge_exists = await self.query(query, params)
@@ -366,9 +372,9 @@ async def add_edge(

query = dedent(
f"""\
MATCH (from_node {{id: $from_node}}),
(to_node {{id: $to_node}})
MERGE (from_node)-[r:{relationship_name}]->(to_node)
MATCH (from_node :`{BASE_LABEL}`{{id: $from_node}}),
(to_node :`{BASE_LABEL}`{{id: $to_node}})
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
Comment on lines +375 to +377
Contributor:

⚠️ Potential issue

Security concern: Relationship name should be parameterized to prevent Cypher injection.

The relationship name is interpolated directly into the query string, creating a potential Cypher injection vulnerability.

Consider parameterizing the relationship name or validating it against a whitelist before interpolation:

-            MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
+            CALL apoc.merge.relationship(from_node, $relationship_name, {}, $properties, to_node) YIELD rel AS r

And update the params to include the relationship name:

         params = {
             "from_node": str(from_node),
             "to_node": str(to_node),
-            "relationship_name": relationship_name,
+            "relationship_name": relationship_name,
             "properties": serialized_properties,
         }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
375 to 377, the relationship name is directly interpolated into the Cypher query
string, which poses a Cypher injection risk. To fix this, validate the
relationship name against a predefined whitelist of allowed names before
including it in the query. Alternatively, if possible, parameterize the
relationship name safely. Update the query construction to use the validated or
parameterized relationship name and ensure the parameters dictionary includes
this validated value instead of direct string interpolation.

ON CREATE SET r += $properties, r.updated_at = timestamp()
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r
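Worth noting that the add_edges hunk below already takes the parameterized route via apoc.merge.relationship. A hedged sketch of what add_edge's body could look like if it did the same, assuming the APOC plugin is available on the server (this is illustrative, not the code the PR merged):

# Illustrative rewrite of the query built inside add_edge.
query = dedent(
    f"""\
    MATCH (from_node:`{BASE_LABEL}` {{id: $from_node}}),
          (to_node:`{BASE_LABEL}` {{id: $to_node}})
    CALL apoc.merge.relationship(
        from_node, $relationship_name, {{}}, $properties, to_node
    ) YIELD rel AS r
    SET r.updated_at = timestamp()
    RETURN r
    """
)
params = {
    "from_node": str(from_node),
    "to_node": str(to_node),
    "relationship_name": relationship_name,
    "properties": serialized_properties,
}
return await self.query(query, params)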
@@ -400,17 +406,17 @@ async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) ->

- None: None
"""
query = """
query = f"""
UNWIND $edges AS edge
MATCH (from_node {id: edge.from_node})
MATCH (to_node {id: edge.to_node})
MATCH (from_node: `{BASE_LABEL}`{{id: edge.from_node}})
MATCH (to_node: `{BASE_LABEL}`{{id: edge.to_node}})
CALL apoc.merge.relationship(
from_node,
edge.relationship_name,
{
{{
source_node_id: edge.from_node,
target_node_id: edge.to_node
},
}},
edge.properties,
to_node
) YIELD rel
@@ -451,8 +457,8 @@ async def get_edges(self, node_id: str):

A list of edges connecting to the specified node, represented as tuples of details.
"""
query = """
MATCH (n {id: $node_id})-[r]-(m)
query = f"""
MATCH (n: `{BASE_LABEL}`{{id: $node_id}})-[r]-(m)
RETURN n, r, m
"""

@@ -525,24 +531,23 @@ async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[s
- list[str]: A list of predecessor node IDs.
"""
if edge_label is not None:
query = """
MATCH (node)<-[r]-(predecessor)
WHERE node.id = $node_id AND type(r) = $edge_label
query = f"""
MATCH (node: `{BASE_LABEL}`)<-[r:`{edge_label}`]-(predecessor)
WHERE node.id = $node_id
Comment on lines +534 to +536
Contributor:

⚠️ Potential issue

Security concern: Edge label should be parameterized to prevent Cypher injection.

The edge label is interpolated directly into the query string, creating a potential Cypher injection vulnerability.

Consider using parameterized queries or validating the edge label against a whitelist before interpolation.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
534 to 536, the edge_label is directly interpolated into the Cypher query
string, which risks Cypher injection. To fix this, avoid direct string
interpolation of edge_label; instead, validate edge_label against a predefined
whitelist of allowed labels before including it in the query. Alternatively,
refactor the query to use parameterized queries if supported by the Neo4j
driver, ensuring edge_label is safely handled without direct string insertion.

RETURN predecessor
"""

results = await self.query(
query,
dict(
node_id=node_id,
edge_label=edge_label,
),
)

return [result["predecessor"] for result in results]
else:
query = """
MATCH (node)<-[r]-(predecessor)
query = f"""
MATCH (node: `{BASE_LABEL}`)<-[r]-(predecessor)
WHERE node.id = $node_id
RETURN predecessor
"""
@@ -572,9 +577,9 @@ async def get_successors(self, node_id: str, edge_label: str = None) -> list[str
- list[str]: A list of successor node IDs.
"""
if edge_label is not None:
query = """
MATCH (node)-[r]->(successor)
WHERE node.id = $node_id AND type(r) = $edge_label
query = f"""
MATCH (node: `{BASE_LABEL}`)-[r:`{edge_label}`]->(successor)
WHERE node.id = $node_id
Comment on lines +580 to +582
Contributor:

⚠️ Potential issue

Security concern: Edge label should be parameterized to prevent Cypher injection.

The edge label is interpolated directly into the query string, creating a potential Cypher injection vulnerability.

Consider using parameterized queries or validating the edge label against a whitelist before interpolation.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/graph/neo4j_driver/adapter.py around lines
580 to 582, the edge label is directly interpolated into the Cypher query
string, which poses a Cypher injection risk. To fix this, avoid direct string
interpolation of the edge label; instead, validate the edge label against a
predefined whitelist of allowed labels before including it in the query.
Alternatively, restructure the query to use parameterized inputs for the edge
label if supported by the Neo4j driver, ensuring no untrusted input is directly
embedded in the query string.

RETURN successor
"""

@@ -588,8 +593,8 @@ async def get_successors(self, node_id: str, edge_label: str = None) -> list[str

return [result["successor"] for result in results]
else:
query = """
MATCH (node)-[r]->(successor)
query = f"""
MATCH (node: `{BASE_LABEL}`)-[r]->(successor)
WHERE node.id = $node_id
RETURN successor
"""
@@ -634,8 +639,8 @@ async def get_node(self, node_id: str) -> Optional[Dict[str, Any]]:
- Optional[Dict[str, Any]]: The requested node as a dictionary, or None if it does
not exist.
"""
query = """
MATCH (node {id: $node_id})
query = f"""
MATCH (node: `{BASE_LABEL}`{{id: $node_id}})
RETURN node
"""
results = await self.query(query, {"node_id": node_id})
@@ -655,9 +660,9 @@ async def get_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:

- List[Dict[str, Any]]: A list of nodes represented as dictionaries.
"""
query = """
query = f"""
UNWIND $node_ids AS id
MATCH (node {id: id})
MATCH (node:`{BASE_LABEL}` {{id: id}})
RETURN node
"""
results = await self.query(query, {"node_ids": node_ids})
@@ -677,13 +682,13 @@ async def get_connections(self, node_id: UUID) -> list:

- list: A list of connections represented as tuples of details.
"""
predecessors_query = """
MATCH (node)<-[relation]-(neighbour)
predecessors_query = f"""
MATCH (node:`{BASE_LABEL}`)<-[relation]-(neighbour)
WHERE node.id = $node_id
RETURN neighbour, relation, node
"""
successors_query = """
MATCH (node)-[relation]->(neighbour)
successors_query = f"""
MATCH (node:`{BASE_LABEL}`)-[relation]->(neighbour)
WHERE node.id = $node_id
RETURN node, relation, neighbour
"""
@@ -723,6 +728,7 @@ async def remove_connection_to_predecessors_of(

- None: None
"""
# Not understanding
query = f"""
UNWIND $node_ids AS id
MATCH (node:`{id}`)-[r:{edge_label}]->(predecessor)
@@ -751,6 +757,7 @@ async def remove_connection_to_successors_of(

- None: None
"""
# Not understanding
query = f"""
UNWIND $node_ids AS id
MATCH (node:`{id}`)<-[r:{edge_label}]-(successor)
@@ -57,9 +57,9 @@ async def get_num_connected_components(adapter: Neo4jAdapter, graph_name: str):
found.
"""
query = f"""
CALL gds.wcc.stream('{graph_name}')
YIELD componentId
RETURN count(DISTINCT componentId) AS num_connected_components;
CALL gds.wcc.stats('{graph_name}')
YIELD componentCount
RETURN componentCount AS num_connected_components;
"""

result = await adapter.query(query)
@@ -181,9 +181,9 @@ async def get_avg_clustering(adapter: Neo4jAdapter, graph_name: str):
The average clustering coefficient as a float, or 0 if no results are available.
"""
query = f"""
CALL gds.localClusteringCoefficient.stream('{graph_name}')
YIELD localClusteringCoefficient
RETURN avg(localClusteringCoefficient) AS avg_clustering;
CALL gds.localClusteringCoefficient.stats('{graph_name}')
YIELD averageClusteringCoefficient
RETURN averageClusteringCoefficient AS avg_clustering;
"""

result = await adapter.query(query)
6 changes: 2 additions & 4 deletions cognee/modules/users/methods/get_authenticated_user.py
@@ -21,17 +21,15 @@ async def get_authenticated_user(authorization: str = Header(...)) -> SimpleName
token, os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret"), algorithms=["HS256"]
)

if payload["tenant_id"]:
if payload.get("tenant_id"):
# SimpleNamespace lets us access dictionary elements like attributes
auth_data = SimpleNamespace(
id=UUID(payload["user_id"]),
tenant_id=UUID(payload["tenant_id"]),
roles=payload["roles"],
)
else:
auth_data = SimpleNamespace(
id=UUID(payload["user_id"]), tenant_id=None, roles=payload["roles"]
)
auth_data = SimpleNamespace(id=UUID(payload["user_id"]), tenant_id=None, roles=[])

return auth_data
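A quick standalone illustration of what the switch to payload.get() changes: a decoded token that carries no tenant_id no longer raises a KeyError on the old payload["tenant_id"] lookup, and the fallback branch now also defaults roles to an empty list. The payload below is a toy dict, not a real cognee JWT:

from types import SimpleNamespace
from uuid import UUID

# Toy decoded-JWT payload with no tenant_id and no roles claim.
payload = {"user_id": "00000000-0000-0000-0000-000000000001"}

if payload.get("tenant_id"):
    auth_data = SimpleNamespace(
        id=UUID(payload["user_id"]),
        tenant_id=UUID(payload["tenant_id"]),
        roles=payload["roles"],
    )
else:
    auth_data = SimpleNamespace(id=UUID(payload["user_id"]), tenant_id=None, roles=[])

print(auth_data.tenant_id, auth_data.roles)  # -> None []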

21 changes: 14 additions & 7 deletions examples/python/relational_database_migration_example.py
@@ -1,7 +1,6 @@
import asyncio
import cognee
import os
import logging

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.api.v1.visualize.visualize import visualize_graph
@@ -10,13 +9,12 @@
)

from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user

from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
create_db_and_tables as create_vector_db_and_tables,
)

# Prerequisites:
@@ -25,17 +23,23 @@
# LLM_API_KEY = "your_key_here"
# 3. Fill all relevant MIGRATION_DB information for the database you want to migrate to graph / Cognee

# NOTE: If you don't have a DB you want to migrate you can try it out with our
# test database at the following location:
# MIGRATION_DB_PATH="/{path_to_your_local_cognee}/cognee/tests/test_data"
# MIGRATION_DB_NAME="migration_database.sqlite"
# MIGRATION_DB_PROVIDER="sqlite"


async def main():
engine = get_migration_relational_engine()

# Clean all data stored in Cognee
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

# Needed to create principals table
# Create tables for databases
# Needed to create appropriate tables only on the Cognee side
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
await create_vector_db_and_tables()

print("\nExtracting schema of database to migrate.")
schema = await engine.extract_schema()
@@ -57,8 +61,11 @@ async def main():
await visualize_graph(destination_file_path)
print(f"Visualization can be found at: {destination_file_path}")

# Make sure to set top_k at a high value for a broader search, the default value is only 10!
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="What kind of data do you contain?"
query_type=SearchType.GRAPH_COMPLETION,
query_text="What kind of data do you contain?",
top_k=1000,
)
print(f"Search results: {search_results}")
