From 50113bb278c4203e0e00e0ed8ca05ffa49af09b5 Mon Sep 17 00:00:00 2001
From: Iaroslav Zeigerman
Date: Sat, 10 Dec 2022 13:41:58 -0800
Subject: [PATCH] Fix key error when removing a dependency which doesn't exist
 when concurrently applying a function to a DAG

---
 example/config.py            |  2 +-
 sqlmesh/utils/concurrency.py | 45 +++++++++++++++++++++---------------
 2 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/example/config.py b/example/config.py
index 8492819068..b68f0ecab7 100644
--- a/example/config.py
+++ b/example/config.py
@@ -33,7 +33,7 @@
 DEFAULT_AIRFLOW_KWARGS = {
     **DEFAULT_KWARGS,
     "backfill_concurrent_tasks": 4,
-    "ddl_concurrent_tasks": 1,
+    "ddl_concurrent_tasks": 4,
 }
 
 
diff --git a/sqlmesh/utils/concurrency.py b/sqlmesh/utils/concurrency.py
index 06b4b25e98..528b5dc6a0 100644
--- a/sqlmesh/utils/concurrency.py
+++ b/sqlmesh/utils/concurrency.py
@@ -23,6 +23,9 @@ def concurrent_apply_to_snapshots(
         fn: The function that will be applied concurrently to each snapshot.
         tasks_num: The number of concurrent tasks.
         reverse_order: Whether the order should be reversed. Default: False..
+
+    Raises:
+        NodeExecutionFailedError when execution fails for any snapshot.
     """
 
     snapshots_by_id = {s.snapshot_id: s for s in snapshots}
@@ -55,6 +58,9 @@ def concurrent_apply_to_dag(
         fn: The function that will be applied concurrently to each snapshot.
         tasks_num: The number of concurrent tasks.
         reverse_order: Whether the order should be reversed. Default: False..
+
+    Raises:
+        NodeExecutionFailedError when execution fails for any node.
     """
     if tasks_num <= 0:
         raise ConfigError(f"Invalid number of concurrent tasks {tasks_num}")
@@ -64,40 +70,43 @@ def concurrent_apply_to_dag(
         return
 
     unprocessed_nodes = dag.graph if not reverse_order else dag.reversed_graph
+    unprocessed_nodes_num = len(unprocessed_nodes)
     unprocessed_nodes_lock = Lock()
     finished_future = Future()  # type: ignore
 
     def submit_next_nodes(
         executor: Executor, processed_node: t.Optional[H] = None
     ) -> None:
-        with unprocessed_nodes_lock:
-            if not unprocessed_nodes:
-                finished_future.set_result(None)
-                return
-
-            submitted_nodes = []
-            for next_node, deps in unprocessed_nodes.items():
-                if processed_node:
-                    deps.remove(processed_node)
-                if not deps:
-                    executor.submit(process_node, next_node, executor)
-                    submitted_nodes.append(next_node)
-            for submitted_node in submitted_nodes:
-                unprocessed_nodes.pop(submitted_node)
+        if not unprocessed_nodes_num:
+            finished_future.set_result(None)
+            return
+
+        submitted_nodes = []
+        for next_node, deps in unprocessed_nodes.items():
+            if processed_node:
+                deps.discard(processed_node)
+            if not deps:
+                executor.submit(process_node, next_node, executor)
+                submitted_nodes.append(next_node)
+        for submitted_node in submitted_nodes:
+            unprocessed_nodes.pop(submitted_node)
 
     def process_node(node: H, executor: Executor) -> None:
         try:
             fn(node)
+
+            with unprocessed_nodes_lock:
+                nonlocal unprocessed_nodes_num
+                unprocessed_nodes_num -= 1
+                submit_next_nodes(executor, node)
         except Exception as ex:
             error = NodeExecutionFailedError(node)
             error.__cause__ = ex
             finished_future.set_exception(error)
-            return
-
-        submit_next_nodes(executor, node)
 
     with ThreadPoolExecutor(max_workers=tasks_num) as pool:
-        submit_next_nodes(pool)
+        with unprocessed_nodes_lock:
+            submit_next_nodes(pool)
         finished_future.result()
 
 
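---

Note on the fix: the old submit_next_nodes called deps.remove(processed_node)
on every remaining node's dependency set, but set.remove raises KeyError when
the element is absent, i.e. whenever the just-finished node was not a
dependency of that particular node. set.discard is the no-op-when-absent
variant. A minimal standalone sketch of the difference (the node names "a"
and "b" here are hypothetical, not taken from the patch):

    deps = {"b"}          # this node depends only on "b"; "a" just finished
    deps.discard("a")     # new behavior: no-op, deps is still {"b"}
    try:
        deps.remove("a")  # old behavior: raises for any non-dependency
    except KeyError as ex:
        print(f"KeyError: {ex}")  # prints: KeyError: 'a'

The unprocessed_nodes_num counter plays the role of the old emptiness check:
submit_next_nodes no longer takes unprocessed_nodes_lock itself, so every
caller must hold it, which is why the decrement in process_node and the
initial submit_next_nodes(pool) call are both wrapped in the lock.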