@@ -23,6 +23,9 @@ def concurrent_apply_to_snapshots(
2323 fn: The function that will be applied concurrently to each snapshot.
2424 tasks_num: The number of concurrent tasks.
2525 reverse_order: Whether the order should be reversed. Default: False..
26+
27+ Raises:
28+ NodeExecutionFailedError when execution fails for any snapshot.
2629 """
2730 snapshots_by_id = {s .snapshot_id : s for s in snapshots }
2831
@@ -55,6 +58,9 @@ def concurrent_apply_to_dag(
5558 fn: The function that will be applied concurrently to each snapshot.
5659 tasks_num: The number of concurrent tasks.
5760 reverse_order: Whether the order should be reversed. Default: False..
61+
62+ Raises:
63+ NodeExecutionFailedError when execution fails for any node.
5864 """
5965 if tasks_num <= 0 :
6066 raise ConfigError (f"Invalid number of concurrent tasks { tasks_num } " )
@@ -64,40 +70,43 @@ def concurrent_apply_to_dag(
6470 return
6571
6672 unprocessed_nodes = dag .graph if not reverse_order else dag .reversed_graph
73+ unprocessed_nodes_num = len (unprocessed_nodes )
6774 unprocessed_nodes_lock = Lock ()
6875 finished_future = Future () # type: ignore
6976
7077 def submit_next_nodes (
7178 executor : Executor , processed_node : t .Optional [H ] = None
7279 ) -> None :
73- with unprocessed_nodes_lock :
74- if not unprocessed_nodes :
75- finished_future .set_result (None )
76- return
77-
78- submitted_nodes = []
79- for next_node , deps in unprocessed_nodes .items ():
80- if processed_node :
81- deps .remove (processed_node )
82- if not deps :
83- executor .submit (process_node , next_node , executor )
84- submitted_nodes .append (next_node )
85- for submitted_node in submitted_nodes :
86- unprocessed_nodes .pop (submitted_node )
80+ if not unprocessed_nodes_num :
81+ finished_future .set_result (None )
82+ return
83+
84+ submitted_nodes = []
85+ for next_node , deps in unprocessed_nodes .items ():
86+ if processed_node :
87+ deps .discard (processed_node )
88+ if not deps :
89+ executor .submit (process_node , next_node , executor )
90+ submitted_nodes .append (next_node )
91+ for submitted_node in submitted_nodes :
92+ unprocessed_nodes .pop (submitted_node )
8793
8894 def process_node (node : H , executor : Executor ) -> None :
8995 try :
9096 fn (node )
97+
98+ with unprocessed_nodes_lock :
99+ nonlocal unprocessed_nodes_num
100+ unprocessed_nodes_num -= 1
101+ submit_next_nodes (executor , node )
91102 except Exception as ex :
92103 error = NodeExecutionFailedError (node )
93104 error .__cause__ = ex
94105 finished_future .set_exception (error )
95- return
96-
97- submit_next_nodes (executor , node )
98106
99107 with ThreadPoolExecutor (max_workers = tasks_num ) as pool :
100- submit_next_nodes (pool )
108+ with unprocessed_nodes_lock :
109+ submit_next_nodes (pool )
101110 finished_future .result ()
102111
103112
0 commit comments