diff --git a/setup.py b/setup.py index b9e7c54c85..5f724318c4 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "duckdb", "dateparser", "hyperscript", + "jinja2", "pandas", "pydantic", "requests", diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 7b690102ce..dcff656ae6 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -49,7 +49,7 @@ from sqlmesh.core.console import Console, get_console from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.dag import DAG -from sqlmesh.core.dialect import extend_sqlglot, format_model_expressions +from sqlmesh.core.dialect import JinjaModel, extend_sqlglot, format_model_expressions from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.environment import Environment from sqlmesh.core.macros import macro @@ -65,6 +65,7 @@ from sqlmesh.utils.date import TimeLike, yesterday_ds from sqlmesh.utils.errors import ConfigError, MissingDependencyError, PlanError from sqlmesh.utils.file_cache import FileCache +from sqlmesh.utils.jinja import JINJA_RE if t.TYPE_CHECKING: import graphviz @@ -657,7 +658,13 @@ def _load_models(self): for path in self._glob_path(self.models_directory_path, ".sql"): self._path_mtimes[path] = path.stat().st_mtime with open(path, "r", encoding="utf-8") as file: - expressions = parse(file.read(), read=self.dialect) + file_contents = file.read() + + if JINJA_RE.search(file_contents): + expressions = [JinjaModel(this=file_contents)] + else: + expressions = parse(file_contents, read=self.dialect) + model = Model.load( expressions, module=module, diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 3794ba7faa..23952cae18 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -9,6 +9,10 @@ class Model(exp.Expression): arg_types = {"expressions": True} +class JinjaModel(exp.Expression): + """Stores a model file that contains Jinja code as a raw string.""" + + class Audit(exp.Expression): arg_types = {"expressions": True} diff --git a/sqlmesh/core/model.py b/sqlmesh/core/model.py index fa4bf33d84..21e98ef6a2 100644 --- a/sqlmesh/core/model.py +++ b/sqlmesh/core/model.py @@ -609,7 +609,7 @@ class Model(ModelMeta, frozen=True): python_env: Dictionary containing all global variables needed to render the model's macros. """ - query: t.Union[exp.Subqueryable, d.MacroVar] + query: t.Union[exp.Subqueryable, d.MacroVar, d.JinjaModel] expressions_: t.Optional[t.List[exp.Expression]] = Field( default=None, alias="expressions" ) @@ -664,6 +664,15 @@ def load( time_column_format: The default time column format to use if no model time column is configured. """ if len(expressions) < 2: + if expressions and isinstance(expressions[0], d.JinjaModel): + model = cls( + query=expressions[0], name="test" + ) # We need the model's name for this instantiation to be valid + model._path = path + + model.validate_definition() + return model + _raise_config_error( "Incomplete model definition, missing MODEL and QUERY", path ) @@ -953,6 +962,11 @@ def render_query( Returns: The rendered expression. """ + + # If the query is an instance of d.JinjaModel, render it and parse the produced string + # to create and validate the resulting model. Then we can extract the query and pass it + # as the query_ argument below. Do we want to do any validation earlier (e.g. inside load())? + return self._render_query( start=start, end=end, diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py new file mode 100644 index 0000000000..4fba7e9acf --- /dev/null +++ b/sqlmesh/utils/jinja.py @@ -0,0 +1,5 @@ +import re + +# Captures one of the following patterns: "{{", "{#", "{%" and "{%-". +# Q: this will also flag text that contains "{{" inside a string as Jinja. Is this a problem? +JINJA_RE = re.compile("{({|#|(%(-)?))")