Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/guides/signals.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,18 @@ MODEL (

SELECT @start_ds AS ds
```

### Accessing execution context / engine adapter
It is possible to access the execution context in a signal, which provides access to the engine adapter (warehouse connection).

```python
import typing as t

from sqlmesh import signal, DatetimeRanges, ExecutionContext


# add the context argument to your function
@signal()
def one_week_ago(batch: DatetimeRanges, context: ExecutionContext) -> t.Union[bool, DatetimeRanges]:
    return len(context.engine_adapter.fetchdf("SELECT 1")) > 0
```
19 changes: 15 additions & 4 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def send(
raise SQLMeshError(f"Macro '{name}' does not exist.")

try:
return call_macro(func, self.dialect, self._path, self, *args, **kwargs) # type: ignore
return call_macro(
func, self.dialect, self._path, provided_args=(self, *args), provided_kwargs=kwargs
) # type: ignore
except Exception as e:
print_exception(e, self.python_env)
raise MacroEvalError("Error trying to eval macro.") from e
Expand Down Expand Up @@ -1286,12 +1288,21 @@ def call_macro(
func: t.Callable,
dialect: DialectType,
path: Path,
*args: t.Any,
**kwargs: t.Any,
provided_args: t.Tuple[t.Any, ...],
provided_kwargs: t.Dict[str, t.Any],
**optional_kwargs: t.Any,
Comment thread
tobymao marked this conversation as resolved.
) -> t.Any:
# Bind the macro's actual parameters to its formal parameters
sig = inspect.signature(func)
bound = sig.bind(*args, **kwargs)

if optional_kwargs:
provided_kwargs = provided_kwargs.copy()

for k, v in optional_kwargs.items():
if k in sig.parameters:
provided_kwargs[k] = v

bound = sig.bind(*provided_args, **provided_kwargs)
bound.apply_defaults()

try:
Expand Down
32 changes: 24 additions & 8 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
):
self.state_sync = state_sync
self.snapshots = {s.snapshot_id: s for s in snapshots}
self.snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()}
self.snapshot_per_version = _resolve_one_snapshot_per_version(self.snapshots.values())
self.default_catalog = default_catalog
self.snapshot_evaluator = snapshot_evaluator
Expand Down Expand Up @@ -348,7 +349,11 @@ def run(

return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS

def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]:
def batch_intervals(
self,
merged_intervals: SnapshotToIntervals,
deployability_index: t.Optional[DeployabilityIndex],
) -> t.Dict[Snapshot, Intervals]:
dag = snapshots_to_dag(merged_intervals)

snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = {
Expand All @@ -369,7 +374,20 @@ def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snaps
continue
snapshot, intervals = snapshot_intervals[snapshot_id]
unready = set(intervals)
intervals = snapshot.check_ready_intervals(intervals)

from sqlmesh.core.context import ExecutionContext

adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway)

context = ExecutionContext(
adapter,
self.snapshots_by_name,
Comment thread
tobymao marked this conversation as resolved.
deployability_index,
default_dialect=adapter.dialect,
default_catalog=self.default_catalog,
)

intervals = snapshot.check_ready_intervals(intervals, context)
unready -= set(intervals)

for parent in snapshot.parents:
Expand Down Expand Up @@ -424,7 +442,7 @@ def run_merged_intervals(
"""
execution_time = execution_time or now_timestamp()

batched_intervals = self.batch_intervals(merged_intervals)
batched_intervals = self.batch_intervals(merged_intervals, deployability_index)

self.console.start_evaluation_progress(
{snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()},
Expand All @@ -434,8 +452,6 @@ def run_merged_intervals(

dag = self._dag(batched_intervals)

snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()}

if run_environment_statements:
environment_statements = self.state_sync.get_environment_statements(
environment_naming_info.name
Expand All @@ -446,7 +462,7 @@ def run_merged_intervals(
runtime_stage=RuntimeStage.BEFORE_ALL,
environment_naming_info=environment_naming_info,
default_catalog=self.default_catalog,
snapshots=snapshots_by_name,
snapshots=self.snapshots_by_name,
start=start,
end=end,
execution_time=execution_time,
Expand All @@ -459,7 +475,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
snapshot_name, ((start, end), batch_idx) = node
if batch_idx == -1:
return
snapshot = snapshots_by_name[snapshot_name]
snapshot = self.snapshots_by_name[snapshot_name]

self.console.start_snapshot_evaluation_progress(snapshot)

Expand Down Expand Up @@ -520,7 +536,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
runtime_stage=RuntimeStage.AFTER_ALL,
environment_naming_info=environment_naming_info,
default_catalog=self.default_catalog,
snapshots=snapshots_by_name,
snapshots=self.snapshots_by_name,
start=start,
end=end,
execution_time=execution_time,
Expand Down
14 changes: 12 additions & 2 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.context import ExecutionContext

Interval = t.Tuple[int, int]
Intervals = t.List[Interval]
Expand Down Expand Up @@ -940,7 +941,7 @@ def missing_intervals(
model_end_ts,
)

def check_ready_intervals(self, intervals: Intervals) -> Intervals:
def check_ready_intervals(self, intervals: Intervals, context: ExecutionContext) -> Intervals:
"""Returns a list of intervals that are considered ready by the provided signal.

Note that this will handle gaps in the provided intervals. The returned intervals
Expand All @@ -959,6 +960,7 @@ def check_ready_intervals(self, intervals: Intervals) -> Intervals:
intervals = _check_ready_intervals(
env[signal_name],
intervals,
context,
dialect=self.model.dialect,
path=self.model._path,
kwargs=kwargs,
Expand Down Expand Up @@ -2148,6 +2150,7 @@ def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]:
def _check_ready_intervals(
check: t.Callable,
intervals: Intervals,
context: ExecutionContext,
dialect: DialectType = None,
path: Path = Path(),
kwargs: t.Optional[t.Dict] = None,
Expand All @@ -2158,7 +2161,14 @@ def _check_ready_intervals(
batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch]

try:
ready_intervals = call_macro(check, dialect, path, batch, **(kwargs or {}))
ready_intervals = call_macro(
check,
dialect,
path,
provided_args=(batch,),
provided_kwargs=(kwargs or {}),
context=context,
)
except Exception:
raise SQLMeshError("Error evaluating signal")

Expand Down
22 changes: 11 additions & 11 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def create(

def _get_data_objects(schema: exp.Table, gateway: t.Optional[str] = None) -> t.Set[str]:
logger.info("Listing data objects in schema %s", schema.sql())
objs = self._get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema])
objs = self.get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema])
return {obj.name for obj in objs}

with self.concurrent_context():
Expand Down Expand Up @@ -409,7 +409,7 @@ def migrate(
s,
snapshots,
allow_destructive_snapshots,
self._get_adapter(s.model_gateway),
self.get_adapter(s.model_gateway),
deployability_index,
),
self.ddl_concurrent_tasks,
Expand Down Expand Up @@ -437,7 +437,7 @@ def cleanup(
lambda s: self._cleanup_snapshot(
s,
snapshots_to_dev_table_only[s.snapshot_id],
self._get_adapter(
self.get_adapter(
snapshot_gateways.get(s.snapshot_id.name) if snapshot_gateways else None
),
on_complete,
Expand Down Expand Up @@ -471,7 +471,7 @@ def audit(
kwargs: Additional kwargs to pass to the renderer.
"""
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
adapter = self._get_adapter(snapshot.model_gateway)
adapter = self.get_adapter(snapshot.model_gateway)

if not snapshot.version:
raise ConfigError(
Expand Down Expand Up @@ -605,7 +605,7 @@ def _evaluate_snapshot(
else snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
)

adapter = self._get_adapter(model.gateway)
adapter = self.get_adapter(model.gateway)
evaluation_strategy = _evaluation_strategy(snapshot, adapter)

# https://github.com/TobikoData/sqlmesh/issues/2609
Expand Down Expand Up @@ -764,7 +764,7 @@ def _create_snapshot(

deployability_index = deployability_index or DeployabilityIndex.all_deployable()

adapter = self._get_adapter(snapshot.model.gateway)
adapter = self.get_adapter(snapshot.model.gateway)
create_render_kwargs: t.Dict[str, t.Any] = dict(
engine_adapter=adapter,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
Expand Down Expand Up @@ -994,7 +994,7 @@ def _wap_publish_snapshot(
) -> None:
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
adapter = self._get_adapter(snapshot.model_gateway)
adapter = self.get_adapter(snapshot.model_gateway)
adapter.wap_publish(table_name, wap_id)

def _audit(
Expand All @@ -1021,7 +1021,7 @@ def _audit(
blocking = audit_args.pop("blocking", None)
blocking = blocking == exp.true() if blocking else audit.blocking

adapter = self._get_adapter(snapshot.model_gateway)
adapter = self.get_adapter(snapshot.model_gateway)

kwargs = {
"start": start,
Expand Down Expand Up @@ -1068,10 +1068,10 @@ def _create_schemas(
for schema_name, catalog in unique_schemas:
schema = schema_(schema_name, catalog)
logger.info("Creating schema '%s'", schema)
adapter = self._get_adapter(gateways.get(schema)) if gateways else self.adapter
adapter = self.get_adapter(gateways.get(schema)) if gateways else self.adapter
adapter.create_schema(schema)

def _get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
"""Returns the adapter for the specified gateway or the default adapter if none is provided."""
if gateway:
if adapter := self.adapters.get(gateway):
Expand All @@ -1089,7 +1089,7 @@ def _execute_create(
rendered_physical_properties: t.Dict[str, exp.Expression],
dry_run: bool,
) -> None:
adapter = self._get_adapter(snapshot.model.gateway)
adapter = self.get_adapter(snapshot.model.gateway)
evaluation_strategy = _evaluation_strategy(snapshot, adapter)

# It can still be useful for some strategies to know if the snapshot was actually deployable
Expand Down
14 changes: 8 additions & 6 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlglot import parse_one, parse
from sqlglot.helper import first

from sqlmesh.core.context import Context
from sqlmesh.core.context import Context, ExecutionContext
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.model import load_sql_based_model
from sqlmesh.core.model.definition import AuditResult, SqlModel
Expand Down Expand Up @@ -66,17 +66,17 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context


@pytest.fixture
def get_batched_missing_intervals() -> (
t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]
):
def get_batched_missing_intervals(
mocker: MockerFixture,
) -> t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]:
def _get_batched_missing_intervals(
scheduler: Scheduler,
start: TimeLike,
end: TimeLike,
execution_time: t.Optional[TimeLike] = None,
) -> SnapshotToIntervals:
merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
return scheduler.batch_intervals(merged_intervals)
return scheduler.batch_intervals(merged_intervals, mocker.Mock())

return _get_batched_missing_intervals

Expand Down Expand Up @@ -622,7 +622,9 @@ def test_interval_diff():

def test_signal_intervals(mocker: MockerFixture, make_snapshot, get_batched_missing_intervals):
@signal()
def signal_a(batch: DatetimeRanges):
def signal_a(batch: DatetimeRanges, context: ExecutionContext):
if not hasattr(context, "engine_adapter"):
raise
return [batch[0], batch[1]]

@signal()
Expand Down
12 changes: 6 additions & 6 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2499,23 +2499,23 @@ def test_contiguous_intervals():

def test_check_ready_intervals(mocker: MockerFixture):
def assert_always_signal(intervals):
assert _check_ready_intervals(lambda _: True, intervals) == intervals
assert _check_ready_intervals(lambda _: True, intervals, mocker.Mock()) == intervals

assert_always_signal([])
assert_always_signal([(0, 1)])
assert_always_signal([(0, 1), (1, 2)])
assert_always_signal([(0, 1), (2, 3)])

def assert_never_signal(intervals):
assert _check_ready_intervals(lambda _: False, intervals) == []
assert _check_ready_intervals(lambda _: False, intervals, mocker.Mock()) == []

assert_never_signal([])
assert_never_signal([(0, 1)])
assert_never_signal([(0, 1), (1, 2)])
assert_never_signal([(0, 1), (2, 3)])

def assert_empty_signal(intervals):
assert _check_ready_intervals(lambda _: [], intervals) == []
assert _check_ready_intervals(lambda _: [], intervals, mocker.Mock()) == []

assert_empty_signal([])
assert_empty_signal([(0, 1)])
Expand All @@ -2532,7 +2532,7 @@ def assert_check_intervals(
):
mock = mocker.Mock()
mock.side_effect = [to_intervals(r) for r in ready]
_check_ready_intervals(mock, intervals) == expected
_check_ready_intervals(mock, intervals, mocker.Mock()) == expected

assert_check_intervals([], [], [])
assert_check_intervals([(0, 1)], [[]], [])
Expand Down Expand Up @@ -2894,7 +2894,7 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot):
]


def test_render_signal(make_snapshot):
def test_render_signal(make_snapshot, mocker):
@signal()
def check_types(batch, env: str, default: int = 0):
if env != "in_memory" or not default == 0:
Expand All @@ -2917,4 +2917,4 @@ def check_types(batch, env: str, default: int = 0):
signal_definitions=signal.get_registry(),
)
snapshot_a = make_snapshot(sql_model)
assert snapshot_a.check_ready_intervals([(0, 1)]) == [(0, 1)]
assert snapshot_a.check_ready_intervals([(0, 1)], mocker.Mock()) == [(0, 1)]
6 changes: 3 additions & 3 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3751,9 +3751,9 @@ def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot):

assert len(evaluator.adapters) == 3
assert evaluator.adapter == engine_adapters["default"]
assert evaluator._get_adapter() == engine_adapters["default"]
assert evaluator._get_adapter("third") == engine_adapters["third"]
assert evaluator._get_adapter("secondary") == engine_adapters["secondary"]
assert evaluator.get_adapter() == engine_adapters["default"]
assert evaluator.get_adapter("third") == engine_adapters["third"]
assert evaluator.get_adapter("secondary") == engine_adapters["secondary"]

model = load_sql_based_model(
parse( # type: ignore
Expand Down
Loading