Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from os import path
from typing import AsyncGenerator
from uuid import UUID
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
from sqlalchemy import text, select
from sqlalchemy import text, select, MetaData, Table
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker

Expand Down Expand Up @@ -50,11 +52,14 @@ async def create_table(self, schema_name: str, table_name: str, table_config: li
await connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))
await connection.close()

async def delete_table(self, table_name: str):
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
async with self.engine.begin() as connection:
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE;"))

await connection.close()
if self.engine.dialect.name == "sqlite":
# SQLite doesn’t support schema namespaces and the CASCADE keyword.
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
else:
await connection.execute(text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;"))
Comment on lines +55 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider using SQL parameters for table and schema names.

The implementation correctly handles different database types, but direct string interpolation in SQL queries could be vulnerable to SQL injection. Consider using parameters:

-                await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
+                await connection.execute(text("DROP TABLE IF EXISTS :table_name;"), {"table_name": table_name})
-                await connection.execute(text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;"))
+                await connection.execute(
+                    text("DROP TABLE IF EXISTS :schema_name.:table_name CASCADE;"),
+                    {"schema_name": schema_name, "table_name": table_name}
+                )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
async with self.engine.begin() as connection:
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE;"))
await connection.close()
if self.engine.dialect.name == "sqlite":
# SQLite doesn’t support schema namespaces and the CASCADE keyword.
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
else:
await connection.execute(text(f"DROP TABLE IF EXISTS {schema_name}.{table_name} CASCADE;"))
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
async with self.engine.begin() as connection:
if self.engine.dialect.name == "sqlite":
# SQLite doesn't support schema namespaces and the CASCADE keyword.
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
await connection.execute(text("DROP TABLE IF EXISTS :table_name;"), {"table_name": table_name})
else:
await connection.execute(
text("DROP TABLE IF EXISTS :schema_name.:table_name CASCADE;"),
{"schema_name": schema_name, "table_name": table_name}
)


async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
columns = ", ".join(data[0].keys())
Expand All @@ -65,6 +70,55 @@ async def insert_data(self, schema_name: str, table_name: str, data: list[dict])
await connection.execute(insert_query, data)
await connection.close()

async def get_schema_list(self) -> List[str]:
"""
Return a list of all schema names in database
"""
if self.engine.dialect.name == "postgresql":
async with self.engine.begin() as connection:
result = await connection.execute(
text("""
SELECT schema_name FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
""")
)
return [schema[0] for schema in result.fetchall()]
return []

async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
"""
Delete data in given table based on id. Table must have an id Column.
"""
async with self.get_async_session() as session:
TableModel = await self.get_table(table_name, schema_name)
await session.execute(TableModel.delete().where(TableModel.c.id == data_id))
await session.commit()

Comment on lines +88 to +96
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Enhance delete_data_by_id method robustness.

The method needs additional error handling and validation to ensure reliable operation.

Consider these improvements:

     async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
-        """
-        Delete data in given table based on id. Table must have an id Column.
-        """
+        """
+        Delete data in given table based on id.
+        
+        Args:
+            table_name: Name of the table
+            data_id: UUID of the record to delete
+            schema_name: Optional schema name, defaults to "public"
+            
+        Returns:
+            bool: True if deletion was successful, False otherwise
+            
+        Raises:
+            ValueError: If table doesn't have an id column
+        """
         async with self.get_async_session() as session:
-            TableModel = await self.get_table(table_name, schema_name)
-            await session.execute(TableModel.delete().where(TableModel.c.id == data_id))
-            await session.commit()
+            try:
+                TableModel = await self.get_table(table_name, schema_name)
+                if 'id' not in TableModel.c:
+                    raise ValueError(f"Table '{table_name}' does not have an 'id' column")
+                result = await session.execute(TableModel.delete().where(TableModel.c.id == data_id))
+                await session.commit()
+                return result.rowcount > 0
+            except Exception as e:
+                await session.rollback()
+                print(f"Error deleting record: {e}")
+                return False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
"""
Delete data in given table based on id. Table must have an id Column.
"""
async with self.get_async_session() as session:
TableModel = await self.get_table(table_name, schema_name)
await session.execute(TableModel.delete().where(TableModel.c.id == data_id))
await session.commit()
async def delete_data_by_id(self, table_name: str, data_id: UUID, schema_name: Optional[str] = "public"):
"""
Delete data in given table based on id.
Args:
table_name: Name of the table
data_id: UUID of the record to delete
schema_name: Optional schema name, defaults to "public"
Returns:
bool: True if deletion was successful, False otherwise
Raises:
ValueError: If table doesn't have an id column
"""
async with self.get_async_session() as session:
try:
TableModel = await self.get_table(table_name, schema_name)
if 'id' not in TableModel.c:
raise ValueError(f"Table '{table_name}' does not have an 'id' column")
result = await session.execute(TableModel.delete().where(TableModel.c.id == data_id))
await session.commit()
return result.rowcount > 0
except Exception as e:
await session.rollback()
print(f"Error deleting record: {e}")
return False

async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table:
"""
Dynamically loads a table using the given table name and schema name.
"""
async with self.engine.begin() as connection:
if self.engine.dialect.name == "sqlite":
# Load the schema information into the MetaData object
await connection.run_sync(Base.metadata.reflect)
if table_name in Base.metadata.tables:
return Base.metadata.tables[table_name]
else:
raise ValueError(f"Table '{table_name}' not found.")
else:
# Create a MetaData instance to load table information
metadata = MetaData()
# Load table information from schema into MetaData
await connection.run_sync(metadata.reflect, schema=schema_name)
# Define the full table name
full_table_name = f"{schema_name}.{table_name}"
# Check if table is in list of tables for the given schema
if full_table_name in metadata.tables:
return metadata.tables[full_table_name]
raise ValueError(f"Table '{full_table_name}' not found.")


async def get_data(self, table_name: str, filters: dict = None):
async with self.engine.begin() as connection:
query = f"SELECT * FROM {table_name}"
Expand Down Expand Up @@ -119,12 +173,17 @@ async def delete_database(self):
self.db_path = None
else:
async with self.engine.begin() as connection:
# Load the schema information into the MetaData object
await connection.run_sync(Base.metadata.reflect)
for table in Base.metadata.sorted_tables:
drop_table_query = text(f"DROP TABLE IF EXISTS {table.name} CASCADE")
await connection.execute(drop_table_query)

schema_list = await self.get_schema_list()
# Create a MetaData instance to load table information
metadata = MetaData()
# Drop all tables from all schemas
for schema_name in schema_list:
# Load the schema information into the MetaData object
await connection.run_sync(metadata.reflect, schema=schema_name)
for table in metadata.sorted_tables:
drop_table_query = text(f"DROP TABLE IF EXISTS {schema_name}.{table.name} CASCADE")
await connection.execute(drop_table_query)
metadata.clear()
except Exception as e:
print(f"Error deleting database: {e}")

Expand Down