diff --git a/docs/guides/signals.md b/docs/guides/signals.md index 385676df42..bc76d9d655 100644 --- a/docs/guides/signals.md +++ b/docs/guides/signals.md @@ -116,3 +116,18 @@ MODEL ( SELECT @start_ds AS ds ``` + +### Accessing execution context / engine adapter +It is possible to access the execution context in a signal and access the engine adapter (warehouse connection). + +```python +import typing as t + +from sqlmesh import signal, DatetimeRanges, ExecutionContext + + +# add the context argument to your function +@signal() +def one_week_ago(batch: DatetimeRanges, context: ExecutionContext) -> t.Union[bool, DatetimeRanges]: + return len(context.engine_adapter.fetchdf("SELECT 1")) > 1 +``` diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 9b0cdfba8a..4928bc0620 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -214,7 +214,9 @@ def send( raise SQLMeshError(f"Macro '{name}' does not exist.") try: - return call_macro(func, self.dialect, self._path, self, *args, **kwargs) # type: ignore + return call_macro( + func, self.dialect, self._path, provided_args=(self, *args), provided_kwargs=kwargs + ) # type: ignore except Exception as e: print_exception(e, self.python_env) raise MacroEvalError("Error trying to eval macro.") from e @@ -1286,12 +1288,21 @@ def call_macro( func: t.Callable, dialect: DialectType, path: Path, - *args: t.Any, - **kwargs: t.Any, + provided_args: t.Tuple[t.Any, ...], + provided_kwargs: t.Dict[str, t.Any], + **optional_kwargs: t.Any, ) -> t.Any: # Bind the macro's actual parameters to its formal parameters sig = inspect.signature(func) - bound = sig.bind(*args, **kwargs) + + if optional_kwargs: + provided_kwargs = provided_kwargs.copy() + + for k, v in optional_kwargs.items(): + if k in sig.parameters: + provided_kwargs[k] = v + + bound = sig.bind(*provided_args, **provided_kwargs) bound.apply_defaults() try: diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 0052cb785b..6d8eef561e 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -95,6 +95,7 @@ def __init__( ): self.state_sync = state_sync self.snapshots = {s.snapshot_id: s for s in snapshots} + self.snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()} self.snapshot_per_version = _resolve_one_snapshot_per_version(self.snapshots.values()) self.default_catalog = default_catalog self.snapshot_evaluator = snapshot_evaluator @@ -348,7 +349,11 @@ def run( return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS - def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]: + def batch_intervals( + self, + merged_intervals: SnapshotToIntervals, + deployability_index: t.Optional[DeployabilityIndex], + ) -> t.Dict[Snapshot, Intervals]: dag = snapshots_to_dag(merged_intervals) snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = { @@ -369,7 +374,20 @@ def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snaps continue snapshot, intervals = snapshot_intervals[snapshot_id] unready = set(intervals) - intervals = snapshot.check_ready_intervals(intervals) + + from sqlmesh.core.context import ExecutionContext + + adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway) + + context = ExecutionContext( + adapter, + self.snapshots_by_name, + deployability_index, + default_dialect=adapter.dialect, + default_catalog=self.default_catalog, + ) + + intervals = snapshot.check_ready_intervals(intervals, context) unready -= set(intervals) for parent in snapshot.parents: @@ -424,7 +442,7 @@ def run_merged_intervals( """ execution_time = execution_time or now_timestamp() - batched_intervals = self.batch_intervals(merged_intervals) + batched_intervals = self.batch_intervals(merged_intervals, deployability_index) self.console.start_evaluation_progress( {snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()}, @@ -434,8 +452,6 @@ def run_merged_intervals( dag = self._dag(batched_intervals) - snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()} - if run_environment_statements: environment_statements = self.state_sync.get_environment_statements( environment_naming_info.name @@ -446,7 +462,7 @@ def run_merged_intervals( runtime_stage=RuntimeStage.BEFORE_ALL, environment_naming_info=environment_naming_info, default_catalog=self.default_catalog, - snapshots=snapshots_by_name, + snapshots=self.snapshots_by_name, start=start, end=end, execution_time=execution_time, @@ -459,7 +475,7 @@ def evaluate_node(node: SchedulingUnit) -> None: snapshot_name, ((start, end), batch_idx) = node if batch_idx == -1: return - snapshot = snapshots_by_name[snapshot_name] + snapshot = self.snapshots_by_name[snapshot_name] self.console.start_snapshot_evaluation_progress(snapshot) @@ -520,7 +536,7 @@ def evaluate_node(node: SchedulingUnit) -> None: runtime_stage=RuntimeStage.AFTER_ALL, environment_naming_info=environment_naming_info, default_catalog=self.default_catalog, - snapshots=snapshots_by_name, + snapshots=self.snapshots_by_name, start=start, end=end, execution_time=execution_time, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index f208001904..56b638cb5a 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -46,6 +46,7 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType from sqlmesh.core.environment import EnvironmentNamingInfo + from sqlmesh.core.context import ExecutionContext Interval = t.Tuple[int, int] Intervals = t.List[Interval] @@ -940,7 +941,7 @@ def missing_intervals( model_end_ts, ) - def check_ready_intervals(self, intervals: Intervals) -> Intervals: + def check_ready_intervals(self, intervals: Intervals, context: ExecutionContext) -> Intervals: """Returns a list of intervals that are considered ready by the provided signal. Note that this will handle gaps in the provided intervals. The returned intervals @@ -959,6 +960,7 @@ def check_ready_intervals(self, intervals: Intervals) -> Intervals: intervals = _check_ready_intervals( env[signal_name], intervals, + context, dialect=self.model.dialect, path=self.model._path, kwargs=kwargs, @@ -2148,6 +2150,7 @@ def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]: def _check_ready_intervals( check: t.Callable, intervals: Intervals, + context: ExecutionContext, dialect: DialectType = None, path: Path = Path(), kwargs: t.Optional[t.Dict] = None, @@ -2158,7 +2161,14 @@ def _check_ready_intervals( batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch] try: - ready_intervals = call_macro(check, dialect, path, batch, **(kwargs or {})) + ready_intervals = call_macro( + check, + dialect, + path, + provided_args=(batch,), + provided_kwargs=(kwargs or {}), + context=context, + ) except Exception: raise SQLMeshError("Error evaluating signal") diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index b56248f15d..aa67453c75 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -321,7 +321,7 @@ def create( def _get_data_objects(schema: exp.Table, gateway: t.Optional[str] = None) -> t.Set[str]: logger.info("Listing data objects in schema %s", schema.sql()) - objs = self._get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema]) + objs = self.get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema]) return {obj.name for obj in objs} with self.concurrent_context(): @@ -409,7 +409,7 @@ def migrate( s, snapshots, allow_destructive_snapshots, - self._get_adapter(s.model_gateway), + self.get_adapter(s.model_gateway), deployability_index, ), self.ddl_concurrent_tasks, @@ -437,7 +437,7 @@ def cleanup( lambda s: self._cleanup_snapshot( s, snapshots_to_dev_table_only[s.snapshot_id], - self._get_adapter( + self.get_adapter( snapshot_gateways.get(s.snapshot_id.name) if snapshot_gateways else None ), on_complete, @@ -471,7 +471,7 @@ def audit( kwargs: Additional kwargs to pass to the renderer. """ deployability_index = deployability_index or DeployabilityIndex.all_deployable() - adapter = self._get_adapter(snapshot.model_gateway) + adapter = self.get_adapter(snapshot.model_gateway) if not snapshot.version: raise ConfigError( @@ -605,7 +605,7 @@ def _evaluate_snapshot( else snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) ) - adapter = self._get_adapter(model.gateway) + adapter = self.get_adapter(model.gateway) evaluation_strategy = _evaluation_strategy(snapshot, adapter) # https://github.com/TobikoData/sqlmesh/issues/2609 @@ -764,7 +764,7 @@ def _create_snapshot( deployability_index = deployability_index or DeployabilityIndex.all_deployable() - adapter = self._get_adapter(snapshot.model.gateway) + adapter = self.get_adapter(snapshot.model.gateway) create_render_kwargs: t.Dict[str, t.Any] = dict( engine_adapter=adapter, snapshots=parent_snapshots_by_name(snapshot, snapshots), @@ -994,7 +994,7 @@ def _wap_publish_snapshot( ) -> None: deployability_index = deployability_index or DeployabilityIndex.all_deployable() table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) - adapter = self._get_adapter(snapshot.model_gateway) + adapter = self.get_adapter(snapshot.model_gateway) adapter.wap_publish(table_name, wap_id) def _audit( @@ -1021,7 +1021,7 @@ def _audit( blocking = audit_args.pop("blocking", None) blocking = blocking == exp.true() if blocking else audit.blocking - adapter = self._get_adapter(snapshot.model_gateway) + adapter = self.get_adapter(snapshot.model_gateway) kwargs = { "start": start, @@ -1068,10 +1068,10 @@ def _create_schemas( for schema_name, catalog in unique_schemas: schema = schema_(schema_name, catalog) logger.info("Creating schema '%s'", schema) - adapter = self._get_adapter(gateways.get(schema)) if gateways else self.adapter + adapter = self.get_adapter(gateways.get(schema)) if gateways else self.adapter adapter.create_schema(schema) - def _get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: + def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: """Returns the adapter for the specified gateway or the default adapter if none is provided.""" if gateway: if adapter := self.adapters.get(gateway): @@ -1089,7 +1089,7 @@ def _execute_create( rendered_physical_properties: t.Dict[str, exp.Expression], dry_run: bool, ) -> None: - adapter = self._get_adapter(snapshot.model.gateway) + adapter = self.get_adapter(snapshot.model.gateway) evaluation_strategy = _evaluation_strategy(snapshot, adapter) # It can still be useful for some strategies to know if the snapshot was actually deployable diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index cfe3bf52bb..41e5e540de 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -5,7 +5,7 @@ from sqlglot import parse_one, parse from sqlglot.helper import first -from sqlmesh.core.context import Context +from sqlmesh.core.context import Context, ExecutionContext from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import AuditResult, SqlModel @@ -66,9 +66,9 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context @pytest.fixture -def get_batched_missing_intervals() -> ( - t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals] -): +def get_batched_missing_intervals( + mocker: MockerFixture, +) -> t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]: def _get_batched_missing_intervals( scheduler: Scheduler, start: TimeLike, @@ -76,7 +76,7 @@ def _get_batched_missing_intervals( execution_time: t.Optional[TimeLike] = None, ) -> SnapshotToIntervals: merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time) - return scheduler.batch_intervals(merged_intervals) + return scheduler.batch_intervals(merged_intervals, mocker.Mock()) return _get_batched_missing_intervals @@ -622,7 +622,9 @@ def test_interval_diff(): def test_signal_intervals(mocker: MockerFixture, make_snapshot, get_batched_missing_intervals): @signal() - def signal_a(batch: DatetimeRanges): + def signal_a(batch: DatetimeRanges, context: ExecutionContext): + if not hasattr(context, "engine_adapter"): + raise return [batch[0], batch[1]] @signal() diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 16547423fb..bc35997bf5 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -2499,7 +2499,7 @@ def test_contiguous_intervals(): def test_check_ready_intervals(mocker: MockerFixture): def assert_always_signal(intervals): - assert _check_ready_intervals(lambda _: True, intervals) == intervals + assert _check_ready_intervals(lambda _: True, intervals, mocker.Mock()) == intervals assert_always_signal([]) assert_always_signal([(0, 1)]) @@ -2507,7 +2507,7 @@ def assert_always_signal(intervals): assert_always_signal([(0, 1), (2, 3)]) def assert_never_signal(intervals): - assert _check_ready_intervals(lambda _: False, intervals) == [] + assert _check_ready_intervals(lambda _: False, intervals, mocker.Mock()) == [] assert_never_signal([]) assert_never_signal([(0, 1)]) @@ -2515,7 +2515,7 @@ def assert_never_signal(intervals): assert_never_signal([(0, 1), (2, 3)]) def assert_empty_signal(intervals): - assert _check_ready_intervals(lambda _: [], intervals) == [] + assert _check_ready_intervals(lambda _: [], intervals, mocker.Mock()) == [] assert_empty_signal([]) assert_empty_signal([(0, 1)]) @@ -2532,7 +2532,7 @@ def assert_check_intervals( ): mock = mocker.Mock() mock.side_effect = [to_intervals(r) for r in ready] - _check_ready_intervals(mock, intervals) == expected + _check_ready_intervals(mock, intervals, mocker.Mock()) == expected assert_check_intervals([], [], []) assert_check_intervals([(0, 1)], [[]], []) @@ -2894,7 +2894,7 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot): ] -def test_render_signal(make_snapshot): +def test_render_signal(make_snapshot, mocker): @signal() def check_types(batch, env: str, default: int = 0): if env != "in_memory" or not default == 0: @@ -2917,4 +2917,4 @@ def check_types(batch, env: str, default: int = 0): signal_definitions=signal.get_registry(), ) snapshot_a = make_snapshot(sql_model) - assert snapshot_a.check_ready_intervals([(0, 1)]) == [(0, 1)] + assert snapshot_a.check_ready_intervals([(0, 1)], mocker.Mock()) == [(0, 1)] diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 728ba8bf02..59c5bbd965 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -3751,9 +3751,9 @@ def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot): assert len(evaluator.adapters) == 3 assert evaluator.adapter == engine_adapters["default"] - assert evaluator._get_adapter() == engine_adapters["default"] - assert evaluator._get_adapter("third") == engine_adapters["third"] - assert evaluator._get_adapter("secondary") == engine_adapters["secondary"] + assert evaluator.get_adapter() == engine_adapters["default"] + assert evaluator.get_adapter("third") == engine_adapters["third"] + assert evaluator.get_adapter("secondary") == engine_adapters["secondary"] model = load_sql_based_model( parse( # type: ignore diff --git a/web/server/api/endpoints/plan.py b/web/server/api/endpoints/plan.py index 6bccb38188..47bab6626e 100644 --- a/web/server/api/endpoints/plan.py +++ b/web/server/api/endpoints/plan.py @@ -132,7 +132,7 @@ def _get_plan_changes(context: Context, plan: Plan) -> models.PlanChanges: def _get_plan_backfills(context: Context, plan: Plan) -> t.Dict[str, t.Any]: """Get plan backfills""" merged_intervals = context.scheduler().merged_missing_intervals() - batches = context.scheduler().batch_intervals(merged_intervals) + batches = context.scheduler().batch_intervals(merged_intervals, None) tasks = {snapshot.name: len(intervals) for snapshot, intervals in batches.items()} snapshots = plan.context_diff.snapshots default_catalog = context.default_catalog