Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Extend the built-in scheduler with the ability to run against the per…
…sisted state
  • Loading branch information
izeigerman committed Dec 9, 2022
commit 6fcfb85ca9e1b023004ebacfbe13cfaf8766e207
28 changes: 20 additions & 8 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,22 @@ def upsert_model(self, model: t.Union[str, Model] = "", **kwargs) -> Model:
self._add_model_to_dag(model)
return model

def scheduler(self) -> Scheduler:
"""The built in scheduler."""
def scheduler(self, use_local_state: bool = True) -> Scheduler:
Comment thread
izeigerman marked this conversation as resolved.
Outdated
"""Returns the built-in scheduler.

Args:
use_local_state: Whether to initialize the scheduler from the currently loaded
local state or use the persisted state instead.

Returns:
The built-in scheduler instance.
"""
if use_local_state:
snapshots = {s.snapshot_id: s for s in self.snapshots.values()}
else:
snapshots = self.state_sync.get_snapshots(None)
return Scheduler(
self.snapshots,
snapshots,
self.snapshot_evaluator,
self.state_sync,
console=self.console,
Expand Down Expand Up @@ -341,18 +353,18 @@ def run(
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
latest: t.Optional[TimeLike] = None,
) -> t.Dict[str, str]:
use_local_state: bool = True,
) -> None:
"""Run the entire dag through the scheduler.

Args:
start: The start of the interval to render.
end: The end of the interval to render.
latest: The latest time used for non incremental datasets.

Returns:
Dictionary of stacktraces if errors occur.
use_local_state: If set to True runs against the currently loaded local state,
otherwise uses the persisted state.
"""
return self.scheduler().run(self.snapshots.values(), start, end, latest)
return self.scheduler(use_local_state).run(start, end, latest)

@property
def snapshots(self) -> t.Dict[str, Snapshot]:
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/plan_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ def evaluate(self, plan: Plan) -> None:
if plan.missing_intervals:
snapshots = plan.snapshots
scheduler = Scheduler(
{snapshot.name: snapshot for snapshot in snapshots},
{snapshot.snapshot_id: snapshot for snapshot in snapshots},
self.snapshot_evaluator,
self.state_sync,
console=self.console,
)
scheduler.run(snapshots, plan.start, plan.end)
scheduler.run(plan.start, plan.end)

self._promote(plan)

Expand Down
162 changes: 58 additions & 104 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import logging
import traceback
import typing as t
from concurrent.futures import Executor, ThreadPoolExecutor, wait
from datetime import datetime
from time import sleep

from sqlmesh.core.console import Console, get_console
from sqlmesh.core.dag import DAG
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotIdLike
from sqlmesh.core.snapshot_evaluator import SnapshotEvaluator
from sqlmesh.core.state_sync import StateSync
from sqlmesh.utils.concurrency import NodeExecutionFailedError, concurrent_apply_to_dag
from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday

logger = logging.getLogger(__name__)
SnapshotBatches = t.List[t.Tuple[Snapshot, t.List[t.Tuple[datetime, datetime]]]]
SchedulingUnit = t.Tuple[SnapshotId, t.Tuple[datetime, datetime]]


class Scheduler:
Expand All @@ -36,7 +37,7 @@ class Scheduler:

def __init__(
self,
snapshots: t.Dict[str, Snapshot],
snapshots: t.Dict[SnapshotId, Snapshot],
Comment thread
izeigerman marked this conversation as resolved.
Outdated
snapshot_evaluator: SnapshotEvaluator,
state_sync: StateSync,
max_workers: int = 1,
Expand All @@ -46,9 +47,6 @@ def __init__(
self.snapshot_evaluator = snapshot_evaluator
self.state_sync = state_sync
self.max_workers = max_workers
self.running: t.Set[str] = set()
self.failed: t.Dict[str, str] = {}
self.finished: t.Set[str] = set()
self.console: Console = console or get_console()

def evaluate(
Expand All @@ -68,107 +66,79 @@ def evaluate(
latest: The latest datetime to use for non-incremental queries.
kwargs: Additional kwargs to pass to the renderer.
"""
try:
self.snapshot_evaluator.evaluate(
snapshot,
start,
end,
latest,
snapshots=self.snapshots,
**kwargs,
)
self.state_sync.add_interval(snapshot.snapshot_id, start, end)
self.snapshot_evaluator.audit(
snapshot=snapshot,
start=start,
end=end,
latest=latest,
snapshots=self.snapshots,
**kwargs,
)
self.console.update_snapshot_progress(snapshot.name, 1)
except Exception:
self.failed[snapshot.name] = traceback.format_exc()

mapping = {
Comment thread
tobymao marked this conversation as resolved.
**{
p_sid.name: self.snapshots[p_sid].table_name
for p_sid in snapshot.parents
},
snapshot.name: snapshot.table_name,
}

self.snapshot_evaluator.evaluate(
snapshot,
start,
end,
latest,
mapping=mapping,
**kwargs,
)
self.state_sync.add_interval(snapshot.snapshot_id, start, end)
self.snapshot_evaluator.audit(
snapshot=snapshot,
start=start,
end=end,
latest=latest,
mapping=mapping,
**kwargs,
)
self.console.update_snapshot_progress(snapshot.name, 1)

def run(
self,
snapshots: t.Iterable[Snapshot],
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
latest: t.Optional[TimeLike] = None,
) -> t.Dict[str, str]:
) -> None:
"""Concurrently runs all snapshots in topological order.

Args:
snapshots: An iterable of all the snapshots to run.
start: The start of the run. Defaults to the min model start date.
end: The end of the run. Defaults to now.
latest: The latest datetime to use for non-incremental queries.

Returns:
A dictionary of model name to error string.
"""
snapshots = tuple(snapshots)
latest = latest or now()
batches = self.interval_params(snapshots, start, end, latest)
batches = self.interval_params(self.snapshots.values(), start, end, latest)

self.running.clear()
self.finished.clear()
self.failed.clear()
dag = []
intervals_per_snapshot_id = {
snapshot.snapshot_id: intervals for snapshot, intervals in batches
}

dag = DAG[SchedulingUnit]()
for snapshot, intervals in batches:
dag.append(
(
snapshot,
intervals,
{
table
for table in snapshot.model.depends_on
if table in self.snapshots
},
)
)
upstream_dependencies = [
(p_sid, interval)
for p_sid in snapshot.parents
for interval in intervals_per_snapshot_id.get(p_sid, [])
]
sid = snapshot.snapshot_id
for interval in intervals:
dag.add((sid, interval), upstream_dependencies)

for snapshot, intervals, _ in dag[::-1]:
if not intervals:
continue
# We have to run all batches per snapshot to mark it as completed
self.console.start_snapshot_progress(snapshot.name, len(intervals))

with self.snapshot_evaluator.multithreaded_context(), ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor(
max_workers=self.max_workers
) as batch_pool:
while True:
if self.failed:
for model_name, error_message in self.failed.items():
self.console.log_error(
f"Failed Executing Batch.\nModel name:{model_name}\n{error_message}"
)
snapshot_pool.shutdown()
batch_pool.shutdown()
break
if self.finished >= {snapshot.name for snapshot, _, _ in dag}:
break
processed = self.running | self.finished
for snapshot, intervals, deps in dag:
if snapshot.name not in processed and self.finished >= deps:
self.running.add(snapshot.name)
snapshot_pool.submit(
self._run_snapshot_intervals,
snapshot,
intervals,
latest,
batch_pool,
)
sleep(0.1)

if self.failed:
self.console.stop_snapshot_progress()
else:
self.console.complete_snapshot_progress()
def evaluate_node(node: SchedulingUnit) -> None:
assert latest
sid, (start, end) = node
self.evaluate(self.snapshots[sid], start, end, latest)

return self.failed
try:
with self.snapshot_evaluator.multithreaded_context():
Comment thread
izeigerman marked this conversation as resolved.
Outdated
concurrent_apply_to_dag(dag, evaluate_node, self.max_workers)
except NodeExecutionFailedError as error:
sid = error.node[0] # type: ignore
self.console.log_error(
f"Failed Executing Batch.\nSnapshot: {sid}\n{traceback.format_exc()}"
)
raise

def interval_params(
self,
Expand Down Expand Up @@ -208,22 +178,6 @@ def interval_params(
latest=latest or now(),
)

def _run_snapshot_intervals(
self,
snapshot: Snapshot,
intervals: t.List[t.Tuple[datetime, datetime]],
latest: TimeLike,
pool: Executor,
) -> None:
wait(
[
pool.submit(self.evaluate, snapshot, start, end, latest)
for start, end in intervals
],
)
self.finished.add(snapshot.name)
self.running.remove(snapshot.name)


def compute_interval_params(
target: t.Iterable[SnapshotIdLike],
Expand Down
11 changes: 1 addition & 10 deletions sqlmesh/core/snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ def evaluate(
start: TimeLike,
end: TimeLike,
latest: TimeLike,
mapping: t.Dict[str, str],
limit: int = 0,
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
mapping: t.Optional[t.Dict[str, str]] = None,
**kwargs,
) -> t.Optional[DF]:
"""Evaluate a snapshot, creating its schema and table if it doesn't exist and then inserting it.
Expand All @@ -72,7 +71,6 @@ def evaluate(
start: The start datetime to render.
end: The end datetime to render.
latest: The latest datetime to use for non-incremental queries.
snapshots: All snapshots to use for mapping of physical locations.
mapping: Mapping of model references to physical snapshots.
limit: If limit is >= 0, the query will not be persisted but evaluated and returned
as a dataframe.
Expand All @@ -86,10 +84,6 @@ def evaluate(
for sql_statement in model.sql_statements:
self.adapter.execute(sql_statement)

mapping = mapping or {
name: snapshot.table_name for name, snapshot in (snapshots or {}).items()
}

if model.is_sql:
query_or_df = model.render_query(
start=start,
Expand Down Expand Up @@ -212,7 +206,6 @@ def audit(
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
latest: t.Optional[TimeLike] = None,
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
mapping: t.Optional[t.Dict[str, str]] = None,
raise_exception: bool = True,
**kwargs,
Expand All @@ -223,7 +216,6 @@ def audit(
snapshot: Snapshot to evaluate. start: The start datetime to audit. Defaults to epoch start.
end: The end datetime to audit. Defaults to epoch start.
latest: The latest datetime to use for non-incremental queries. Defaults to epoch start.
snapshots: All snapshots to use for mapping of physical locations.
mapping: Mapping of model references to physical snapshots.
collection_exceptions:
kwargs: Additional kwargs to pass to the renderer.
Expand All @@ -233,7 +225,6 @@ def audit(
start=start,
end=end,
latest=latest,
snapshots=snapshots,
mapping=mapping,
**kwargs,
):
Expand Down
9 changes: 5 additions & 4 deletions sqlmesh/core/state_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ class StateReader(abc.ABC):

@abc.abstractmethod
def get_snapshots(
self, snapshot_ids: t.Iterable[SnapshotIdLike]
self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
) -> t.Dict[SnapshotId, Snapshot]:
"""Bulk fetch snapshots given the corresponding snapshot ids.

Args:
snapshot_ids: Iterable of snapshot ids to get.
snapshot_ids: Iterable of snapshot ids to get. If not provided all
available snapshots will be returned.

Returns:
A dictionary of snapshot ids to snapshots for ones that could be found.
Expand Down Expand Up @@ -312,7 +313,7 @@ def remove_expired_snapshots(self) -> t.List[Snapshot]:

class CommonStateSyncMixin(StateSync):
def get_snapshots(
self, snapshot_ids: t.Iterable[SnapshotIdLike]
self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
) -> t.Dict[SnapshotId, Snapshot]:
return self._get_snapshots(snapshot_ids)

Expand Down Expand Up @@ -766,7 +767,7 @@ def remove_expired_snapshots(self) -> t.List[Snapshot]:
return expired_snapshots

def get_snapshots(
self, snapshot_ids: t.Iterable[SnapshotIdLike]
self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
) -> t.Dict[SnapshotId, Snapshot]:
snapshots = super().get_snapshots(snapshot_ids)
self._update_cache(snapshots.values())
Expand Down
6 changes: 4 additions & 2 deletions sqlmesh/schedulers/airflow/state_sync/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_environments(self) -> t.List[Environment]:
)

def get_snapshots(
self, snapshot_ids: t.Iterable[SnapshotIdLike]
self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]]
) -> t.Dict[SnapshotId, Snapshot]:
"""Gets multiple snapshots from the rest api.

Expand All @@ -79,7 +79,9 @@ def get_snapshots(
call to the rest api. Multiple threads can be used, but it could possibly have detrimental effects
on the production server.
"""
snapshot_ids = list(snapshot_ids)
snapshot_ids = (
list(snapshot_ids) if snapshot_ids else self._client.get_snapshot_ids()
)
if len(snapshot_ids) > 1:
logger.warning(
"Fetching multiple snapshots from Airflow using the REST API is inefficient and not recommended"
Comment thread
izeigerman marked this conversation as resolved.
Outdated
Expand Down
Loading