33import logging
44import traceback
55import typing as t
6- from concurrent .futures import Executor , ThreadPoolExecutor , wait
76from datetime import datetime
8- from time import sleep
97
108from sqlmesh .core .console import Console , get_console
9+ from sqlmesh .core .dag import DAG
1110from sqlmesh .core .snapshot import Snapshot , SnapshotId , SnapshotIdLike
1211from sqlmesh .core .snapshot_evaluator import SnapshotEvaluator
1312from sqlmesh .core .state_sync import StateSync
13+ from sqlmesh .utils .concurrency import NodeExecutionFailedError , concurrent_apply_to_dag
1414from sqlmesh .utils .date import TimeLike , now , to_datetime , yesterday
1515
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Pairs each snapshot with the list of (start, end) datetime intervals
# associated with it.
SnapshotBatches = t.List[
    t.Tuple[
        Snapshot,
        t.List[t.Tuple[datetime, datetime]],
    ]
]

# One schedulable node: a snapshot id together with a single
# (start, end) datetime interval.
SchedulingUnit = t.Tuple[
    SnapshotId,
    t.Tuple[datetime, datetime],
]
1819
1920
2021class Scheduler :
@@ -36,7 +37,7 @@ class Scheduler:
3637
3738 def __init__ (
3839 self ,
39- snapshots : t .Dict [str , Snapshot ],
40+ snapshots : t .Dict [SnapshotId , Snapshot ],
4041 snapshot_evaluator : SnapshotEvaluator ,
4142 state_sync : StateSync ,
4243 max_workers : int = 1 ,
@@ -46,9 +47,6 @@ def __init__(
4647 self .snapshot_evaluator = snapshot_evaluator
4748 self .state_sync = state_sync
4849 self .max_workers = max_workers
49- self .running : t .Set [str ] = set ()
50- self .failed : t .Dict [str , str ] = {}
51- self .finished : t .Set [str ] = set ()
5250 self .console : Console = console or get_console ()
5351
5452 def evaluate (
@@ -68,107 +66,79 @@ def evaluate(
6866 latest: The latest datetime to use for non-incremental queries.
6967 kwargs: Additional kwargs to pass to the renderer.
7068 """
71- try :
72- self .snapshot_evaluator .evaluate (
73- snapshot ,
74- start ,
75- end ,
76- latest ,
77- snapshots = self .snapshots ,
78- ** kwargs ,
79- )
80- self .state_sync .add_interval (snapshot .snapshot_id , start , end )
81- self .snapshot_evaluator .audit (
82- snapshot = snapshot ,
83- start = start ,
84- end = end ,
85- latest = latest ,
86- snapshots = self .snapshots ,
87- ** kwargs ,
88- )
89- self .console .update_snapshot_progress (snapshot .name , 1 )
90- except Exception :
91- self .failed [snapshot .name ] = traceback .format_exc ()
69+
70+ mapping = {
71+ ** {
72+ p_sid .name : self .snapshots [p_sid ].table_name
73+ for p_sid in snapshot .parents
74+ },
75+ snapshot .name : snapshot .table_name ,
76+ }
77+
78+ self .snapshot_evaluator .evaluate (
79+ snapshot ,
80+ start ,
81+ end ,
82+ latest ,
83+ mapping = mapping ,
84+ ** kwargs ,
85+ )
86+ self .state_sync .add_interval (snapshot .snapshot_id , start , end )
87+ self .snapshot_evaluator .audit (
88+ snapshot = snapshot ,
89+ start = start ,
90+ end = end ,
91+ latest = latest ,
92+ mapping = mapping ,
93+ ** kwargs ,
94+ )
95+ self .console .update_snapshot_progress (snapshot .name , 1 )
9296
9397 def run (
9498 self ,
95- snapshots : t .Iterable [Snapshot ],
9699 start : t .Optional [TimeLike ] = None ,
97100 end : t .Optional [TimeLike ] = None ,
98101 latest : t .Optional [TimeLike ] = None ,
99- ) -> t . Dict [ str , str ] :
102+ ) -> None :
100103 """Concurrently runs all snapshots in topological order.
101104
102105 Args:
103- snapshots: An iterable of all the snapshots to run.
104106 start: The start of the run. Defaults to the min model start date.
105107 end: The end of the run. Defaults to now.
106108 latest: The latest datetime to use for non-incremental queries.
107-
108- Returns:
109- A dictionary of model name to error string.
110109 """
111- snapshots = tuple (snapshots )
112110 latest = latest or now ()
113- batches = self .interval_params (snapshots , start , end , latest )
111+ batches = self .interval_params (self . snapshots . values () , start , end , latest )
114112
115- self .running .clear ()
116- self .finished .clear ()
117- self .failed .clear ()
118- dag = []
113+ intervals_per_snapshot_id = {
114+ snapshot .snapshot_id : intervals for snapshot , intervals in batches
115+ }
119116
117+ dag = DAG [SchedulingUnit ]()
120118 for snapshot , intervals in batches :
121- dag .append (
122- (
123- snapshot ,
124- intervals ,
125- {
126- table
127- for table in snapshot .model .depends_on
128- if table in self .snapshots
129- },
130- )
131- )
119+ upstream_dependencies = [
120+ (p_sid , interval )
121+ for p_sid in snapshot .parents
122+ for interval in intervals_per_snapshot_id .get (p_sid , [])
123+ ]
124+ sid = snapshot .snapshot_id
125+ for interval in intervals :
126+ dag .add ((sid , interval ), upstream_dependencies )
132127
133- for snapshot , intervals , _ in dag [::- 1 ]:
134- if not intervals :
135- continue
136- # We have to run all batches per snapshot to mark it as completed
137- self .console .start_snapshot_progress (snapshot .name , len (intervals ))
138-
139- with self .snapshot_evaluator .multithreaded_context (), ThreadPoolExecutor () as snapshot_pool , ThreadPoolExecutor (
140- max_workers = self .max_workers
141- ) as batch_pool :
142- while True :
143- if self .failed :
144- for model_name , error_message in self .failed .items ():
145- self .console .log_error (
146- f"Failed Executing Batch.\n Model name:{ model_name } \n { error_message } "
147- )
148- snapshot_pool .shutdown ()
149- batch_pool .shutdown ()
150- break
151- if self .finished >= {snapshot .name for snapshot , _ , _ in dag }:
152- break
153- processed = self .running | self .finished
154- for snapshot , intervals , deps in dag :
155- if snapshot .name not in processed and self .finished >= deps :
156- self .running .add (snapshot .name )
157- snapshot_pool .submit (
158- self ._run_snapshot_intervals ,
159- snapshot ,
160- intervals ,
161- latest ,
162- batch_pool ,
163- )
164- sleep (0.1 )
165-
166- if self .failed :
167- self .console .stop_snapshot_progress ()
168- else :
169- self .console .complete_snapshot_progress ()
128+ def evaluate_node (node : SchedulingUnit ) -> None :
129+ assert latest
130+ sid , (start , end ) = node
131+ self .evaluate (self .snapshots [sid ], start , end , latest )
170132
171- return self .failed
133+ try :
134+ with self .snapshot_evaluator .multithreaded_context ():
135+ concurrent_apply_to_dag (dag , evaluate_node , self .max_workers )
136+ except NodeExecutionFailedError as error :
137+ sid = error .node [0 ] # type: ignore
138+ self .console .log_error (
139+ f"Failed Executing Batch.\n Snapshot: { sid } \n { traceback .format_exc ()} "
140+ )
141+ raise
172142
173143 def interval_params (
174144 self ,
@@ -208,22 +178,6 @@ def interval_params(
208178 latest = latest or now (),
209179 )
210180
211- def _run_snapshot_intervals (
212- self ,
213- snapshot : Snapshot ,
214- intervals : t .List [t .Tuple [datetime , datetime ]],
215- latest : TimeLike ,
216- pool : Executor ,
217- ) -> None :
218- wait (
219- [
220- pool .submit (self .evaluate , snapshot , start , end , latest )
221- for start , end in intervals
222- ],
223- )
224- self .finished .add (snapshot .name )
225- self .running .remove (snapshot .name )
226-
227181
228182def compute_interval_params (
229183 target : t .Iterable [SnapshotIdLike ],
0 commit comments