Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ async def embed_text(self, text: List[str]) -> List[List[float]]:
litellm.aembedding(
model=self.model,
input=text,
api_key=self.api_key if self.api_key and self.api_key.strip() != "" else "EMPTY",
api_key=self.api_key
if self.api_key and self.api_key.strip() != ""
else "EMPTY",
api_base=self.endpoint,
api_version=self.api_version,
),
Expand Down
6 changes: 6 additions & 0 deletions cognee/modules/graph/methods/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from .get_formatted_graph_data import get_formatted_graph_data

from .delete_data_related_nodes import delete_data_related_nodes
from .delete_data_related_edges import delete_data_related_edges

from .delete_dataset_related_nodes import delete_dataset_related_nodes
from .delete_dataset_related_edges import delete_dataset_related_edges
13 changes: 13 additions & 0 deletions cognee/modules/graph/methods/delete_data_related_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Edge


@with_async_session
async def delete_data_related_edges(data_id: UUID, session: AsyncSession):
    """Delete every ``Edge`` row whose ``data_id`` equals the given id.

    Two statements are issued on ``session``: a SELECT that loads the matching
    ``Edge`` rows, followed by a DELETE restricted to the primary keys of those
    rows. Nothing is returned and no commit is issued in this function.

    Args:
        data_id: Value compared against ``Edge.data_id``.
        session: Async session used for both statements; presumably supplied
            by ``with_async_session`` when the caller omits it (confirm
            against the decorator).
    """
    lookup = select(Edge).where(Edge.data_id == data_id)
    matched = await session.scalars(lookup)
    edge_ids = [edge.id for edge in matched.all()]

    # The DELETE is sent even when no rows matched (empty id list).
    removal = delete(Edge).where(Edge.id.in_(edge_ids))
    await session.execute(removal)
13 changes: 13 additions & 0 deletions cognee/modules/graph/methods/delete_data_related_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node


@with_async_session
async def delete_data_related_nodes(data_id: UUID, session: AsyncSession):
    """Delete every ``Node`` row whose ``data_id`` equals the given id.

    Two statements are issued on ``session``: a SELECT that loads the matching
    ``Node`` rows, followed by a DELETE restricted to the primary keys of those
    rows. Nothing is returned and no commit is issued in this function.

    Args:
        data_id: Value compared against ``Node.data_id``.
        session: Async session used for both statements; presumably supplied
            by ``with_async_session`` when the caller omits it (confirm
            against the decorator).
    """
    lookup = select(Node).where(Node.data_id == data_id)
    matched = await session.scalars(lookup)
    node_ids = [node.id for node in matched.all()]

    # The DELETE is sent even when no rows matched (empty id list).
    removal = delete(Node).where(Node.id.in_(node_ids))
    await session.execute(removal)
13 changes: 13 additions & 0 deletions cognee/modules/graph/methods/delete_dataset_related_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Edge


@with_async_session
async def delete_dataset_related_edges(dataset_id: UUID, session: AsyncSession):
    """Delete every ``Edge`` row whose ``dataset_id`` equals the given id.

    Two statements are issued on ``session``: a SELECT that loads the matching
    ``Edge`` rows, followed by a DELETE restricted to the primary keys of those
    rows. Nothing is returned and no commit is issued in this function.

    Args:
        dataset_id: Value compared against ``Edge.dataset_id``.
        session: Async session used for both statements; presumably supplied
            by ``with_async_session`` when the caller omits it (confirm
            against the decorator).
    """
    lookup = select(Edge).where(Edge.dataset_id == dataset_id)
    matched = await session.scalars(lookup)
    edge_ids = [edge.id for edge in matched.all()]

    # The DELETE is sent even when no rows matched (empty id list).
    removal = delete(Edge).where(Edge.id.in_(edge_ids))
    await session.execute(removal)
13 changes: 13 additions & 0 deletions cognee/modules/graph/methods/delete_dataset_related_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from uuid import UUID
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node


@with_async_session
async def delete_dataset_related_nodes(dataset_id: UUID, session: AsyncSession):
    """Delete every ``Node`` row whose ``dataset_id`` equals the given id.

    Two statements are issued on ``session``: a SELECT that loads the matching
    ``Node`` rows, followed by a DELETE restricted to the primary keys of those
    rows. Nothing is returned and no commit is issued in this function.

    Args:
        dataset_id: Value compared against ``Node.dataset_id``.
        session: Async session used for both statements; presumably supplied
            by ``with_async_session`` when the caller omits it (confirm
            against the decorator).
    """
    lookup = select(Node).where(Node.dataset_id == dataset_id)
    matched = await session.scalars(lookup)
    node_ids = [node.id for node in matched.all()]

    # The DELETE is sent even when no rows matched (empty id list).
    removal = delete(Node).where(Node.id.in_(node_ids))
    await session.execute(removal)
58 changes: 58 additions & 0 deletions cognee/modules/graph/models/Edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from datetime import datetime, timezone
from sqlalchemy import (
# event,
DateTime,
JSON,
UUID,
Text,
)

# from sqlalchemy.schema import DDL
from sqlalchemy.orm import Mapped, mapped_column

from cognee.infrastructure.databases.relational import Base


class Edge(Base):
    """ORM model for one row of the ``edges`` table.

    A row holds its own UUID primary key, a ``slug`` UUID, three required UUID
    references (``user_id``, ``data_id``, ``dataset_id``), the UUIDs of the
    source and destination nodes, a relationship name, and optional label and
    JSON attributes. No foreign-key constraints are declared on this model.
    """

    __tablename__ = "edges"

    # Primary key; no default is declared, so a value must be provided.
    id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
    )

    slug: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
    )

    user_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
    )

    # Both data_id and dataset_id carry their own single-column index.
    data_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )

    dataset_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )

    # Endpoints of the edge, stored as plain UUID columns.
    source_node_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
    )
    destination_node_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
    )

    relationship_name: Mapped[str] = mapped_column(
        Text,
        nullable=False,
    )

    # Optional columns: NULL is allowed for both.
    label: Mapped[str | None] = mapped_column(Text, nullable=True)
    attributes: Mapped[dict | None] = mapped_column(JSON, nullable=True)

    # Filled in on insert with the current timezone-aware UTC time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(tz=timezone.utc),
    )

    # Disabled: hash partitioning of the table by user.
    # __table_args__ = (
    #     {"postgresql_partition_by": "HASH (user_id)"},
    # )


# Enable row-level security (RLS) for edges
# enable_edge_rls = DDL("""
# ALTER TABLE edges ENABLE ROW LEVEL SECURITY;
# """)
# create_user_isolation_policy = DDL("""
# CREATE POLICY user_isolation_policy
# ON edges
# USING (user_id = current_setting('app.current_user_id')::uuid)
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
# """)

# event.listen(Edge.__table__, "after_create", enable_edge_rls)
# event.listen(Edge.__table__, "after_create", create_user_isolation_policy)
59 changes: 59 additions & 0 deletions cognee/modules/graph/models/Node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from datetime import datetime, timezone
from sqlalchemy import (
DateTime,
Index,
# event,
String,
JSON,
UUID,
)

# from sqlalchemy.schema import DDL
from sqlalchemy.orm import Mapped, mapped_column

from cognee.infrastructure.databases.relational import Base


class Node(Base):
    """ORM model for one row of the ``nodes`` table.

    A row holds its own UUID primary key, a ``slug`` UUID, three required UUID
    references (``user_id``, ``data_id``, ``dataset_id``), an optional label,
    a required type string, a required JSON list of indexed fields, and
    optional JSON attributes. No foreign-key constraints are declared on this
    model.
    """

    __tablename__ = "nodes"

    # Primary key; no default is declared, so a value must be provided.
    id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
    )

    slug: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
    )

    user_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
    )

    # NOTE(review): data_id has no single-column index; it only appears as the
    # second column of the composite index below, while
    # delete_data_related_nodes filters on data_id alone. Confirm whether a
    # standalone index is wanted.
    data_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
    )

    dataset_id: Mapped[UUID] = mapped_column(
        UUID(as_uuid=True),
        nullable=False,
        index=True,
    )

    # label may be NULL; type is required. Both are capped at 255 characters.
    label: Mapped[str | None] = mapped_column(String(255), nullable=True)
    type: Mapped[str] = mapped_column(String(255), nullable=False)
    indexed_fields: Mapped[list] = mapped_column(JSON, nullable=False)

    attributes: Mapped[dict | None] = mapped_column(JSON, nullable=True)

    # Filled in on insert with the current timezone-aware UTC time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(tz=timezone.utc),
    )

    # Composite indexes, both led by dataset_id.
    __table_args__ = (
        Index("index_node_dataset_slug", "dataset_id", "slug"),
        Index("index_node_dataset_data", "dataset_id", "data_id"),
        # Disabled: HASH partitioning on user_id.
        # {"postgresql_partition_by": "HASH (user_id)"},
    )


# Enable row-level security (RLS) for nodes
# enable_node_rls = DDL("""
# ALTER TABLE nodes ENABLE ROW LEVEL SECURITY;
# """)
# create_user_isolation_policy = DDL("""
# CREATE POLICY user_isolation_policy
# ON nodes
# USING (user_id = current_setting('app.current_user_id')::uuid)
# WITH CHECK (user_id = current_setting('app.current_user_id')::uuid);
# """)

# event.listen(Node.__table__, "after_create", enable_node_rls)
# event.listen(Node.__table__, "after_create", create_user_isolation_policy)
2 changes: 2 additions & 0 deletions cognee/modules/graph/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .Edge import Edge
from .Node import Node
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
lifetime_seconds = int(os.getenv("JWT_LIFETIME_SECONDS", "3600"))
except ValueError:
lifetime_seconds = 3600

return APIJWTStrategy(secret, lifetime_seconds=lifetime_seconds)

auth_backend = AuthenticationBackend(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock

from cognee.modules.graph.methods import delete_data_related_edges


class DummyScalarResult:
    """Test double for the object returned by an awaited ``session.scalars``.

    It only supports ``all()``, which hands back the exact object passed to
    the constructor.
    """

    def __init__(self, items):
        self._rows = items

    def all(self):
        """Return the stored items unchanged (same object, no copy)."""
        return self._rows


class FakeEdge:
    """Minimal edge stand-in that exposes only an ``id`` attribute."""

    def __init__(self, edge_id):
        # Stored as-is; the function under test only reads ``.id``.
        self.id = edge_id


@pytest.mark.asyncio
async def test_delete_data_related_edges_deletes_found_rows():
    """With two edges found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeEdge(1), FakeEdge(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_data_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_delete_data_related_edges_handles_empty_list():
    """With no edges found, ``scalars`` and ``execute`` are still each awaited once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_data_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock

from cognee.modules.graph.methods import delete_data_related_nodes


class DummyScalarResult:
    """Test double for the object returned by an awaited ``session.scalars``.

    It only supports ``all()``, which hands back the exact object passed to
    the constructor.
    """

    def __init__(self, items):
        self._rows = items

    def all(self):
        """Return the stored items unchanged (same object, no copy)."""
        return self._rows


class FakeNode:
    """Minimal node stand-in that exposes only an ``id`` attribute."""

    def __init__(self, node_id):
        # Stored as-is; the function under test only reads ``.id``.
        self.id = node_id


@pytest.mark.asyncio
async def test_delete_data_related_nodes_deletes_found_rows():
    """With two nodes found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeNode(1), FakeNode(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_data_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_delete_data_related_nodes_handles_empty_list():
    """With no nodes found, ``scalars`` and ``execute`` are still each awaited once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_data_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock

from cognee.modules.graph.methods import delete_dataset_related_edges


class DummyScalarResult:
    """Test double for the object returned by an awaited ``session.scalars``.

    It only supports ``all()``, which hands back the exact object passed to
    the constructor.
    """

    def __init__(self, items):
        self._rows = items

    def all(self):
        """Return the stored items unchanged (same object, no copy)."""
        return self._rows


class FakeEdge:
    """Minimal edge stand-in that exposes only an ``id`` attribute."""

    def __init__(self, edge_id):
        # Stored as-is; the function under test only reads ``.id``.
        self.id = edge_id


@pytest.mark.asyncio
async def test_delete_dataset_related_edges_deletes_found_rows():
    """With two edges found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeEdge(1), FakeEdge(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_delete_dataset_related_edges_handles_empty_list():
    """With no edges found, ``scalars`` and ``execute`` are still each awaited once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_edges(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from uuid import uuid4
from types import SimpleNamespace
from unittest.mock import AsyncMock

from cognee.modules.graph.methods import delete_dataset_related_nodes


class DummyScalarResult:
    """Test double for the object returned by an awaited ``session.scalars``.

    It only supports ``all()``, which hands back the exact object passed to
    the constructor.
    """

    def __init__(self, items):
        self._rows = items

    def all(self):
        """Return the stored items unchanged (same object, no copy)."""
        return self._rows


class FakeNode:
    """Minimal node stand-in that exposes only an ``id`` attribute."""

    def __init__(self, node_id):
        # Stored as-is; the function under test only reads ``.id``.
        self.id = node_id


@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_deletes_found_rows():
    """With two nodes found, ``scalars`` and ``execute`` are each awaited once."""
    found = DummyScalarResult([FakeNode(1), FakeNode(2)])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_delete_dataset_related_nodes_handles_empty_list():
    """With no nodes found, ``scalars`` and ``execute`` are still each awaited once."""
    nothing_found = DummyScalarResult([])
    fake_session = SimpleNamespace(
        scalars=AsyncMock(return_value=nothing_found),
        execute=AsyncMock(),
    )

    await delete_dataset_related_nodes(uuid4(), session=fake_session)

    fake_session.scalars.assert_awaited_once()
    fake_session.execute.assert_awaited_once()
Loading