async def delete_data_entity(self, data_id: UUID):
    """
    Delete a Data row and, when nothing else references it, the local file behind it.

    The local file is removed only when both conditions hold:
      * no other Data row points to the same raw data location, and
      * the file lives inside the configured data root directory, i.e. it was
        created by cognee. Files outside that directory are never touched.

    The database row is deleted and committed first. The file is removed only
    afterwards, so a failed delete/commit can not leave a row that points to a
    file which has already been removed.

    Args:
        data_id: Id of the Data row to delete.

    Raises:
        EntityNotFoundError: If no Data row with the given id can be loaded.
    """
    async with self.get_async_session() as session:
        if self.engine.dialect.name == "sqlite":
            # Foreign key constraints are disabled by default in SQLite (for backwards compatibility),
            # so must be enabled for each database connection/session separately.
            await session.execute(text("PRAGMA foreign_keys = ON;"))

        try:
            data_entity = (await session.scalars(select(Data).where(Data.id == data_id))).one()
        except (ValueError, NoResultFound) as e:
            # Chain the original error so the root cause stays visible in tracebacks.
            raise EntityNotFoundError(message=f"Entity not found: {str(e)}") from e

        # Keep the location in a plain local variable: ORM attributes may be
        # expired once the session commits, and the row itself is gone by then.
        raw_data_location = data_entity.raw_data_location

        # Check if other data objects point to the same raw data location.
        # This has to happen before the row is deleted, while it still counts itself.
        location_references = (await session.execute(
            select(Data.raw_data_location).where(Data.raw_data_location == raw_data_location)
        )).all()
        is_only_reference = len(location_references) == 1

        await session.execute(delete(Data).where(Data.id == data_id))
        await session.commit()

    # Don't delete local file unless this was the only reference to the file in the database.
    if not is_only_reference or not raw_data_location:
        return

    # Delete local file only if it's created by cognee, i.e. stored under the data root.
    from cognee.base_config import get_base_config
    config = get_base_config()

    data_root = os.path.abspath(config.data_root_directory)
    file_path = os.path.abspath(raw_data_location)

    # Compare whole path components instead of substrings, otherwise a sibling
    # such as "<root>_backup/file" would be treated as living inside "<root>".
    try:
        is_inside_data_root = os.path.commonpath([data_root, file_path]) == data_root
    except ValueError:
        # Raised e.g. when the two paths are on different drives.
        is_inside_data_root = False

    if not is_inside_data_root:
        return

    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Report bug as file should exist
        logger.error("Local file which should exist can't be found.")
async def test_local_file_deletion(data_text, file_location):
    """
    Check which local files are removed when a Data entry is deleted.

    Two cases are exercised through the relational engine:
      1. The entry found by the md5 hash of ``data_text``: its local file is
         expected to be gone after the entry is deleted.
      2. The entry whose raw data location equals ``file_location``: its file
         is expected to still exist after the entry is deleted.
    """
    import hashlib
    from sqlalchemy import select
    from cognee.infrastructure.databases.relational import get_relational_engine

    engine = get_relational_engine()

    # Case 1: data along with local files created by cognee gets deleted.
    async with engine.get_async_session() as session:
        # Look the entry up through the hash of its contents.
        content_digest = hashlib.md5(data_text.encode("utf-8")).hexdigest()
        query_by_hash = select(Data).where(Data.content_hash == content_digest)
        found = await session.scalars(query_by_hash)
        data = found.one()

        assert os.path.isfile(data.raw_data_location), f"Data location doesn't exist: {data.raw_data_location}"
        await engine.delete_data_entity(data.id)
        assert not os.path.exists(
            data.raw_data_location), f"Data location still exists after deletion: {data.raw_data_location}"

    # Case 2: local files not created by cognee won't get deleted.
    async with engine.get_async_session() as session:
        # Look the entry up through its file path.
        query_by_location = select(Data).where(Data.raw_data_location == file_location)
        found = await session.scalars(query_by_location)
        data = found.one()

        assert os.path.isfile(data.raw_data_location), f"Data location doesn't exist: {data.raw_data_location}"
        await engine.delete_data_entity(data.id)
        assert os.path.exists(data.raw_data_location), f"Data location doesn't exists: {data.raw_data_location}"


async def test_getting_of_documents(dataset_name_1):
    """
    Check how many document ids the default user gets back for search.

    Exactly one document is expected when the lookup is limited to
    ``dataset_name_1`` and exactly two when no dataset is provided.
    """
    from cognee.modules.users.permissions.methods import get_document_ids_for_user

    # Documents for search per dataset.
    default_user = await get_default_user()
    requested_datasets = [dataset_name_1]
    document_ids = await get_document_ids_for_user(default_user.id, requested_datasets)
    assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1"

    # Documents for search when no dataset is provided.
    default_user = await get_default_user()
    document_ids = await get_document_ids_for_user(default_user.id)
    assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2"
get_document_ids_for_user(user.id, [dataset_name_1]) - assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1" - - # Test getting of documents for search when no dataset is provided - user = await get_default_user() - document_ids = await get_document_ids_for_user(user.id) - assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2" + await test_getting_of_documents(dataset_name_1) vector_engine = get_vector_engine() random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0] @@ -106,6 +138,8 @@ async def main(): results = await brute_force_triplet_search('What is a quantum computer?') assert len(results) > 0 + await test_local_file_deletion(text, explanation_file_path) + await cognee.prune.prune_data() assert not os.path.isdir(data_directory_path), "Local data files are not deleted"