
Commit 6fcfb85

Extend the built-in scheduler with the ability to run against the persisted state
1 parent 3fee692 commit 6fcfb85

10 files changed

Lines changed: 121 additions & 150 deletions


sqlmesh/core/context.py

Lines changed: 20 additions & 8 deletions
@@ -264,10 +264,22 @@ def upsert_model(self, model: t.Union[str, Model] = "", **kwargs) -> Model:
         self._add_model_to_dag(model)
         return model
 
-    def scheduler(self) -> Scheduler:
-        """The built in scheduler."""
+    def scheduler(self, use_local_state: bool = True) -> Scheduler:
+        """Returns the built-in scheduler.
+
+        Args:
+            use_local_state: Whether to initialize the scheduler from the currently loaded
+                local state or use the persisted state instead.
+
+        Returns:
+            The built-in scheduler instance.
+        """
+        if use_local_state:
+            snapshots = {s.snapshot_id: s for s in self.snapshots.values()}
+        else:
+            snapshots = self.state_sync.get_snapshots(None)
         return Scheduler(
-            self.snapshots,
+            snapshots,
             self.snapshot_evaluator,
             self.state_sync,
             console=self.console,
@@ -341,18 +353,18 @@ def run(
         start: t.Optional[TimeLike] = None,
         end: t.Optional[TimeLike] = None,
         latest: t.Optional[TimeLike] = None,
-    ) -> t.Dict[str, str]:
+        use_local_state: bool = True,
+    ) -> None:
         """Run the entire dag through the scheduler.
 
         Args:
             start: The start of the interval to render.
             end: The end of the interval to render.
             latest: The latest time used for non incremental datasets.
-
-        Returns:
-            Dictionary of stacktraces if errors occur.
+            use_local_state: If set to True runs against the currently loaded local state,
+                otherwise uses the persisted state.
         """
-        return self.scheduler().run(self.snapshots.values(), start, end, latest)
+        return self.scheduler(use_local_state).run(start, end, latest)
 
     @property
     def snapshots(self) -> t.Dict[str, Snapshot]:
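
The context changes surface the new flag on the public API. A minimal usage sketch, assuming a Context constructed from a project path (the path and date arguments below are illustrative, not part of this commit):

from sqlmesh.core.context import Context

context = Context(path="my_project")  # hypothetical project directory

# Default behavior: schedule against the snapshots loaded from local project files.
context.run(start="2022-01-01", end="2022-01-07")

# New in this commit: hydrate the scheduler from the persisted state instead,
# running whatever snapshots the state sync currently knows about.
context.run(start="2022-01-01", end="2022-01-07", use_local_state=False)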

sqlmesh/core/plan_evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -72,12 +72,12 @@ def evaluate(self, plan: Plan) -> None:
         if plan.missing_intervals:
             snapshots = plan.snapshots
             scheduler = Scheduler(
-                {snapshot.name: snapshot for snapshot in snapshots},
+                {snapshot.snapshot_id: snapshot for snapshot in snapshots},
                 self.snapshot_evaluator,
                 self.state_sync,
                 console=self.console,
             )
-            scheduler.run(snapshots, plan.start, plan.end)
+            scheduler.run(plan.start, plan.end)
 
         self._promote(plan)
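
The switch from name keys to snapshot_id keys is more than cosmetic: two snapshots of the same model (same name, different fingerprints) can exist, and a name-keyed dict silently drops all but one. A schematic illustration using a stand-in id type rather than the real SnapshotId class:

from dataclasses import dataclass

@dataclass(frozen=True)
class FakeSnapshotId:  # stand-in for illustration only
    name: str
    fingerprint: str

a = FakeSnapshotId("db.model", "abc123")
b = FakeSnapshotId("db.model", "def456")  # same model, newer version

by_name = {sid.name: sid for sid in (a, b)}  # collision: one entry survives
by_id = {sid: sid for sid in (a, b)}         # both versions are preserved
assert len(by_name) == 1 and len(by_id) == 2
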
sqlmesh/core/scheduler.py

Lines changed: 58 additions & 104 deletions
@@ -3,18 +3,19 @@
 import logging
 import traceback
 import typing as t
-from concurrent.futures import Executor, ThreadPoolExecutor, wait
 from datetime import datetime
-from time import sleep
 
 from sqlmesh.core.console import Console, get_console
+from sqlmesh.core.dag import DAG
 from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotIdLike
 from sqlmesh.core.snapshot_evaluator import SnapshotEvaluator
 from sqlmesh.core.state_sync import StateSync
+from sqlmesh.utils.concurrency import NodeExecutionFailedError, concurrent_apply_to_dag
 from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday
 
 logger = logging.getLogger(__name__)
 SnapshotBatches = t.List[t.Tuple[Snapshot, t.List[t.Tuple[datetime, datetime]]]]
+SchedulingUnit = t.Tuple[SnapshotId, t.Tuple[datetime, datetime]]
 
 
 class Scheduler:
@@ -36,7 +37,7 @@ class Scheduler:
 
     def __init__(
         self,
-        snapshots: t.Dict[str, Snapshot],
+        snapshots: t.Dict[SnapshotId, Snapshot],
         snapshot_evaluator: SnapshotEvaluator,
         state_sync: StateSync,
         max_workers: int = 1,
@@ -46,9 +47,6 @@
         self.snapshot_evaluator = snapshot_evaluator
         self.state_sync = state_sync
         self.max_workers = max_workers
-        self.running: t.Set[str] = set()
-        self.failed: t.Dict[str, str] = {}
-        self.finished: t.Set[str] = set()
         self.console: Console = console or get_console()
 
     def evaluate(
@@ -68,107 +66,79 @@ def evaluate(
             latest: The latest datetime to use for non-incremental queries.
             kwargs: Additional kwargs to pass to the renderer.
         """
-        try:
-            self.snapshot_evaluator.evaluate(
-                snapshot,
-                start,
-                end,
-                latest,
-                snapshots=self.snapshots,
-                **kwargs,
-            )
-            self.state_sync.add_interval(snapshot.snapshot_id, start, end)
-            self.snapshot_evaluator.audit(
-                snapshot=snapshot,
-                start=start,
-                end=end,
-                latest=latest,
-                snapshots=self.snapshots,
-                **kwargs,
-            )
-            self.console.update_snapshot_progress(snapshot.name, 1)
-        except Exception:
-            self.failed[snapshot.name] = traceback.format_exc()
+
+        mapping = {
+            **{
+                p_sid.name: self.snapshots[p_sid].table_name
+                for p_sid in snapshot.parents
+            },
+            snapshot.name: snapshot.table_name,
+        }
+
+        self.snapshot_evaluator.evaluate(
+            snapshot,
+            start,
+            end,
+            latest,
+            mapping=mapping,
+            **kwargs,
+        )
+        self.state_sync.add_interval(snapshot.snapshot_id, start, end)
+        self.snapshot_evaluator.audit(
+            snapshot=snapshot,
+            start=start,
+            end=end,
+            latest=latest,
+            mapping=mapping,
+            **kwargs,
+        )
+        self.console.update_snapshot_progress(snapshot.name, 1)
 
     def run(
         self,
-        snapshots: t.Iterable[Snapshot],
         start: t.Optional[TimeLike] = None,
         end: t.Optional[TimeLike] = None,
         latest: t.Optional[TimeLike] = None,
-    ) -> t.Dict[str, str]:
+    ) -> None:
         """Concurrently runs all snapshots in topological order.
 
         Args:
-            snapshots: An iterable of all the snapshots to run.
             start: The start of the run. Defaults to the min model start date.
             end: The end of the run. Defaults to now.
             latest: The latest datetime to use for non-incremental queries.
-
-        Returns:
-            A dictionary of model name to error string.
         """
-        snapshots = tuple(snapshots)
         latest = latest or now()
-        batches = self.interval_params(snapshots, start, end, latest)
+        batches = self.interval_params(self.snapshots.values(), start, end, latest)
 
-        self.running.clear()
-        self.finished.clear()
-        self.failed.clear()
-        dag = []
+        intervals_per_snapshot_id = {
+            snapshot.snapshot_id: intervals for snapshot, intervals in batches
+        }
 
+        dag = DAG[SchedulingUnit]()
         for snapshot, intervals in batches:
-            dag.append(
-                (
-                    snapshot,
-                    intervals,
-                    {
-                        table
-                        for table in snapshot.model.depends_on
-                        if table in self.snapshots
-                    },
-                )
-            )
+            upstream_dependencies = [
+                (p_sid, interval)
+                for p_sid in snapshot.parents
+                for interval in intervals_per_snapshot_id.get(p_sid, [])
+            ]
+            sid = snapshot.snapshot_id
+            for interval in intervals:
+                dag.add((sid, interval), upstream_dependencies)
 
-        for snapshot, intervals, _ in dag[::-1]:
-            if not intervals:
-                continue
-            # We have to run all batches per snapshot to mark it as completed
-            self.console.start_snapshot_progress(snapshot.name, len(intervals))
-
-        with self.snapshot_evaluator.multithreaded_context(), ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor(
-            max_workers=self.max_workers
-        ) as batch_pool:
-            while True:
-                if self.failed:
-                    for model_name, error_message in self.failed.items():
-                        self.console.log_error(
-                            f"Failed Executing Batch.\nModel name:{model_name}\n{error_message}"
-                        )
-                    snapshot_pool.shutdown()
-                    batch_pool.shutdown()
-                    break
-                if self.finished >= {snapshot.name for snapshot, _, _ in dag}:
-                    break
-                processed = self.running | self.finished
-                for snapshot, intervals, deps in dag:
-                    if snapshot.name not in processed and self.finished >= deps:
-                        self.running.add(snapshot.name)
-                        snapshot_pool.submit(
-                            self._run_snapshot_intervals,
-                            snapshot,
-                            intervals,
-                            latest,
-                            batch_pool,
-                        )
-                sleep(0.1)
-
-        if self.failed:
-            self.console.stop_snapshot_progress()
-        else:
-            self.console.complete_snapshot_progress()
+        def evaluate_node(node: SchedulingUnit) -> None:
+            assert latest
+            sid, (start, end) = node
+            self.evaluate(self.snapshots[sid], start, end, latest)
 
-        return self.failed
+        try:
+            with self.snapshot_evaluator.multithreaded_context():
+                concurrent_apply_to_dag(dag, evaluate_node, self.max_workers)
+        except NodeExecutionFailedError as error:
+            sid = error.node[0]  # type: ignore
+            self.console.log_error(
+                f"Failed Executing Batch.\nSnapshot: {sid}\n{traceback.format_exc()}"
+            )
+            raise
 
     def interval_params(
         self,
@@ -208,22 +178,6 @@ def interval_params(
             latest=latest or now(),
         )
 
-    def _run_snapshot_intervals(
-        self,
-        snapshot: Snapshot,
-        intervals: t.List[t.Tuple[datetime, datetime]],
-        latest: TimeLike,
-        pool: Executor,
-    ) -> None:
-        wait(
-            [
-                pool.submit(self.evaluate, snapshot, start, end, latest)
-                for start, end in intervals
-            ],
-        )
-        self.finished.add(snapshot.name)
-        self.running.remove(snapshot.name)
-
 
 def compute_interval_params(
     target: t.Iterable[SnapshotIdLike],
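
The structural change here is the scheduling granularity. The old implementation treated a whole snapshot as the unit of work, tracked with hand-rolled running/failed/finished sets and a polling loop; the new one makes each (snapshot_id, interval) pair a DAG node whose edges connect every parent batch to every child batch, then delegates execution to concurrent_apply_to_dag. A self-contained sketch of the same idea using only the standard library (toy names and intervals; this is not SQLMesh's DAG implementation):

from graphlib import TopologicalSorter  # stdlib, Python 3.9+

parents = {"child": ["parent"]}
intervals = {
    "parent": [(1, 2), (2, 3)],
    "child": [(1, 2), (2, 3)],
}

# Nodes mirror SchedulingUnit: (snapshot name, interval).
graph = {}
for name, ivals in intervals.items():
    upstream = [
        (p, p_ival)
        for p in parents.get(name, [])
        for p_ival in intervals[p]
    ]
    for ival in ivals:
        graph[(name, ival)] = upstream

# Topological order guarantees every parent batch runs before any child batch.
for node in TopologicalSorter(graph).static_order():
    print("evaluate", node)

Note that, as in the diff, every interval of a child depends on all intervals of its parent, so a failure in any parent batch (surfaced as NodeExecutionFailedError) blocks the child entirely rather than just the overlapping interval.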

sqlmesh/core/snapshot_evaluator.py

Lines changed: 1 addition & 10 deletions
@@ -60,9 +60,8 @@ def evaluate(
         start: TimeLike,
         end: TimeLike,
         latest: TimeLike,
+        mapping: t.Dict[str, str],
         limit: int = 0,
-        snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
-        mapping: t.Optional[t.Dict[str, str]] = None,
         **kwargs,
     ) -> t.Optional[DF]:
         """Evaluate a snapshot, creating its schema and table if it doesn't exist and then inserting it.
@@ -72,7 +71,6 @@
             start: The start datetime to render.
             end: The end datetime to render.
             latest: The latest datetime to use for non-incremental queries.
-            snapshots: All snapshots to use for mapping of physical locations.
             mapping: Mapping of model references to physical snapshots.
             limit: If limit is >= 0, the query will not be persisted but evaluated and returned
                 as a dataframe.
@@ -86,10 +84,6 @@
         for sql_statement in model.sql_statements:
             self.adapter.execute(sql_statement)
 
-        mapping = mapping or {
-            name: snapshot.table_name for name, snapshot in (snapshots or {}).items()
-        }
-
         if model.is_sql:
             query_or_df = model.render_query(
                 start=start,
@@ -212,7 +206,6 @@ def audit(
         start: t.Optional[TimeLike] = None,
         end: t.Optional[TimeLike] = None,
         latest: t.Optional[TimeLike] = None,
-        snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
         mapping: t.Optional[t.Dict[str, str]] = None,
         raise_exception: bool = True,
         **kwargs,
@@ -223,7 +216,6 @@
             snapshot: Snapshot to evaluate. start: The start datetime to audit. Defaults to epoch start.
             end: The end datetime to audit. Defaults to epoch start.
             latest: The latest datetime to use for non-incremental queries. Defaults to epoch start.
-            snapshots: All snapshots to use for mapping of physical locations.
             mapping: Mapping of model references to physical snapshots.
             collection_exceptions:
             kwargs: Additional kwargs to pass to the renderer.
@@ -233,7 +225,6 @@
             start=start,
             end=end,
             latest=latest,
-            snapshots=snapshots,
             mapping=mapping,
             **kwargs,
         ):
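
With the snapshots argument gone, the caller now owns the resolution of logical model names to physical tables and passes it in as mapping (built in Scheduler.evaluate from snapshot.parents plus the snapshot itself). A schematic example of what such a mapping holds; the table names are invented, and the string replacement below is only a stand-in for the real renderer, which rewrites the parsed query:

mapping = {
    "db.parent_model": "sqlmesh.db__parent_model__1234",
    "db.child_model": "sqlmesh.db__child_model__5678",
}

# Conceptually, rendering swaps each logical reference for its physical table.
sql = "SELECT * FROM db.parent_model"
for logical, physical in mapping.items():
    sql = sql.replace(logical, physical)
print(sql)  # SELECT * FROM sqlmesh.db__parent_model__1234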

sqlmesh/core/state_sync.py

Lines changed: 5 additions & 4 deletions
@@ -46,12 +46,13 @@ class StateReader(abc.ABC):
 
     @abc.abstractmethod
     def get_snapshots(
-        self, snapshot_ids: t.Iterable[SnapshotIdLike]
+        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
     ) -> t.Dict[SnapshotId, Snapshot]:
         """Bulk fetch snapshots given the corresponding snapshot ids.
 
         Args:
-            snapshot_ids: Iterable of snapshot ids to get.
+            snapshot_ids: Iterable of snapshot ids to get. If not provided all
+                available snapshots will be returned.
 
         Returns:
             A dictionary of snapshot ids to snapshots for ones that could be found.
@@ -312,7 +313,7 @@ def remove_expired_snapshots(self) -> t.List[Snapshot]:
 
 class CommonStateSyncMixin(StateSync):
     def get_snapshots(
-        self, snapshot_ids: t.Iterable[SnapshotIdLike]
+        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
     ) -> t.Dict[SnapshotId, Snapshot]:
         return self._get_snapshots(snapshot_ids)
 
@@ -766,7 +767,7 @@ def remove_expired_snapshots(self) -> t.List[Snapshot]:
         return expired_snapshots
 
     def get_snapshots(
-        self, snapshot_ids: t.Iterable[SnapshotIdLike]
+        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
     ) -> t.Dict[SnapshotId, Snapshot]:
         snapshots = super().get_snapshots(snapshot_ids)
         self._update_cache(snapshots.values())
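
All three get_snapshots signatures gain the same Optional semantics: passing None now means "fetch every available snapshot". A hedged sketch of the branch an implementation needs (schematic; the actual mixin delegates to _get_snapshots):

import typing as t

def get_snapshots_sketch(
    snapshot_ids: t.Optional[t.Iterable[str]],
    store: t.Dict[str, str],
) -> t.Dict[str, str]:
    # None -> return every persisted snapshot.
    if snapshot_ids is None:
        return dict(store)
    # Otherwise filter to the requested ids, skipping missing ones.
    requested = set(snapshot_ids)
    return {sid: snap for sid, snap in store.items() if sid in requested}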

sqlmesh/schedulers/airflow/state_sync/http.py

Lines changed: 4 additions & 2 deletions
@@ -70,7 +70,7 @@ def get_environments(self) -> t.List[Environment]:
         )
 
     def get_snapshots(
-        self, snapshot_ids: t.Iterable[SnapshotIdLike]
+        self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
     ) -> t.Dict[SnapshotId, Snapshot]:
         """Gets multiple snapshots from the rest api.
 
@@ -79,7 +79,9 @@ def get_snapshots(
         call to the rest api. Multiple threads can be used, but it could possibly have detrimental effects
         on the production server.
         """
-        snapshot_ids = list(snapshot_ids)
+        snapshot_ids = (
+            list(snapshot_ids) if snapshot_ids else self._client.get_snapshot_ids()
+        )
         if len(snapshot_ids) > 1:
             logger.warning(
                 "Fetching multiple snapshots from Airflow using the REST API is inefficient and not recommended"
