Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Add multithread context manager to the snapshot evaluator
  • Loading branch information
izeigerman committed Dec 8, 2022
commit d28d85f3158c7127f822d3e2f740f2c8c598fc48
4 changes: 2 additions & 2 deletions sqlmesh/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import duckdb
from sqlmesh.core.engine_adapter import EngineAdapter
local_config = Config(
engine_config_factory=duckdb.connect,
engine_connection_factory=duckdb.connect,
engine_dialect="duckdb"
)
# End config.py
Expand Down Expand Up @@ -274,7 +274,7 @@ class Config(PydanticModel):
engine_dialect: str = "duckdb"
scheduler_backend: SchedulerBackend = BuiltInSchedulerBackend()
notification_targets: t.List[NotificationTarget] = []
dialect: t.Optional[str] = None
dialect: str = ""
physical_schema: str = ""
snapshot_ttl: str = ""
ignore_patterns: t.List[str] = []
Expand Down
4 changes: 1 addition & 3 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def run(
# We have to run all batches per snapshot to mark it as completed
self.console.start_snapshot_progress(snapshot.name, len(intervals))

with ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor(
with self.snapshot_evaluator.multithreaded_context(), ThreadPoolExecutor() as snapshot_pool, ThreadPoolExecutor(
max_workers=self.max_workers
) as batch_pool:
while True:
Expand Down Expand Up @@ -168,8 +168,6 @@ def run(
else:
self.console.complete_snapshot_progress()

self.snapshot_evaluator.recycle()

return self.failed

def interval_params(
Expand Down
58 changes: 33 additions & 25 deletions sqlmesh/core/snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import logging
import typing as t
from contextlib import contextmanager

from sqlglot import exp, select

Expand Down Expand Up @@ -137,12 +138,12 @@ def promote(
target_snapshots: Snapshots to promote.
environment: The target environment.
"""
concurrent_apply_to_snapshots(
target_snapshots,
lambda s: self._promote_snapshot(s, environment),
self.ddl_concurrent_tasks,
)
self.recycle()
with self.multithreaded_context():
concurrent_apply_to_snapshots(
target_snapshots,
lambda s: self._promote_snapshot(s, environment),
self.ddl_concurrent_tasks,
)

def demote(
self, target_snapshots: t.Iterable[SnapshotInfoLike], environment: str
Expand All @@ -153,12 +154,12 @@ def demote(
target_snapshots: Snapshots to demote.
environment: The target environment.
"""
concurrent_apply_to_snapshots(
target_snapshots,
lambda s: self._demote_snapshot(s, environment),
self.ddl_concurrent_tasks,
)
self.recycle()
with self.multithreaded_context():
concurrent_apply_to_snapshots(
target_snapshots,
lambda s: self._demote_snapshot(s, environment),
self.ddl_concurrent_tasks,
)

def create(
self,
Expand All @@ -170,26 +171,26 @@ def create(
Args:
target_snapshots: Target snapshots.
"""
concurrent_apply_to_snapshots(
target_snapshots,
lambda s: self._create_snapshot(s, snapshots),
self.ddl_concurrent_tasks,
)
self.recycle()
with self.multithreaded_context():
concurrent_apply_to_snapshots(
target_snapshots,
lambda s: self._create_snapshot(s, snapshots),
self.ddl_concurrent_tasks,
)

def cleanup(self, target_snapshots: t.Iterable[SnapshotInfoLike]) -> None:
    """Cleans up the given snapshots by removing their physical tables.

    Runs the per-snapshot cleanup concurrently across up to
    ``ddl_concurrent_tasks`` worker threads. The surrounding
    ``multithreaded_context()`` guarantees that connections opened by those
    worker threads are recycled once the concurrent work finishes, even if
    a cleanup task raises.

    Args:
        target_snapshots: Snapshots to cleanup.
    """
    with self.multithreaded_context():
        concurrent_apply_to_snapshots(
            target_snapshots,
            self._cleanup_snapshot,
            self.ddl_concurrent_tasks,
            # NOTE(review): reverse_order presumably drops dependents before
            # their upstream snapshots — confirm against
            # concurrent_apply_to_snapshots' ordering semantics.
            reverse_order=True,
        )

def audit(
self,
Expand Down Expand Up @@ -235,6 +236,13 @@ def audit(
results.append(AuditResult(audit=audit, count=count, query=query))
return results

@contextmanager
def multithreaded_context(self) -> t.Generator[None, None, None]:
    """Context manager that wraps multithreaded work against the evaluator.

    On exit — whether the wrapped block completed normally or raised —
    ``recycle()`` is invoked to close connections and release resources
    allocated by threads other than the calling one, so worker-thread
    connections never outlive the concurrent operation that created them.
    """
    try:
        yield
    finally:
        # Always recycle, even if the wrapped concurrent work raised, so
        # per-thread connections are not leaked on failure.
        self.recycle()

def recycle(self) -> None:
"""Closes all open connections and releases all allocated resources associated with any thread
except the calling one."""
Expand Down