Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions example/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
import duckdb

from sqlmesh.core.config import AirflowSchedulerBackend, Config
from sqlmesh.core.engine_adapter import EngineAdapter

DATA_DIR = os.path.join(os.path.dirname(__file__), "data")


DEFAULT_KWARGS = {
"dialect": "duckdb", # The default dialect of models is DuckDB SQL.
"engine_adapter": EngineAdapter(
duckdb.connect(), "duckdb"
), # The default engine runs in DuckDB.
"engine_dialect": "duckdb",
"engine_connection_factory": duckdb.connect,
}

# An in memory DuckDB config.
Expand All @@ -22,8 +19,8 @@
local_config = Config(
**{
**DEFAULT_KWARGS,
"engine_adapter": EngineAdapter(
lambda: duckdb.connect(database=f"{DATA_DIR}/local.duckdb"), "duckdb"
"engine_connection_factory": lambda: duckdb.connect(
database=f"{DATA_DIR}/local.duckdb"
),
}
)
Expand Down
52 changes: 15 additions & 37 deletions sqlmesh/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import duckdb
from sqlmesh.core.engine_adapter import EngineAdapter
local_config = Config(
engine_adapter=EngineAdapter(duckdb.connect(), "duckdb"),
dialect="duckdb"
engine_connection_factory=duckdb.connect,
engine_dialect="duckdb"
)
# End config.py

Expand All @@ -25,17 +25,20 @@
>>> from sqlmesh import Context
>>> from sqlmesh.core.config import Config
>>> my_config = Config(
... engine_adapter=EngineAdapter(duckdb.connect(), "duckdb"),
... dialect="duckdb"
... engine_connection_factory=duckdb.connect,
... engine_dialect="duckdb"
... )
>>> context = Context(path="example", config=my_config)

```
- Individual config parameters used when initializing a Context.
```python
>>> adapter = EngineAdapter(duckdb.connect(), "duckdb")
>>> from sqlmesh import Context
>>> from sqlmesh.core.engine_adapter import EngineAdapter
>>> adapter = EngineAdapter(duckdb.connect, "duckdb")
>>> context = Context(
... path="example", engine_adapter=adapter,
... path="example",
... engine_adapter=adapter,
... dialect="duckdb",
... )

Expand All @@ -60,7 +63,7 @@

DEFAULT_KWARGS = {
"dialect": "duckdb", # The default dialect of models is DuckDB SQL.
"engine_adapter": EngineAdapter(duckdb.connect(), "duckdb"), # The default engine runs in DuckDB.
"engine_adapter": EngineAdapter(duckdb.connect, "duckdb"), # The default engine runs in DuckDB.
}

# An in memory DuckDB config.
Expand Down Expand Up @@ -102,13 +105,11 @@
import typing as t

import duckdb
from pydantic import Field
from requests import Session

from sqlmesh.core import constants as c
from sqlmesh.core._typing import NotificationTarget
from sqlmesh.core.console import Console
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.plan_evaluator import (
AirflowPlanEvaluator,
BuiltInPlanEvaluator,
Expand Down Expand Up @@ -256,44 +257,21 @@ class Config(PydanticModel):
"""
An object used by a Context to configure your SQLMesh project.

An engine adapter can lazily establish a database connection if it is passed a callable that returns a
database API compliant connection.
```python
>>> from sqlmesh import Context
>>> context = Context(
... path="example",
... engine_adapter=EngineAdapter(duckdb.connect, "duckdb"),
... dialect="duckdb"
... )

```
```python
>>> def create_connection():
... return duckdb.connect()
...
>>> context = Context(
... path="example",
... engine_adapter=EngineAdapter(create_connection, "duckdb"),
... dialect="duckdb"
... )

```

Args:
engine_adapter: The default engine adapter to use
        engine_connection_factory: The callable which creates a new engine connection on each call.
engine_dialect: The engine dialect.
scheduler_backend: Identifies which scheduler backend to use.
notification_targets: The notification targets to use.
dialect: The default sql dialect of model queries.
dialect: The default sql dialect of model queries. Default: same as engine dialect.
physical_schema: The default schema used to store materialized tables.
snapshot_ttl: Duration before unpromoted snapshots are removed.
time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d.
ddl_concurrent_task: The number of concurrent tasks used for DDL
operations (table / view creation, deletion, etc). Default: 1.
"""

engine_adapter: EngineAdapter = Field(
default_factory=lambda: EngineAdapter(duckdb.connect, "duckdb")
)
engine_connection_factory: t.Callable[[], t.Any] = duckdb.connect
engine_dialect: str = "duckdb"
scheduler_backend: SchedulerBackend = BuiltInSchedulerBackend()
notification_targets: t.List[NotificationTarget] = []
dialect: str = ""
Expand Down
61 changes: 41 additions & 20 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,24 @@ def __init__(
load: bool = True,
console: t.Optional[Console] = None,
):
self.console = console or get_console()
self.path = Path(path).absolute()
self.config = self._load_config(config)

self.test_config = None
try:
self.test_config = self._load_config(test_config or "test_config")
except ConfigError:
self.console.log_error(
"Running without test support since `test_config` was not provided and ` "
"test_config` variable was not found in the namespace"
)

# Initialize cache
cache_path = self.path.joinpath(c.CACHE_PATH)
cache_path.mkdir(exist_ok=True)
self.table_info_cache = FileCache(cache_path.joinpath(c.TABLE_INFO_CACHE))
self.dialect = dialect or self.config.dialect
self.dialect = dialect or self.config.dialect or self.config.engine_dialect
self.physical_schema = (
physical_schema or self.config.physical_schema or "sqlmesh"
)
Expand All @@ -122,34 +132,41 @@ def __init__(
self.models = UniqueKeyDict("models")
self.macros = UniqueKeyDict("macros")
self.dag: DAG[str] = DAG()
self.engine_adapter = engine_adapter or self.config.engine_adapter

self.ddl_concurrent_tasks = (
ddl_concurrent_tasks or self.config.ddl_concurrent_tasks
)

self.engine_adapter = engine_adapter or EngineAdapter(
self.config.engine_connection_factory,
self.config.engine_dialect,
multithreaded=self.ddl_concurrent_tasks > 1,
)
self.test_engine_adapter = (
EngineAdapter(
self.test_config.engine_connection_factory,
self.test_config.engine_dialect,
multithreaded=self.test_config.ddl_concurrent_tasks > 1,
)
if self.test_config
else None
)

self.snapshot_evaluator = SnapshotEvaluator(
self.engine_adapter, ddl_concurrent_tasks=self.ddl_concurrent_tasks
)
self._ignore_patterns = c.IGNORE_PATTERNS + self.config.ignore_patterns
self.console = console or get_console()

self.notification_targets = self.config.notification_targets + (
notification_targets or []
)

self._provided_state_sync: t.Optional[StateSync] = state_sync
self._state_sync: t.Optional[StateSync] = None
self._state_reader: t.Optional[StateReader] = None

self.notification_targets = self.config.notification_targets + (
notification_targets or []
)
self.test_config = None
self._ignore_patterns = c.IGNORE_PATTERNS + self.config.ignore_patterns
self._path_mtimes: t.Dict[Path, float] = {}

try:
self.test_config = self._load_config(test_config or "test_config")
except ConfigError:
self.console.log_error(
"Running without test support since `test_config` was not provided and ` "
"test_config` variable was not found in the namespace"
)

if load:
self.load()

Expand Down Expand Up @@ -367,10 +384,10 @@ def format(self) -> None:
def _run_plan_tests(
self, skip_tests: bool = False
) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]:
if self.test_config and not skip_tests:
if self.test_engine_adapter and not skip_tests:
result, test_output = self.run_tests()
self.console.log_test_results(
result, test_output, self.test_config.engine_adapter.dialect
result, test_output, self.test_engine_adapter.dialect
)
if not result.wasSuccessful():
raise PlanError(
Expand Down Expand Up @@ -522,14 +539,14 @@ def run_tests(
self, path: t.Optional[str] = None
) -> t.Tuple[unittest.result.TestResult, str]:
"""Discover and run model tests"""
if not self.test_config:
if not self.test_engine_adapter:
raise ConfigError("Tried to run tests but test_config is not defined")
test_output = StringIO()
with contextlib.redirect_stderr(test_output):
result = run_all_model_tests(
path=Path(path) if path else self.test_directory_path,
snapshots=self.snapshots,
engine_adapter=self.test_config.engine_adapter,
engine_adapter=self.test_engine_adapter,
ignore_patterns=self._ignore_patterns,
)
return result, test_output.getvalue()
Expand Down Expand Up @@ -584,6 +601,10 @@ def audit(
self.console.show_sql(f"{error.query}")
self.console.log_status_update("Done.")

def close(self) -> None:
    """Releases all resources allocated by this context."""
    # Delegates cleanup to the snapshot evaluator; presumably this also
    # releases the underlying engine connection(s) — verify in SnapshotEvaluator.
    self.snapshot_evaluator.close()

def _context_diff(
self,
environment: str | Environment,
Expand Down
29 changes: 18 additions & 11 deletions sqlmesh/core/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class DAG(t.Generic[T]):
def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
self.graph: t.Dict[T, t.Set[T]] = {}
self._graph: t.Dict[T, t.Set[T]] = {}
for node, dependencies in (graph or {}).items():
self.add(node, dependencies)

Expand All @@ -25,10 +25,10 @@ def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None:
node: The node to add.
dependencies: Optional dependencies to add to the node.
"""
if node not in self.graph:
self.graph[node] = set()
if node not in self._graph:
self._graph[node] = set()
if dependencies:
self.graph[node].update(dependencies)
self._graph[node].update(dependencies)
for d in dependencies:
self.add(d)

Expand All @@ -46,7 +46,7 @@ def subdag(self, *nodes: T) -> DAG[T]:

while queue:
node = queue.pop()
deps = self.graph.get(node, set())
deps = self._graph.get(node, set())
graph[node] = deps
queue.update(deps)

Expand All @@ -60,17 +60,24 @@ def upstream(self, node: T) -> t.List[T]:
def leaves(self) -> t.Set[T]:
"""Returns all nodes in the graph without any upstream dependencies."""
return {
dep for deps in self.graph.values() for dep in deps if dep not in self.graph
dep
for deps in self._graph.values()
for dep in deps
if dep not in self._graph
}

@property
def graph(self) -> t.Dict[T, t.Set[T]]:
    """Returns a snapshot of the internal dependency mapping.

    Each node is mapped to a copy of its dependency set, so callers may
    mutate the returned structure without affecting this DAG.
    """
    return {node: deps.copy() for node, deps in self._graph.items()}

def sorted(self) -> t.List[T]:
"""Returns a list of nodes sorted in topological order."""
result: t.List[T] = []

unprocessed_nodes = {}
for node, deps in self.graph.items():
unprocessed_nodes[node] = deps.copy()

unprocessed_nodes = self.graph
while unprocessed_nodes:
next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}

Expand Down Expand Up @@ -103,7 +110,7 @@ def visit() -> t.Iterator[T]:
"""Visit topologically sorted nodes after input node and yield downstream dependants."""
downstream = {node}
for current_node in sorted_nodes[node_index + 1 :]:
upstream = self.graph.get(current_node, set())
upstream = self._graph.get(current_node, set())
if not upstream.isdisjoint(downstream):
downstream.add(current_node)
yield current_node
Expand Down
Loading