Skip to content

Commit 213699d

Browse files
authored
Respect dependencies when creating tables / views for new snapshots (#23)
1 parent 7827cbf commit 213699d

26 files changed

Lines changed: 479 additions & 216 deletions

example/models/top_waiters.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/* View of top waiters. */
22
MODEL (
33
name sushi.top_waiters,
4-
kind full,
4+
kind view,
55
owner jen
66
);
77

sqlmesh/core/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def create_plan_evaluator(self, context: Context) -> PlanEvaluator:
217217
dag_creation_max_retry_attempts=self.dag_creation_max_retry_attempts,
218218
console=context.console,
219219
notification_targets=context.notification_targets,
220+
ddl_concurrent_tasks=context.ddl_concurrent_tasks,
220221
)
221222

222223

@@ -286,6 +287,8 @@ class Config(PydanticModel):
286287
physical_schema: The default schema used to store materialized tables.
287288
snapshot_ttl: Duration before unpromoted snapshots are removed.
288289
time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d.
290+
ddl_concurrent_tasks: The number of concurrent tasks used for DDL
291+
operations (table / view creation, deletion, etc). Default: 1.
289292
"""
290293

291294
engine_adapter: EngineAdapter = Field(
@@ -298,6 +301,7 @@ class Config(PydanticModel):
298301
snapshot_ttl: str = ""
299302
ignore_patterns: t.List[str] = []
300303
time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT
304+
ddl_concurrent_tasks: int = 1
301305

302306
class Config:
303307
arbitrary_types_allowed = True

sqlmesh/core/context.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class Context:
8282
physical_schema: The schema used to store physical materialized tables.
8383
snapshot_ttl: Duration before unpromoted snapshots are removed.
8484
path: The directory containing SQLMesh files.
85+
ddl_concurrent_tasks: The number of concurrent tasks used for DDL
86+
operations (table / view creation, deletion, etc). Default: 1.
8587
config: A Config object or the name of a Config object in config.py.
8688
test_config: A Config object or name of a Config object in config.py to use for testing only
8789
load: Whether or not to automatically load all models and macros (default True).
@@ -97,6 +99,7 @@ def __init__(
9799
physical_schema: str = "",
98100
snapshot_ttl: str = "",
99101
path: str = "",
102+
ddl_concurrent_tasks: t.Optional[int] = None,
100103
config: t.Optional[t.Union[Config, str]] = None,
101104
test_config: t.Optional[t.Union[Config, str]] = None,
102105
load: bool = True,
@@ -120,7 +123,12 @@ def __init__(
120123
self.macros = UniqueKeyDict("macros")
121124
self.dag: DAG[str] = DAG()
122125
self.engine_adapter = engine_adapter or self.config.engine_adapter
123-
self.snapshot_evaluator = SnapshotEvaluator(self.engine_adapter)
126+
self.ddl_concurrent_tasks = (
127+
ddl_concurrent_tasks or self.config.ddl_concurrent_tasks
128+
)
129+
self.snapshot_evaluator = SnapshotEvaluator(
130+
self.engine_adapter, ddl_concurrent_tasks=self.ddl_concurrent_tasks
131+
)
124132
self._ignore_patterns = c.IGNORE_PATTERNS + self.config.ignore_patterns
125133
self.console = console or get_console()
126134

@@ -705,5 +713,4 @@ def _glob_path(
705713
def _add_model_to_dag(self, model: Model) -> None:
706714
self.dag.graph[model.name] = set()
707715

708-
for table in model.depends_on:
709-
self.dag.add(model.name, table)
716+
self.dag.add(model.name, model.depends_on)

sqlmesh/core/dag.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@ class DAG(t.Generic[T]):
1616
def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
1717
self.graph = graph or {}
1818

19-
def add(self, node: T, dependency: t.Optional[T] = None) -> None:
19+
def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
2020
"""Add a node to the graph with an optional upstream dependency.
2121
2222
Args:
2323
node: The node to add.
24-
dependency: An optional dependency to add to the node.
24+
dependencies: Optional dependencies to add to the node.
2525
"""
2626
if node not in self.graph:
2727
self.graph[node] = set()
28-
if dependency:
29-
self.graph[node].add(dependency)
28+
if dependencies:
29+
self.graph[node].update(dependencies)
3030

3131
def subdag(self, *nodes: T) -> DAG[T]:
3232
"""Create a new subdag given node(s).
@@ -50,7 +50,7 @@ def subdag(self, *nodes: T) -> DAG[T]:
5050

5151
def upstream(self, node: T) -> t.List[T]:
5252
"""Returns all upstream dependencies in topologically sorted order."""
53-
return self.subdag(node).sort()[:-1]
53+
return self.subdag(node).sorted()[:-1]
5454

5555
@property
5656
def leaves(self) -> t.Set[T]:
@@ -59,7 +59,7 @@ def leaves(self) -> t.Set[T]:
5959
dep for deps in self.graph.values() for dep in deps if dep not in self.graph
6060
}
6161

62-
def sort(self) -> t.List[T]:
62+
def sorted(self) -> t.List[T]:
6363
"""Topologically sort the graph.
6464
6565
Returns:
@@ -98,7 +98,7 @@ def downstream(self, node: T) -> t.List[T]:
9898
Returns:
9999
A list of descendant nodes sorted in topological order.
100100
"""
101-
sorted_nodes = self.sort()
101+
sorted_nodes = self.sorted()
102102
try:
103103
node_index = sorted_nodes.index(node)
104104
except ValueError:

sqlmesh/core/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -814,9 +814,9 @@ def _render_query(
814814
end: The end datetime to render. Defaults to epoch start.
815815
latest: The latest datetime to use for non-incremental queries. Defaults to epoch start.
816816
add_incremental_filter: Add an incremental filter to the query if the model is incremental.
817-
snapshots: All snapshots to use for expansion and mapping of physical locations.
817+
snapshots: All upstream snapshots to use for expansion and mapping of physical locations.
818818
If passing snapshots is undesirable, mapping can be used instead to manually map tables.
819-
mapping: Mapping to replace table names, if not set, the mapping wil be created from snapshots.
819+
mapping: Mapping to replace table names, if not set, the mapping will be created from snapshots.
820820
expand: Expand referenced models as subqueries. This is used to bypass backfills when running queries
821821
that depend on materialized tables. Model definitions are inlined and can thus be run end to
822822
end on the fly.
@@ -913,7 +913,7 @@ def ctas_query(self, snapshots: t.Dict[str, Snapshot]) -> exp.Subqueryable:
913913
SELECTS and hopefully the optimizer is smart enough to not do anything.
914914
915915
Args:
916-
All upstream snapshots of this model so queries can be expanded.
916+
snapshots: All upstream snapshots of this model so queries can be expanded.
917917
Return:
918918
The mocked out ctas query.
919919
"""

sqlmesh/core/plan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def snapshot_change_category(self, snapshot: Snapshot) -> SnapshotChangeCategory
288288
def info_cache(self) -> InfoCache:
289289
"""Returns the info cache of categorized, uncategorized snapshots."""
290290
if self._info_cache is None:
291-
queue = deque(self._dag.sort())
291+
queue = deque(self._dag.sorted())
292292
snapshots = []
293293
all_indirectly_modified = set()
294294

sqlmesh/core/plan_evaluator.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,18 @@ def _push(self, plan: Plan) -> None:
9393
Args:
9494
plan: The plan to source snapshots from.
9595
"""
96-
snapshots = {snapshot.name: snapshot for snapshot in plan.new_snapshots}
97-
self.state_sync.push_snapshots(snapshots.values())
98-
for snapshot in snapshots.values():
99-
self.snapshot_evaluator.create(snapshot, snapshots)
96+
parent_snapshot_ids = {
97+
p_sid for snapshot in plan.new_snapshots for p_sid in snapshot.parents
98+
}
99+
100+
stored_snapshots_by_id = self.state_sync.get_snapshots(parent_snapshot_ids)
101+
new_snapshots_by_id = {
102+
snapshot.snapshot_id: snapshot for snapshot in plan.new_snapshots
103+
}
104+
all_snapshots_by_id = {**stored_snapshots_by_id, **new_snapshots_by_id}
105+
106+
self.snapshot_evaluator.create(plan.new_snapshots, all_snapshots_by_id)
107+
self.state_sync.push_snapshots(plan.new_snapshots)
100108

101109
def _promote(self, plan: Plan) -> None:
102110
"""Promote a plan.
@@ -110,16 +118,14 @@ def _promote(self, plan: Plan) -> None:
110118

111119
added, removed = self.state_sync.promote(environment, no_gaps=plan.no_gaps)
112120

113-
for snapshot_table_info in added:
114-
self.snapshot_evaluator.promote(
115-
snapshot_table_info,
116-
environment=environment.name,
117-
)
118-
for snapshot_table_info in removed:
119-
self.snapshot_evaluator.demote(
120-
snapshot_table_info,
121-
environment=environment.name,
122-
)
121+
self.snapshot_evaluator.promote(
122+
added,
123+
environment=environment.name,
124+
)
125+
self.snapshot_evaluator.demote(
126+
removed,
127+
environment=environment.name,
128+
)
123129

124130

125131
class AirflowPlanEvaluator(PlanEvaluator):
@@ -132,6 +138,7 @@ def __init__(
132138
dag_creation_poll_interval_secs: int = 30,
133139
dag_creation_max_retry_attempts: int = 10,
134140
notification_targets: t.Optional[t.List[NotificationTarget]] = None,
141+
ddl_concurrent_tasks: int = 1,
135142
):
136143
self.airflow_client = airflow_client
137144
self.blocking = blocking
@@ -140,6 +147,7 @@ def __init__(
140147
self.dag_creation_max_retry_attempts = dag_creation_max_retry_attempts
141148
self.console = console or get_console()
142149
self.notification_targets = notification_targets or []
150+
self.ddl_concurrent_tasks = ddl_concurrent_tasks
143151

144152
def evaluate(self, plan: Plan) -> None:
145153
environment = plan.environment
@@ -153,6 +161,7 @@ def evaluate(self, plan: Plan) -> None:
153161
no_gaps=plan.no_gaps,
154162
restatements=plan.restatements,
155163
notification_targets=self.notification_targets,
164+
ddl_concurrent_tasks=self.ddl_concurrent_tasks,
156165
)
157166

158167
if self.blocking:

sqlmesh/core/snapshot.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ class SnapshotTableInfo(PydanticModel, SnapshotInfoMixin, frozen=True):
168168
fingerprint: str
169169
version: str
170170
physical_schema: str
171+
parents: t.Tuple[SnapshotId, ...]
171172
previous_versions: t.Tuple[SnapshotDataVersion, ...] = ()
172173
change_category: t.Optional[SnapshotChangeCategory]
173174

@@ -233,7 +234,7 @@ class Snapshot(PydanticModel, SnapshotInfoMixin):
233234
fingerprint: str
234235
physical_schema: str
235236
model: Model
236-
parents: t.List[SnapshotId]
237+
parents: t.Tuple[SnapshotId, ...]
237238
intervals: Intervals
238239
created_ts: int
239240
updated_ts: int
@@ -327,7 +328,7 @@ def from_model(
327328
),
328329
physical_schema=physical_schema,
329330
model=model,
330-
parents=[
331+
parents=tuple(
331332
SnapshotId(
332333
name=name,
333334
fingerprint=fingerprint_from_model(
@@ -338,7 +339,7 @@ def from_model(
338339
),
339340
)
340341
for name in _parents_from_model(model, models)
341-
],
342+
),
342343
intervals=[],
343344
created_ts=created_ts,
344345
updated_ts=created_ts,
@@ -509,6 +510,7 @@ def table_info(self) -> SnapshotTableInfo:
509510
name=self.name,
510511
fingerprint=self.fingerprint,
511512
version=self.version,
513+
parents=self.parents,
512514
previous_versions=self.previous_versions,
513515
change_category=self.change_category,
514516
)

0 commit comments

Comments (0)