Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def evaluate(
def format(self) -> None:
"""Format all models in a given directory."""
for model in self.models.values():
with open(model.path, "r+", encoding="utf-8") as file:
with open(model._path, "r+", encoding="utf-8") as file:
expressions = [e for e in parse(file.read(), read=self.dialect) if e]
file.seek(0)
file.write(format_model_expressions(expressions, model.dialect))
Expand Down
20 changes: 15 additions & 5 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from sqlmesh.utils import registry_decorator
from sqlmesh.utils.errors import MacroEvalError, SQLMeshError
from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception


class MacroStrTemplate(Template):
Expand Down Expand Up @@ -93,17 +94,21 @@ class MacroEvaluator:
env: Python execution environment including global variables
"""

def __init__(self, dialect: str = "", env: t.Optional[t.Dict[str, t.Any]] = None):
from sqlmesh.core.model import prepare_env

def __init__(
self, dialect: str = "", env: t.Optional[t.Dict[str, Executable]] = None
):
self.dialect = dialect
self.generator = MacroDialect().generator()
self.locals: t.Dict[str, t.Any] = {}
self.env = {**ENV, "self": self}
self.python_env = env or {}
self.macros = {
normalize_macro_name(k): v.func for k, v in macro.get_registry().items()
}
prepare_env(self.env, env, self.macros)
prepare_env(self.env, self.python_env)
for k, v in self.python_env.items():
if v.is_def:
self.macros[normalize_macro_name(k)] = self.env[v.name or k]

def send(
self, name: str, *args
Expand All @@ -113,7 +118,11 @@ def send(
if not callable(func):
raise SQLMeshError(f"Macro '{name}' does not exist.")

return func(self, *args)
try:
return func(self, *args)
except Exception as e:
print_exception(e, self.python_env)
raise MacroEvalError(f"Error trying to eval macro.") from e

def transform(
self, query: exp.Expression
Expand Down Expand Up @@ -194,6 +203,7 @@ def eval_expression(self, node: exp.Expression) -> t.Any:
code = self.generator.generate(node)
return eval(code, self.env, self.locals)
except Exception as e:
print_exception(e, self.python_env)
raise MacroEvalError(
f"Error trying to eval macro.\n\nGenerated code: {code}\n\nOriginal sql: {node}"
) from e
Expand Down
91 changes: 29 additions & 62 deletions sqlmesh/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,8 @@

from sqlmesh.core import dialect as d
from sqlmesh.core.audit import Audit
from sqlmesh.core.macros import (
MacroEvaluator,
MacroRegistry,
macro,
normalize_macro_name,
)
from sqlmesh.utils import UniqueKeyDict, registry_decorator, trim_path, unique
from sqlmesh.core.macros import MacroEvaluator, MacroRegistry, macro
from sqlmesh.utils import UniqueKeyDict, registry_decorator, unique
from sqlmesh.utils.date import (
TimeLike,
date_dict,
Expand All @@ -274,7 +269,13 @@
to_datetime,
)
from sqlmesh.utils.errors import ConfigError, SQLMeshError
from sqlmesh.utils.metaprogramming import build_env, print_exception, serialize_env
from sqlmesh.utils.metaprogramming import (
Executable,
build_env,
prepare_env,
print_exception,
serialize_env,
)
from sqlmesh.utils.pydantic import PydanticModel

if t.TYPE_CHECKING:
Expand All @@ -299,7 +300,6 @@
),
}

EXEC_PREFIX = "__EXEC__ "
EPOCH_DS = "1970-01-01"


Expand Down Expand Up @@ -613,10 +613,10 @@ class Model(ModelMeta, frozen=True):
expressions_: t.Optional[t.List[exp.Expression]] = Field(
default=None, alias="expressions"
)
python_env_: t.Optional[t.Dict[str, t.Any]] = Field(
python_env_: t.Optional[t.Dict[str, Executable]] = Field(
default=None, alias="python_env"
)
path: Path = Path()
_path: Path = Path()
_depends_on: t.Optional[t.Set[str]] = None
_columns: t.Optional[t.Dict[str, exp.DataType]] = None
_column_descriptions: t.Optional[t.Dict[str, str]] = None
Expand Down Expand Up @@ -689,7 +689,6 @@ def load(
query=query,
expressions=statements,
python_env=_python_env(query, module, macros or macro.get_registry()),
path=path,
**{
"dialect": dialect or "",
**ModelMeta(
Expand All @@ -701,6 +700,8 @@ def load(
},
)

model._path = path

if time_column_format and model.time_column and not model.time_column.format:
if dialect != model.dialect:
# Transpile default time column format in default dialect to model's dialect
Expand Down Expand Up @@ -998,10 +999,9 @@ def exec_python(
)
return df
except Exception as e:
path = trim_path(self.path, "models")
print_exception(e, self.python_env, str(path))
print_exception(e, self.python_env)
raise SQLMeshError(
f"Error executing Python model '{path}::{self.query.name}'"
f"Error executing Python model '{self.name}::{self.query.name}'"
)

def render_audit_queries(
Expand Down Expand Up @@ -1101,7 +1101,7 @@ def expressions(self) -> t.List[exp.Expression]:
return self.expressions_ or []

@property
def python_env(self) -> t.Dict[str, t.Any]:
def python_env(self) -> t.Dict[str, Executable]:
return self.python_env_ or {}

def validate_definition(self) -> None:
Expand All @@ -1120,7 +1120,7 @@ def validate_definition(self) -> None:
if isinstance(expression, exp.Star):
_raise_config_error(
"SELECT * is not allowed you must explicitly select columns.",
self.path,
self._path,
)

alias = expression.alias_or_name
Expand All @@ -1131,21 +1131,21 @@ def validate_definition(self) -> None:
elif not alias:
_raise_config_error(
f"Outer projection `{expression}` must have inferrable names or explicit aliases.",
self.path,
self._path,
)

for name, count in name_counts.items():
if count > 1:
_raise_config_error(
f"Found duplicate outer select name `{name}`", self.path
f"Found duplicate outer select name `{name}`", self._path
)

if self.partitioned_by:
unique_partition_keys = {k.strip().lower() for k in self.partitioned_by}
if len(self.partitioned_by) != len(unique_partition_keys):
_raise_config_error(
"All partition keys must be unique in the model definition",
self.path,
self._path,
)

projections = {p.lower() for p in query.named_selects}
Expand All @@ -1154,13 +1154,13 @@ def validate_definition(self) -> None:
missing_keys_str = ", ".join(f"'{k}'" for k in sorted(missing_keys))
_raise_config_error(
f"Partition keys [{missing_keys_str}] are missing in the query in the model definition",
self.path,
self._path,
)

if self.kind == ModelKind.INCREMENTAL and not self.time_column:
_raise_config_error(
"Incremental models must have a time_column field.",
self.path,
self._path,
)

def _filter_time_column(
Expand Down Expand Up @@ -1236,47 +1236,14 @@ def model(self, module: str, path: Path) -> Model:
module=module,
)

return Model(
model = Model(
query=f"@{name}",
python_env=serialize_env(env, module=module, prefix=EXEC_PREFIX),
path=path,
python_env=serialize_env(env, module=module),
**self.meta.dict(exclude_defaults=True),
)


def strip_exec_prefix(code: str) -> str:
    """Return `code` without its leading EXEC_PREFIX marker.

    The first ``len(EXEC_PREFIX)`` characters are sliced off unconditionally;
    this function does not itself verify that the marker is present.

    Args:
        code: A serialized code string that starts with EXEC_PREFIX.

    Returns:
        The remainder of the string after the marker.
    """
    prefix_length = len(EXEC_PREFIX)
    return code[prefix_length:]


def prepare_env(
    env: t.Dict[str, t.Any],
    python_env: t.Optional[t.Dict[str, t.Any]] = None,
    functions: t.Optional[t.Dict[str, t.Callable]] = None,
) -> None:
    """Prepare a python env by hydrating and executing functions.

    The Python ENV is stored in a json serializable format.
    Because we store macros as function strings, we need to detect
    when a variable is supposed to be a function and then deserialize it
    appropriately. We assume strings with EXEC_PREFIX are actually
    functions and then execute them using our local env to hydrate them.

    Args:
        env: The dictionary to execute code in. It is mutated in place.
        python_env: The dictionary containing the serialized python environment.
        functions: Optional dictionary that receives the hydrated functions,
            keyed by their normalized macro name. It is mutated in place.
    """
    if not python_env:
        return

    for name, value in python_env.items():
        if not (isinstance(value, str) and value.startswith(EXEC_PREFIX)):
            # Anything that is not a prefixed code string is copied as-is.
            env[name] = value
            continue

        code = strip_exec_prefix(value)
        # NOTE(review): exec runs whatever code the serialized env carries, so
        # python_env must only ever come from a trusted source.
        exec(code, env)

        # Test against None, not truthiness: a caller may hand in an empty
        # dict precisely so that it gets populated here.
        if functions is None or not code.startswith("def "):
            continue

        # The object that exec just defined is taken to be the most recently
        # inserted value of env.
        # NOTE(review): if the defined name already existed in env, its dict
        # position does not move and this picks a different value - confirm
        # callers never re-define a name.
        func = list(env.values())[-1]
        if callable(func):
            functions[normalize_macro_name(name)] = func
model._path = path
return model


def parse_model_name(name: str) -> t.Tuple[t.Optional[str], t.Optional[str], str]:
Expand Down Expand Up @@ -1316,8 +1283,8 @@ def find_tables(query: exp.Expression) -> t.Set[str]:

def _python_env(
query: exp.Expression, module: str, macros: MacroRegistry
) -> t.Dict[str, t.Any]:
python_env: t.Dict[str, t.Any] = {}
) -> t.Dict[str, Executable]:
python_env: t.Dict[str, Executable] = {}

for macro_func in query.find_all(d.MacroFunc):
if isinstance(macro_func, (d.MacroSQL, d.MacroStrReplace)):
Expand All @@ -1332,7 +1299,7 @@ def _python_env(
module=module,
)

return serialize_env(python_env, module=module, prefix=EXEC_PREFIX)
return serialize_env(python_env, module=module)


def _raise_config_error(msg: str, location: t.Optional[str | Path] = None) -> None:
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def model(self, line: str, sql: t.Optional[str] = None):
loaded = Model.load(
parse(sql),
macros=self.context.macros,
path=model.path,
path=model._path,
dialect=self.context.dialect,
time_column_format=self.context.config.time_column_format,
)
Expand All @@ -87,7 +87,7 @@ def model(self, line: str, sql: t.Optional[str] = None):
replace=True,
)

with open(model.path, "w", encoding="utf-8") as file:
with open(model._path, "w", encoding="utf-8") as file:
file.write(formatted)

self.context.models.update({model.name: model})
Expand Down
Loading