Skip to content

Commit d28d85f

Browse files
committed
Add multithread context manager to the snapshot evaluator
1 parent 8c8c36f commit d28d85f

3 files changed

Lines changed: 36 additions & 30 deletions

File tree

sqlmesh/core/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import duckdb
1212
from sqlmesh.core.engine_adapter import EngineAdapter
1313
local_config = Config(
14-
engine_config_factory=duckdb.connect,
14+
engine_connection_factory=duckdb.connect,
1515
engine_dialect="duckdb"
1616
)
1717
# End config.py
@@ -274,7 +274,7 @@ class Config(PydanticModel):
274274
engine_dialect: str = "duckdb"
275275
scheduler_backend: SchedulerBackend = BuiltInSchedulerBackend()
276276
notification_targets: t.List[NotificationTarget] = []
277-
dialect: t.Optional[str] = None
277+
dialect: str = ""
278278
physical_schema: str = ""
279279
snapshot_ttl: str = ""
280280
ignore_patterns: t.List[str] = []

sqlmesh/core/scheduler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def run(
136136
# We have to run all batches per snapshot to mark it as completed
137137
self.console.start_snapshot_progress(snapshot.name, len(intervals))
138138

139-
with ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor(
139+
with self.snapshot_evaluator.multithreaded_context(), ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor(
140140
max_workers=self.max_workers
141141
) as batch_pool:
142142
while True:
@@ -168,8 +168,6 @@ def run(
168168
else:
169169
self.console.complete_snapshot_progress()
170170

171-
self.snapshot_evaluator.recycle()
172-
173171
return self.failed
174172

175173
def interval_params(

sqlmesh/core/snapshot_evaluator.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import logging
2525
import typing as t
26+
from contextlib import contextmanager
2627

2728
from sqlglot import exp, select
2829

@@ -137,12 +138,12 @@ def promote(
137138
target_snapshots: Snapshots to promote.
138139
environment: The target environment.
139140
"""
140-
concurrent_apply_to_snapshots(
141-
target_snapshots,
142-
lambda s: self._promote_snapshot(s, environment),
143-
self.ddl_concurrent_tasks,
144-
)
145-
self.recycle()
141+
with self.multithreaded_context():
142+
concurrent_apply_to_snapshots(
143+
target_snapshots,
144+
lambda s: self._promote_snapshot(s, environment),
145+
self.ddl_concurrent_tasks,
146+
)
146147

147148
def demote(
148149
self, target_snapshots: t.Iterable[SnapshotInfoLike], environment: str
@@ -153,12 +154,12 @@ def demote(
153154
target_snapshots: Snapshots to demote.
154155
environment: The target environment.
155156
"""
156-
concurrent_apply_to_snapshots(
157-
target_snapshots,
158-
lambda s: self._demote_snapshot(s, environment),
159-
self.ddl_concurrent_tasks,
160-
)
161-
self.recycle()
157+
with self.multithreaded_context():
158+
concurrent_apply_to_snapshots(
159+
target_snapshots,
160+
lambda s: self._demote_snapshot(s, environment),
161+
self.ddl_concurrent_tasks,
162+
)
162163

163164
def create(
164165
self,
@@ -170,26 +171,26 @@ def create(
170171
Args:
171172
target_snapshots: Target snapshost.
172173
"""
173-
concurrent_apply_to_snapshots(
174-
target_snapshots,
175-
lambda s: self._create_snapshot(s, snapshots),
176-
self.ddl_concurrent_tasks,
177-
)
178-
self.recycle()
174+
with self.multithreaded_context():
175+
concurrent_apply_to_snapshots(
176+
target_snapshots,
177+
lambda s: self._create_snapshot(s, snapshots),
178+
self.ddl_concurrent_tasks,
179+
)
179180

180181
def cleanup(self, target_snapshots: t.Iterable[SnapshotInfoLike]) -> None:
181182
"""Cleans up the given snapshots by removing its table
182183
183184
Args:
184185
target_snapshots: Snapshots to cleanup.
185186
"""
186-
concurrent_apply_to_snapshots(
187-
target_snapshots,
188-
self._cleanup_snapshot,
189-
self.ddl_concurrent_tasks,
190-
reverse_order=True,
191-
)
192-
self.recycle()
187+
with self.multithreaded_context():
188+
concurrent_apply_to_snapshots(
189+
target_snapshots,
190+
self._cleanup_snapshot,
191+
self.ddl_concurrent_tasks,
192+
reverse_order=True,
193+
)
193194

194195
def audit(
195196
self,
@@ -235,6 +236,13 @@ def audit(
235236
results.append(AuditResult(audit=audit, count=count, query=query))
236237
return results
237238

239+
@contextmanager
240+
def multithreaded_context(self) -> t.Generator[None, None, None]:
241+
try:
242+
yield
243+
finally:
244+
self.recycle()
245+
238246
def recycle(self) -> None:
239247
"""Closes all open connections and releases all allocated resources associated with any thread
240248
except the calling one."""

0 commit comments

Comments
 (0)