diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 5d9867facd..1f8908af4e 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -5,7 +5,7 @@ import types import typing as t from enum import Enum -from functools import reduce +from functools import lru_cache, reduce from itertools import chain from pathlib import Path from string import Template @@ -237,7 +237,7 @@ def evaluate_macros( self.transform(value) if isinstance(value, exp.Expression) else value ) if isinstance(node, exp.Identifier) and "@" in node.this: - text = self.template(node.this, self.locals) + text = self.template(node.this, {}) if node.this != text: changed = True node.args["this"] = text @@ -287,18 +287,9 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str: for k, v in chain(variables.items(), self.locals.items(), local_variables.items()): # try to convert all variables into sqlglot expressions # because they're going to be converted into strings in sql - # we use bare Exception instead of ValueError because there's - # a recursive error with MagicMock. # we don't convert strings because that would result in adding quotes - if not isinstance(v, str): - try: - v = exp.convert(v) - except Exception: - pass - - if isinstance(v, exp.Expression): - v = v.sql(dialect=self.dialect) - mapping[k] = v + if k != c.SQLMESH_VARS: + mapping[k] = convert_sql(v, self.dialect) return MacroStrTemplate(str(text)).safe_substitute(mapping) @@ -1378,3 +1369,30 @@ def _coerce( f"Coercion of expression '{expr}' to type '{typ}' failed. Using non coerced expression at '{path}'", ) return expr + + +def convert_sql(v: t.Any, dialect: DialectType) -> t.Any: + try: + return _cache_convert_sql(v, dialect, v.__class__) + # dicts aren't hashable but are convertable + except TypeError: + return _convert_sql(v, dialect) + + +def _convert_sql(v: t.Any, dialect: DialectType) -> t.Any: + if not isinstance(v, str): + try: + v = exp.convert(v) + # we use bare Exception instead of ValueError because there's + # a recursive error with MagicMock. + except Exception: + pass + + if isinstance(v, exp.Expression): + v = v.sql(dialect=dialect) + return v + + +@lru_cache(maxsize=1028) +def _cache_convert_sql(v: t.Any, dialect: DialectType, t: type) -> t.Any: + return _convert_sql(v, dialect)