From df611d1c47810512f4e2d1c5340ced5e8cc01413 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 9 Dec 2022 10:45:40 -0800 Subject: [PATCH] Add configuration parameters for the number of concurrent tasks used for model evaluation and backfilling --- example/config.py | 11 +++++++-- sqlmesh/core/config.py | 8 +++++++ sqlmesh/core/context.py | 24 +++++++++++++++++-- sqlmesh/core/plan_evaluator.py | 6 +++++ sqlmesh/engines/spark/app.py | 9 +++---- sqlmesh/engines/spark/db_api/spark_session.py | 7 +++--- sqlmesh/schedulers/airflow/client.py | 2 ++ sqlmesh/schedulers/airflow/common.py | 2 ++ sqlmesh/schedulers/airflow/dag_generator.py | 1 + sqlmesh/schedulers/airflow/integration.py | 3 ++- tests/core/test_plan_evaluator.py | 3 ++- tests/schedulers/airflow/test_client.py | 1 + tests/schedulers/airflow/test_integration.py | 4 ++++ 13 files changed, 67 insertions(+), 14 deletions(-) diff --git a/example/config.py b/example/config.py index 6485d8749e..b68f0ecab7 100644 --- a/example/config.py +++ b/example/config.py @@ -30,9 +30,16 @@ # A config that uses Airflow + Spark. +DEFAULT_AIRFLOW_KWARGS = { + **DEFAULT_KWARGS, + "backfill_concurrent_tasks": 4, + "ddl_concurrent_tasks": 4, +} + + airflow_config = Config( **{ - **DEFAULT_KWARGS, + **DEFAULT_AIRFLOW_KWARGS, "scheduler_backend": AirflowSchedulerBackend(), } ) @@ -40,7 +47,7 @@ airflow_config_docker = Config( **{ - **DEFAULT_KWARGS, + **DEFAULT_AIRFLOW_KWARGS, "scheduler_backend": AirflowSchedulerBackend( airflow_url="http://airflow-webserver:8080/" ), diff --git a/sqlmesh/core/config.py b/sqlmesh/core/config.py index 0f68714f0b..46674f1d15 100644 --- a/sqlmesh/core/config.py +++ b/sqlmesh/core/config.py @@ -173,6 +173,7 @@ def create_plan_evaluator(self, context: Context) -> PlanEvaluator: return BuiltInPlanEvaluator( state_sync=context.state_sync, snapshot_evaluator=context.snapshot_evaluator, + backfill_concurrent_tasks=context.backfill_concurrent_tasks, console=context.console, ) @@ -218,6 +219,7 @@ def create_plan_evaluator(self, context: Context) -> PlanEvaluator: dag_creation_max_retry_attempts=self.dag_creation_max_retry_attempts, console=context.console, notification_targets=context.notification_targets, + backfill_concurrent_tasks=context.backfill_concurrent_tasks, ddl_concurrent_tasks=context.ddl_concurrent_tasks, ) @@ -267,8 +269,12 @@ class Config(PydanticModel): snapshot_ttl: Duration before unpromoted snapshots are removed. time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d. This time format uses python format codes. https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes. + backfill_concurrent_tasks: The number of concurrent tasks used for model backfilling during + plan application. Default: 1. ddl_concurrent_task: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). Default: 1. + evaluation_concurrent_tasks: The number of concurrent tasks used for model evaluation when + running with the built-in scheduler. Default: 1. """ engine_connection_factory: t.Callable[[], t.Any] = duckdb.connect @@ -280,7 +286,9 @@ class Config(PydanticModel): snapshot_ttl: str = "" ignore_patterns: t.List[str] = [] time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT + backfill_concurrent_tasks: int = 1 ddl_concurrent_tasks: int = 1 + evaluation_concurrent_tasks: int = 1 class Config: arbitrary_types_allowed = True diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 76b4dfd84d..bd252fdd5e 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -149,8 +149,12 @@ class Context(BaseContext): physical_schema: The schema used to store physical materialized tables. snapshot_ttl: Duration before unpromoted snapshots are removed. path: The directory containing SQLMesh files. + backfill_concurrent_tasks: The number of concurrent tasks used for model backfilling during + plan application. Default: 1. ddl_concurrent_task: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). Default: 1. + evaluation_concurrent_tasks: The number of concurrent tasks used for model evaluation when + running with the built-in scheduler. Default: 1. config: A Config object or the name of a Config object in config.py. test_config: A Config object or name of a Config object in config.py to use for testing only load: Whether or not to automatically load all models and macros (default True). @@ -166,7 +170,9 @@ def __init__( physical_schema: str = "", snapshot_ttl: str = "", path: str = "", + backfill_concurrent_tasks: t.Optional[int] = None, ddl_concurrent_tasks: t.Optional[int] = None, + evaluation_concurrent_tasks: t.Optional[int] = None, config: t.Optional[t.Union[Config, str]] = None, test_config: t.Optional[t.Union[Config, str]] = None, load: bool = True, @@ -200,20 +206,33 @@ def __init__( self.macros = UniqueKeyDict("macros") self.dag: DAG[str] = DAG() + self.backfill_concurrent_tasks = ( + backfill_concurrent_tasks or self.config.backfill_concurrent_tasks + ) self.ddl_concurrent_tasks = ( ddl_concurrent_tasks or self.config.ddl_concurrent_tasks ) + self.evaluation_concurrent_tasks = ( + evaluation_concurrent_tasks or self.config.evaluation_concurrent_tasks + ) + self.is_multithreaded = ( + max( + self.backfill_concurrent_tasks, + self.ddl_concurrent_tasks, + self.evaluation_concurrent_tasks, + ) + > 1 + ) self._engine_adapter = engine_adapter or EngineAdapter( self.config.engine_connection_factory, self.config.engine_dialect, - multithreaded=self.ddl_concurrent_tasks > 1, + multithreaded=self.is_multithreaded, ) self.test_engine_adapter = ( EngineAdapter( self.test_config.engine_connection_factory, self.test_config.engine_dialect, - multithreaded=self.test_config.ddl_concurrent_tasks > 1, ) if self.test_config else None @@ -270,6 +289,7 @@ def scheduler(self) -> Scheduler: self.snapshots, self.snapshot_evaluator, self.state_sync, + max_workers=self.evaluation_concurrent_tasks, console=self.console, ) diff --git a/sqlmesh/core/plan_evaluator.py b/sqlmesh/core/plan_evaluator.py index 54eaf7e1df..e4b2fe2872 100644 --- a/sqlmesh/core/plan_evaluator.py +++ b/sqlmesh/core/plan_evaluator.py @@ -50,10 +50,12 @@ def __init__( self, state_sync: StateSync, snapshot_evaluator: SnapshotEvaluator, + backfill_concurrent_tasks: int = 1, console: t.Optional[Console] = None, ): self.state_sync = state_sync self.snapshot_evaluator = snapshot_evaluator + self.backfill_concurrent_tasks = backfill_concurrent_tasks self.console = console or get_console() def evaluate(self, plan: Plan) -> None: @@ -75,6 +77,7 @@ def evaluate(self, plan: Plan) -> None: {snapshot.name: snapshot for snapshot in snapshots}, self.snapshot_evaluator, self.state_sync, + max_workers=self.backfill_concurrent_tasks, console=self.console, ) scheduler.run(snapshots, plan.start, plan.end) @@ -138,6 +141,7 @@ def __init__( dag_creation_poll_interval_secs: int = 30, dag_creation_max_retry_attempts: int = 10, notification_targets: t.Optional[t.List[NotificationTarget]] = None, + backfill_concurrent_tasks: int = 1, ddl_concurrent_tasks: int = 1, ): self.airflow_client = airflow_client @@ -147,6 +151,7 @@ def __init__( self.dag_creation_max_retry_attempts = dag_creation_max_retry_attempts self.console = console or get_console() self.notification_targets = notification_targets or [] + self.backfill_concurrent_tasks = backfill_concurrent_tasks self.ddl_concurrent_tasks = ddl_concurrent_tasks def evaluate(self, plan: Plan) -> None: @@ -161,6 +166,7 @@ def evaluate(self, plan: Plan) -> None: no_gaps=plan.no_gaps, restatements=plan.restatements, notification_targets=self.notification_targets, + backfill_concurrent_tasks=self.backfill_concurrent_tasks, ddl_concurrent_tasks=self.ddl_concurrent_tasks, ) diff --git a/sqlmesh/engines/spark/app.py b/sqlmesh/engines/spark/app.py index d9d59784c4..1fa6343f76 100644 --- a/sqlmesh/engines/spark/app.py +++ b/sqlmesh/engines/spark/app.py @@ -32,13 +32,14 @@ def main() -> None: if not command_handler: raise NotSupportedError(f"Command '{command_type.value}' not supported") - ddl_concurrent_tasks = int(sys.argv[2]) if len(sys.argv) > 2 else 1 - spark = create_spark_session() - connection = spark_session_db.connection(spark) + + ddl_concurrent_tasks = int(sys.argv[2]) if len(sys.argv) > 2 else 1 evaluator = SnapshotEvaluator( EngineAdapter( - lambda: connection, "spark", multithreaded=ddl_concurrent_tasks > 1 + lambda: spark_session_db.connection(spark), + "spark", + multithreaded=ddl_concurrent_tasks > 1, ), ddl_concurrent_tasks=ddl_concurrent_tasks, ) diff --git a/sqlmesh/engines/spark/db_api/spark_session.py b/sqlmesh/engines/spark/db_api/spark_session.py index 1670fed597..dc59c110c7 100644 --- a/sqlmesh/engines/spark/db_api/spark_session.py +++ b/sqlmesh/engines/spark/db_api/spark_session.py @@ -18,10 +18,6 @@ def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None: if parameters: raise NotSupportedError("Parameterized queries are not supported") - self._spark.sparkContext.setLocalProperty( - "spark.scheduler.pool", f"pool_{get_ident()}" - ) - self._last_df = self._spark.sql(query) self._last_output = None self._last_output_cursor = 0 @@ -71,6 +67,9 @@ def __init__(self, spark: SparkSession): self.spark = spark def cursor(self) -> SparkSessionCursor: + self.spark.sparkContext.setLocalProperty( + "spark.scheduler.pool", f"pool_{get_ident()}" + ) return SparkSessionCursor(self.spark) def commit(self) -> None: diff --git a/sqlmesh/schedulers/airflow/client.py b/sqlmesh/schedulers/airflow/client.py index 72e0277b1c..1169afd1d2 100644 --- a/sqlmesh/schedulers/airflow/client.py +++ b/sqlmesh/schedulers/airflow/client.py @@ -50,6 +50,7 @@ def apply_plan( no_gaps: bool = False, restatements: t.Optional[t.Iterable[str]] = None, notification_targets: t.Optional[t.List[NotificationTarget]] = None, + backfill_concurrent_tasks: int = 1, ddl_concurrent_tasks: int = 1, timestamp: t.Optional[datetime] = None, ) -> str: @@ -63,6 +64,7 @@ def apply_plan( request_id=request_id, restatements=set(restatements or []), notification_targets=notification_targets or [], + backfill_concurrent_tasks=backfill_concurrent_tasks, ddl_concurrent_tasks=ddl_concurrent_tasks, ), dag_run_id=common.INIT_RUN_ID if is_first_run else None, diff --git a/sqlmesh/schedulers/airflow/common.py b/sqlmesh/schedulers/airflow/common.py index add9ce4c99..e1d9784817 100644 --- a/sqlmesh/schedulers/airflow/common.py +++ b/sqlmesh/schedulers/airflow/common.py @@ -47,6 +47,7 @@ class PlanReceiverDagConf(PydanticModel): no_gaps: bool restatements: t.Set[str] notification_targets: t.List[NotificationTarget] + backfill_concurrent_tasks: int ddl_concurrent_tasks: int @@ -68,6 +69,7 @@ class PlanApplicationRequest(PydanticModel): plan_id: str previous_plan_id: t.Optional[str] notification_targets: t.List[NotificationTarget] + backfill_concurrent_tasks: int ddl_concurrent_tasks: int diff --git a/sqlmesh/schedulers/airflow/dag_generator.py b/sqlmesh/schedulers/airflow/dag_generator.py index e07966af93..ba925bef37 100644 --- a/sqlmesh/schedulers/airflow/dag_generator.py +++ b/sqlmesh/schedulers/airflow/dag_generator.py @@ -129,6 +129,7 @@ def _create_plan_application_dag( dag_id=dag_id, schedule_interval="@once", start_date=now(), + max_active_tasks=request.backfill_concurrent_tasks, catchup=False, is_paused_upon_creation=False, tags=[ diff --git a/sqlmesh/schedulers/airflow/integration.py b/sqlmesh/schedulers/airflow/integration.py index 4e671e5974..a93fa3f4a8 100644 --- a/sqlmesh/schedulers/airflow/integration.py +++ b/sqlmesh/schedulers/airflow/integration.py @@ -112,7 +112,7 @@ def dags(self) -> t.List[DAG]: def _create_plan_receiver_dag(self) -> DAG: dag = self._create_system_dag(common.PLAN_RECEIVER_DAG_ID, None) - receiver_task = PythonOperator( + PythonOperator( task_id=common.PLAN_RECEIVER_TASK_ID, python_callable=_plan_receiver_task, dag=dag, @@ -234,6 +234,7 @@ def _plan_receiver_task( plan_id=plan_conf.environment.plan_id, previous_plan_id=plan_conf.environment.previous_plan_id, notification_targets=plan_conf.notification_targets, + backfill_concurrent_tasks=plan_conf.backfill_concurrent_tasks, ddl_concurrent_tasks=plan_conf.ddl_concurrent_tasks, ) diff --git a/tests/core/test_plan_evaluator.py b/tests/core/test_plan_evaluator.py index f370b8f69b..45ea35b990 100644 --- a/tests/core/test_plan_evaluator.py +++ b/tests/core/test_plan_evaluator.py @@ -66,7 +66,7 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot): evaluator = BuiltInPlanEvaluator( sushi_context.state_sync, sushi_context.snapshot_evaluator, - sushi_context.console, + console=sushi_context.console, ) evaluator._push(plan) @@ -100,6 +100,7 @@ def test_airflow_evaluator(sushi_plan: Plan, mocker: MockerFixture): no_gaps=False, notification_targets=[], restatements=set(), + backfill_concurrent_tasks=1, ddl_concurrent_tasks=1, ) diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py index c8067f8ff5..1955147c8b 100644 --- a/tests/schedulers/airflow/test_client.py +++ b/tests/schedulers/airflow/test_client.py @@ -133,6 +133,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot, dag_run_entries: "notification_targets": [], "request_id": request_id, "restatements": [], + "backfill_concurrent_tasks": 1, "ddl_concurrent_tasks": 1, }, "dag_run_id": expected_dag_run_id, diff --git a/tests/schedulers/airflow/test_integration.py b/tests/schedulers/airflow/test_integration.py index 2f92303a34..194ddfef55 100644 --- a/tests/schedulers/airflow/test_integration.py +++ b/tests/schedulers/airflow/test_integration.py @@ -122,6 +122,7 @@ def test_plan_receiver_task(mocker: MockerFixture, make_snapshot, random_name): no_gaps=True, restatements={"raw.items"}, notification_targets=[], + backfill_concurrent_tasks=1, ddl_concurrent_tasks=1, ) @@ -183,6 +184,7 @@ def test_plan_receiver_task(mocker: MockerFixture, make_snapshot, random_name): plan_id="test_plan_id", previous_plan_id=None, notification_targets=[], + backfill_concurrent_tasks=1, ddl_concurrent_tasks=1, ) @@ -204,6 +206,7 @@ def test_plan_receiver_task_duplicated_snapshot( no_gaps=False, restatements=set(), notification_targets=[], + backfill_concurrent_tasks=1, ddl_concurrent_tasks=1, ) @@ -246,6 +249,7 @@ def test_plan_receiver_task_unbounded_end( no_gaps=True, restatements={"raw.items"}, notification_targets=[], + backfill_concurrent_tasks=1, ddl_concurrent_tasks=1, )