Skip to content

Commit 21433b9

Browse files
committed
feat!: add cron_tz to allow scheduling according to time zones
1 parent cdd97c5 commit 21433b9

10 files changed

Lines changed: 102 additions & 22 deletions

File tree

sqlmesh/core/audit/definition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def metadata_hash(self) -> str:
270270
*sorted(self.tags),
271271
str(self.sorted_python_env),
272272
self.stamp,
273+
self.cron,
274+
self.cron_tz.key if self.cron_tz else None,
273275
]
274276

275277
query = self.render_audit_query() or self.query

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,7 @@ def metadata_hash(self) -> str:
10861086
self.description,
10871087
json.dumps(self.column_descriptions, sort_keys=True),
10881088
self.cron,
1089+
self.cron_tz.key if self.cron_tz else None,
10891090
str(self.start) if self.start else None,
10901091
str(self.end) if self.end else None,
10911092
str(self.retention) if self.retention else None,

sqlmesh/core/node.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import typing as t
4+
import zoneinfo
45
from datetime import datetime
56
from enum import Enum
67
from pathlib import Path
@@ -177,6 +178,7 @@ class _Node(PydanticModel):
177178
the date from the scheduler will be used
178179
cron: A cron string specifying how often the node should be run, leveraging the
179180
[croniter](https://github.com/kiorky/croniter) library.
181+
cron_tz: Time zone in which the cron expression is evaluated, given as an [IANA time zone](https://docs.python.org/3/library/zoneinfo.html) name. Defaults to UTC.
180182
interval_unit: The duration of an interval for the node. By default, it is computed from the cron expression.
181183
tags: A list of tags that can be used to filter nodes.
182184
stamp: An optional arbitrary string sequence used to create new node versions without making
@@ -190,6 +192,7 @@ class _Node(PydanticModel):
190192
start: t.Optional[TimeLike] = None
191193
end: t.Optional[TimeLike] = None
192194
cron: SQLGlotCron = "@daily"
195+
cron_tz: t.Optional[zoneinfo.ZoneInfo] = None
193196
interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None)
194197
tags: t.List[str] = []
195198
stamp: t.Optional[str] = None
@@ -226,6 +229,22 @@ def _name_validator(cls, v: t.Any) -> t.Optional[str]:
226229
return v.meta["sql"]
227230
return str(v)
228231

232+
@field_validator("cron_tz", mode="before")
233+
def _cron_tz_validator(cls, v: t.Any) -> t.Optional[zoneinfo.ZoneInfo]:
234+
if not v or v == "UTC":
235+
return None
236+
237+
v = str_or_exp_to_str(v)
238+
239+
try:
240+
return zoneinfo.ZoneInfo(v)
241+
except Exception as e:
242+
raise ConfigError(
243+
f"{e}. {v} must be in {zoneinfo.available_timezones()} or IANA time zone data is not available on your system. `pip install tzdata` to leverage cron time zones or remove this field which will default to UTC."
244+
)
245+
246+
return None
247+
229248
@field_validator("start", "end", mode="before")
230249
@classmethod
231250
def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]:
@@ -319,7 +338,7 @@ def croniter(self, value: TimeLike) -> CroniterCache:
319338
if self._croniter is None:
320339
self._croniter = CroniterCache(self.cron, value)
321340
else:
322-
self._croniter.curr = to_datetime(value)
341+
self._croniter.curr = to_datetime(value, tz=self.cron_tz)
323342
return self._croniter
324343

325344
def cron_next(self, value: TimeLike, estimate: bool = False) -> datetime:

sqlmesh/migrations/v0071_add_dev_version_to_intervals.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def migrate(state_sync, **kwargs): # type: ignore
3131
)
3232
engine_adapter.execute(alter_table_exp)
3333

34-
used_dev_versions: t.Set[t.Tuple[str, str]] = set()
35-
used_versions: t.Set[t.Tuple[str, str]] = set()
36-
used_snapshot_ids: t.Set[t.Tuple[str, str]] = set()
37-
snapshot_ids_to_dev_versions: t.Dict[t.Tuple[str, str], str] = {}
34+
used_dev_versions = set()
35+
used_versions = set()
36+
used_snapshot_ids = set()
37+
snapshot_ids_to_dev_versions = {}
3838

3939
_migrate_snapshots(
4040
engine_adapter,

sqlmesh/utils/cron.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import typing as t
4-
from datetime import datetime, timedelta
4+
from datetime import datetime, timedelta, tzinfo
55
from functools import lru_cache
66

77
from croniter import croniter
@@ -34,21 +34,22 @@ def interval_seconds(cron: str) -> int:
3434

3535

3636
class CroniterCache:
37-
def __init__(self, cron: str, time: t.Optional[TimeLike] = None):
37+
def __init__(self, cron: str, time: t.Optional[TimeLike] = None, tz: t.Optional[tzinfo] = None):
3838
self.cron = cron
39-
self.curr: datetime = to_datetime(now() if time is None else time)
39+
self.tz = tz
40+
self.curr: datetime = to_datetime(now() if time is None else time, tz=self.tz)
4041
self.interval_seconds = interval_seconds(self.cron)
4142

4243
def get_next(self, estimate: bool = False) -> datetime:
4344
if estimate and self.interval_seconds:
4445
self.curr = self.curr + timedelta(seconds=self.interval_seconds)
4546
else:
46-
self.curr = to_datetime(croniter(self.cron, self.curr).get_next() * 1000)
47+
self.curr = to_datetime(croniter(self.cron, self.curr).get_next() * 1000, tz=self.tz)
4748
return self.curr
4849

4950
def get_prev(self, estimate: bool = False) -> datetime:
5051
if estimate and self.interval_seconds:
5152
self.curr = self.curr - timedelta(seconds=self.interval_seconds)
5253
else:
53-
self.curr = to_datetime(croniter(self.cron, self.curr).get_prev() * 1000)
54+
self.curr = to_datetime(croniter(self.cron, self.curr).get_prev() * 1000, tz=self.tz)
5455
return self.curr

sqlmesh/utils/date.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
import typing as t
66
import warnings
77

8-
from pandas.api.types import is_datetime64_any_dtype # type: ignore
9-
10-
from datetime import date, datetime, timedelta, timezone
8+
from datetime import date, datetime, timedelta, timezone, tzinfo
119

1210
import dateparser
1311
import pandas as pd
1412
from dateparser import freshness_date_parser as freshness_date_parser_module
1513
from dateparser.freshness_date_parser import freshness_date_parser
14+
from pandas.api.types import is_datetime64_any_dtype # type: ignore
1615
from sqlglot import exp
1716

1817
from sqlmesh.utils import ttl_cache
@@ -149,19 +148,21 @@ def to_datetime(
149148
value: TimeLike,
150149
relative_base: t.Optional[datetime] = None,
151150
check_categorical_relative_expression: bool = True,
151+
tz: t.Optional[tzinfo] = None,
152152
) -> datetime:
153153
"""Converts a value into a timezone-aware datetime object (UTC unless `tz` is given).
154154
155155
Args:
156156
value: A variety of date formats. If the value is number-like, it is assumed to be millisecond epochs.
157157
relative_base: The datetime to reference for time expressions that are using relative terms.
158158
check_categorical_relative_expression: If True, takes into account the relative expressions that are categorical.
159+
tz: Time zone of the returned datetime. Aware values are converted to it; naive values are assumed to already be in it. Defaults to UTC.
159160
160161
Raises:
161162
ValueError if value cannot be converted to a datetime.
162163
163164
Returns:
164-
A datetime object with tz utc.
165+
A datetime object with tz (default UTC).
165166
"""
166167
if isinstance(value, datetime):
167168
dt: t.Optional[datetime] = value
@@ -198,9 +199,11 @@ def to_datetime(
198199
if dt is None:
199200
raise ValueError(f"Could not convert `{value}` to datetime.")
200201

202+
tz = tz or UTC
203+
201204
if dt.tzinfo:
202-
return dt if dt.tzinfo == UTC else dt.astimezone(UTC)
203-
return dt.replace(tzinfo=UTC)
205+
return dt if dt.tzinfo == tz else dt.astimezone(tz)
206+
return dt.replace(tzinfo=tz)
204207

205208

206209
def to_date(value: TimeLike, relative_base: t.Optional[datetime] = None) -> date:

sqlmesh/utils/pydantic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import typing as t
5+
from datetime import tzinfo
56

67
import pydantic
78
from pydantic import ValidationInfo as ValidationInfo
@@ -72,6 +73,7 @@ class PydanticModel(pydantic.BaseModel):
7273
exp.Tuple: _expression_encoder,
7374
AuditQueryTypes: _expression_encoder, # type: ignore
7475
ModelQueryTypes: _expression_encoder, # type: ignore
76+
tzinfo: lambda tz: tz.key,
7577
},
7678
arbitrary_types_allowed=True,
7779
extra="forbid",

tests/core/test_integration.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4884,6 +4884,58 @@ def test_plan_production_environment_statements(tmp_path: Path):
48844884
assert environment_statements[0].python_env["__sqlmesh__vars__"].payload == "{'var_5': 5}"
48854885

48864886

4887+
@time_machine.travel("2025-03-08 00:00:00 UTC")
4888+
def test_tz(init_and_plan_context):
4889+
context, _ = init_and_plan_context("examples/sushi")
4890+
4891+
model = context.get_model("sushi.waiter_revenue_by_day")
4892+
context.upsert_model(
4893+
SqlModel.parse_obj(
4894+
{**model.dict(), "cron_tz": "America/Los_Angeles", "start": "2025-03-07"}
4895+
)
4896+
)
4897+
4898+
def assert_intervals(plan, intervals):
4899+
assert (
4900+
next(
4901+
intervals.intervals
4902+
for intervals in plan.missing_intervals
4903+
if intervals.snapshot_id.name == model.fqn
4904+
)
4905+
== intervals
4906+
)
4907+
4908+
plan = context.plan_builder("prod", skip_tests=True).build()
4909+
4910+
assert_intervals(plan, [(to_timestamp("2025-03-07"), to_timestamp("2025-03-08"))])
4911+
4912+
with time_machine.travel("2025-03-09 07:00:00 UTC"):
4913+
plan = context.plan_builder("prod", skip_tests=True).build()
4914+
4915+
assert_intervals(
4916+
plan,
4917+
[
4918+
(to_timestamp("2025-03-07"), to_timestamp("2025-03-08")),
4919+
],
4920+
)
4921+
4922+
with time_machine.travel("2025-03-09 08:00:00 UTC"):
4923+
plan = context.plan_builder("prod", skip_tests=True).build()
4924+
4925+
assert_intervals(
4926+
plan,
4927+
[
4928+
(to_timestamp("2025-03-07"), to_timestamp("2025-03-08")),
4929+
(to_timestamp("2025-03-08"), to_timestamp("2025-03-09")),
4930+
],
4931+
)
4932+
4933+
context.apply(plan)
4934+
4935+
plan = context.plan_builder("prod", skip_tests=True).build()
4936+
assert not plan.missing_intervals
4937+
4938+
48874939
def apply_to_environment(
48884940
context: Context,
48894941
environment: str,

tests/core/test_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,7 +2113,7 @@ def python_model_prop(context, **kwargs):
21132113
),
21142114
}
21152115

2116-
snapshot: Snapshot = make_snapshot(m)
2116+
snapshot = make_snapshot(m)
21172117
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
21182118

21192119
# Rendering the properties will result in a TRANSIENT creatable_type and the removal of the conditional prop
@@ -3674,7 +3674,7 @@ def test_conditional_physical_properties(make_snapshot):
36743674
)
36753675

36763676
# substitution occurs at runtime
3677-
snapshot: Snapshot = make_snapshot(full_model)
3677+
snapshot = make_snapshot(full_model)
36783678
snapshot_view: Snapshot = make_snapshot(view_model)
36793679
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
36803680

@@ -3775,7 +3775,7 @@ def test_model_defaults_macros(make_snapshot):
37753775
variables={"gateway": "dev", "create_type": "SECURE"},
37763776
)
37773777

3778-
snapshot: Snapshot = make_snapshot(model)
3778+
snapshot = make_snapshot(model)
37793779
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
37803780

37813781
# Validate rendering of model defaults
@@ -3881,7 +3881,7 @@ def python_model_prop_macro(context, **kwargs):
38813881
# Validate disabling attribute dynamically
38823882
assert not m.storage_format
38833883

3884-
snapshot: Snapshot = make_snapshot(m)
3884+
snapshot = make_snapshot(m)
38853885
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
38863886

38873887
# Ensure properties are not rendered at load time

tests/core/test_snapshot_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3192,8 +3192,8 @@ def insert(
31923192
def test_custom_materialization_strategy_with_custom_properties(adapter_mock, make_snapshot):
31933193
custom_insert_kind = None
31943194

3195-
class TestCustomKind(CustomKind): # type: ignore[no-untyped-def]
3196-
_primary_key: t.List[exp.Expression]
3195+
class TestCustomKind(CustomKind):
3196+
_primary_key: t.List[exp.Expression] # type: ignore[no-untyped-def]
31973197

31983198
@model_validator(mode="after")
31993199
def _validate_model(self) -> Self:

0 commit comments

Comments
 (0)