Skip to content

Commit 39754d7

Browse files
authored
Break up the DB state sync (#3903)
1 parent 28e242c commit 39754d7

15 files changed

Lines changed: 2742 additions & 2201 deletions

sqlmesh/core/state_sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@
2121
)
2222
from sqlmesh.core.state_sync.cache import CachingStateSync as CachingStateSync
2323
from sqlmesh.core.state_sync.common import cleanup_expired_views as cleanup_expired_views
24-
from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync as EngineAdapterStateSync
24+
from sqlmesh.core.state_sync.db import EngineAdapterStateSync as EngineAdapterStateSync

sqlmesh/core/state_sync/base.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,9 @@ def raise_error(
252252
return versions
253253

254254
@abc.abstractmethod
255-
def _get_versions(self, lock_for_update: bool = False) -> Versions:
255+
def _get_versions(self) -> Versions:
256256
"""Queries the store to get the current versions of SQLMesh and deps.
257257
258-
Args:
259-
lock_for_update: Whether or not the usage of this method plans to update the row.
260-
261258
Returns:
262259
The versions object.
263260
"""
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from sqlmesh.core.state_sync.db.facade import EngineAdapterStateSync
2+
3+
__all__ = ["EngineAdapterStateSync"]
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
import pandas as pd
5+
import json
6+
import logging
7+
from sqlglot import exp
8+
9+
from sqlmesh.core import constants as c
10+
from sqlmesh.core.engine_adapter import EngineAdapter
11+
from sqlmesh.core.state_sync.db.utils import (
12+
fetchall,
13+
fetchone,
14+
)
15+
from sqlmesh.core.environment import Environment
16+
from sqlmesh.utils.migration import index_text_type, blob_text_type
17+
from sqlmesh.utils.date import now_timestamp, time_like_to_str
18+
from sqlmesh.utils.errors import SQLMeshError
19+
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class EnvironmentState:
    """Stores and retrieves environment records in the ``_environments`` table.

    All statements are executed through the supplied engine adapter.

    Args:
        engine_adapter: The engine adapter used to run queries and DML.
        schema: Optional schema used to qualify the environments table.
    """

    def __init__(
        self,
        engine_adapter: EngineAdapter,
        schema: t.Optional[str] = None,
    ):
        self.engine_adapter = engine_adapter
        self.environments_table = exp.table_("_environments", db=schema)

        # The indexed and blob text types are derived from the adapter's dialect.
        index_type = index_text_type(engine_adapter.dialect)
        blob_type = blob_text_type(engine_adapter.dialect)

        # Column name -> type name, in table column order.
        column_type_names = {
            "name": index_type,
            "snapshots": blob_type,
            "start_at": "text",
            "end_at": "text",
            "plan_id": "text",
            "previous_plan_id": "text",
            "expiration_ts": "bigint",
            "finalized_ts": "bigint",
            "promoted_snapshot_ids": blob_type,
            "suffix_target": "text",
            "catalog_name_override": "text",
            "previous_finalized_snapshots": blob_type,
            "normalize_name": "boolean",
            "requirements": blob_type,
        }
        self._environment_columns_to_types = {
            column: exp.DataType.build(type_name)
            for column, type_name in column_type_names.items()
        }

    def update_environment(self, environment: Environment) -> None:
        """Replaces the stored record of the given environment.

        Any existing row with the same name is deleted first, then a fresh row
        built from the environment is appended.

        Args:
            environment: The environment to store.
        """
        same_name = exp.EQ(
            this=exp.column("name"),
            expression=exp.Literal.string(environment.name),
        )
        self.engine_adapter.delete_from(self.environments_table, where=same_name)

        environment_df = _environment_to_df(environment)
        self.engine_adapter.insert_append(
            self.environments_table,
            environment_df,
            columns_to_types=self._environment_columns_to_types,
        )

    def invalidate_environment(self, name: str) -> None:
        """Marks the named environment as expired as of now.

        Args:
            name: The name of the environment; it is lower-cased before use.

        Raises:
            SQLMeshError: If the name refers to the production environment.
        """
        normalized_name = name.lower()
        if normalized_name == c.PROD:
            raise SQLMeshError("Cannot invalidate the production environment.")

        name_filter = exp.column("name").eq(normalized_name)

        # Setting the expiration timestamp to the current time invalidates the row.
        self.engine_adapter.update_table(
            self.environments_table,
            {"expiration_ts": now_timestamp()},
            where=name_filter,
        )

    def finalize(self, environment: Environment) -> None:
        """Records that the target environment has been fully promoted.

        The stored plan ID is read with a locking query and compared against the
        plan ID of the given environment before the finalized timestamp is
        written. The timestamp is also set on the passed-in environment object.

        Args:
            environment: The target environment to finalize.

        Raises:
            SQLMeshError: If the environment is not stored, or if its stored plan
                ID differs from the plan ID of the given environment.
        """
        logger.info("Finalizing environment '%s'", environment.name)

        name_filter = exp.column("name").eq(exp.Literal.string(environment.name))

        plan_id_row = fetchone(
            self.engine_adapter,
            exp.select("plan_id")
            .from_(self.environments_table)
            .where(name_filter, copy=False)
            .lock(copy=False),
        )
        if not plan_id_row:
            raise SQLMeshError(f"Missing environment '{environment.name}' can't be finalized")

        stored_plan_id = plan_id_row[0]
        if stored_plan_id != environment.plan_id:
            raise SQLMeshError(
                f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. "
                f"Stored plan ID: '{stored_plan_id}'. Please recreate the plan and try again"
            )

        environment.finalized_ts = now_timestamp()
        self.engine_adapter.update_table(
            self.environments_table,
            {"finalized_ts": environment.finalized_ts},
            where=name_filter,
        )

    def delete_expired_environments(self) -> t.List[Environment]:
        """Removes every environment whose expiration timestamp has passed.

        The matching rows are fetched with a locking query before they are deleted.

        Returns:
            The environments that were deleted.
        """
        expired_filter = exp.LTE(
            this=exp.column("expiration_ts"),
            expression=exp.Literal.number(now_timestamp()),
        )

        expired_query = self._environments_query(where=expired_filter, lock_for_update=True)
        expired_rows = fetchall(self.engine_adapter, expired_query)
        deleted = [self._environment_from_row(row) for row in expired_rows]

        self.engine_adapter.delete_from(self.environments_table, where=expired_filter)

        return deleted

    def get_environments(self) -> t.List[Environment]:
        """Loads every stored environment.

        Returns:
            A list with one environment per stored row.
        """
        rows = fetchall(self.engine_adapter, self._environments_query())
        return list(map(self._environment_from_row, rows))

    def get_environments_summary(self) -> t.Dict[str, int]:
        """Loads the name and expiration timestamp of every stored environment.

        Returns:
            A dict mapping each environment name to its ``expiration_ts`` value.
        """
        summary_query = self._environments_query(required_fields=["name", "expiration_ts"])
        return dict(fetchall(self.engine_adapter, summary_query))

    def get_environment(
        self, environment: str, lock_for_update: bool = False
    ) -> t.Optional[Environment]:
        """Loads a single environment by name.

        Args:
            environment: The name of the environment.
            lock_for_update: Whether to add a lock clause to the lookup query.

        Returns:
            The environment object, or None when no row matches the name.
        """
        name_filter = exp.EQ(
            this=exp.column("name"),
            expression=exp.Literal.string(environment),
        )
        lookup_query = self._environments_query(
            where=name_filter,
            lock_for_update=lock_for_update,
        )

        row = fetchone(self.engine_adapter, lookup_query)
        if row:
            return self._environment_from_row(row)
        return None

    def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment:
        """Builds an Environment from a row ordered like ``Environment.all_fields()``."""
        field_values = {}
        for position, field in enumerate(Environment.all_fields()):
            field_values[field] = row[position]
        return Environment(**field_values)

    def _environments_query(
        self,
        where: t.Optional[str | exp.Expression] = None,
        lock_for_update: bool = False,
        required_fields: t.Optional[t.List[str]] = None,
    ) -> exp.Select:
        """Builds a SELECT over the environments table.

        Args:
            where: Optional filter applied to the query.
            lock_for_update: Whether to add a lock clause to the query.
            required_fields: Columns to select; when empty or omitted, all
                environment fields are selected.

        Returns:
            The select expression.
        """
        selected_fields = required_fields or Environment.all_fields()
        identifiers = [exp.to_identifier(field) for field in selected_fields]
        query = exp.select(*identifiers).from_(self.environments_table).where(where)
        return query.lock(copy=False) if lock_for_update else query
226+
227+
228+
def _environment_to_df(environment: Environment) -> pd.DataFrame:
    """Serializes an environment into a single-row DataFrame.

    Snapshot collections and requirements are JSON-encoded, and the start/end
    values are converted to strings. An unset end, promoted snapshot IDs, or
    previous finalized snapshots is stored as None.

    Args:
        environment: The environment to serialize.

    Returns:
        A DataFrame containing one row whose columns are the environment fields.
    """
    # Keys are inserted in the same order as the table columns.
    record: t.Dict[str, t.Any] = {
        "name": environment.name,
        "snapshots": json.dumps(environment.snapshot_dicts()),
        "start_at": time_like_to_str(environment.start_at),
    }

    if environment.end_at:
        record["end_at"] = time_like_to_str(environment.end_at)
    else:
        record["end_at"] = None

    record["plan_id"] = environment.plan_id
    record["previous_plan_id"] = environment.previous_plan_id
    record["expiration_ts"] = environment.expiration_ts
    record["finalized_ts"] = environment.finalized_ts

    if environment.promoted_snapshot_ids is not None:
        record["promoted_snapshot_ids"] = json.dumps(environment.promoted_snapshot_id_dicts())
    else:
        record["promoted_snapshot_ids"] = None

    record["suffix_target"] = environment.suffix_target.value
    record["catalog_name_override"] = environment.catalog_name_override

    if environment.previous_finalized_snapshots is not None:
        record["previous_finalized_snapshots"] = json.dumps(
            environment.previous_finalized_snapshot_dicts()
        )
    else:
        record["previous_finalized_snapshots"] = None

    record["normalize_name"] = environment.normalize_name
    record["requirements"] = json.dumps(environment.requirements)

    return pd.DataFrame([record])

0 commit comments

Comments
 (0)