diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 22c9535425..5291b035d8 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -350,7 +350,7 @@ def evaluate( def format(self) -> None: """Format all models in a given directory.""" for model in self.models.values(): - with open(model.path, "r+", encoding="utf-8") as file: + with open(model._path, "r+", encoding="utf-8") as file: expressions = [e for e in parse(file.read(), read=self.dialect) if e] file.seek(0) file.write(format_model_expressions(expressions, model.dialect)) diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 8a4744948c..04d63c3698 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -19,6 +19,7 @@ ) from sqlmesh.utils import registry_decorator from sqlmesh.utils.errors import MacroEvalError, SQLMeshError +from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception class MacroStrTemplate(Template): @@ -93,17 +94,21 @@ class MacroEvaluator: env: Python execution environment including global variables """ - def __init__(self, dialect: str = "", env: t.Optional[t.Dict[str, t.Any]] = None): - from sqlmesh.core.model import prepare_env - + def __init__( + self, dialect: str = "", env: t.Optional[t.Dict[str, Executable]] = None + ): self.dialect = dialect self.generator = MacroDialect().generator() self.locals: t.Dict[str, t.Any] = {} self.env = {**ENV, "self": self} + self.python_env = env or {} self.macros = { normalize_macro_name(k): v.func for k, v in macro.get_registry().items() } - prepare_env(self.env, env, self.macros) + prepare_env(self.env, self.python_env) + for k, v in self.python_env.items(): + if v.is_def: + self.macros[normalize_macro_name(k)] = self.env[v.name or k] def send( self, name: str, *args @@ -113,7 +118,11 @@ def send( if not callable(func): raise SQLMeshError(f"Macro '{name}' does not exist.") - return func(self, *args) + try: + return func(self, *args) + except Exception as e: + print_exception(e, self.python_env) + raise MacroEvalError(f"Error trying to eval macro.") from e def transform( self, query: exp.Expression @@ -194,6 +203,7 @@ def eval_expression(self, node: exp.Expression) -> t.Any: code = self.generator.generate(node) return eval(code, self.env, self.locals) except Exception as e: + print_exception(e, self.python_env) raise MacroEvalError( f"Error trying to eval macro.\n\nGenerated code: {code}\n\nOriginal sql: {node}" ) from e diff --git a/sqlmesh/core/model.py b/sqlmesh/core/model.py index fe5e381174..0ab2100926 100644 --- a/sqlmesh/core/model.py +++ b/sqlmesh/core/model.py @@ -259,13 +259,8 @@ from sqlmesh.core import dialect as d from sqlmesh.core.audit import Audit -from sqlmesh.core.macros import ( - MacroEvaluator, - MacroRegistry, - macro, - normalize_macro_name, -) -from sqlmesh.utils import UniqueKeyDict, registry_decorator, trim_path, unique +from sqlmesh.core.macros import MacroEvaluator, MacroRegistry, macro +from sqlmesh.utils import UniqueKeyDict, registry_decorator, unique from sqlmesh.utils.date import ( TimeLike, date_dict, @@ -274,7 +269,13 @@ to_datetime, ) from sqlmesh.utils.errors import ConfigError, SQLMeshError -from sqlmesh.utils.metaprogramming import build_env, print_exception, serialize_env +from sqlmesh.utils.metaprogramming import ( + Executable, + build_env, + prepare_env, + print_exception, + serialize_env, +) from sqlmesh.utils.pydantic import PydanticModel if t.TYPE_CHECKING: @@ -299,7 +300,6 @@ ), } -EXEC_PREFIX = "__EXEC__ " EPOCH_DS = "1970-01-01" @@ -613,10 +613,10 @@ class Model(ModelMeta, frozen=True): expressions_: t.Optional[t.List[exp.Expression]] = Field( default=None, alias="expressions" ) - python_env_: t.Optional[t.Dict[str, t.Any]] = Field( + python_env_: t.Optional[t.Dict[str, Executable]] = Field( default=None, alias="python_env" ) - path: Path = Path() + _path: Path = Path() _depends_on: t.Optional[t.Set[str]] = None _columns: t.Optional[t.Dict[str, exp.DataType]] = None _column_descriptions: t.Optional[t.Dict[str, str]] = None @@ -689,7 +689,6 @@ def load( query=query, expressions=statements, python_env=_python_env(query, module, macros or macro.get_registry()), - path=path, **{ "dialect": dialect or "", **ModelMeta( @@ -701,6 +700,8 @@ def load( }, ) + model._path = path + if time_column_format and model.time_column and not model.time_column.format: if dialect != model.dialect: # Transpile default time column format in default dialect to model's dialect @@ -998,10 +999,9 @@ def exec_python( ) return df except Exception as e: - path = trim_path(self.path, "models") - print_exception(e, self.python_env, str(path)) + print_exception(e, self.python_env) raise SQLMeshError( - f"Error executing Python model '{path}::{self.query.name}'" + f"Error executing Python model '{self.name}::{self.query.name}'" ) def render_audit_queries( @@ -1101,7 +1101,7 @@ def expressions(self) -> t.List[exp.Expression]: return self.expressions_ or [] @property - def python_env(self) -> t.Dict[str, t.Any]: + def python_env(self) -> t.Dict[str, Executable]: return self.python_env_ or {} def validate_definition(self) -> None: @@ -1120,7 +1120,7 @@ def validate_definition(self) -> None: if isinstance(expression, exp.Star): _raise_config_error( "SELECT * is not allowed you must explicitly select columns.", - self.path, + self._path, ) alias = expression.alias_or_name @@ -1131,13 +1131,13 @@ def validate_definition(self) -> None: elif not alias: _raise_config_error( f"Outer projection `{expression}` must have inferrable names or explicit aliases.", - self.path, + self._path, ) for name, count in name_counts.items(): if count > 1: _raise_config_error( - f"Found duplicate outer select name `{name}`", self.path + f"Found duplicate outer select name `{name}`", self._path ) if self.partitioned_by: @@ -1145,7 +1145,7 @@ def validate_definition(self) -> None: if len(self.partitioned_by) != len(unique_partition_keys): _raise_config_error( "All partition keys must be unique in the model definition", - self.path, + self._path, ) projections = {p.lower() for p in query.named_selects} @@ -1154,13 +1154,13 @@ def validate_definition(self) -> None: missing_keys_str = ", ".join(f"'{k}'" for k in sorted(missing_keys)) _raise_config_error( f"Partition keys [{missing_keys_str}] are missing in the query in the model definition", - self.path, + self._path, ) if self.kind == ModelKind.INCREMENTAL and not self.time_column: _raise_config_error( "Incremental models must have a time_column field.", - self.path, + self._path, ) def _filter_time_column( @@ -1236,47 +1236,14 @@ def model(self, module: str, path: Path) -> Model: module=module, ) - return Model( + model = Model( query=f"@{name}", - python_env=serialize_env(env, module=module, prefix=EXEC_PREFIX), - path=path, + python_env=serialize_env(env, module=module), **self.meta.dict(exclude_defaults=True), ) - -def strip_exec_prefix(code: str) -> str: - return code[len(EXEC_PREFIX) :] - - -def prepare_env( - env: t.Dict[str, t.Any], - python_env: t.Optional[t.Dict[str, t.Any]] = None, - functions: t.Optional[t.Dict[str, t.Callable]] = None, -) -> None: - """Prepare a python env by hydrating and executing functions. - - The Python ENV is stored in a json serializable format. - Because we store macros as function strings, we need to detect - when a variable is supposed to a function and then deserialize it - appropriately. We assume strings with EXEC_PREFIX are actually - functions and then execute them using our local env to hydrate them. - - Args: - env: The dictionary to execute code in. - python_env: The dictionary containing the serialized python environment. - functions: Optional dictionary to set function names. - """ - if python_env: - for name, value in python_env.items(): - if isinstance(value, str) and value.startswith(EXEC_PREFIX): - code = strip_exec_prefix(value) - exec(code, env) - if functions: - func = list(env.values())[-1] - if callable(func) and code.startswith("def "): - functions[normalize_macro_name(name)] = func - else: - env[name] = value + model._path = path + return model def parse_model_name(name: str) -> t.Tuple[t.Optional[str], t.Optional[str], str]: @@ -1316,8 +1283,8 @@ def find_tables(query: exp.Expression) -> t.Set[str]: def _python_env( query: exp.Expression, module: str, macros: MacroRegistry -) -> t.Dict[str, t.Any]: - python_env: t.Dict[str, t.Any] = {} +) -> t.Dict[str, Executable]: + python_env: t.Dict[str, Executable] = {} for macro_func in query.find_all(d.MacroFunc): if isinstance(macro_func, (d.MacroSQL, d.MacroStrReplace)): @@ -1332,7 +1299,7 @@ def _python_env( module=module, ) - return serialize_env(python_env, module=module, prefix=EXEC_PREFIX) + return serialize_env(python_env, module=module) def _raise_config_error(msg: str, location: t.Optional[str | Path] = None) -> None: diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index d8e8c7b025..3ab828f308 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -64,7 +64,7 @@ def model(self, line: str, sql: t.Optional[str] = None): loaded = Model.load( parse(sql), macros=self.context.macros, - path=model.path, + path=model._path, dialect=self.context.dialect, time_column_format=self.context.config.time_column_format, ) @@ -87,7 +87,7 @@ def model(self, line: str, sql: t.Optional[str] = None): replace=True, ) - with open(model.path, "w", encoding="utf-8") as file: + with open(model._path, "w", encoding="utf-8") as file: file.write(formatted) self.context.models.update({model.name: model}) diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index 58a6e2c3cf..e2959721fe 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -8,8 +8,12 @@ import traceback import types import typing as t +from enum import Enum +from pathlib import Path +from sqlmesh.utils import trim_path from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.pydantic import PydanticModel def _code_globals(code: types.CodeType) -> t.Dict[str, None]: @@ -138,9 +142,34 @@ def walk(obj: t.Any) -> None: ) -def serialize_env( - env: t.Dict[str, t.Any], *, module: str, prefix: str = "" -) -> t.Dict[str, t.Any]: +class ExecutableKind(str, Enum): + """The kind of of executable.""" + + DEF = "def" + IMPORT = "import" + VALUE = "value" + + +class Executable(PydanticModel): + payload: t.Any + kind: ExecutableKind = ExecutableKind.DEF + name: t.Optional[str] = None + path: t.Optional[str] = None + + @property + def is_def(self): + return self.kind == ExecutableKind.DEF + + @property + def is_import(self): + return self.kind == ExecutableKind.IMPORT + + @property + def is_value(self): + return self.kind == ExecutableKind.VALUE + + +def serialize_env(env: t.Dict[str, t.Any], module: str) -> t.Dict[str, Executable]: """Serializes a python function into a self contained dictionary. Recursively walks a function's globals to store all other references inside of env. @@ -148,30 +177,62 @@ def serialize_env( Args: env: Dictionary to store the env. module: The module to filter on. Other modules will not be walked and treated as imports. - prefix: Optional prefix to namespace the function definition. """ serialized = {} for k, v in env.items(): if callable(v): + name = v.__name__ + name = k if name == "" else name + if v.__module__.startswith(module): - serialized[k] = f"{prefix}{normalize_source(v)}" + serialized[k] = Executable( + name=name if name != k else None, + payload=normalize_source(v), + kind=ExecutableKind.DEF, + path=trim_path(Path(inspect.getfile(v)), module).name, + ) else: - serialized[k] = f"{prefix}from {v.__module__} import {k}" + serialized[k] = Executable( + payload=f"from {v.__module__} import {name}", + kind=ExecutableKind.IMPORT, + ) elif inspect.ismodule(v): name = v.__name__ postfix = "" if name == k else f" as {k}" - serialized[k] = f"{prefix}import {name}" + postfix + serialized[k] = Executable( + payload=f"import {name}{postfix}", + kind=ExecutableKind.IMPORT, + ) else: - serialized[k] = v + serialized[k] = Executable(payload=v, kind=ExecutableKind.VALUE) return serialized +def prepare_env( + env: t.Dict[str, t.Any], + python_env: t.Dict[str, Executable], +) -> None: + """Prepare a python env by hydrating and executing functions. + + The Python ENV is stored in a json serializable format. + Functions and imports are stored as a special data class. + + Args: + env: The dictionary to execute code in. + python_env: The dictionary containing the serialized python environment. + """ + for name, executable in python_env.items(): + if executable.is_value: + env[name] = executable.payload + else: + exec(executable.payload, env) + + def print_exception( exception: Exception, - python_env: t.Dict[str, t.Any], - path: str, + python_env: t.Dict[str, Executable], out=sys.stderr, ) -> None: """Formats exceptions that occur from evaled code. @@ -182,10 +243,7 @@ def print_exception( Args: exception: The exception to print the stack trace for. python_env: The environment containing stringified python code. - path: The path to show in the error message. """ - from sqlmesh.core.model import strip_exec_prefix - tb: t.List[str] = [] if sys.version_info < (3, 10): @@ -209,12 +267,12 @@ def print_exception( tb.append(error_line) continue + executable = python_env[func] indent = error_line[: match.start()] - error_line = ( - f"{indent}File '{path}' (or imported file), line {line_num}, in {func}" - ) - code = strip_exec_prefix(python_env[func]) + error_line = f"{indent}File '{executable.path}' (or imported file), line {line_num}, in {func}" + + code = executable.payload formatted = [] for i, code_line in enumerate(code.splitlines()): diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py index 88a5552bb4..556154fe58 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -2,7 +2,7 @@ from sqlglot import exp, parse_one from sqlmesh.core.macros import MacroEvaluator, macro -from sqlmesh.core.model import EXEC_PREFIX +from sqlmesh.utils.metaprogramming import Executable @macro() @@ -15,7 +15,8 @@ def filter_country( @pytest.fixture def macro_evaluator() -> MacroEvaluator: return MacroEvaluator( - "hive", {"test": f"{EXEC_PREFIX}def test(_):\n return 'test'"} + "hive", + {"test": Executable(name="test", payload=f"def test(_):\n return 'test'")}, ) diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 0c4f6195d5..71bf4f7971 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -62,7 +62,6 @@ def test_json(snapshot: Snapshot): "model": { "audits": {}, "cron": "1 0 * * *", - "path": ".", "batch_size": 30, "kind": "incremental", "start": "2020-01-01", diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py index 51ed22d8a4..18b73cd847 100644 --- a/tests/schedulers/airflow/test_client.py +++ b/tests/schedulers/airflow/test_client.py @@ -101,7 +101,6 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot, dag_run_entries: "kind": "incremental", "name": "test_model", "partitioned_by": ["a"], - "path": ".", "query": "SELECT a " "FROM tbl", "storage_format": "parquet", }, diff --git a/tests/utils/test_metaprogramming.py b/tests/utils/test_metaprogramming.py index 598fdafd6b..d8659d8186 100644 --- a/tests/utils/test_metaprogramming.py +++ b/tests/utils/test_metaprogramming.py @@ -6,11 +6,13 @@ from pytest_mock.plugin import MockerFixture from sqlglot.expressions import to_table -from sqlmesh.core.model import EXEC_PREFIX, prepare_env from sqlmesh.utils.metaprogramming import ( + Executable, + ExecutableKind, build_env, func_globals, normalize_source, + prepare_env, print_exception, serialize_env, ) @@ -103,26 +105,36 @@ def closure(z): def test_serialize_env() -> None: env: t.Dict[str, t.Any] = {} build_env(main_func, env=env, name="MAIN", module="tests") - env = serialize_env(env, module="tests", prefix="PREFIX ") # type: ignore + env = serialize_env(env, module="tests") # type: ignore lambda_no_args_padding = " " if sys.version_info < (3, 11) else "" assert env == { - "MAIN": """PREFIX def main_func(y): + "MAIN": Executable( + name="main_func", + path="test_metaprogramming.py", + payload="""def main_func(y): sqlglot.parse_one('1') MyClass() def closure(z): return z + Z return closure(y) + other_func(Y)""", - "X": 1, - "Y": 2, - "Z": 3, - "KLASS_X": 1, - "KLASS_Y": 2, - "KLASS_Z": 3, - "to_table": "PREFIX from sqlglot.expressions import to_table", - "MyClass": """PREFIX class MyClass: + ), + "X": Executable(payload=1, kind=ExecutableKind.VALUE), + "Y": Executable(payload=2, kind=ExecutableKind.VALUE), + "Z": Executable(payload=3, kind=ExecutableKind.VALUE), + "KLASS_X": Executable(payload=1, kind=ExecutableKind.VALUE), + "KLASS_Y": Executable(payload=2, kind=ExecutableKind.VALUE), + "KLASS_Z": Executable(payload=3, kind=ExecutableKind.VALUE), + "to_table": Executable( + kind=ExecutableKind.IMPORT, + payload="from sqlglot.expressions import to_table", + ), + "MyClass": Executable( + kind=ExecutableKind.DEF, + path="test_metaprogramming.py", + payload="""class MyClass: @staticmethod def foo(): @@ -134,48 +146,52 @@ def bar(cls): def baz(self): return KLASS_Z""", - "pd": "PREFIX import pandas as pd", - "sqlglot": "PREFIX import sqlglot", - "my_lambda": f"PREFIX my_lambda = lambda{lambda_no_args_padding}: print('z')", - "other_func": """PREFIX def other_func(a): + ), + "pd": Executable(payload="import pandas as pd", kind=ExecutableKind.IMPORT), + "sqlglot": Executable(kind=ExecutableKind.IMPORT, payload="import sqlglot"), + "my_lambda": Executable( + path="test_metaprogramming.py", + payload=f"my_lambda = lambda{lambda_no_args_padding}: print('z')", + ), + "other_func": Executable( + path="test_metaprogramming.py", + payload="""def other_func(a): import sqlglot sqlglot.parse_one('1') pd.DataFrame([{'x': 1}]) to_table('y') my_lambda() return X + a""", + ), } def test_print_exception(mocker: MockerFixture): out_mock = mocker.Mock() - test_code = """ - -def test_fun(): - raise RuntimeError("error") - -""" - - test_path = "/test/path.py" - test_env = {"test_fun": f"{EXEC_PREFIX}{test_code}"} + test_env = { + "test_fun": Executable( + name="test_func", + payload="""def test_fun(): + raise RuntimeError("error")""", + path="/test/path.py", + ), + } env: t.Dict[str, t.Any] = {} prepare_env(env, test_env) try: eval("test_fun()", env) except Exception as ex: - print_exception(ex, test_env, test_path, out_mock) + print_exception(ex, test_env, out_mock) expected_message = f"""Traceback (most recent call last): - File "{__file__}", line 165, in test_print_exception + File "{__file__}", line 196, in test_print_exception eval("test_fun()", env) File "", line 1, in - File '/test/path.py' (or imported file), line 4, in test_fun - - + File '/test/path.py' (or imported file), line 2, in test_fun def test_fun(): raise RuntimeError("error")