diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 81a828bd83..edde075658 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -1,7 +1,9 @@ from os import path -from typing import AsyncGenerator +from uuid import UUID +from typing import Optional +from typing import AsyncGenerator, List from contextlib import asynccontextmanager -from sqlalchemy import text, select +from sqlalchemy import text, select, MetaData, Table from sqlalchemy.orm import joinedload from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker @@ -50,11 +52,14 @@ async def create_table(self, schema_name: str, table_name: str, table_config: li await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});")) await connection.close() - async def delete_table(self, table_name: str): + async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"): async with self.engine.begin() as connection: - await connection.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE;")) - - await connection.close() + if self.engine.dialect.name == "sqlite": + # SQLite doesn’t support schema namespaces and the CASCADE keyword. + # However, foreign key constraint can be defined with ON DELETE CASCADE during table creation. + await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};")) + else: + await connection.execute(text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;")) async def insert_data(self, schema_name: str, table_name: str, data: list[dict]): columns = ", ".join(data[0].keys()) @@ -65,6 +70,55 @@ async def insert_data(self, schema_name: str, table_name: str, data: list[dict]) await connection.execute(insert_query, data) await connection.close() + async def get_schema_list(self) -> List[str]: + """ + Return a list of all schema names in database + """ + if self.engine.dialect.name == "postgresql": + async with self.engine.begin() as connection: + result = await connection.execute( + text(""" + SELECT schema_name FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema'); + """) + ) + return [schema[0] for schema in result.fetchall()] + return [] + + async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"): + """ + Delete data in given table based on id. Table must have an id Column. + """ + async with self.get_async_session() as session: + TableModel = await self.get_table(table_name, schema_name) + await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) + await session.commit() + + async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table: + """ + Dynamically loads a table using the given table name and schema name. + """ + async with self.engine.begin() as connection: + if self.engine.dialect.name == "sqlite": + # Load the schema information into the MetaData object + await connection.run_sync(Base.metadata.reflect) + if table_name in Base.metadata.tables: + return Base.metadata.tables[table_name] + else: + raise ValueError(f"Table '{table_name}' not found.") + else: + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect, schema=schema_name) + # Define the full table name + full_table_name = f"{schema_name}.{table_name}" + # Check if table is in list of tables for the given schema + if full_table_name in metadata.tables: + return metadata.tables[full_table_name] + raise ValueError(f"Table '{full_table_name}' not found.") + + async def get_data(self, table_name: str, filters: dict = None): async with self.engine.begin() as connection: query = f"SELECT * FROM {table_name}" @@ -119,12 +173,17 @@ async def delete_database(self): self.db_path = None else: async with self.engine.begin() as connection: - # Load the schema information into the MetaData object - await connection.run_sync(Base.metadata.reflect) - for table in Base.metadata.sorted_tables: - drop_table_query = text(f"DROP TABLE IF EXISTS {table.name} CASCADE") - await connection.execute(drop_table_query) - + schema_list = await self.get_schema_list() + # Create a MetaData instance to load table information + metadata = MetaData() + # Drop all tables from all schemas + for schema_name in schema_list: + # Load the schema information into the MetaData object + await connection.run_sync(metadata.reflect, schema=schema_name) + for table in metadata.sorted_tables: + drop_table_query = text(f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE") + await connection.execute(drop_table_query) + metadata.clear() except Exception as e: print(f"Error deleting database: {e}")