diff --git a/example/models/top_waiters.sql b/example/models/top_waiters.sql index 125e09c9f0..2cd74fc99a 100644 --- a/example/models/top_waiters.sql +++ b/example/models/top_waiters.sql @@ -1,7 +1,7 @@ /* View of top waiters. */ MODEL ( name sushi.top_waiters, - kind full, + kind view, owner jen ); diff --git a/sqlmesh/core/config.py b/sqlmesh/core/config.py index 110575a70a..4bf1f551c5 100644 --- a/sqlmesh/core/config.py +++ b/sqlmesh/core/config.py @@ -217,6 +217,7 @@ def create_plan_evaluator(self, context: Context) -> PlanEvaluator: dag_creation_max_retry_attempts=self.dag_creation_max_retry_attempts, console=context.console, notification_targets=context.notification_targets, + ddl_concurrent_tasks=context.ddl_concurrent_tasks, ) @@ -286,6 +287,8 @@ class Config(PydanticModel): physical_schema: The default schema used to store materialized tables. snapshot_ttl: Duration before unpromoted snapshots are removed. time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d. + ddl_concurrent_tasks: The number of concurrent tasks used for DDL + operations (table / view creation, deletion, etc). Default: 1. """ engine_adapter: EngineAdapter = Field( @@ -298,6 +301,7 @@ class Config(PydanticModel): snapshot_ttl: str = "" ignore_patterns: t.List[str] = [] time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT + ddl_concurrent_tasks: int = 1 class Config: arbitrary_types_allowed = True diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 5291b035d8..c70cf92ceb 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -82,6 +82,8 @@ class Context: physical_schema: The schema used to store physical materialized tables. snapshot_ttl: Duration before unpromoted snapshots are removed. path: The directory containing SQLMesh files. + ddl_concurrent_tasks: The number of concurrent tasks used for DDL + operations (table / view creation, deletion, etc). Default: 1. 
config: A Config object or the name of a Config object in config.py. test_config: A Config object or name of a Config object in config.py to use for testing only load: Whether or not to automatically load all models and macros (default True). @@ -97,6 +99,7 @@ def __init__( physical_schema: str = "", snapshot_ttl: str = "", path: str = "", + ddl_concurrent_tasks: t.Optional[int] = None, config: t.Optional[t.Union[Config, str]] = None, test_config: t.Optional[t.Union[Config, str]] = None, load: bool = True, @@ -120,7 +123,12 @@ def __init__( self.macros = UniqueKeyDict("macros") self.dag: DAG[str] = DAG() self.engine_adapter = engine_adapter or self.config.engine_adapter - self.snapshot_evaluator = SnapshotEvaluator(self.engine_adapter) + self.ddl_concurrent_tasks = ( + ddl_concurrent_tasks or self.config.ddl_concurrent_tasks + ) + self.snapshot_evaluator = SnapshotEvaluator( + self.engine_adapter, ddl_concurrent_tasks=self.ddl_concurrent_tasks + ) self._ignore_patterns = c.IGNORE_PATTERNS + self.config.ignore_patterns self.console = console or get_console() @@ -705,5 +713,4 @@ def _glob_path( def _add_model_to_dag(self, model: Model) -> None: self.dag.graph[model.name] = set() - for table in model.depends_on: - self.dag.add(model.name, table) + self.dag.add(model.name, model.depends_on) diff --git a/sqlmesh/core/dag.py b/sqlmesh/core/dag.py index 8d5bd80e9b..bfc88d26d0 100644 --- a/sqlmesh/core/dag.py +++ b/sqlmesh/core/dag.py @@ -16,17 +16,17 @@ class DAG(t.Generic[T]): def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None): self.graph = graph or {} - def add(self, node: T, dependency: t.Optional[T] = None) -> None: + def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None: """Add a node to the graph with an optional upstream dependency. Args: node: The node to add. - dependency: An optional dependency to add to the node. + dependencies: Optional dependencies to add to the node. 
""" if node not in self.graph: self.graph[node] = set() - if dependency: - self.graph[node].add(dependency) + if dependencies: + self.graph[node].update(dependencies) def subdag(self, *nodes: T) -> DAG[T]: """Create a new subdag given node(s). @@ -50,7 +50,7 @@ def subdag(self, *nodes: T) -> DAG[T]: def upstream(self, node: T) -> t.List[T]: """Returns all upstream dependencies in topologically sorted order.""" - return self.subdag(node).sort()[:-1] + return self.subdag(node).sorted()[:-1] @property def leaves(self) -> t.Set[T]: @@ -59,7 +59,7 @@ def leaves(self) -> t.Set[T]: dep for deps in self.graph.values() for dep in deps if dep not in self.graph } - def sort(self) -> t.List[T]: + def sorted(self) -> t.List[T]: """Topologically sort the graph. Returns: @@ -98,7 +98,7 @@ def downstream(self, node: T) -> t.List[T]: Returns: A list of descendant nodes sorted in topological order. """ - sorted_nodes = self.sort() + sorted_nodes = self.sorted() try: node_index = sorted_nodes.index(node) except ValueError: diff --git a/sqlmesh/core/model.py b/sqlmesh/core/model.py index 0ab2100926..fa4bf33d84 100644 --- a/sqlmesh/core/model.py +++ b/sqlmesh/core/model.py @@ -814,9 +814,9 @@ def _render_query( end: The end datetime to render. Defaults to epoch start. latest: The latest datetime to use for non-incremental queries. Defaults to epoch start. add_incremental_filter: Add an incremental filter to the query if the model is incremental. - snapshots: All snapshots to use for expansion and mapping of physical locations. + snapshots: All upstream snapshots to use for expansion and mapping of physical locations. If passing snapshots is undesirable, mapping can be used instead to manually map tables. - mapping: Mapping to replace table names, if not set, the mapping wil be created from snapshots. + mapping: Mapping to replace table names, if not set, the mapping will be created from snapshots. expand: Expand referenced models as subqueries. 
This is used to bypass backfills when running queries that depend on materialized tables. Model definitions are inlined and can thus be run end to end on the fly. @@ -913,7 +913,7 @@ def ctas_query(self, snapshots: t.Dict[str, Snapshot]) -> exp.Subqueryable: SELECTS and hopefully the optimizer is smart enough to not do anything. Args: - All upstream snapshots of this model so queries can be expanded. + snapshots: All upstream snapshots of this model so queries can be expanded. Return: The mocked out ctas query. """ diff --git a/sqlmesh/core/plan.py b/sqlmesh/core/plan.py index 70178f5670..92ba965a19 100644 --- a/sqlmesh/core/plan.py +++ b/sqlmesh/core/plan.py @@ -288,7 +288,7 @@ def snapshot_change_category(self, snapshot: Snapshot) -> SnapshotChangeCategory def info_cache(self) -> InfoCache: """Returns the info cache of categorized, uncategorized snapshots.""" if self._info_cache is None: - queue = deque(self._dag.sort()) + queue = deque(self._dag.sorted()) snapshots = [] all_indirectly_modified = set() diff --git a/sqlmesh/core/plan_evaluator.py b/sqlmesh/core/plan_evaluator.py index dfe5916007..54eaf7e1df 100644 --- a/sqlmesh/core/plan_evaluator.py +++ b/sqlmesh/core/plan_evaluator.py @@ -93,10 +93,18 @@ def _push(self, plan: Plan) -> None: Args: plan: The plan to source snapshots from. 
""" - snapshots = {snapshot.name: snapshot for snapshot in plan.new_snapshots} - self.state_sync.push_snapshots(snapshots.values()) - for snapshot in snapshots.values(): - self.snapshot_evaluator.create(snapshot, snapshots) + parent_snapshot_ids = { + p_sid for snapshot in plan.new_snapshots for p_sid in snapshot.parents + } + + stored_snapshots_by_id = self.state_sync.get_snapshots(parent_snapshot_ids) + new_snapshots_by_id = { + snapshot.snapshot_id: snapshot for snapshot in plan.new_snapshots + } + all_snapshots_by_id = {**stored_snapshots_by_id, **new_snapshots_by_id} + + self.snapshot_evaluator.create(plan.new_snapshots, all_snapshots_by_id) + self.state_sync.push_snapshots(plan.new_snapshots) def _promote(self, plan: Plan) -> None: """Promote a plan. @@ -110,16 +118,14 @@ def _promote(self, plan: Plan) -> None: added, removed = self.state_sync.promote(environment, no_gaps=plan.no_gaps) - for snapshot_table_info in added: - self.snapshot_evaluator.promote( - snapshot_table_info, - environment=environment.name, - ) - for snapshot_table_info in removed: - self.snapshot_evaluator.demote( - snapshot_table_info, - environment=environment.name, - ) + self.snapshot_evaluator.promote( + added, + environment=environment.name, + ) + self.snapshot_evaluator.demote( + removed, + environment=environment.name, + ) class AirflowPlanEvaluator(PlanEvaluator): @@ -132,6 +138,7 @@ def __init__( dag_creation_poll_interval_secs: int = 30, dag_creation_max_retry_attempts: int = 10, notification_targets: t.Optional[t.List[NotificationTarget]] = None, + ddl_concurrent_tasks: int = 1, ): self.airflow_client = airflow_client self.blocking = blocking @@ -140,6 +147,7 @@ def __init__( self.dag_creation_max_retry_attempts = dag_creation_max_retry_attempts self.console = console or get_console() self.notification_targets = notification_targets or [] + self.ddl_concurrent_tasks = ddl_concurrent_tasks def evaluate(self, plan: Plan) -> None: environment = plan.environment @@ -153,6 +161,7 @@ 
def evaluate(self, plan: Plan) -> None: no_gaps=plan.no_gaps, restatements=plan.restatements, notification_targets=self.notification_targets, + ddl_concurrent_tasks=self.ddl_concurrent_tasks, ) if self.blocking: diff --git a/sqlmesh/core/snapshot.py b/sqlmesh/core/snapshot.py index 09347f851d..f235ffaa0c 100644 --- a/sqlmesh/core/snapshot.py +++ b/sqlmesh/core/snapshot.py @@ -168,6 +168,7 @@ class SnapshotTableInfo(PydanticModel, SnapshotInfoMixin, frozen=True): fingerprint: str version: str physical_schema: str + parents: t.Tuple[SnapshotId, ...] previous_versions: t.Tuple[SnapshotDataVersion, ...] = () change_category: t.Optional[SnapshotChangeCategory] @@ -233,7 +234,7 @@ class Snapshot(PydanticModel, SnapshotInfoMixin): fingerprint: str physical_schema: str model: Model - parents: t.List[SnapshotId] + parents: t.Tuple[SnapshotId, ...] intervals: Intervals created_ts: int updated_ts: int @@ -327,7 +328,7 @@ def from_model( ), physical_schema=physical_schema, model=model, - parents=[ + parents=tuple( SnapshotId( name=name, fingerprint=fingerprint_from_model( @@ -338,7 +339,7 @@ def from_model( ), ) for name in _parents_from_model(model, models) - ], + ), intervals=[], created_ts=created_ts, updated_ts=created_ts, @@ -509,6 +510,7 @@ def table_info(self) -> SnapshotTableInfo: name=self.name, fingerprint=self.fingerprint, version=self.version, + parents=self.parents, previous_versions=self.previous_versions, change_category=self.change_category, ) diff --git a/sqlmesh/core/snapshot_evaluator.py b/sqlmesh/core/snapshot_evaluator.py index 9b4c7e3f97..2d4aa05721 100644 --- a/sqlmesh/core/snapshot_evaluator.py +++ b/sqlmesh/core/snapshot_evaluator.py @@ -28,7 +28,8 @@ from sqlmesh.core.audit import AuditResult from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo +from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotInfoLike +from sqlmesh.utils.concurrency import 
concurrent_apply_to_snapshots from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import AuditError @@ -44,10 +45,13 @@ class SnapshotEvaluator: Args: adapter: The adapter that interfaces with the execution engine. + ddl_concurrent_tasks: The number of concurrent tasks used for DDL + operations (table / view creation, deletion, etc). Default: 1. """ - def __init__(self, adapter: EngineAdapter): + def __init__(self, adapter: EngineAdapter, ddl_concurrent_tasks: int = 1): self.adapter = adapter + self.ddl_concurrent_tasks = ddl_concurrent_tasks def evaluate( self, @@ -123,90 +127,65 @@ def evaluate( else: self.adapter.insert_append(table_name, query_or_df, columns=columns) - def promote(self, snapshot: SnapshotInfoLike, environment: str) -> None: - """Promotes the given snapshot in the target environment by replacing a corresponding view with - a physical table associated with the given snapshot. + def promote( + self, target_snapshots: t.Iterable[SnapshotInfoLike], environment: str + ) -> None: + """Promotes the given collection of snapshots in the target environment by replacing a corresponding + view with a physical table associated with the given snapshot. Args: - snapshot: Snapshot to promote. + target_snapshots: Snapshots to promote. environment: The target environment. 
""" - qualified_view_name = snapshot.qualified_view_name - schema = qualified_view_name.schema_for_environment(environment=environment) - if schema is not None: - self.adapter.create_schema(schema) - - view_name = qualified_view_name.for_environment(environment=environment) - table_name = snapshot.table_name - if self.adapter.table_exists(table_name): - logger.info( - "Updating view '%s' to point at table '%s'", view_name, table_name - ) - self.adapter.create_view(view_name, exp.select("*").from_(table_name)) - else: - logger.info("Dropping view '%s' for non-materialized table", view_name) - self.adapter.drop_view(view_name) + concurrent_apply_to_snapshots( + target_snapshots, + lambda s: self._promote_snapshot(s, environment), + self.ddl_concurrent_tasks, + ) - def demote(self, snapshot: SnapshotInfoLike, environment: str) -> None: - """Demotes the given snapshot in the target environment by removing its view. + def demote( + self, target_snapshots: t.Iterable[SnapshotInfoLike], environment: str + ) -> None: + """Demotes the given collection of snapshots in the target environment by removing its view. Args: - snapshot: Snapshot to remove. + target_snapshots: Snapshots to demote. environment: The target environment. """ - view_name = snapshot.qualified_view_name.for_environment( - environment=environment + concurrent_apply_to_snapshots( + target_snapshots, + lambda s: self._demote_snapshot(s, environment), + self.ddl_concurrent_tasks, ) - if self.adapter.table_exists(view_name): - logger.info("Dropping view '%s'", view_name) - self.adapter.drop_view(view_name) - def create(self, snapshot: Snapshot, snapshots: t.Dict[str, Snapshot]) -> None: - """Creates a physical snapshot schema and table. + def create( + self, + target_snapshots: t.Iterable[Snapshot], + snapshots: t.Dict[SnapshotId, Snapshot], + ) -> None: + """Creates a physical snapshot schema and table for the given collection of snapshots. Args: - snapshot: Snapshot to create. 
+ target_snapshots: Target snapshots. """ - if snapshot.is_embedded_kind: - return - - self.adapter.create_schema(snapshot.physical_schema) - table_name = snapshot.table_name - - if snapshot.is_view_kind: - logger.info("Creating view '%s'", table_name) - self.adapter.create_view( - table_name, snapshot.model.render_query(snapshots=snapshots) - ) - else: - logger.info("Creating table '%s'", table_name) - self.adapter.create_table( - table_name, - query_or_columns=snapshot.model.columns - if snapshot.model.annotated - else snapshot.model.ctas_query(snapshots), - storage_format=snapshot.model.storage_format, - partitioned_by=snapshot.model.partitioned_by, - ) + concurrent_apply_to_snapshots( + target_snapshots, + lambda s: self._create_snapshot(s, snapshots), + self.ddl_concurrent_tasks, + ) - def cleanup(self, snapshots: t.Iterable[SnapshotTableInfo | Snapshot]) -> None: + def cleanup(self, target_snapshots: t.Iterable[SnapshotInfoLike]) -> None: """Cleans up the given snapshots by removing its table Args: - snapshots: Snapshots to cleanup. + target_snapshots: Snapshots to cleanup. 
""" - for snapshot in snapshots: - snapshot = snapshot.table_info - table_name = snapshot.table_name - if not self.adapter.table_exists(table_name): - continue - - try: - self.adapter.drop_table(table_name) - logger.info("Dropped table '%s'", table_name) - except Exception: - self.adapter.drop_view(table_name) - logger.info("Dropped view '%s'", table_name) + concurrent_apply_to_snapshots( + target_snapshots, + self._cleanup_snapshot, + self.ddl_concurrent_tasks, + reverse_order=True, + ) def audit( self, @@ -251,3 +230,69 @@ def audit( ) results.append(AuditResult(audit=audit, count=count, query=query)) return results + + def _create_snapshot( + self, snapshot: Snapshot, snapshots: t.Dict[SnapshotId, Snapshot] + ) -> None: + if snapshot.is_embedded_kind: + return + + self.adapter.create_schema(snapshot.physical_schema) + table_name = snapshot.table_name + + parent_snapshots_by_name = { + snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents + } + + if snapshot.is_view_kind: + logger.info("Creating view '%s'", table_name) + self.adapter.create_view( + table_name, + snapshot.model.render_query(snapshots=parent_snapshots_by_name), + ) + else: + logger.info("Creating table '%s'", table_name) + self.adapter.create_table( + table_name, + query_or_columns=snapshot.model.columns + if snapshot.model.annotated + else snapshot.model.ctas_query(parent_snapshots_by_name), + storage_format=snapshot.model.storage_format, + partitioned_by=snapshot.model.partitioned_by, + ) + + def _promote_snapshot(self, snapshot: SnapshotInfoLike, environment: str) -> None: + qualified_view_name = snapshot.qualified_view_name + schema = qualified_view_name.schema_for_environment(environment=environment) + if schema is not None: + self.adapter.create_schema(schema) + + view_name = qualified_view_name.for_environment(environment=environment) + table_name = snapshot.table_name + if self.adapter.table_exists(table_name): + logger.info( + "Updating view '%s' to point at table '%s'", 
view_name, table_name + ) + self.adapter.create_view(view_name, exp.select("*").from_(table_name)) + else: + logger.info("Dropping view '%s' for non-materialized table", view_name) + self.adapter.drop_view(view_name) + + def _demote_snapshot(self, snapshot: SnapshotInfoLike, environment: str) -> None: + view_name = snapshot.qualified_view_name.for_environment( + environment=environment + ) + if self.adapter.table_exists(view_name): + logger.info("Dropping view '%s'", view_name) + self.adapter.drop_view(view_name) + + def _cleanup_snapshot(self, snapshot: SnapshotInfoLike) -> None: + snapshot = snapshot.table_info + table_name = snapshot.table_name + if self.adapter.table_exists(table_name): + try: + self.adapter.drop_table(table_name) + logger.info("Dropped table '%s'", table_name) + except Exception: + self.adapter.drop_view(table_name) + logger.info("Dropped view '%s'", table_name) diff --git a/sqlmesh/engines/commands.py b/sqlmesh/engines/commands.py index 98d4fd6d2a..897326d3e9 100644 --- a/sqlmesh/engines/commands.py +++ b/sqlmesh/engines/commands.py @@ -64,8 +64,7 @@ def promote( ) -> None: if isinstance(command_payload, str): command_payload = PromoteCommandPayload.parse_raw(command_payload) - for s in command_payload.snapshots: - evaluator.promote(s, command_payload.environment) + evaluator.promote(command_payload.snapshots, command_payload.environment) def demote( @@ -73,8 +72,7 @@ def demote( ) -> None: if isinstance(command_payload, str): command_payload = DemoteCommandPayload.parse_raw(command_payload) - for s in command_payload.snapshots: - evaluator.demote(s, command_payload.environment) + evaluator.demote(command_payload.snapshots, command_payload.environment) def cleanup( @@ -92,16 +90,11 @@ def create_tables( if isinstance(command_payload, str): command_payload = CreateTablesCommandPayload.parse_raw(command_payload) - snapshots = {s.snapshot_id: s for s in command_payload.snapshots} - for target_sid in command_payload.target_snapshot_ids: - 
target_snapshot = snapshots[target_sid] - parent_snapshots = { - p.name: p - for p in [ - snapshots[sid] for sid in target_snapshot.parents if sid in snapshots - ] - } - evaluator.create(target_snapshot, parent_snapshots) + snapshots_by_id = {s.snapshot_id: s for s in command_payload.snapshots} + target_snapshots = [ + snapshots_by_id[sid] for sid in command_payload.target_snapshot_ids + ] + evaluator.create(target_snapshots, snapshots_by_id) COMMAND_HANDLERS: t.Dict[CommandType, t.Callable[[SnapshotEvaluator, str], None]] = { diff --git a/sqlmesh/engines/spark/app.py b/sqlmesh/engines/spark/app.py index 98bb8491e7..e49f053555 100644 --- a/sqlmesh/engines/spark/app.py +++ b/sqlmesh/engines/spark/app.py @@ -14,19 +14,32 @@ def create_spark_session() -> SparkSession: - return SparkSession.builder.enableHiveSupport().getOrCreate() + return ( + SparkSession.builder.config("spark.scheduler.mode", "FAIR") + .enableHiveSupport() + .getOrCreate() + ) def main() -> None: - spark = create_spark_session() - connection = spark_session_db.connection(spark) - evaluator = SnapshotEvaluator(EngineAdapter(connection, "spark")) + logging.basicConfig( + format="%(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)", + level=logging.INFO, + ) command_type = commands.CommandType(sys.argv[1]) command_handler = commands.COMMAND_HANDLERS.get(command_type) if not command_handler: raise NotSupportedError(f"Command '{command_type.value}' not supported") + ddl_concurrent_tasks = int(sys.argv[2]) if len(sys.argv) > 2 else 1 + + spark = create_spark_session() + connection = spark_session_db.connection(spark) + evaluator = SnapshotEvaluator( + EngineAdapter(connection, "spark"), ddl_concurrent_tasks=ddl_concurrent_tasks + ) + with open(SparkFiles.get(commands.COMMAND_PAYLOAD_FILE_NAME), "r") as payload_fd: command_payload = payload_fd.read() logger.info("Command payload:\n %s", command_payload) diff --git a/sqlmesh/engines/spark/db_api/spark_session.py 
b/sqlmesh/engines/spark/db_api/spark_session.py index 6c58a737c5..1670fed597 100644 --- a/sqlmesh/engines/spark/db_api/spark_session.py +++ b/sqlmesh/engines/spark/db_api/spark_session.py @@ -1,4 +1,5 @@ import typing as t +from threading import get_ident from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import Row @@ -16,6 +17,11 @@ def __init__(self, spark: SparkSession): def execute(self, query: str, parameters: t.Optional[t.Any] = None) -> None: if parameters: raise NotSupportedError("Parameterized queries are not supported") + + self._spark.sparkContext.setLocalProperty( + "spark.scheduler.pool", f"pool_{get_ident()}" + ) + self._last_df = self._spark.sql(query) self._last_output = None self._last_output_cursor = 0 diff --git a/sqlmesh/schedulers/airflow/client.py b/sqlmesh/schedulers/airflow/client.py index 7842ec1778..72e0277b1c 100644 --- a/sqlmesh/schedulers/airflow/client.py +++ b/sqlmesh/schedulers/airflow/client.py @@ -50,6 +50,7 @@ def apply_plan( no_gaps: bool = False, restatements: t.Optional[t.Iterable[str]] = None, notification_targets: t.Optional[t.List[NotificationTarget]] = None, + ddl_concurrent_tasks: int = 1, timestamp: t.Optional[datetime] = None, ) -> str: is_first_run = self._get_first_dag_run_id(common.PLAN_RECEIVER_DAG_ID) is None @@ -62,6 +63,7 @@ def apply_plan( request_id=request_id, restatements=set(restatements or []), notification_targets=notification_targets or [], + ddl_concurrent_tasks=ddl_concurrent_tasks, ), dag_run_id=common.INIT_RUN_ID if is_first_run else None, timestamp=timestamp, diff --git a/sqlmesh/schedulers/airflow/common.py b/sqlmesh/schedulers/airflow/common.py index 00e099ea44..add9ce4c99 100644 --- a/sqlmesh/schedulers/airflow/common.py +++ b/sqlmesh/schedulers/airflow/common.py @@ -47,6 +47,7 @@ class PlanReceiverDagConf(PydanticModel): no_gaps: bool restatements: t.Set[str] notification_targets: t.List[NotificationTarget] + ddl_concurrent_tasks: int class 
BackfillIntervalsPerSnapshot(PydanticModel): @@ -57,16 +58,17 @@ class BackfillIntervalsPerSnapshot(PydanticModel): class PlanApplicationRequest(PydanticModel): request_id: str environment_name: str - new_snapshot_batches: t.List[t.List[Snapshot]] + new_snapshots: t.List[Snapshot] backfill_intervals_per_snapshot: t.List[BackfillIntervalsPerSnapshot] - promotion_batches: t.List[t.List[SnapshotTableInfo]] - demotion_batches: t.List[t.List[SnapshotTableInfo]] + promoted_snapshots: t.List[SnapshotTableInfo] + demoted_snapshots: t.List[SnapshotTableInfo] start: TimeLike end: t.Optional[TimeLike] no_gaps: bool plan_id: str previous_plan_id: t.Optional[str] notification_targets: t.List[NotificationTarget] + ddl_concurrent_tasks: int def snapshot_xcom_key(snapshot: SnapshotIdLike) -> str: diff --git a/sqlmesh/schedulers/airflow/dag_generator.py b/sqlmesh/schedulers/airflow/dag_generator.py index 6a47839a48..e07966af93 100644 --- a/sqlmesh/schedulers/airflow/dag_generator.py +++ b/sqlmesh/schedulers/airflow/dag_generator.py @@ -120,11 +120,8 @@ def _create_plan_application_dag( request.environment_name, ) - new_snapshots = { - s for snapshots in request.new_snapshot_batches for s in snapshots - } all_snapshots = { - **{s.snapshot_id: s for s in new_snapshots}, + **{s.snapshot_id: s for s in request.new_snapshots}, **self._snapshots, } @@ -144,7 +141,7 @@ def _create_plan_application_dag( end_task = EmptyOperator(task_id="plan_application_end") (create_start_task, create_end_task) = self._create_creation_tasks( - request.new_snapshot_batches + request.new_snapshots, request.ddl_concurrent_tasks ) (backfill_start_task, backfill_end_task) = self._create_backfill_tasks( @@ -155,14 +152,15 @@ def _create_plan_application_dag( promote_start_task, promote_end_task, ) = self._create_promotion_demotion_tasks( - request.promotion_batches, - request.demotion_batches, + request.promoted_snapshots, + request.demoted_snapshots, request.environment_name, request.start, request.end, 
request.no_gaps, request.plan_id, request.previous_plan_id, + request.ddl_concurrent_tasks, ) start_task >> create_start_task @@ -211,55 +209,51 @@ def _add_notification_target_tasks( promote_end_task >> end_task def _create_creation_tasks( - self, new_snapshot_batches: t.List[t.List[Snapshot]] + self, new_snapshots: t.List[Snapshot], ddl_concurrent_tasks: int ) -> t.Tuple[BaseOperator, BaseOperator]: start_task = EmptyOperator(task_id="snapshot_creation_start") end_task = EmptyOperator(task_id="snapshot_creation_end") - new_snapshot_batches = [b for b in new_snapshot_batches if b] - - if not new_snapshot_batches: + if not new_snapshots: start_task >> end_task return (start_task, end_task) - new_snapshots = [s for snapshots in new_snapshot_batches for s in snapshots] + creation_task = self._create_snapshot_create_table_operator( + new_snapshots, ddl_concurrent_tasks, "snapshot_creation__create_tables" + ) + update_state_task = PythonOperator( task_id="snapshot_creation__update_state", python_callable=creation_update_state_task, op_kwargs={"new_snapshots": new_snapshots}, ) + start_task >> creation_task + creation_task >> update_state_task update_state_task >> end_task - for batch_id, batch in enumerate(new_snapshot_batches): - task = self._create_snapshot_create_table_operator( - batch, f"snapshot_creation__create_tables_batch_{batch_id}" - ) - start_task >> task - task >> update_state_task - return (start_task, end_task) def _create_promotion_demotion_tasks( self, - promotion_batches: t.List[t.List[SnapshotTableInfo]], - demotion_batches: t.List[t.List[SnapshotTableInfo]], + promoted_snapshots: t.List[SnapshotTableInfo], + demoted_snapshots: t.List[SnapshotTableInfo], environment: str, start: TimeLike, end: t.Optional[TimeLike], no_gaps: bool, plan_id: str, previous_plan_id: t.Optional[str], + ddl_concurrent_tasks: int, ) -> t.Tuple[BaseOperator, BaseOperator]: start_task = EmptyOperator(task_id="snapshot_promotion_start") end_task = 
EmptyOperator(task_id="snapshot_promotion_end") - snapshots = [s for snapshots in promotion_batches for s in snapshots] update_state_task = PythonOperator( task_id="snapshot_promotion__update_state", python_callable=promotion_update_state_task, op_kwargs={ - "snapshots": snapshots, + "snapshots": promoted_snapshots, "environment_name": environment, "start": start, "end": end, @@ -271,28 +265,27 @@ def _create_promotion_demotion_tasks( start_task >> update_state_task - promotion_batches = [b for b in promotion_batches if b] - demotion_batches = [b for b in demotion_batches if b] - - for batch_id, batch in enumerate(promotion_batches): - task = self._create_snapshot_promotion_operator( - batch, + if promoted_snapshots: + create_views_task = self._create_snapshot_promotion_operator( + promoted_snapshots, environment, - f"snapshot_promotion__create_views_batch_{batch_id}", + ddl_concurrent_tasks, + "snapshot_promotion__create_views", ) - update_state_task >> task - task >> end_task + update_state_task >> create_views_task + create_views_task >> end_task - for batch_id, batch in enumerate(demotion_batches): - task = self._create_snapshot_demotion_operator( - batch, + if demoted_snapshots: + delete_views_task = self._create_snapshot_demotion_operator( + demoted_snapshots, environment, - f"snapshot_promotion__delete_views_batch_{batch_id}", + ddl_concurrent_tasks, + "snapshot_promotion__delete_views", ) - update_state_task >> task - task >> end_task + update_state_task >> delete_views_task + delete_views_task >> end_task - if not promotion_batches and not demotion_batches: + if not promoted_snapshots and not demoted_snapshots: update_state_task >> end_task return (start_task, end_task) @@ -373,6 +366,7 @@ def _create_snapshot_promotion_operator( self, snapshots: t.List[SnapshotTableInfo], environment: str, + ddl_concurrent_tasks: int, task_id: str, ) -> BaseOperator: return self._engine_operator( @@ -380,6 +374,7 @@ def _create_snapshot_promotion_operator( 
target=targets.SnapshotPromotionTarget( snapshots=snapshots, environment=environment, + ddl_concurrent_tasks=ddl_concurrent_tasks, ), task_id=task_id, ) @@ -388,6 +383,7 @@ def _create_snapshot_demotion_operator( self, snapshots: t.List[SnapshotTableInfo], environment: str, + ddl_concurrent_tasks: int, task_id: str, ) -> BaseOperator: return self._engine_operator( @@ -395,6 +391,7 @@ def _create_snapshot_demotion_operator( target=targets.SnapshotDemotionTarget( snapshots=snapshots, environment=environment, + ddl_concurrent_tasks=ddl_concurrent_tasks, ), task_id=task_id, ) @@ -402,11 +399,14 @@ def _create_snapshot_demotion_operator( def _create_snapshot_create_table_operator( self, new_snapshots: t.List[Snapshot], + ddl_concurrent_tasks: int, task_id: str, ) -> BaseOperator: return self._engine_operator( **self._ddl_engine_operator_args, - target=targets.SnapshotCreateTableTarget(new_snapshots=new_snapshots), + target=targets.SnapshotCreateTableTarget( + new_snapshots=new_snapshots, ddl_concurrent_tasks=ddl_concurrent_tasks + ), task_id=task_id, ) @@ -454,7 +454,7 @@ def _create_hwm_sensors(self, snapshot: Snapshot) -> t.List[HighWaterMarkSensor] @provide_session def creation_update_state_task( - new_snapshots: t.List[Snapshot], + new_snapshots: t.Iterable[Snapshot], session: Session = util.PROVIDED_SESSION, ) -> None: XComStateSync(session).push_snapshots(new_snapshots) diff --git a/sqlmesh/schedulers/airflow/integration.py b/sqlmesh/schedulers/airflow/integration.py index 5b181f279f..4e671e5974 100644 --- a/sqlmesh/schedulers/airflow/integration.py +++ b/sqlmesh/schedulers/airflow/integration.py @@ -53,7 +53,6 @@ class SQLMeshAirflow: connection ID. ddl_engine_operator_args: Same as `engine_operator_args` but only used for the snapshot promotion process. If not specified falls back to using `engine_operator_args`. - ddl_concurrent_task: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). 
janitor_interval: Defines how often the janitor DAG runs. The janitor DAG removes platform-managed DAG instances that are pending deletion from Airflow. Default: 1 hour. @@ -66,7 +65,6 @@ def __init__( engine_operator: t.Union[str, t.Type[BaseOperator]], engine_operator_args: t.Optional[t.Dict[str, t.Any]] = None, ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]] = None, - ddl_concurrent_tasks: int = 1, janitor_interval: timedelta = timedelta(hours=1), plan_application_dag_ttl: timedelta = timedelta(days=2), ): @@ -78,7 +76,6 @@ def __init__( self._ddl_engine_operator_args = ( ddl_engine_operator_args or engine_operator_args or {} ) - self._ddl_concurrent_tasks = ddl_concurrent_tasks self._janitor_interval = janitor_interval self._plan_application_dag_ttl = plan_application_dag_ttl @@ -118,7 +115,6 @@ def _create_plan_receiver_dag(self) -> DAG: receiver_task = PythonOperator( task_id=common.PLAN_RECEIVER_TASK_ID, python_callable=_plan_receiver_task, - op_kwargs={"ddl_concurrent_tasks": self._ddl_concurrent_tasks}, dag=dag, ) @@ -166,7 +162,6 @@ def _create_system_dag( def _plan_receiver_task( dag_run: DagRun, ti: TaskInstance, - ddl_concurrent_tasks: int, session: Session = util.PROVIDED_SESSION, ) -> None: state_sync = XComStateSync(session) @@ -226,31 +221,20 @@ def _plan_receiver_task( for (snapshot, intervals) in backfill_batches ] - new_snapshot_batches = util.create_batches( - plan_conf.new_snapshots, ddl_concurrent_tasks - ) - - promotion_batches = util.create_batches( - plan_conf.environment.snapshots, ddl_concurrent_tasks - ) - - demotion_batches = util.create_batches( - _get_demoted_snapshots(plan_conf.environment, state_sync), ddl_concurrent_tasks - ) - request = common.PlanApplicationRequest( request_id=plan_conf.request_id, environment_name=plan_conf.environment.name, - new_snapshot_batches=new_snapshot_batches, + new_snapshots=plan_conf.new_snapshots, backfill_intervals_per_snapshot=backfill_intervals_per_snapshot, - 
promotion_batches=promotion_batches, - demotion_batches=demotion_batches, + promoted_snapshots=plan_conf.environment.snapshots, + demoted_snapshots=_get_demoted_snapshots(plan_conf.environment, state_sync), start=plan_conf.environment.start, end=plan_conf.environment.end, no_gaps=plan_conf.no_gaps, plan_id=plan_conf.environment.plan_id, previous_plan_id=plan_conf.environment.previous_plan_id, notification_targets=plan_conf.notification_targets, + ddl_concurrent_tasks=plan_conf.ddl_concurrent_tasks, ) ti.xcom_push( diff --git a/sqlmesh/schedulers/airflow/operators/spark_submit.py b/sqlmesh/schedulers/airflow/operators/spark_submit.py index 55e3ec5097..3b5f58e3db 100644 --- a/sqlmesh/schedulers/airflow/operators/spark_submit.py +++ b/sqlmesh/schedulers/airflow/operators/spark_submit.py @@ -78,21 +78,28 @@ def execute(self, context: Context) -> None: if self._hook is None: self._hook = self._get_hook( - self._target.command_type, payload_file_path + self._target.command_type, + payload_file_path, + self._target.ddl_concurrent_tasks, ) self._hook.submit(self._application) self._target.post_hook(context) def on_kill(self) -> None: if self._hook is None: - self._hook = self._get_hook(None, None) + self._hook = self._get_hook(None, None, None) self._hook.on_kill() def _get_hook( self, command_type: t.Optional[commands.CommandType], command_payload_file_path: t.Optional[str], + ddl_concurrent_tasks: t.Optional[int], ) -> SparkSubmitHook: + application_args = [ + *([command_type.value] if command_type else []), + *([str(ddl_concurrent_tasks)] if ddl_concurrent_tasks else []), + ] return SparkSubmitHook( conf=self._spark_conf, conn_id=self._connection_id, @@ -105,6 +112,6 @@ def _get_hook( proxy_user=self._proxy_user, name=self._application_name, num_executors=self._num_executors, - application_args=[command_type.value] if command_type else None, + application_args=application_args, files=command_payload_file_path, ) diff --git 
a/sqlmesh/schedulers/airflow/operators/targets.py b/sqlmesh/schedulers/airflow/operators/targets.py index c7ec7188a3..1e84841de1 100644 --- a/sqlmesh/schedulers/airflow/operators/targets.py +++ b/sqlmesh/schedulers/airflow/operators/targets.py @@ -22,6 +22,7 @@ class BaseTarget(abc.ABC, t.Generic[CP]): command_type: commands.CommandType command_handler: t.Callable[[SnapshotEvaluator, CP], None] + ddl_concurrent_tasks: int def serialized_command_payload(self, context: Context) -> str: """Returns the serialized command payload for the Spark application. @@ -44,7 +45,10 @@ def execute(self, context: Context, connection: t.Any, dialect: str) -> None: dialect: The dialect with which this adapter is associated. """ payload = self._get_command_payload_or_skip(context) - snapshot_evaluator = SnapshotEvaluator(EngineAdapter(connection, dialect)) + snapshot_evaluator = SnapshotEvaluator( + EngineAdapter(connection, dialect), + ddl_concurrent_tasks=self.ddl_concurrent_tasks, + ) self.command_handler(snapshot_evaluator, payload) self.post_hook(context) @@ -93,6 +97,7 @@ class SnapshotEvaluationTarget( command_handler: t.Callable[ [SnapshotEvaluator, commands.EvaluateCommandPayload], None ] = commands.evaluate + ddl_concurrent_tasks: int = 1 snapshot: Snapshot table_mapping: t.Dict[str, str] @@ -151,6 +156,7 @@ class SnapshotPromotionTarget( snapshots: t.List[SnapshotTableInfo] environment: str + ddl_concurrent_tasks: int def _get_command_payload( self, context: Context @@ -178,6 +184,7 @@ class SnapshotDemotionTarget(BaseTarget[commands.DemoteCommandPayload], Pydantic snapshots: t.List[SnapshotTableInfo] environment: str + ddl_concurrent_tasks: int def _get_command_payload( self, context: Context @@ -197,6 +204,7 @@ class SnapshotTableCleanupTarget( command_handler: t.Callable[ [SnapshotEvaluator, commands.CleanupCommandPayload], None ] = commands.cleanup + ddl_concurrent_tasks: int = 1 @provide_session def post_hook( @@ -237,7 +245,7 @@ class SnapshotCreateTableTarget( ] = 
commands.create_tables new_snapshots: t.List[Snapshot] - snapshots: t.Optional[t.List[Snapshot]] + ddl_concurrent_tasks: int def _get_command_payload( self, context: Context diff --git a/sqlmesh/schedulers/airflow/util.py b/sqlmesh/schedulers/airflow/util.py index 2415ceb14f..47182e8257 100644 --- a/sqlmesh/schedulers/airflow/util.py +++ b/sqlmesh/schedulers/airflow/util.py @@ -1,5 +1,4 @@ import logging -import math import typing as t from datetime import datetime, timedelta, timezone @@ -107,18 +106,3 @@ def safe_utcfromtimestamp(timestamp: t.Optional[float]) -> t.Optional[datetime]: if timestamp is not None else None ) - - -T = t.TypeVar("T") - - -def create_batches(snapshots: t.List[T], tasks_num: int) -> t.List[t.List[T]]: - batch_size = math.ceil(len(snapshots) / tasks_num) - - result = [] - for i in range(0, tasks_num): - batch_offset = i * batch_size - batch = snapshots[batch_offset : batch_offset + batch_size] - result.append(batch) - - return result diff --git a/sqlmesh/utils/concurrency.py b/sqlmesh/utils/concurrency.py new file mode 100644 index 0000000000..44f3f7c3bd --- /dev/null +++ b/sqlmesh/utils/concurrency.py @@ -0,0 +1,89 @@ +import typing as t +from concurrent.futures import Future, ThreadPoolExecutor, wait +from threading import Event + +from sqlmesh.core.dag import DAG +from sqlmesh.core.snapshot import SnapshotId, SnapshotInfoLike +from sqlmesh.utils.errors import ConfigError + +T = t.TypeVar("T", bound=SnapshotInfoLike) + + +def concurrent_apply_to_snapshots( + snapshots: t.Iterable[T], + fn: t.Callable[[T], None], + tasks_num: int, + reverse_order: bool = False, +) -> None: + """Applies a function to the given collection of snapshots concurrently while + preserving the topological order between snapshots. + + Args: + snapshots: Target snapshots. + fn: The function that will be applied concurrently to each snapshot. + tasks_num: The number of concurrent tasks. + reverse_order: Whether the order should be reversed. Default: False.. 
+ """ + snapshots_by_id = {s.snapshot_id: s for s in snapshots} + + dag: DAG[SnapshotId] = DAG[SnapshotId]() + for snapshot in snapshots: + dag.add( + snapshot.snapshot_id, + [p_sid for p_sid in snapshot.parents if p_sid in snapshots_by_id], + ) + + concurrent_apply_to_dag( + dag, + lambda s_id: fn(snapshots_by_id[s_id]), + tasks_num, + reverse_order=reverse_order, + ) + + +H = t.TypeVar("H", bound=t.Hashable) + + +def concurrent_apply_to_dag( + dag: DAG[H], fn: t.Callable[[H], None], tasks_num: int, reverse_order: bool = False +) -> None: + """Applies a function to the given DAG concurrently while preserving the topological + order between snapshots. + + Args: + dag: The target DAG. + fn: The function that will be applied concurrently to each snapshot. + tasks_num: The number of concurrent tasks. + reverse_order: Whether the order should be reversed. Default: False.. + """ + if tasks_num <= 0: + raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}") + + ordered_nodes = dag.sorted() + if reverse_order: + ordered_nodes.reverse() + + if tasks_num == 1: + for node in ordered_nodes: + fn(node) + return + + future_map: t.Dict[H, Future] = {} + future_map_ready = Event() + + def process_node(node: H) -> None: + future_map_ready.wait() + + wait_for = [ + future_map[n] + for n in (dag.upstream(node) if not reverse_order else dag.downstream(node)) + ] + wait(wait_for) + fn(node) + + with ThreadPoolExecutor(max_workers=tasks_num) as pool: + for node in ordered_nodes: + future_map[node] = pool.submit(process_node, node) + future_map_ready.set() + + wait(future_map.values()) diff --git a/tests/core/test_dag.py b/tests/core/test_dag.py index e3a71db6cf..85eee579ac 100644 --- a/tests/core/test_dag.py +++ b/tests/core/test_dag.py @@ -11,7 +11,7 @@ def test_no_downstream(sushi_context): def test_lineage(sushi_context): - lineage = sushi_context.dag.lineage("sushi.order_items").sort() + lineage = sushi_context.dag.lineage("sushi.order_items").sorted() assert 
lineage.index("sushi.order_items") > lineage.index("sushi.items") assert lineage.index("sushi.customer_revenue_by_day") > lineage.index( "sushi.order_items" diff --git a/tests/core/test_plan_evaluator.py b/tests/core/test_plan_evaluator.py index 51ea2805c3..f370b8f69b 100644 --- a/tests/core/test_plan_evaluator.py +++ b/tests/core/test_plan_evaluator.py @@ -1,9 +1,11 @@ import pytest from pytest_mock.plugin import MockerFixture +from sqlglot import parse_one from sqlmesh.core.context import Context +from sqlmesh.core.model import Model, ModelKind from sqlmesh.core.plan import Plan -from sqlmesh.core.plan_evaluator import AirflowPlanEvaluator +from sqlmesh.core.plan_evaluator import AirflowPlanEvaluator, BuiltInPlanEvaluator from sqlmesh.schedulers.airflow import common as airflow_common from sqlmesh.utils.errors import SQLMeshError @@ -21,6 +23,65 @@ def sushi_plan(sushi_context: Context, mocker: MockerFixture) -> Plan: ) +def test_builtin_evaluator_push(sushi_context: Context, make_snapshot): + new_model = Model( + name="sushi.new_test_model", + owner="jen", + cron="@daily", + start="2020-01-01", + query=parse_one("SELECT 1::INT AS one"), + ) + new_view_model = Model( + name="sushi.new_test_view_model", + kind=ModelKind.VIEW, + owner="jen", + start="2020-01-01", + query=parse_one( + "SELECT 1::INT AS one FROM sushi.new_test_model, sushi.waiters" + ), + ) + + sushi_context.upsert_model(new_view_model) + sushi_context.upsert_model(new_model) + + snapshots = sushi_context.snapshots + new_model_snapshot = snapshots[new_model.name] + new_view_model_snapshot = snapshots[new_view_model.name] + + new_model_snapshot.version = new_model_snapshot.fingerprint + sushi_context.table_info_cache[ + new_model_snapshot.snapshot_id + ] = new_model_snapshot.table_info + new_view_model_snapshot.version = new_view_model_snapshot.fingerprint + sushi_context.table_info_cache[ + new_view_model_snapshot.snapshot_id + ] = new_view_model_snapshot.table_info + + plan = Plan( + 
sushi_context._context_diff("prod"), + dag=sushi_context.dag, + state_reader=sushi_context.state_reader, + ) + + evaluator = BuiltInPlanEvaluator( + sushi_context.state_sync, + sushi_context.snapshot_evaluator, + sushi_context.console, + ) + evaluator._push(plan) + + assert ( + len( + sushi_context.state_sync.get_snapshots( + [new_model_snapshot, new_view_model_snapshot] + ) + ) + == 2 + ) + assert sushi_context.engine_adapter.table_exists(new_model_snapshot.table_name) + assert sushi_context.engine_adapter.table_exists(new_view_model_snapshot.table_name) + + def test_airflow_evaluator(sushi_plan: Plan, mocker: MockerFixture): airflow_client_mock = mocker.Mock() airflow_client_mock.apply_plan.return_value = "test_plan_receiver_dag_run_id" @@ -39,6 +100,7 @@ def test_airflow_evaluator(sushi_plan: Plan, mocker: MockerFixture): no_gaps=False, notification_targets=[], restatements=set(), + ddl_concurrent_tasks=1, ) assert airflow_client_mock.wait_for_dag_run_completion.call_count == 2 diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index eeae5a3a8e..9ab24b4ce3 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -51,7 +51,7 @@ def test_evaluate(mocker: MockerFixture, make_snapshot): ) snapshot = make_snapshot(model, physical_schema="physical_schema", version="1") - evaluator.create(snapshot, {}) + evaluator.create([snapshot], {}) evaluator.evaluate( snapshot, "2020-01-01", @@ -85,7 +85,7 @@ def test_promote(mocker: MockerFixture, make_snapshot): ) evaluator.promote( - make_snapshot(model, physical_schema="physical_schema", version="1"), + [make_snapshot(model, physical_schema="physical_schema", version="1")], "test_env", ) @@ -102,12 +102,15 @@ def test_promote_model_info(mocker: MockerFixture): evaluator = SnapshotEvaluator(adapter_mock) evaluator.promote( - SnapshotTableInfo( - physical_schema="physical_schema", - name="test_schema.test_model", - fingerprint="1", - version="1", - 
), + [ + SnapshotTableInfo( + physical_schema="physical_schema", + name="test_schema.test_model", + fingerprint="1", + version="1", + parents=[], + ) + ], "test_env", ) @@ -124,7 +127,7 @@ def test_evaluate_creation_duckdb( date_kwargs: t.Dict[str, str], ): evaluator = SnapshotEvaluator(EngineAdapter(duck_conn, "duckdb")) - evaluator.create(snapshot, {}) + evaluator.create([snapshot], {}) version = snapshot.version def assert_tables_exist() -> None: diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py index 18b73cd847..c8067f8ff5 100644 --- a/tests/schedulers/airflow/test_client.py +++ b/tests/schedulers/airflow/test_client.py @@ -121,6 +121,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot, dag_run_entries: "physical_schema": "physical_schema", "previous_versions": [], "version": snapshot.version, + "parents": [], } ], "start": "2022-01-01", @@ -132,6 +133,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot, dag_run_entries: "notification_targets": [], "request_id": request_id, "restatements": [], + "ddl_concurrent_tasks": 1, }, "dag_run_id": expected_dag_run_id, "logical_date": "2022-08-16T02:40:19.000000Z", diff --git a/tests/schedulers/airflow/test_integration.py b/tests/schedulers/airflow/test_integration.py index 3d64718e71..2f92303a34 100644 --- a/tests/schedulers/airflow/test_integration.py +++ b/tests/schedulers/airflow/test_integration.py @@ -122,6 +122,7 @@ def test_plan_receiver_task(mocker: MockerFixture, make_snapshot, random_name): no_gaps=True, restatements={"raw.items"}, notification_targets=[], + ddl_concurrent_tasks=1, ) deleted_snapshot = SnapshotTableInfo( @@ -129,6 +130,7 @@ def test_plan_receiver_task(mocker: MockerFixture, make_snapshot, random_name): fingerprint="test_fingerprint", version="test_version", physical_schema="test_physical_schema", + parents=[], ) old_environment = Environment( name=environment_name, @@ -166,21 +168,22 @@ def 
test_plan_receiver_task(mocker: MockerFixture, make_snapshot, random_name): ) == common.PlanApplicationRequest( request_id="test_request_id", environment_name=environment_name, - new_snapshot_batches=[[snapshot]], + new_snapshots=[snapshot], backfill_intervals_per_snapshot=[ common.BackfillIntervalsPerSnapshot( snapshot_id=snapshot.snapshot_id, intervals=[(to_datetime("2022-01-01"), to_datetime("2022-01-02"))], ) ], - promotion_batches=[[snapshot.table_info]], - demotion_batches=[[deleted_snapshot]], + promoted_snapshots=[snapshot.table_info], + demoted_snapshots=[deleted_snapshot], start="2022-01-01", end="2022-01-01", no_gaps=True, plan_id="test_plan_id", previous_plan_id=None, notification_targets=[], + ddl_concurrent_tasks=1, ) @@ -201,6 +204,7 @@ def test_plan_receiver_task_duplicated_snapshot( no_gaps=False, restatements=set(), notification_targets=[], + ddl_concurrent_tasks=1, ) task_instance_mock = mocker.Mock() @@ -242,6 +246,7 @@ def test_plan_receiver_task_unbounded_end( no_gaps=True, restatements={"raw.items"}, notification_targets=[], + ddl_concurrent_tasks=1, ) task_instance_mock = mocker.Mock() diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py new file mode 100644 index 0000000000..1f1d8375e4 --- /dev/null +++ b/tests/utils/test_concurrency.py @@ -0,0 +1,36 @@ +from pytest_mock.plugin import MockerFixture + +from sqlmesh.core.snapshot import SnapshotId +from sqlmesh.utils.concurrency import concurrent_apply_to_snapshots + + +def test_concurrent_apply_to_snapshots(mocker: MockerFixture): + snapshot_a = mocker.Mock() + snapshot_a.snapshot_id = SnapshotId(name="model_a", fingerprint="snapshot_a") + snapshot_a.parents = [] + + snapshot_b = mocker.Mock() + snapshot_b.snapshot_id = SnapshotId(name="model_b", fingerprint="snapshot_b") + snapshot_b.parents = [] + + snapshot_c = mocker.Mock() + snapshot_c.snapshot_id = SnapshotId(name="model_c", fingerprint="snapshot_c") + snapshot_c.parents = [snapshot_a.snapshot_id, 
snapshot_b.snapshot_id] + + snapshot_d = mocker.Mock() + snapshot_d.snapshot_id = SnapshotId(name="model_d", fingerprint="snapshot_d") + snapshot_d.parents = [snapshot_b.snapshot_id, snapshot_c.snapshot_id] + + processed_snapshots = [] + + concurrent_apply_to_snapshots( + [snapshot_a, snapshot_b, snapshot_c, snapshot_d], + lambda s: processed_snapshots.append(s), + 2, + ) + + assert len(processed_snapshots) == 4 + assert processed_snapshots[0] in (snapshot_a, snapshot_b) + assert processed_snapshots[1] in (snapshot_a, snapshot_b) + assert processed_snapshots[2] == snapshot_c + assert processed_snapshots[3] == snapshot_d