Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import types
import typing as t
from enum import Enum
from functools import reduce
from functools import lru_cache, reduce
from itertools import chain
from pathlib import Path
from string import Template
Expand Down Expand Up @@ -237,7 +237,7 @@ def evaluate_macros(
self.transform(value) if isinstance(value, exp.Expression) else value
)
if isinstance(node, exp.Identifier) and "@" in node.this:
text = self.template(node.this, self.locals)
text = self.template(node.this, {})
if node.this != text:
changed = True
node.args["this"] = text
Expand Down Expand Up @@ -287,18 +287,9 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
for k, v in chain(variables.items(), self.locals.items(), local_variables.items()):
# try to convert all variables into sqlglot expressions
# because they're going to be converted into strings in sql
# we use bare Exception instead of ValueError because there's
# a recursive error with MagicMock.
# we don't convert strings because that would result in adding quotes
if not isinstance(v, str):
try:
v = exp.convert(v)
except Exception:
pass

if isinstance(v, exp.Expression):
v = v.sql(dialect=self.dialect)
mapping[k] = v
if k != c.SQLMESH_VARS:
mapping[k] = convert_sql(v, self.dialect)

return MacroStrTemplate(str(text)).safe_substitute(mapping)

Expand Down Expand Up @@ -1378,3 +1369,30 @@ def _coerce(
f"Coercion of expression '{expr}' to type '{typ}' failed. Using non coerced expression at '{path}'",
)
return expr


def convert_sql(v: t.Any, dialect: DialectType) -> t.Any:
try:
return _cache_convert_sql(v, dialect, v.__class__)
# dicts aren't hashable but are convertable
except TypeError:
return _convert_sql(v, dialect)


def _convert_sql(v: t.Any, dialect: DialectType) -> t.Any:
if not isinstance(v, str):
try:
v = exp.convert(v)
# we use bare Exception instead of ValueError because there's
# a recursive error with MagicMock.
except Exception:
pass

if isinstance(v, exp.Expression):
v = v.sql(dialect=dialect)
return v


@lru_cache(maxsize=1028)
def _cache_convert_sql(v: t.Any, dialect: DialectType, t: type) -> t.Any:
return _convert_sql(v, dialect)