Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,23 @@ def upsert_model(self, model: t.Union[str, Model] = "", **kwargs) -> Model:
self._add_model_to_dag(model)
return model

def scheduler(self) -> Scheduler:
"""The built in scheduler."""
def scheduler(self, global_state: bool = False) -> Scheduler:
"""Returns the built-in scheduler.

Args:
global_state: Whether to initialize the scheduler from the persisted state
or from the currently loaded local state. Default: False.

Returns:
The built-in scheduler instance.
"""
snapshots: t.Iterable[Snapshot]
if global_state:
snapshots = self.state_sync.get_snapshots(None).values()
else:
snapshots = self.snapshots.values()
return Scheduler(
self.snapshots,
snapshots,
self.snapshot_evaluator,
self.state_sync,
console=self.console,
Expand Down Expand Up @@ -341,18 +354,18 @@ def run(
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
latest: t.Optional[TimeLike] = None,
) -> t.Dict[str, str]:
global_state: bool = False,
) -> None:
"""Run the entire dag through the scheduler.

Args:
start: The start of the interval to render.
end: The end of the interval to render.
latest: The latest time used for non incremental datasets.

Returns:
Dictionary of stacktraces if errors occur.
global_state: If set to True runs against the persisted state,
otherwise uses the currently loaded local state. Default: False.
"""
return self.scheduler().run(self.snapshots.values(), start, end, latest)
return self.scheduler(global_state).run(start, end, latest)

@property
def snapshots(self) -> t.Dict[str, Snapshot]:
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/plan_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ def evaluate(self, plan: Plan) -> None:
if plan.missing_intervals:
snapshots = plan.snapshots
scheduler = Scheduler(
{snapshot.name: snapshot for snapshot in snapshots},
snapshots,
self.snapshot_evaluator,
self.state_sync,
console=self.console,
)
scheduler.run(snapshots, plan.start, plan.end)
scheduler.run(plan.start, plan.end)

self._promote(plan)

Expand Down
166 changes: 60 additions & 106 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import logging
import traceback
import typing as t
from concurrent.futures import Executor, ThreadPoolExecutor, wait
from datetime import datetime
from time import sleep

from sqlmesh.core.console import Console, get_console
from sqlmesh.core.dag import DAG
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotIdLike
from sqlmesh.core.snapshot_evaluator import SnapshotEvaluator
from sqlmesh.core.state_sync import StateSync
from sqlmesh.utils.concurrency import NodeExecutionFailedError, concurrent_apply_to_dag
from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday

logger = logging.getLogger(__name__)
SnapshotBatches = t.List[t.Tuple[Snapshot, t.List[t.Tuple[datetime, datetime]]]]
SchedulingUnit = t.Tuple[SnapshotId, t.Tuple[datetime, datetime]]


class Scheduler:
Expand All @@ -27,7 +28,7 @@ class Scheduler:
The scheduler comes equipped with a simple ThreadPoolExecutor based evaluation engine.

Args:
snapshots: A dictionary of all snapshots.
snapshots: A collection of snapshots.
snapshot_evaluator: The snapshot evaluator to execute queries.
state_sync: The state sync to pull saved snapshots.
max_workers: The maximum number of parallel queries to run.
Expand All @@ -36,19 +37,16 @@ class Scheduler:

def __init__(
self,
snapshots: t.Dict[str, Snapshot],
snapshots: t.Iterable[Snapshot],
snapshot_evaluator: SnapshotEvaluator,
state_sync: StateSync,
max_workers: int = 1,
console: t.Optional[Console] = None,
):
self.snapshots = snapshots
self.snapshots = {s.snapshot_id: s for s in snapshots}
self.snapshot_evaluator = snapshot_evaluator
self.state_sync = state_sync
self.max_workers = max_workers
self.running: t.Set[str] = set()
self.failed: t.Dict[str, str] = {}
self.finished: t.Set[str] = set()
self.console: Console = console or get_console()

def evaluate(
Expand All @@ -68,107 +66,79 @@ def evaluate(
latest: The latest datetime to use for non-incremental queries.
kwargs: Additional kwargs to pass to the renderer.
"""
try:
self.snapshot_evaluator.evaluate(
snapshot,
start,
end,
latest,
snapshots=self.snapshots,
**kwargs,
)
self.state_sync.add_interval(snapshot.snapshot_id, start, end)
self.snapshot_evaluator.audit(
snapshot=snapshot,
start=start,
end=end,
latest=latest,
snapshots=self.snapshots,
**kwargs,
)
self.console.update_snapshot_progress(snapshot.name, 1)
except Exception:
self.failed[snapshot.name] = traceback.format_exc()

mapping = {
Comment thread
tobymao marked this conversation as resolved.
**{
p_sid.name: self.snapshots[p_sid].table_name
for p_sid in snapshot.parents
},
snapshot.name: snapshot.table_name,
}

self.snapshot_evaluator.evaluate(
snapshot,
start,
end,
latest,
mapping=mapping,
**kwargs,
)
self.state_sync.add_interval(snapshot.snapshot_id, start, end)
self.snapshot_evaluator.audit(
snapshot=snapshot,
start=start,
end=end,
latest=latest,
mapping=mapping,
**kwargs,
)
self.console.update_snapshot_progress(snapshot.name, 1)

def run(
self,
snapshots: t.Iterable[Snapshot],
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
latest: t.Optional[TimeLike] = None,
) -> t.Dict[str, str]:
) -> None:
"""Concurrently runs all snapshots in topological order.

Args:
snapshots: An iterable of all the snapshots to run.
start: The start of the run. Defaults to the min model start date.
end: The end of the run. Defaults to now.
latest: The latest datetime to use for non-incremental queries.

Returns:
A dictionary of model name to error string.
"""
snapshots = tuple(snapshots)
latest = latest or now()
batches = self.interval_params(snapshots, start, end, latest)
batches = self.interval_params(self.snapshots.values(), start, end, latest)

self.running.clear()
self.finished.clear()
self.failed.clear()
dag = []
intervals_per_snapshot_id = {
snapshot.snapshot_id: intervals for snapshot, intervals in batches
}

dag = DAG[SchedulingUnit]()
for snapshot, intervals in batches:
dag.append(
(
snapshot,
intervals,
{
table
for table in snapshot.model.depends_on
if table in self.snapshots
},
)
)
upstream_dependencies = [
(p_sid, interval)
for p_sid in snapshot.parents
for interval in intervals_per_snapshot_id.get(p_sid, [])
]
sid = snapshot.snapshot_id
for interval in intervals:
dag.add((sid, interval), upstream_dependencies)

for snapshot, intervals, _ in dag[::-1]:
if not intervals:
continue
# We have to run all batches per snapshot to mark it as completed
self.console.start_snapshot_progress(snapshot.name, len(intervals))

with self.snapshot_evaluator.multithreaded_context(), ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor(
max_workers=self.max_workers
) as batch_pool:
while True:
if self.failed:
for model_name, error_message in self.failed.items():
self.console.log_error(
f"Failed Executing Batch.\nModel name:{model_name}\n{error_message}"
)
snapshot_pool.shutdown()
batch_pool.shutdown()
break
if self.finished >= {snapshot.name for snapshot, _, _ in dag}:
break
processed = self.running | self.finished
for snapshot, intervals, deps in dag:
if snapshot.name not in processed and self.finished >= deps:
self.running.add(snapshot.name)
snapshot_pool.submit(
self._run_snapshot_intervals,
snapshot,
intervals,
latest,
batch_pool,
)
sleep(0.1)

if self.failed:
self.console.stop_snapshot_progress()
else:
self.console.complete_snapshot_progress()
def evaluate_node(node: SchedulingUnit) -> None:
assert latest
sid, (start, end) = node
self.evaluate(self.snapshots[sid], start, end, latest)

return self.failed
try:
with self.snapshot_evaluator.concurrent_context():
concurrent_apply_to_dag(dag, evaluate_node, self.max_workers)
except NodeExecutionFailedError as error:
sid = error.node[0] # type: ignore
self.console.log_error(
f"Failed Executing Batch.\nSnapshot: {sid}\n{traceback.format_exc()}"
)
raise

def interval_params(
self,
Expand Down Expand Up @@ -208,22 +178,6 @@ def interval_params(
latest=latest or now(),
)

def _run_snapshot_intervals(
self,
snapshot: Snapshot,
intervals: t.List[t.Tuple[datetime, datetime]],
latest: TimeLike,
pool: Executor,
) -> None:
wait(
[
pool.submit(self.evaluate, snapshot, start, end, latest)
for start, end in intervals
],
)
self.finished.add(snapshot.name)
self.running.remove(snapshot.name)


def compute_interval_params(
target: t.Iterable[SnapshotIdLike],
Expand Down
Loading