Skip to content

Commit 909d6d6

Browse files
authored
Fix key error when removing a dependency which doesn't exist when concurrently applying a function to a DAG (#41)
1 parent b716869 commit 909d6d6

2 files changed

Lines changed: 28 additions & 19 deletions

File tree

example/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
DEFAULT_AIRFLOW_KWARGS = {
3434
**DEFAULT_KWARGS,
3535
"backfill_concurrent_tasks": 4,
36-
"ddl_concurrent_tasks": 1,
36+
"ddl_concurrent_tasks": 4,
3737
}
3838

3939

sqlmesh/utils/concurrency.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def concurrent_apply_to_snapshots(
2323
fn: The function that will be applied concurrently to each snapshot.
2424
tasks_num: The number of concurrent tasks.
2525
reverse_order: Whether the order should be reversed. Default: False.
26+
27+
Raises:
28+
NodeExecutionFailedError when execution fails for any snapshot.
2629
"""
2730
snapshots_by_id = {s.snapshot_id: s for s in snapshots}
2831

@@ -55,6 +58,9 @@ def concurrent_apply_to_dag(
5558
fn: The function that will be applied concurrently to each snapshot.
5659
tasks_num: The number of concurrent tasks.
5760
reverse_order: Whether the order should be reversed. Default: False.
61+
62+
Raises:
63+
NodeExecutionFailedError when execution fails for any node.
5864
"""
5965
if tasks_num <= 0:
6066
raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")
@@ -64,40 +70,43 @@ def concurrent_apply_to_dag(
6470
return
6571

6672
unprocessed_nodes = dag.graph if not reverse_order else dag.reversed_graph
73+
unprocessed_nodes_num = len(unprocessed_nodes)
6774
unprocessed_nodes_lock = Lock()
6875
finished_future = Future() # type: ignore
6976

7077
def submit_next_nodes(
7178
executor: Executor, processed_node: t.Optional[H] = None
7279
) -> None:
73-
with unprocessed_nodes_lock:
74-
if not unprocessed_nodes:
75-
finished_future.set_result(None)
76-
return
77-
78-
submitted_nodes = []
79-
for next_node, deps in unprocessed_nodes.items():
80-
if processed_node:
81-
deps.remove(processed_node)
82-
if not deps:
83-
executor.submit(process_node, next_node, executor)
84-
submitted_nodes.append(next_node)
85-
for submitted_node in submitted_nodes:
86-
unprocessed_nodes.pop(submitted_node)
80+
if not unprocessed_nodes_num:
81+
finished_future.set_result(None)
82+
return
83+
84+
submitted_nodes = []
85+
for next_node, deps in unprocessed_nodes.items():
86+
if processed_node:
87+
deps.discard(processed_node)
88+
if not deps:
89+
executor.submit(process_node, next_node, executor)
90+
submitted_nodes.append(next_node)
91+
for submitted_node in submitted_nodes:
92+
unprocessed_nodes.pop(submitted_node)
8793

8894
def process_node(node: H, executor: Executor) -> None:
8995
try:
9096
fn(node)
97+
98+
with unprocessed_nodes_lock:
99+
nonlocal unprocessed_nodes_num
100+
unprocessed_nodes_num -= 1
101+
submit_next_nodes(executor, node)
91102
except Exception as ex:
92103
error = NodeExecutionFailedError(node)
93104
error.__cause__ = ex
94105
finished_future.set_exception(error)
95-
return
96-
97-
submit_next_nodes(executor, node)
98106

99107
with ThreadPoolExecutor(max_workers=tasks_num) as pool:
100-
submit_next_nodes(pool)
108+
with unprocessed_nodes_lock:
109+
submit_next_nodes(pool)
101110
finished_future.result()
102111

103112

0 commit comments

Comments
 (0)