Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions example/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,24 @@


# A config that uses Airflow + Spark.
DEFAULT_AIRFLOW_KWARGS = {
**DEFAULT_KWARGS,
"backfill_concurrent_tasks": 4,
"ddl_concurrent_tasks": 4,
}


airflow_config = Config(
**{
**DEFAULT_KWARGS,
**DEFAULT_AIRFLOW_KWARGS,
"scheduler_backend": AirflowSchedulerBackend(),
}
)


airflow_config_docker = Config(
**{
**DEFAULT_KWARGS,
**DEFAULT_AIRFLOW_KWARGS,
"scheduler_backend": AirflowSchedulerBackend(
airflow_url="http://airflow-webserver:8080/"
),
Expand Down
8 changes: 8 additions & 0 deletions sqlmesh/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def create_plan_evaluator(self, context: Context) -> PlanEvaluator:
return BuiltInPlanEvaluator(
state_sync=context.state_sync,
snapshot_evaluator=context.snapshot_evaluator,
backfill_concurrent_tasks=context.backfill_concurrent_tasks,
console=context.console,
)

Expand Down Expand Up @@ -218,6 +219,7 @@ def create_plan_evaluator(self, context: Context) -> PlanEvaluator:
dag_creation_max_retry_attempts=self.dag_creation_max_retry_attempts,
console=context.console,
notification_targets=context.notification_targets,
backfill_concurrent_tasks=context.backfill_concurrent_tasks,
ddl_concurrent_tasks=context.ddl_concurrent_tasks,
)

Expand Down Expand Up @@ -267,8 +269,12 @@ class Config(PydanticModel):
snapshot_ttl: Duration before unpromoted snapshots are removed.
time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d.
This time format uses python format codes. https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes.
backfill_concurrent_tasks: The number of concurrent tasks used for model backfilling during
plan application. Default: 1.
ddl_concurrent_tasks: The number of concurrent tasks used for DDL
operations (table / view creation, deletion, etc). Default: 1.
evaluation_concurrent_tasks: The number of concurrent tasks used for model evaluation when
running with the built-in scheduler. Default: 1.
"""

engine_connection_factory: t.Callable[[], t.Any] = duckdb.connect
Expand All @@ -280,7 +286,9 @@ class Config(PydanticModel):
snapshot_ttl: str = ""
ignore_patterns: t.List[str] = []
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT
backfill_concurrent_tasks: int = 1
ddl_concurrent_tasks: int = 1
evaluation_concurrent_tasks: int = 1

class Config:
arbitrary_types_allowed = True
24 changes: 22 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,12 @@ class Context(BaseContext):
physical_schema: The schema used to store physical materialized tables.
snapshot_ttl: Duration before unpromoted snapshots are removed.
path: The directory containing SQLMesh files.
backfill_concurrent_tasks: The number of concurrent tasks used for model backfilling during
plan application. Default: 1.
ddl_concurrent_tasks: The number of concurrent tasks used for DDL
operations (table / view creation, deletion, etc). Default: 1.
evaluation_concurrent_tasks: The number of concurrent tasks used for model evaluation when
running with the built-in scheduler. Default: 1.
config: A Config object or the name of a Config object in config.py.
test_config: A Config object or name of a Config object in config.py to use for testing only
load: Whether or not to automatically load all models and macros (default True).
Expand All @@ -166,7 +170,9 @@ def __init__(
physical_schema: str = "",
snapshot_ttl: str = "",
path: str = "",
backfill_concurrent_tasks: t.Optional[int] = None,
ddl_concurrent_tasks: t.Optional[int] = None,
evaluation_concurrent_tasks: t.Optional[int] = None,
config: t.Optional[t.Union[Config, str]] = None,
test_config: t.Optional[t.Union[Config, str]] = None,
load: bool = True,
Expand Down Expand Up @@ -200,20 +206,33 @@ def __init__(
self.macros = UniqueKeyDict("macros")
self.dag: DAG[str] = DAG()

self.backfill_concurrent_tasks = (
backfill_concurrent_tasks or self.config.backfill_concurrent_tasks
)
self.ddl_concurrent_tasks = (
ddl_concurrent_tasks or self.config.ddl_concurrent_tasks
)
self.evaluation_concurrent_tasks = (
evaluation_concurrent_tasks or self.config.evaluation_concurrent_tasks
)
self.is_multithreaded = (
max(
self.backfill_concurrent_tasks,
self.ddl_concurrent_tasks,
self.evaluation_concurrent_tasks,
)
> 1
)

self._engine_adapter = engine_adapter or EngineAdapter(
self.config.engine_connection_factory,
self.config.engine_dialect,
multithreaded=self.ddl_concurrent_tasks > 1,
multithreaded=self.is_multithreaded,
)
self.test_engine_adapter = (
EngineAdapter(
self.test_config.engine_connection_factory,
self.test_config.engine_dialect,
multithreaded=self.test_config.ddl_concurrent_tasks > 1,
)
if self.test_config
else None
Expand Down Expand Up @@ -270,6 +289,7 @@ def scheduler(self) -> Scheduler:
self.snapshots,
self.snapshot_evaluator,
self.state_sync,
max_workers=self.evaluation_concurrent_tasks,
console=self.console,
)

Expand Down
6 changes: 6 additions & 0 deletions sqlmesh/core/plan_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ def __init__(
self,
state_sync: StateSync,
snapshot_evaluator: SnapshotEvaluator,
backfill_concurrent_tasks: int = 1,
console: t.Optional[Console] = None,
):
self.state_sync = state_sync
self.snapshot_evaluator = snapshot_evaluator
self.backfill_concurrent_tasks = backfill_concurrent_tasks
self.console = console or get_console()

def evaluate(self, plan: Plan) -> None:
Expand All @@ -75,6 +77,7 @@ def evaluate(self, plan: Plan) -> None:
{snapshot.name: snapshot for snapshot in snapshots},
self.snapshot_evaluator,
self.state_sync,
max_workers=self.backfill_concurrent_tasks,
console=self.console,
)
scheduler.run(snapshots, plan.start, plan.end)
Expand Down Expand Up @@ -138,6 +141,7 @@ def __init__(
dag_creation_poll_interval_secs: int = 30,
dag_creation_max_retry_attempts: int = 10,
notification_targets: t.Optional[t.List[NotificationTarget]] = None,
backfill_concurrent_tasks: int = 1,
ddl_concurrent_tasks: int = 1,
):
self.airflow_client = airflow_client
Expand All @@ -147,6 +151,7 @@ def __init__(
self.dag_creation_max_retry_attempts = dag_creation_max_retry_attempts
self.console = console or get_console()
self.notification_targets = notification_targets or []
self.backfill_concurrent_tasks = backfill_concurrent_tasks
self.ddl_concurrent_tasks = ddl_concurrent_tasks

def evaluate(self, plan: Plan) -> None:
Expand All @@ -161,6 +166,7 @@ def evaluate(self, plan: Plan) -> None:
no_gaps=plan.no_gaps,
restatements=plan.restatements,
notification_targets=self.notification_targets,
backfill_concurrent_tasks=self.backfill_concurrent_tasks,
ddl_concurrent_tasks=self.ddl_concurrent_tasks,
)

Expand Down
9 changes: 5 additions & 4 deletions sqlmesh/engines/spark/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def main() -> None:
if not command_handler:
raise NotSupportedError(f"Command '{command_type.value}' not supported")

ddl_concurrent_tasks = int(sys.argv[2]) if len(sys.argv) > 2 else 1

spark = create_spark_session()
connection = spark_session_db.connection(spark)

ddl_concurrent_tasks = int(sys.argv[2]) if len(sys.argv) > 2 else 1
evaluator = SnapshotEvaluator(
EngineAdapter(
lambda: connection, "spark", multithreaded=ddl_concurrent_tasks > 1
lambda: spark_session_db.connection(spark),
"spark",
multithreaded=ddl_concurrent_tasks > 1,
),
ddl_concurrent_tasks=ddl_concurrent_tasks,
)
Expand Down
7 changes: 3 additions & 4 deletions sqlmesh/engines/spark/db_api/spark_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None:
if parameters:
raise NotSupportedError("Parameterized queries are not supported")

self._spark.sparkContext.setLocalProperty(
"spark.scheduler.pool", f"pool_{get_ident()}"
)

self._last_df = self._spark.sql(query)
self._last_output = None
self._last_output_cursor = 0
Expand Down Expand Up @@ -71,6 +67,9 @@ def __init__(self, spark: SparkSession):
self.spark = spark

def cursor(self) -> SparkSessionCursor:
self.spark.sparkContext.setLocalProperty(
"spark.scheduler.pool", f"pool_{get_ident()}"
)
return SparkSessionCursor(self.spark)

def commit(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/schedulers/airflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def apply_plan(
no_gaps: bool = False,
restatements: t.Optional[t.Iterable[str]] = None,
notification_targets: t.Optional[t.List[NotificationTarget]] = None,
backfill_concurrent_tasks: int = 1,
ddl_concurrent_tasks: int = 1,
timestamp: t.Optional[datetime] = None,
) -> str:
Expand All @@ -63,6 +64,7 @@ def apply_plan(
request_id=request_id,
restatements=set(restatements or []),
notification_targets=notification_targets or [],
backfill_concurrent_tasks=backfill_concurrent_tasks,
ddl_concurrent_tasks=ddl_concurrent_tasks,
),
dag_run_id=common.INIT_RUN_ID if is_first_run else None,
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/schedulers/airflow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class PlanReceiverDagConf(PydanticModel):
no_gaps: bool
restatements: t.Set[str]
notification_targets: t.List[NotificationTarget]
backfill_concurrent_tasks: int
ddl_concurrent_tasks: int


Expand All @@ -68,6 +69,7 @@ class PlanApplicationRequest(PydanticModel):
plan_id: str
previous_plan_id: t.Optional[str]
notification_targets: t.List[NotificationTarget]
backfill_concurrent_tasks: int
ddl_concurrent_tasks: int


Expand Down
1 change: 1 addition & 0 deletions sqlmesh/schedulers/airflow/dag_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _create_plan_application_dag(
dag_id=dag_id,
schedule_interval="@once",
start_date=now(),
max_active_tasks=request.backfill_concurrent_tasks,
catchup=False,
is_paused_upon_creation=False,
tags=[
Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/schedulers/airflow/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def dags(self) -> t.List[DAG]:
def _create_plan_receiver_dag(self) -> DAG:
dag = self._create_system_dag(common.PLAN_RECEIVER_DAG_ID, None)

receiver_task = PythonOperator(
PythonOperator(
task_id=common.PLAN_RECEIVER_TASK_ID,
python_callable=_plan_receiver_task,
dag=dag,
Expand Down Expand Up @@ -234,6 +234,7 @@ def _plan_receiver_task(
plan_id=plan_conf.environment.plan_id,
previous_plan_id=plan_conf.environment.previous_plan_id,
notification_targets=plan_conf.notification_targets,
backfill_concurrent_tasks=plan_conf.backfill_concurrent_tasks,
ddl_concurrent_tasks=plan_conf.ddl_concurrent_tasks,
)

Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_plan_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):
evaluator = BuiltInPlanEvaluator(
sushi_context.state_sync,
sushi_context.snapshot_evaluator,
sushi_context.console,
console=sushi_context.console,
)
evaluator._push(plan)

Expand Down Expand Up @@ -100,6 +100,7 @@ def test_airflow_evaluator(sushi_plan: Plan, mocker: MockerFixture):
no_gaps=False,
notification_targets=[],
restatements=set(),
backfill_concurrent_tasks=1,
ddl_concurrent_tasks=1,
)

Expand Down
1 change: 1 addition & 0 deletions tests/schedulers/airflow/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot, dag_run_entries:
"notification_targets": [],
"request_id": request_id,
"restatements": [],
"backfill_concurrent_tasks": 1,
"ddl_concurrent_tasks": 1,
},
"dag_run_id": expected_dag_run_id,
Expand Down
4 changes: 4 additions & 0 deletions tests/schedulers/airflow/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def test_plan_receiver_task(mocker: MockerFixture, make_snapshot, random_name):
no_gaps=True,
restatements={"raw.items"},
notification_targets=[],
backfill_concurrent_tasks=1,
ddl_concurrent_tasks=1,
)

Expand Down Expand Up @@ -183,6 +184,7 @@ def test_plan_receiver_task(mocker: MockerFixture, make_snapshot, random_name):
plan_id="test_plan_id",
previous_plan_id=None,
notification_targets=[],
backfill_concurrent_tasks=1,
ddl_concurrent_tasks=1,
)

Expand All @@ -204,6 +206,7 @@ def test_plan_receiver_task_duplicated_snapshot(
no_gaps=False,
restatements=set(),
notification_targets=[],
backfill_concurrent_tasks=1,
ddl_concurrent_tasks=1,
)

Expand Down Expand Up @@ -246,6 +249,7 @@ def test_plan_receiver_task_unbounded_end(
no_gaps=True,
restatements={"raw.items"},
notification_targets=[],
backfill_concurrent_tasks=1,
ddl_concurrent_tasks=1,
)

Expand Down