From 8c8c36f59c69bdec81fbd008d7d2d4267bc4edac Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Thu, 8 Dec 2022 13:28:29 -0800 Subject: [PATCH 1/2] Use Connection Pool to manage connections --- example/config.py | 11 +- sqlmesh/core/config.py | 54 ++----- sqlmesh/core/context.py | 61 +++++--- sqlmesh/core/dag.py | 29 ++-- sqlmesh/core/engine_adapter.py | 46 ++++-- sqlmesh/core/scheduler.py | 2 + sqlmesh/core/snapshot_evaluator.py | 19 +++ sqlmesh/engines/spark/app.py | 7 +- .../airflow/operators/databricks.py | 3 +- .../schedulers/airflow/operators/targets.py | 11 +- sqlmesh/utils/connection_pool.py | 138 ++++++++++++++++++ tests/core/test_context.py | 4 +- tests/core/test_engine_adapter.py | 37 ++--- tests/core/test_snapshot_evaluator.py | 2 +- tests/core/test_state_sync.py | 2 +- .../airflow/operators/test_targets.py | 6 +- tests/utils/test_connection_pool.py | 119 +++++++++++++++ 17 files changed, 419 insertions(+), 132 deletions(-) create mode 100644 sqlmesh/utils/connection_pool.py create mode 100644 tests/utils/test_connection_pool.py diff --git a/example/config.py b/example/config.py index 647db5f290..6485d8749e 100644 --- a/example/config.py +++ b/example/config.py @@ -3,16 +3,13 @@ import duckdb from sqlmesh.core.config import AirflowSchedulerBackend, Config -from sqlmesh.core.engine_adapter import EngineAdapter DATA_DIR = os.path.join(os.path.dirname(__file__), "data") DEFAULT_KWARGS = { - "dialect": "duckdb", # The default dialect of models is DuckDB SQL. - "engine_adapter": EngineAdapter( - duckdb.connect(), "duckdb" - ), # The default engine runs in DuckDB. + "engine_dialect": "duckdb", + "engine_connection_factory": duckdb.connect, } # An in memory DuckDB config. @@ -22,8 +19,8 @@ local_config = Config( **{ **DEFAULT_KWARGS, - "engine_adapter": EngineAdapter( - lambda: duckdb.connect(database=f"{DATA_DIR}/local.duckdb"), "duckdb" + "engine_connection_factory": lambda: duckdb.connect( + database=f"{DATA_DIR}/local.duckdb" ), } ) diff --git a/sqlmesh/core/config.py b/sqlmesh/core/config.py index 4bf1f551c5..d3544aa5a9 100644 --- a/sqlmesh/core/config.py +++ b/sqlmesh/core/config.py @@ -11,8 +11,8 @@ import duckdb from sqlmesh.core.engine_adapter import EngineAdapter local_config = Config( - engine_adapter=EngineAdapter(duckdb.connect(), "duckdb"), - dialect="duckdb" + engine_config_factory=duckdb.connect, + engine_dialect="duckdb" ) # End config.py @@ -25,17 +25,20 @@ >>> from sqlmesh import Context >>> from sqlmesh.core.config import Config >>> my_config = Config( - ... engine_adapter=EngineAdapter(duckdb.connect(), "duckdb"), - ... dialect="duckdb" + ... engine_connection_factory=duckdb.connect, + ... engine_dialect="duckdb" ... ) >>> context = Context(path="example", config=my_config) ``` - Individual config parameters used when initializing a Context. ```python - >>> adapter = EngineAdapter(duckdb.connect(), "duckdb") + >>> from sqlmesh import Context + >>> from sqlmesh.core.engine_adapter import EngineAdapter + >>> adapter = EngineAdapter(duckdb.connect, "duckdb") >>> context = Context( - ... path="example", engine_adapter=adapter, + ... path="example", + ... engine_adapter=adapter, ... dialect="duckdb", ... ) @@ -60,7 +63,7 @@ DEFAULT_KWARGS = { "dialect": "duckdb", # The default dialect of models is DuckDB SQL. - "engine_adapter": EngineAdapter(duckdb.connect(), "duckdb"), # The default engine runs in DuckDB. + "engine_adapter": EngineAdapter(duckdb.connect, "duckdb"), # The default engine runs in DuckDB. } # An in memory DuckDB config. 
@@ -102,13 +105,11 @@
 import typing as t
 
 import duckdb
-from pydantic import Field
 from requests import Session
 
 from sqlmesh.core import constants as c
 from sqlmesh.core._typing import NotificationTarget
 from sqlmesh.core.console import Console
-from sqlmesh.core.engine_adapter import EngineAdapter
 from sqlmesh.core.plan_evaluator import (
     AirflowPlanEvaluator,
     BuiltInPlanEvaluator,
@@ -256,34 +257,12 @@ class Config(PydanticModel):
     """
     An object used by a Context to configure your SQLMesh project.
 
-    An engine adapter can lazily establish a database connection if it is passed a callable that returns a
-    database API compliant connection.
-    ```python
-    >>> from sqlmesh import Context
-    >>> context = Context(
-    ...     path="example",
-    ...     engine_adapter=EngineAdapter(duckdb.connect, "duckdb"),
-    ...     dialect="duckdb"
-    ... )
-
-    ```
-    ```python
-    >>> def create_connection():
-    ...     return duckdb.connect()
-    ...
-    >>> context = Context(
-    ...     path="example",
-    ...     engine_adapter=EngineAdapter(create_connection, "duckdb"),
-    ...     dialect="duckdb"
-    ... )
-
-    ```
-
     Args:
-        engine_adapter: The default engine adapter to use
+        engine_connection_factory: The callable that creates a new engine connection on each call.
+        engine_dialect: The engine dialect.
         scheduler_backend: Identifies which scheduler backend to use.
         notification_targets: The notification targets to use.
-        dialect: The default sql dialect of model queries.
+        dialect: The default sql dialect of model queries. Defaults to the engine dialect.
         physical_schema: The default schema used to store materialized tables.
         snapshot_ttl: Duration before unpromoted snapshots are removed.
        time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d.
        ddl_concurrent_tasks: The number of concurrent tasks used for DDL
            operations (table / view creation, deletion, etc). Default: 1.
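+
+    Example (mirrors the module-level example above, using the new connection factory fields):
+    ```python
+    >>> import duckdb
+    >>> from sqlmesh import Context
+    >>> from sqlmesh.core.config import Config
+    >>> my_config = Config(
+    ...     engine_connection_factory=duckdb.connect,
+    ...     engine_dialect="duckdb",
+    ... )
+    >>> context = Context(path="example", config=my_config)
+
+    ```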
""" - engine_adapter: EngineAdapter = Field( - default_factory=lambda: EngineAdapter(duckdb.connect, "duckdb") - ) + engine_connection_factory: t.Callable[[], t.Any] = duckdb.connect + engine_dialect: str = "duckdb" scheduler_backend: SchedulerBackend = BuiltInSchedulerBackend() notification_targets: t.List[NotificationTarget] = [] - dialect: str = "" + dialect: t.Optional[str] = None physical_schema: str = "" snapshot_ttl: str = "" ignore_patterns: t.List[str] = [] diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index c70cf92ceb..7b690102ce 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -105,14 +105,24 @@ def __init__( load: bool = True, console: t.Optional[Console] = None, ): + self.console = console or get_console() self.path = Path(path).absolute() self.config = self._load_config(config) + self.test_config = None + try: + self.test_config = self._load_config(test_config or "test_config") + except ConfigError: + self.console.log_error( + "Running without test support since `test_config` was not provided and ` " + "test_config` variable was not found in the namespace" + ) + # Initialize cache cache_path = self.path.joinpath(c.CACHE_PATH) cache_path.mkdir(exist_ok=True) self.table_info_cache = FileCache(cache_path.joinpath(c.TABLE_INFO_CACHE)) - self.dialect = dialect or self.config.dialect + self.dialect = dialect or self.config.dialect or self.config.engine_dialect self.physical_schema = ( physical_schema or self.config.physical_schema or "sqlmesh" ) @@ -122,34 +132,41 @@ def __init__( self.models = UniqueKeyDict("models") self.macros = UniqueKeyDict("macros") self.dag: DAG[str] = DAG() - self.engine_adapter = engine_adapter or self.config.engine_adapter + self.ddl_concurrent_tasks = ( ddl_concurrent_tasks or self.config.ddl_concurrent_tasks ) + + self.engine_adapter = engine_adapter or EngineAdapter( + self.config.engine_connection_factory, + self.config.engine_dialect, + multithreaded=self.ddl_concurrent_tasks > 1, + ) + self.test_engine_adapter = ( + EngineAdapter( + self.test_config.engine_connection_factory, + self.test_config.engine_dialect, + multithreaded=self.test_config.ddl_concurrent_tasks > 1, + ) + if self.test_config + else None + ) + self.snapshot_evaluator = SnapshotEvaluator( self.engine_adapter, ddl_concurrent_tasks=self.ddl_concurrent_tasks ) - self._ignore_patterns = c.IGNORE_PATTERNS + self.config.ignore_patterns - self.console = console or get_console() + + self.notification_targets = self.config.notification_targets + ( + notification_targets or [] + ) self._provided_state_sync: t.Optional[StateSync] = state_sync self._state_sync: t.Optional[StateSync] = None self._state_reader: t.Optional[StateReader] = None - self.notification_targets = self.config.notification_targets + ( - notification_targets or [] - ) - self.test_config = None + self._ignore_patterns = c.IGNORE_PATTERNS + self.config.ignore_patterns self._path_mtimes: t.Dict[Path, float] = {} - try: - self.test_config = self._load_config(test_config or "test_config") - except ConfigError: - self.console.log_error( - "Running without test support since `test_config` was not provided and ` " - "test_config` variable was not found in the namespace" - ) - if load: self.load() @@ -367,10 +384,10 @@ def format(self) -> None: def _run_plan_tests( self, skip_tests: bool = False ) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]: - if self.test_config and not skip_tests: + if self.test_engine_adapter and not skip_tests: result, test_output = self.run_tests() 
self.console.log_test_results( - result, test_output, self.test_config.engine_adapter.dialect + result, test_output, self.test_engine_adapter.dialect ) if not result.wasSuccessful(): raise PlanError( @@ -522,14 +539,14 @@ def run_tests( self, path: t.Optional[str] = None ) -> t.Tuple[unittest.result.TestResult, str]: """Discover and run model tests""" - if not self.test_config: + if not self.test_engine_adapter: raise ConfigError("Tried to run tests but test_config is not defined") test_output = StringIO() with contextlib.redirect_stderr(test_output): result = run_all_model_tests( path=Path(path) if path else self.test_directory_path, snapshots=self.snapshots, - engine_adapter=self.test_config.engine_adapter, + engine_adapter=self.test_engine_adapter, ignore_patterns=self._ignore_patterns, ) return result, test_output.getvalue() @@ -584,6 +601,10 @@ def audit( self.console.show_sql(f"{error.query}") self.console.log_status_update("Done.") + def close(self): + """Releases all resources allocated by this context.""" + self.snapshot_evaluator.close() + def _context_diff( self, environment: str | Environment, diff --git a/sqlmesh/core/dag.py b/sqlmesh/core/dag.py index 294a9ec6c5..4f60138b4f 100644 --- a/sqlmesh/core/dag.py +++ b/sqlmesh/core/dag.py @@ -14,7 +14,7 @@ class DAG(t.Generic[T]): def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None): - self.graph: t.Dict[T, t.Set[T]] = {} + self._graph: t.Dict[T, t.Set[T]] = {} for node, dependencies in (graph or {}).items(): self.add(node, dependencies) @@ -25,10 +25,10 @@ def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None: node: The node to add. dependencies: Optional dependencies to add to the node. """ - if node not in self.graph: - self.graph[node] = set() + if node not in self._graph: + self._graph[node] = set() if dependencies: - self.graph[node].update(dependencies) + self._graph[node].update(dependencies) for d in dependencies: self.add(d) @@ -46,7 +46,7 @@ def subdag(self, *nodes: T) -> DAG[T]: while queue: node = queue.pop() - deps = self.graph.get(node, set()) + deps = self._graph.get(node, set()) graph[node] = deps queue.update(deps) @@ -60,17 +60,24 @@ def upstream(self, node: T) -> t.List[T]: def leaves(self) -> t.Set[T]: """Returns all nodes in the graph without any upstream dependencies.""" return { - dep for deps in self.graph.values() for dep in deps if dep not in self.graph + dep + for deps in self._graph.values() + for dep in deps + if dep not in self._graph } + @property + def graph(self) -> t.Dict[T, t.Set[T]]: + graph = {} + for node, deps in self._graph.items(): + graph[node] = deps.copy() + return graph + def sorted(self) -> t.List[T]: """Returns a list of nodes sorted in topological order.""" result: t.List[T] = [] - unprocessed_nodes = {} - for node, deps in self.graph.items(): - unprocessed_nodes[node] = deps.copy() - + unprocessed_nodes = self.graph while unprocessed_nodes: next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps} @@ -103,7 +110,7 @@ def visit() -> t.Iterator[T]: """Visit topologically sorted nodes after input node and yield downstream dependants.""" downstream = {node} for current_node in sorted_nodes[node_index + 1 :]: - upstream = self.graph.get(current_node, set()) + upstream = self._graph.get(current_node, set()) if not upstream.isdisjoint(downstream): downstream.add(current_node) yield current_node diff --git a/sqlmesh/core/engine_adapter.py b/sqlmesh/core/engine_adapter.py index 3bc5d7e5b9..6234b970fe 100644 --- 
a/sqlmesh/core/engine_adapter.py
+++ b/sqlmesh/core/engine_adapter.py
@@ -17,6 +17,7 @@
 from sqlglot import exp
 
 from sqlmesh.utils import optional_import
+from sqlmesh.utils.connection_pool import connection_pool
 from sqlmesh.utils.df import pandas_to_sql
 from sqlmesh.utils.errors import SQLMeshError
 
@@ -40,27 +41,38 @@ class EngineAdapter:
     with the underlying engine and data store.
 
     Args:
-        connection: Database API compliant connection. The connection will be
-            lazily established if a callable that returns a connection is passed in.
+        connection_factory: A callable that produces a new Database API compliant
+            connection on every call.
         dialect: The dialect with which this adapter is associated.
+        multithreaded: Indicates whether this adapter will be used by more than one thread.
     """
 
-    def __init__(self, connection: t.Any, dialect: str):
-        self.connection = connection
+    def __init__(
+        self,
+        connection_factory: t.Callable[[], t.Any],
+        dialect: str,
+        multithreaded: bool = False,
+    ):
         self.dialect = dialect.lower()
-        self.spark: t.Optional["pyspark.sql.SparkSession"] = getattr(  # type: ignore
-            connection, "spark", None
-        )
+        self._connection_pool = connection_pool(connection_factory, multithreaded)
         self._transaction = False
 
     @property
     def cursor(self) -> t.Any:
-        if not hasattr(self, "_cursor"):
-            if callable(self.connection):
-                self._cursor = self.connection().cursor()
-            else:
-                self._cursor = self.connection.cursor()
-        return self._cursor
+        return self._connection_pool.get_cursor()
+
+    @property
+    def spark(self) -> t.Optional["pyspark.sql.SparkSession"]:  # type: ignore
+        return getattr(self._connection_pool.get(), "spark", None)
+
+    def recycle(self) -> None:
+        """Closes all open connections and releases all allocated resources associated with any thread
+        except the calling one."""
+        self._connection_pool.close_all(exclude_calling_thread=True)
+
+    def close(self) -> None:
+        """Closes all open connections and releases all allocated resources."""
+        self._connection_pool.close_all()
 
     def create_and_insert(
         self,
@@ -324,6 +336,8 @@ def _insert(
             expressions=[exp.column(c, quoted=True) for c in columns],
         )
 
+        connection = self._connection_pool.get()
+
         if (
             self.spark
             and pyspark
@@ -340,17 +354,17 @@ def _insert(
         if (
             not overwrite
             and sqlalchemy
-            and isinstance(self.connection, sqlalchemy.engine.Connectable)
+            and isinstance(connection, sqlalchemy.engine.Connectable)
         ):
             query_or_df.to_sql(
                 table_name,
-                self.connection,
+                connection,
                 if_exists="append",
                 index=False,
                 chunksize=batch_size,
                 method="multi",
             )
-        elif isinstance(self.connection, duckdb.DuckDBPyConnection):
+        elif isinstance(connection, duckdb.DuckDBPyConnection):
             self.execute(
                 exp.Insert(
                     this=into,
diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py
index 6ec7c9526e..7822234574 100644
--- a/sqlmesh/core/scheduler.py
+++ b/sqlmesh/core/scheduler.py
@@ -168,6 +168,8 @@ def run(
         else:
             self.console.complete_snapshot_progress()
 
+        self.snapshot_evaluator.recycle()
+
         return self.failed
 
     def interval_params(
diff --git a/sqlmesh/core/snapshot_evaluator.py b/sqlmesh/core/snapshot_evaluator.py
index 2d4aa05721..3680bdd371 100644
--- a/sqlmesh/core/snapshot_evaluator.py
+++ b/sqlmesh/core/snapshot_evaluator.py
@@ -142,6 +142,7 @@ def promote(
             lambda s: self._promote_snapshot(s, environment),
             self.ddl_concurrent_tasks,
         )
+        self.recycle()
 
     def demote(
         self, target_snapshots: t.Iterable[SnapshotInfoLike], environment: str
@@ -157,6 +158,7 @@ def demote(
             lambda s: self._demote_snapshot(s, environment),
            self.ddl_concurrent_tasks,
        )
+        self.recycle()
 
     def create(
         self,
@@ -173,6 +175,7 @@
             lambda s: self._create_snapshot(s, snapshots),
             self.ddl_concurrent_tasks,
         )
+        self.recycle()
 
     def cleanup(self, target_snapshots: t.Iterable[SnapshotInfoLike]) -> None:
         """Cleans up the given snapshots by removing its table
@@ -186,6 +189,7 @@ def cleanup(self, target_snapshots: t.Iterable[SnapshotInfoLike]) -> None:
             self.ddl_concurrent_tasks,
             reverse_order=True,
         )
+        self.recycle()
 
     def audit(
         self,
@@ -231,6 +235,21 @@ def audit(
         results.append(AuditResult(audit=audit, count=count, query=query))
         return results
 
+    def recycle(self) -> None:
+        """Closes all open connections and releases all allocated resources associated with any thread
+        except the calling one."""
+        try:
+            self.adapter.recycle()
+        except Exception:
+            logger.exception("Failed to recycle Snapshot Evaluator")
+
+    def close(self) -> None:
+        """Closes all open connections and releases all allocated resources."""
+        try:
+            self.adapter.close()
+        except Exception:
+            logger.exception("Failed to close Snapshot Evaluator")
+
     def _create_snapshot(
         self, snapshot: Snapshot, snapshots: t.Dict[SnapshotId, Snapshot]
     ) -> None:
diff --git a/sqlmesh/engines/spark/app.py b/sqlmesh/engines/spark/app.py
index e49f053555..d9d59784c4 100644
--- a/sqlmesh/engines/spark/app.py
+++ b/sqlmesh/engines/spark/app.py
@@ -37,7 +37,10 @@ def main() -> None:
     spark = create_spark_session()
     connection = spark_session_db.connection(spark)
     evaluator = SnapshotEvaluator(
-        EngineAdapter(connection, "spark"), ddl_concurrent_tasks=ddl_concurrent_tasks
+        EngineAdapter(
+            lambda: connection, "spark", multithreaded=ddl_concurrent_tasks > 1
+        ),
+        ddl_concurrent_tasks=ddl_concurrent_tasks,
     )
 
     with open(SparkFiles.get(commands.COMMAND_PAYLOAD_FILE_NAME), "r") as payload_fd:
@@ -46,6 +49,8 @@ def main() -> None:
 
     command_handler(evaluator, command_payload)
 
+    evaluator.close()
+
 
 if __name__ == "__main__":
     main()
diff --git a/sqlmesh/schedulers/airflow/operators/databricks.py b/sqlmesh/schedulers/airflow/operators/databricks.py
index ae59480ae1..acad09a7e4 100644
--- a/sqlmesh/schedulers/airflow/operators/databricks.py
+++ b/sqlmesh/schedulers/airflow/operators/databricks.py
@@ -34,5 +34,4 @@ def get_db_hook(self) -> DatabricksSqlHook:
 
     def execute(self, context: Context) -> None:
         """Executes the desired target against the configured Databricks connection"""
-        connection = self.get_db_hook().get_conn()
-        self._target.execute(context, connection, "spark")
+        self._target.execute(context, lambda: self.get_db_hook().get_conn(), "spark")
diff --git a/sqlmesh/schedulers/airflow/operators/targets.py b/sqlmesh/schedulers/airflow/operators/targets.py
index 1e84841de1..6e855c85c0 100644
--- a/sqlmesh/schedulers/airflow/operators/targets.py
+++ b/sqlmesh/schedulers/airflow/operators/targets.py
@@ -35,21 +35,24 @@ def serialized_command_payload(self, context: Context) -> str:
         """
         return self._get_command_payload_or_skip(context).json()
 
-    def execute(self, context: Context, connection: t.Any, dialect: str) -> None:
+    def execute(
+        self, context: Context, connection_factory: t.Callable[[], t.Any], dialect: str
+    ) -> None:
         """Executes this target.
 
         Args:
             context: Airflow task context.
-            connection: Database API compliant connection. The connection will be
-                lazily established if a callable that returns a connection is passed in.
+            connection_factory: A callable that produces a new Database API compliant
+                connection on every call.
            dialect: The dialect with which this adapter is associated.
""" payload = self._get_command_payload_or_skip(context) snapshot_evaluator = SnapshotEvaluator( - EngineAdapter(connection, dialect), + EngineAdapter(connection_factory, dialect), ddl_concurrent_tasks=self.ddl_concurrent_tasks, ) self.command_handler(snapshot_evaluator, payload) + snapshot_evaluator.close() self.post_hook(context) def post_hook(self, context: Context, **kwargs) -> None: diff --git a/sqlmesh/utils/connection_pool.py b/sqlmesh/utils/connection_pool.py new file mode 100644 index 0000000000..994f5caa77 --- /dev/null +++ b/sqlmesh/utils/connection_pool.py @@ -0,0 +1,138 @@ +import abc +import logging +import typing as t +from threading import Lock, get_ident + +logger = logging.getLogger(__name__) + + +class ConnectionPool(abc.ABC): + def get_cursor(self) -> t.Any: + """Returns cached cursor instance. + + Automatically creates a new instance if one is not available. + + Returns: + A cursor instance. + """ + + def get(self) -> t.Any: + """Returns cached connection instance. + + Automatically opens a new connection if one is not available. + + Returns: + A connection instance. + """ + + def close_cursor(self) -> None: + """Closes the current cursor instance if exists.""" + + def close(self) -> None: + """Closes the current connection instance if exists. + + Note: if there is a cursor instance available it will be closed as well. + """ + + def close_all(self, exclude_calling_thread: bool = False) -> None: + """Closes all cached cursors and connections. + + Args: + exclude_calling_thread: If set to True excludes cursors and connections associated + with the calling thread. + """ + + +class ThreadLocalConnectionPool(ConnectionPool): + def __init__(self, connection_factory: t.Callable[[], t.Any]): + self._connection_factory = connection_factory + self._thread_connections: t.Dict[t.Hashable, t.Any] = {} + self._thread_cursors: t.Dict[t.Hashable, t.Any] = {} + self._thread_connections_lock = Lock() + self._thread_cursors_lock = Lock() + + def get_cursor(self) -> t.Any: + thread_id = get_ident() + with self._thread_cursors_lock: + if thread_id not in self._thread_cursors: + self._thread_cursors[thread_id] = self.get().cursor() + return self._thread_cursors[thread_id] + + def get(self) -> t.Any: + thread_id = get_ident() + with self._thread_connections_lock: + if thread_id not in self._thread_connections: + self._thread_connections[thread_id] = self._connection_factory() + return self._thread_connections[thread_id] + + def close_cursor(self) -> None: + thread_id = get_ident() + with self._thread_cursors_lock: + if thread_id in self._thread_cursors: + _try_close(self._thread_cursors[thread_id], "cursor") + self._thread_cursors.pop(thread_id) + + def close(self) -> None: + thread_id = get_ident() + with self._thread_cursors_lock, self._thread_connections_lock: + if thread_id in self._thread_connections: + _try_close(self._thread_connections[thread_id], "connection") + self._thread_connections.pop(thread_id) + self._thread_cursors.pop(thread_id, None) + + def close_all(self, exclude_calling_thread: bool = False) -> None: + calling_thread_id = get_ident() + with self._thread_cursors_lock, self._thread_connections_lock: + for thread_id, connection in self._thread_connections.copy().items(): + if not exclude_calling_thread or thread_id != calling_thread_id: + # NOTE: the access to the connection instance itself is not thread-safe here. 
+ _try_close(connection, "connection") + self._thread_connections.pop(thread_id) + self._thread_cursors.pop(thread_id, None) + + +class SingletonConnectionPool(ConnectionPool): + def __init__(self, connection_factory: t.Callable[[], t.Any]): + self._connection_factory = connection_factory + self._connection: t.Optional[t.Any] = None + self._cursor: t.Optional[t.Any] = None + + def get_cursor(self) -> t.Any: + if not self._cursor: + self._cursor = self.get().cursor() + return self._cursor + + def get(self) -> t.Any: + if not self._connection: + self._connection = self._connection_factory() + return self._connection + + def close_cursor(self) -> None: + _try_close(self._cursor, "cursor") + self._cursor = None + + def close(self) -> None: + _try_close(self._connection, "connection") + self._connection = None + self._cursor = None + + def close_all(self, exclude_calling_thread: bool = False) -> None: + if not exclude_calling_thread: + self.close() + + +def connection_pool( + connection_factory: t.Callable[[], t.Any], multithreaded: bool +) -> ConnectionPool: + return ( + ThreadLocalConnectionPool(connection_factory) + if multithreaded + else SingletonConnectionPool(connection_factory) + ) + + +def _try_close(closeable: t.Any, kind: str) -> None: + try: + closeable.close() + except Exception: + logger.exception("Failed to close %s", kind) diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 50032e43e2..aa977c2f55 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1,7 +1,6 @@ import pathlib from datetime import date -import duckdb import pytest from pytest_mock.plugin import MockerFixture from sqlglot import parse_one @@ -9,7 +8,6 @@ import sqlmesh.core.constants from sqlmesh.core.config import Config from sqlmesh.core.context import Context -from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.plan import Plan from sqlmesh.core.plan_evaluator import BuiltInPlanEvaluator from sqlmesh.utils.errors import ConfigError @@ -306,5 +304,5 @@ def test_incremental_model_without_partition_support(tmpdir) -> None: ): Context( path=str(tmpdir), - config=Config(engine_adapter=EngineAdapter(duckdb.connect, "duckdb")), + config=Config(), ) diff --git a/tests/core/test_engine_adapter.py b/tests/core/test_engine_adapter.py index 7ab106699c..8bd951a3ea 100644 --- a/tests/core/test_engine_adapter.py +++ b/tests/core/test_engine_adapter.py @@ -14,7 +14,7 @@ def test_create_view(mocker: MockerFixture): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.create_view("test_view", parse_one("SELECT a FROM tbl")) adapter.create_view("test_view", parse_one("SELECT a FROM tbl"), replace=False) @@ -29,7 +29,7 @@ def test_create_schema(mocker: MockerFixture): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.create_schema("test_schema") adapter.create_schema("test_schema", ignore_if_exists=False) @@ -44,7 +44,7 @@ def test_table_exists(mocker: MockerFixture): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore assert 
adapter.table_exists("test_table") cursor_mock.execute.assert_called_once_with( "DESCRIBE TABLE test_table", @@ -54,7 +54,7 @@ def test_table_exists(mocker: MockerFixture): cursor_mock.execute.side_effect = RuntimeError("error") connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore assert not adapter.table_exists("test_table") cursor_mock.execute.assert_called_once_with( "DESCRIBE TABLE test_table", @@ -66,7 +66,7 @@ def test_insert_overwrite(mocker: MockerFixture): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.insert_overwrite( "test_table", parse_one("SELECT a FROM tbl"), columns=["a"] ) @@ -81,7 +81,7 @@ def test_insert_append(mocker: MockerFixture): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.insert_append("test_table", parse_one("SELECT a FROM tbl"), columns=["a"]) cursor_mock.execute.assert_called_once_with( @@ -94,7 +94,7 @@ def test_delete_insert_query(mocker: MockerFixture): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.delete_insert_query( "test_table", parse_one("SELECT a FROM tbl"), @@ -117,7 +117,7 @@ def test_create_and_insert(mocker: MockerFixture): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.create_and_insert( "test_table", {"a": exp.DataType.build("bigint")}, @@ -140,7 +140,7 @@ def test_create_table(mocker: MockerFixture): "colb": exp.DataType.build("TEXT"), } - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.create_table("test_table", column_mapping) cursor_mock.execute.assert_called_once_with( @@ -158,7 +158,7 @@ def test_create_table_properties(mocker: MockerFixture): "colb": exp.DataType.build("TEXT"), } - adapter = EngineAdapter(connection_mock, "spark") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "spark") # type: ignore adapter.create_table( "test_table", column_mapping, @@ -181,7 +181,7 @@ def test_create_table_properties_ignored(mocker: MockerFixture): "colb": exp.DataType.build("TEXT"), } - adapter = EngineAdapter(connection_mock, "duckdb") # type: ignore + adapter = EngineAdapter(lambda: connection_mock, "duckdb") # type: ignore adapter.create_table( "test_table", column_mapping, @@ -194,23 +194,10 @@ def test_create_table_properties_ignored(mocker: MockerFixture): ) -def test_lazy_connection(mocker: MockerFixture): - create_connection_mock = mocker.Mock() - adapter = EngineAdapter(create_connection_mock, "duckdb") - create_connection_mock.assert_not_called() - - cursor1 = adapter.cursor - create_connection_mock.assert_called_once() - - cursor2 = adapter.cursor - create_connection_mock.assert_called_once() - assert cursor1 == cursor2 - - @pytest.fixture def adapter(duck_conn): 
duck_conn.execute("CREATE VIEW tbl AS SELECT 1 AS a") - return EngineAdapter(duck_conn, "duckdb") + return EngineAdapter(lambda: duck_conn, "duckdb") def test_create_view_duckdb(adapter: EngineAdapter, duck_conn): diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 9ab24b4ce3..99b54d6451 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -126,7 +126,7 @@ def test_evaluate_creation_duckdb( duck_conn, date_kwargs: t.Dict[str, str], ): - evaluator = SnapshotEvaluator(EngineAdapter(duck_conn, "duckdb")) + evaluator = SnapshotEvaluator(EngineAdapter(lambda: duck_conn, "duckdb")) evaluator.create([snapshot], {}) version = snapshot.version diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index c67ba7a14c..e6a20d124f 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -17,7 +17,7 @@ @pytest.fixture def state_sync(duck_conn, mock_file_cache): state_sync = EngineAdapterStateSync( - EngineAdapter(duck_conn, "duckdb"), + EngineAdapter(lambda: duck_conn, "duckdb"), "sqlmesh", mock_file_cache, ) diff --git a/tests/schedulers/airflow/operators/test_targets.py b/tests/schedulers/airflow/operators/test_targets.py index 53c3d45268..1a4eea36b0 100644 --- a/tests/schedulers/airflow/operators/test_targets.py +++ b/tests/schedulers/airflow/operators/test_targets.py @@ -48,7 +48,7 @@ def test_evaluation_target_execute( target = targets.SnapshotEvaluationTarget( snapshot=snapshot, table_mapping=table_mapping ) - target.execute(context, None, "spark") + target.execute(context, lambda: None, "spark") evaluator_evaluate_mock.assert_called_once_with( snapshot, interval_ds, interval_ds, logical_ds, {}, mapping=table_mapping @@ -81,7 +81,7 @@ def test_table_cleanup_target_execute( target = targets.SnapshotTableCleanupTarget() - target.execute(context, None, "spark") + target.execute(context, lambda: None, "spark") evaluator_cleanup_mock.assert_called_once_with([snapshot.table_info]) @@ -114,7 +114,7 @@ def test_table_cleanup_target_skip_execution( target = targets.SnapshotTableCleanupTarget() with pytest.raises(AirflowSkipException): - target.execute(context, None, "spark") + target.execute(context, lambda: None, "spark") evaluator_cleanup_mock.assert_not_called() delete_xcom_mock.assert_called_once() diff --git a/tests/utils/test_connection_pool.py b/tests/utils/test_connection_pool.py new file mode 100644 index 0000000000..71268f0cf3 --- /dev/null +++ b/tests/utils/test_connection_pool.py @@ -0,0 +1,119 @@ +from concurrent.futures import ThreadPoolExecutor +from threading import get_ident + +from pytest_mock.plugin import MockerFixture + +from sqlmesh.utils.connection_pool import ( + SingletonConnectionPool, + ThreadLocalConnectionPool, +) + + +def test_singleton_connection_pool_get(mocker: MockerFixture): + cursor_mock = mocker.Mock() + connection_mock = mocker.Mock() + connection_mock.cursor.return_value = cursor_mock + connection_factory_mock = mocker.Mock(return_value=connection_mock) + + pool = SingletonConnectionPool(connection_factory_mock) + + assert pool.get_cursor() == cursor_mock + assert pool.get_cursor() == cursor_mock + assert pool.get() == connection_mock + assert pool.get() == connection_mock + + connection_factory_mock.assert_called_once() + connection_mock.cursor.assert_called_once() + + +def test_singleton_connection_pool_close(mocker: MockerFixture): + cursor_mock = mocker.Mock() + connection_mock = mocker.Mock() + connection_mock.cursor.return_value = 
cursor_mock + connection_factory_mock = mocker.Mock(return_value=connection_mock) + + pool = SingletonConnectionPool(connection_factory_mock) + + pool.close() + pool.close_cursor() + pool.close_all() + pool.close_all(exclude_calling_thread=True) + + assert pool.get_cursor() == cursor_mock + pool.close_cursor() + + assert pool.get_cursor() == cursor_mock + pool.close() + + assert pool.get_cursor() == cursor_mock + pool.close_all() + + assert pool.get_cursor() == cursor_mock + pool.close_all(exclude_calling_thread=True) + + assert connection_mock.close.call_count == 2 + assert connection_mock.cursor.call_count == 4 + assert cursor_mock.close.call_count == 1 + assert connection_factory_mock.call_count == 3 + + +def test_thread_local_connection_pool(mocker: MockerFixture): + cursor_mock_thread_one = mocker.Mock() + connection_mock_thread_one = mocker.Mock() + connection_mock_thread_one.cursor.return_value = cursor_mock_thread_one + + cursor_mock_thread_two = mocker.Mock() + connection_mock_thread_two = mocker.Mock() + connection_mock_thread_two.cursor.return_value = cursor_mock_thread_two + + test_thread_id = get_ident() + + def connection_factory(): + return ( + connection_mock_thread_one + if get_ident() == test_thread_id + else connection_mock_thread_two + ) + + connection_factory_mock = mocker.Mock(side_effect=connection_factory) + pool = ThreadLocalConnectionPool(connection_factory_mock) + + def thread(): + assert pool.get_cursor() == cursor_mock_thread_two + assert pool.get_cursor() == cursor_mock_thread_two + assert pool.get() == connection_mock_thread_two + assert pool.get() == connection_mock_thread_two + + with ThreadPoolExecutor(max_workers=1) as executor: + executor.submit(thread).result() + + assert pool.get_cursor() == cursor_mock_thread_one + assert pool.get_cursor() == cursor_mock_thread_one + assert pool.get() == connection_mock_thread_one + assert pool.get() == connection_mock_thread_one + + assert len(pool._thread_connections) == 2 + assert len(pool._thread_cursors) == 2 + + pool.close_all(exclude_calling_thread=True) + + assert len(pool._thread_connections) == 1 + assert len(pool._thread_cursors) == 1 + assert test_thread_id in pool._thread_connections + assert test_thread_id in pool._thread_cursors + + pool.close_cursor() + pool.close() + + assert pool.get_cursor() == cursor_mock_thread_one + + pool.close_all() + + assert connection_factory_mock.call_count == 3 + + assert cursor_mock_thread_one.close.call_count == 1 + assert connection_mock_thread_one.cursor.call_count == 2 + assert connection_mock_thread_one.close.call_count == 2 + + assert connection_mock_thread_two.cursor.call_count == 1 + assert connection_mock_thread_two.close.call_count == 1 From d28d85f3158c7127f822d3e2f740f2c8c598fc48 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Thu, 8 Dec 2022 15:34:38 -0800 Subject: [PATCH 2/2] Add multithread context manager to the snapshot evaluator --- sqlmesh/core/config.py | 4 +-- sqlmesh/core/scheduler.py | 4 +-- sqlmesh/core/snapshot_evaluator.py | 58 +++++++++++++++++------------- 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/sqlmesh/core/config.py b/sqlmesh/core/config.py index d3544aa5a9..6cb98ef5b6 100644 --- a/sqlmesh/core/config.py +++ b/sqlmesh/core/config.py @@ -11,7 +11,7 @@ import duckdb from sqlmesh.core.engine_adapter import EngineAdapter local_config = Config( - engine_config_factory=duckdb.connect, + engine_connection_factory=duckdb.connect, engine_dialect="duckdb" ) # End config.py @@ -274,7 +274,7 @@ class Config(PydanticModel): 
engine_dialect: str = "duckdb" scheduler_backend: SchedulerBackend = BuiltInSchedulerBackend() notification_targets: t.List[NotificationTarget] = [] - dialect: t.Optional[str] = None + dialect: str = "" physical_schema: str = "" snapshot_ttl: str = "" ignore_patterns: t.List[str] = [] diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 7822234574..0f3634260b 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -136,7 +136,7 @@ def run( # We have to run all batches per snapshot to mark it as completed self.console.start_snapshot_progress(snapshot.name, len(intervals)) - with ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor( + with self.snapshot_evaluator.multithreaded_context(), ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor( max_workers=self.max_workers ) as batch_pool: while True: @@ -168,8 +168,6 @@ def run( else: self.console.complete_snapshot_progress() - self.snapshot_evaluator.recycle() - return self.failed def interval_params( diff --git a/sqlmesh/core/snapshot_evaluator.py b/sqlmesh/core/snapshot_evaluator.py index 3680bdd371..3ba6a71bae 100644 --- a/sqlmesh/core/snapshot_evaluator.py +++ b/sqlmesh/core/snapshot_evaluator.py @@ -23,6 +23,7 @@ import logging import typing as t +from contextlib import contextmanager from sqlglot import exp, select @@ -137,12 +138,12 @@ def promote( target_snapshots: Snapshots to promote. environment: The target environment. """ - concurrent_apply_to_snapshots( - target_snapshots, - lambda s: self._promote_snapshot(s, environment), - self.ddl_concurrent_tasks, - ) - self.recycle() + with self.multithreaded_context(): + concurrent_apply_to_snapshots( + target_snapshots, + lambda s: self._promote_snapshot(s, environment), + self.ddl_concurrent_tasks, + ) def demote( self, target_snapshots: t.Iterable[SnapshotInfoLike], environment: str @@ -153,12 +154,12 @@ def demote( target_snapshots: Snapshots to demote. environment: The target environment. """ - concurrent_apply_to_snapshots( - target_snapshots, - lambda s: self._demote_snapshot(s, environment), - self.ddl_concurrent_tasks, - ) - self.recycle() + with self.multithreaded_context(): + concurrent_apply_to_snapshots( + target_snapshots, + lambda s: self._demote_snapshot(s, environment), + self.ddl_concurrent_tasks, + ) def create( self, @@ -170,12 +171,12 @@ def create( Args: target_snapshots: Target snapshost. """ - concurrent_apply_to_snapshots( - target_snapshots, - lambda s: self._create_snapshot(s, snapshots), - self.ddl_concurrent_tasks, - ) - self.recycle() + with self.multithreaded_context(): + concurrent_apply_to_snapshots( + target_snapshots, + lambda s: self._create_snapshot(s, snapshots), + self.ddl_concurrent_tasks, + ) def cleanup(self, target_snapshots: t.Iterable[SnapshotInfoLike]) -> None: """Cleans up the given snapshots by removing its table @@ -183,13 +184,13 @@ def cleanup(self, target_snapshots: t.Iterable[SnapshotInfoLike]) -> None: Args: target_snapshots: Snapshots to cleanup. 
""" - concurrent_apply_to_snapshots( - target_snapshots, - self._cleanup_snapshot, - self.ddl_concurrent_tasks, - reverse_order=True, - ) - self.recycle() + with self.multithreaded_context(): + concurrent_apply_to_snapshots( + target_snapshots, + self._cleanup_snapshot, + self.ddl_concurrent_tasks, + reverse_order=True, + ) def audit( self, @@ -235,6 +236,13 @@ def audit( results.append(AuditResult(audit=audit, count=count, query=query)) return results + @contextmanager + def multithreaded_context(self) -> t.Generator[None, None, None]: + try: + yield + finally: + self.recycle() + def recycle(self) -> None: """Closes all open connections and releases all allocated resources associated with any thread except the calling one."""