Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions example/models/customer_revenue_by_day.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ MODEL (
name sushi.customer_revenue_by_day,
owner jen,
cron '@daily',
dialect hive,
batch_size 10,
time_column ds
);
Expand All @@ -16,7 +17,7 @@ WITH order_total AS (
LEFT JOIN sushi.items AS i
ON oi.item_id = i.id AND oi.ds = i.ds
WHERE
oi.ds BETWEEN @start_ds AND @end_ds
oi.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}'
Copy link
Copy Markdown
Collaborator

@izeigerman izeigerman Dec 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noooooooooooooooooooooo 🤮 😆

GROUP BY
oi.order_id,
oi.ds
Expand All @@ -29,7 +30,7 @@ FROM sushi.orders AS o
LEFT JOIN order_total AS ot
ON o.id = ot.order_id AND o.ds = ot.ds
WHERE
o.ds BETWEEN @start_ds AND @end_ds
o.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}'
GROUP BY
o.customer_id,
o.ds
2 changes: 1 addition & 1 deletion example/models/order_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def execute(
"order_id": order["id"],
"item_id": item["id"],
"quantity": random.randint(1, 10),
"ds": dt,
"ds": to_ds(dt),
}
)
dfs.append(
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@
"duckdb",
"dateparser",
"hyperscript",
"jinja2",
"pandas",
"pydantic",
"requests",
"rich",
"ruamel.yaml",
"sqlglot>=10.2.5",
"sqlglot>=10.2.6",
],
extras_require={
"dev": [
Expand Down
8 changes: 4 additions & 4 deletions sqlmesh/core/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _parse_expression(cls, v: str) -> exp.Expression:
@classmethod
def load(
cls,
expressions: t.List[exp.Expression | None],
expressions: t.List[exp.Expression],
*,
path: pathlib.Path,
dialect: t.Optional[str] = None,
Expand Down Expand Up @@ -178,7 +178,7 @@ def load(

audit = cls(
query=query,
expressions=[statement for statement in statements if statement],
expressions=statements,
**{
"dialect": dialect or "",
**AuditMeta(
Expand All @@ -197,12 +197,12 @@ def load(
@classmethod
def load_multiple(
cls,
expressions: t.List[exp.Expression | None],
expressions: t.List[exp.Expression],
*,
path: pathlib.Path,
dialect: t.Optional[str] = None,
) -> t.Generator[Audit, None, None]:
audit_block: t.List[t.Optional[exp.Expression]] = []
audit_block: t.List[exp.Expression] = []
for expression in expressions:
if isinstance(expression, d.Audit):
if audit_block:
Expand Down
10 changes: 5 additions & 5 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from io import StringIO
from pathlib import Path

from sqlglot import exp, parse
from sqlglot import exp

from sqlmesh.core import constants as c
from sqlmesh.core._typing import NotificationTarget
Expand All @@ -50,7 +50,7 @@
from sqlmesh.core.console import Console, get_console
from sqlmesh.core.context_diff import ContextDiff
from sqlmesh.core.dag import DAG
from sqlmesh.core.dialect import extend_sqlglot, format_model_expressions
from sqlmesh.core.dialect import extend_sqlglot, format_model_expressions, parse_model
from sqlmesh.core.engine_adapter import DF, EngineAdapter
from sqlmesh.core.environment import Environment
from sqlmesh.core.macros import macro
Expand Down Expand Up @@ -511,7 +511,7 @@ def format(self) -> None:
"""Format all models in a given directory."""
for model in self.models.values():
with open(model._path, "r+", encoding="utf-8") as file:
expressions = [e for e in parse(file.read(), read=self.dialect) if e]
expressions = parse_model(file.read(), default_dialect=self.dialect)
file.seek(0)
file.write(format_model_expressions(expressions, model.dialect))
file.truncate()
Expand Down Expand Up @@ -792,7 +792,7 @@ def _load_models(self):
for path in self._glob_path(self.models_directory_path, ".sql"):
self._path_mtimes[path] = path.stat().st_mtime
with open(path, "r", encoding="utf-8") as file:
expressions = parse(file.read(), read=self.dialect)
expressions = parse_model(file.read(), default_dialect=self.dialect)
model = Model.load(
expressions,
module=module,
Expand Down Expand Up @@ -827,7 +827,7 @@ def _load_audits(self) -> None:
for path in self._glob_path(self.audits_directory_path, ".sql"):
self._path_mtimes[path] = path.stat().st_mtime
with open(path, "r", encoding="utf-8") as file:
expressions = parse(file.read(), read=self.dialect)
expressions = parse_model(file.read(), default_dialect=self.dialect)
for audit in Audit.load_multiple(
expressions=expressions,
path=path,
Expand Down
82 changes: 81 additions & 1 deletion sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from __future__ import annotations

import re
import typing as t
from difflib import unified_diff

from jinja2 import Environment
from jinja2.meta import find_undeclared_variables
from sqlglot import Dialect, Generator, Parser, TokenType, exp
from sqlglot.tokens import Token

Expand All @@ -13,12 +18,19 @@ class Audit(exp.Expression):
arg_types = {"expressions": True}


class Jinja(exp.Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True


class MacroVar(exp.Var):
pass


class MacroFunc(exp.Func):
pass
@property
def name(self):
return self.this.name


class MacroDef(MacroFunc):
Expand Down Expand Up @@ -322,6 +334,69 @@ def text_diff(
)


DIALECT_PATTERN = re.compile(
r"(model|audit).*?\(.*?dialect.+?([a-z]+)", re.IGNORECASE | re.DOTALL
)
JINJA_PATTERN = re.compile(r"{{|{%|{#")


def parse_model(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
Comment thread
tobymao marked this conversation as resolved.
"""Parse a sql string containing a model definition.

If a jinja block is detected, the query is stored as raw string in a Jinja node.

Args:
sql: The sql based definition.
default_dialect: The dialect to use if the model does not specify one.

Returns:
A list of the expressions, [Model, *Statements, Query | Jinja]
"""
match = DIALECT_PATTERN.search(sql)
dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)()

tokens = dialect.tokenizer.tokenize(sql)
chunks: t.List[t.Tuple[t.List, bool]] = [([], False)]
total = len(tokens)

for i, token in enumerate(tokens):
if token.token_type == TokenType.SEMICOLON:
if i < total - 1:
chunks.append(([], False))
else:
if token.token_type == TokenType.BLOCK_START or (
token.token_type == TokenType.STRING
and JINJA_PATTERN.search(token.text)
):
chunks[-1] = (chunks[-1][0], True)
chunks[-1][0].append(token)

expressions: t.List[exp.Expression] = []
sql_lines = None

for chunk, is_jinja in chunks:
if is_jinja:
start, *_, end = chunk
sql_lines = sql_lines or sql.split("\n")
lines = sql_lines[start.line - 1 : end.line]
lines[0] = lines[0][start.col - 1 :]
lines[-1] = lines[-1][: end.col + len(end.text) - 1]
segment = "\n".join(lines)
variables = [
exp.Literal.string(var)
for var in find_undeclared_variables(Environment().parse(segment))
]
expressions.append(
Jinja(this=exp.Literal.string(segment), expressions=variables)
)
else:
for expression in dialect.parser().parse(chunk, sql):
if expression:
expressions.append(expression)

return expressions


@t.no_type_check
def extend_sqlglot() -> None:
"""Extend SQLGlot with SQLMesh's custom macro aware dialect."""
Expand Down Expand Up @@ -353,6 +428,11 @@ def extend_sqlglot() -> None:
)

for parser in parsers:
parser.FUNCTIONS.update(
{
"JINJA": Jinja.from_arg_list,
}
)
parser.PRIMARY_PARSERS.update(
{
TokenType.PARAMETER: _parse_macro,
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
self.macros = {
normalize_macro_name(k): v.func for k, v in macro.get_registry().items()
}
prepare_env(self.env, self.python_env)
prepare_env(self.python_env, self.env)
for k, v in self.python_env.items():
if v.is_definition:
self.macros[normalize_macro_name(k)] = self.env[v.name or k]
Expand Down
Loading