diff --git a/example/models/items.py b/example/models/items.py index 6fa432d333..f226bb8381 100644 --- a/example/models/items.py +++ b/example/models/items.py @@ -5,7 +5,7 @@ import pandas as pd from example.helper import iter_dates -from sqlmesh import EngineAdapter, model +from sqlmesh import ExecutionContext, model from sqlmesh.utils.date import to_ds ITEMS = [ @@ -63,7 +63,7 @@ """ ) def execute( - engine: EngineAdapter, + context: ExecutionContext, start: datetime, end: datetime, latest: datetime, diff --git a/example/models/order_items.py b/example/models/order_items.py index 98c77a0bfc..b01a4e3bb0 100644 --- a/example/models/order_items.py +++ b/example/models/order_items.py @@ -1,11 +1,10 @@ import random -import typing as t from datetime import datetime import pandas as pd from example.helper import iter_dates -from sqlmesh import EngineAdapter, Snapshot, model +from sqlmesh import ExecutionContext, model from sqlmesh.utils.date import to_ds @@ -29,26 +28,22 @@ """ ) def execute( - engine: EngineAdapter, + context: ExecutionContext, start: datetime, end: datetime, latest: datetime, - snapshots: t.Dict[str, Snapshot], - mapping: t.Optional[t.Dict[str, str]], **kwargs, ) -> pd.DataFrame: dfs = [] - raw_orders = ( - snapshots["sushi.orders"].table_name if snapshots else mapping["sushi.orders"] - ) + orders_table = context.table("sushi.orders") + items_table = context.table("sushi.items") for dt in iter_dates(start, end): - # this section not super clean, make it easier to fetch other snapshots - orders = engine.fetchdf( + orders = context.fetchdf( f""" SELECT * - FROM {raw_orders} + FROM {orders_table} WHERE ds = '{to_ds(dt)}' """ ) @@ -56,10 +51,10 @@ def execute( if not isinstance(orders, pd.DataFrame): orders = orders.toPandas() - items = engine.fetchdf( + items = context.fetchdf( f""" SELECT * - FROM {raw_orders} + FROM {items_table} WHERE ds = '{to_ds(dt)}' """ ) diff --git a/example/models/orders.py b/example/models/orders.py index 0156a40810..dffd4a2de8 100644 --- a/example/models/orders.py +++ b/example/models/orders.py @@ -4,7 +4,7 @@ import pandas as pd from example.helper import iter_dates -from sqlmesh import EngineAdapter, model +from sqlmesh import ExecutionContext, model from sqlmesh.utils.date import to_ds CUSTOMERS = list(range(0, 100)) @@ -33,7 +33,7 @@ """ ) def execute( - engine: EngineAdapter, + context: ExecutionContext, start: datetime, end: datetime, latest: datetime, diff --git a/setup.py b/setup.py index b9e7c54c85..ce39e787cc 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ "requests", "rich", "ruamel.yaml", - "sqlglot>=10.2.4", + "sqlglot>=10.2.5", ], extras_require={ "dev": [ diff --git a/sqlmesh/__init__.py b/sqlmesh/__init__.py index 4de42dc00b..de0b493d94 100644 --- a/sqlmesh/__init__.py +++ b/sqlmesh/__init__.py @@ -8,7 +8,7 @@ import sys from enum import Enum -from sqlmesh.core.context import Context +from sqlmesh.core.context import Context, ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.macros import macro from sqlmesh.core.model import Model, model diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 5b42e6eaad..990215592e 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -80,6 +80,36 @@ def render( ctx.obj.console.show_sql(sql) +@cli.command("evaluate") +@click.argument("model") +@opt.start_time +@opt.end_time +@opt.latest_time +@click.option( + "--limit", + type=int, + help="The number of rows which the query should be limited to.", +) +@click.pass_context +def evaluate( + ctx, + model: str, + start: TimeLike, + end: TimeLike, + latest: t.Optional[TimeLike] = None, + limit: t.Optional[int] = None, +) -> None: + """Evaluate a model and return a dataframe with a default limit of 1000.""" + df = ctx.obj.evaluate( + model, + start=start, + end=end, + latest=latest, + limit=limit, + ) + ctx.obj.console.log_success(df) + + @cli.command("format") @click.pass_context def format(ctx) -> None: diff --git a/sqlmesh/core/config.py b/sqlmesh/core/config.py index 6cb98ef5b6..0f68714f0b 100644 --- a/sqlmesh/core/config.py +++ b/sqlmesh/core/config.py @@ -266,6 +266,7 @@ class Config(PydanticModel): physical_schema: The default schema used to store materialized tables. snapshot_ttl: Duration before unpromoted snapshots are removed. time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d. + This time format uses python format codes. https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes. ddl_concurrent_task: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). Default: 1. """ diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 7b690102ce..76b4dfd84d 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -32,6 +32,7 @@ """ from __future__ import annotations +import abc import contextlib import importlib import types @@ -50,7 +51,7 @@ from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.dag import DAG from sqlmesh.core.dialect import extend_sqlglot, format_model_expressions -from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.engine_adapter import DF, EngineAdapter from sqlmesh.core.environment import Environment from sqlmesh.core.macros import macro from sqlmesh.core.model import Model @@ -69,10 +70,76 @@ if t.TYPE_CHECKING: import graphviz + MODEL_OR_SNAPSHOT = t.Union[str, Model, Snapshot] + extend_sqlglot() -class Context: +class BaseContext(abc.ABC): + """The base context which defines methods to execute a model.""" + + @property + @abc.abstractmethod + def model_tables(self) -> t.Dict[str, str]: + """Returns a mapping of model names to tables.""" + + @property + @abc.abstractmethod + def engine_adapter(self) -> EngineAdapter: + """Returns an engine adapter.""" + + @property + def spark(self) -> t.Optional["pyspark.sql.SparkSession"]: # type: ignore + """Returns the spark session if it exists.""" + return self.engine_adapter.spark + + def table(self, model_name: str) -> str: + """Gets the physical table name for a given model. + + Args: + model_name: The model name. + + Returns: + The physical table name. + """ + return self.model_tables[model_name] + + def fetchdf(self, query: t.Union[exp.Expression, str]) -> DF: + """Fetches a dataframe given a sql string or sqlglot expression. + + Args: + query: SQL string or sqlglot expression. + + Returns: + The default dataframe is Pandas, but for Spark a PySpark dataframe is returned. + """ + return self.engine_adapter.fetchdf(query) + + +class ExecutionContext(BaseContext): + """The minimal context needed to execute a model. + + Args: + engine_adapter: The engine adapter to execute queries against. + mapping: A mapping of models to physical tables. + """ + + def __init__(self, engine_adapter: EngineAdapter, model_tables: t.Dict[str, str]): + self._engine_adapter = engine_adapter + self._model_tables = model_tables + + @property + def engine_adapter(self) -> EngineAdapter: + """Returns an engine adapter.""" + return self._engine_adapter + + @property + def model_tables(self) -> t.Dict[str, str]: + """Returns a mapping of model names to tables.""" + return self._model_tables + + +class Context(BaseContext): """Encapsulates a SQLMesh environment supplying convenient functions to perform various tasks. Args: @@ -137,7 +204,7 @@ def __init__( ddl_concurrent_tasks or self.config.ddl_concurrent_tasks ) - self.engine_adapter = engine_adapter or EngineAdapter( + self._engine_adapter = engine_adapter or EngineAdapter( self.config.engine_connection_factory, self.config.engine_dialect, multithreaded=self.ddl_concurrent_tasks > 1, @@ -170,6 +237,11 @@ def __init__( if load: self.load() + @property + def engine_adapter(self) -> EngineAdapter: + """Returns an engine adapter.""" + return self._engine_adapter + def upsert_model(self, model: t.Union[str, Model] = "", **kwargs) -> Model: """Update or insert a model. @@ -303,9 +375,22 @@ def snapshots(self) -> t.Dict[str, Snapshot]: snapshots[model.name] = snapshot return snapshots + @property + def model_tables(self) -> t.Dict[str, str]: + """Mapping of model name to physical table name. + + If a snapshot has not been versioned yet, its view name will be returned. + """ + return { + name: snapshot.table_name + if snapshot.version + else snapshot.qualified_view_name.for_environment(c.PROD) + for name, snapshot in self.snapshots.items() + } + def render( self, - model: t.Union[str, Model], + model_or_snapshot: MODEL_OR_SNAPSHOT, *, start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, @@ -317,7 +402,7 @@ def render( """Renders a model's query, expanding macros with provided kwargs, and optionally expanding referenced models. Args: - model: The model name or instance to render. + model_or_snapshot: The model, model name, or snapshot to render. start: The start of the interval to render. end: The end of the interval to render. latest: The latest time used for non incremental datasets. @@ -330,7 +415,14 @@ def render( The rendered expression. """ latest = latest or yesterday_ds() - model = model if isinstance(model, Model) else self.models[model] + + if isinstance(model_or_snapshot, str): + model = self.models[model_or_snapshot] + elif isinstance(model_or_snapshot, Snapshot): + model = model_or_snapshot.model + else: + model = model_or_snapshot + expand = self.dag.upstream(model.name) if expand is True else expand or [] return model.render_query( @@ -345,33 +437,43 @@ def render( def evaluate( self, - snapshot: Snapshot | str, + model_or_snapshot: MODEL_OR_SNAPSHOT, start: TimeLike, end: TimeLike, latest: TimeLike, + limit: t.Optional[int] = None, **kwargs, - ) -> None: - """Evaluate a snapshot (running its query against a DB/Engine). + ) -> DF: + """Evaluate a model or snapshot (running its query against a DB/Engine). + + This method is used to test or iterate on models without side effects. Args: - snapshot: The snapshot to evaluate. + model_or_snapshot: The model, model name, or snapshot to render. start: The start of the interval to evaluate. end: The end of the interval to evaluate. latest: The latest time used for non incremental datasets. + limit: A limit applied to the model, this must be > 0. """ - if isinstance(snapshot, str): - snapshot = self.snapshots[snapshot] + if isinstance(model_or_snapshot, str): + snapshot = self.snapshots[model_or_snapshot] + elif isinstance(model_or_snapshot, Model): + snapshot = self.snapshots[model_or_snapshot.name] + else: + snapshot = model_or_snapshot + + if not limit or limit <= 0: + limit = 1000 - self.snapshot_evaluator.evaluate( + return self.snapshot_evaluator.evaluate( snapshot, start, end, latest, - snapshots=self.snapshots, + mapping=self.model_tables, + limit=limit, ) - self.state_sync.add_interval(snapshot.snapshot_id, start, end) - def format(self) -> None: """Format all models in a given directory.""" for model in self.models.values(): @@ -680,7 +782,11 @@ def _load_models(self): new = registry.keys() - registered registered |= new for name in new: - model = registry[name].model(module, path) + model = registry[name].model( + module=module, + path=path, + time_column_format=self.config.time_column_format, + ) self.models[model.name] = model self._add_model_to_dag(model) diff --git a/sqlmesh/core/model.py b/sqlmesh/core/model.py index fa4bf33d84..a8c8c0e68e 100644 --- a/sqlmesh/core/model.py +++ b/sqlmesh/core/model.py @@ -257,6 +257,7 @@ from sqlglot.optimizer.simplify import simplify from sqlglot.time import format_time +from sqlmesh.core import constants as c from sqlmesh.core import dialect as d from sqlmesh.core.audit import Audit from sqlmesh.core.macros import MacroEvaluator, MacroRegistry, macro @@ -279,7 +280,8 @@ from sqlmesh.utils.pydantic import PydanticModel if t.TYPE_CHECKING: - from sqlmesh.core.engine_adapter import DF, EngineAdapter + from sqlmesh.core.context import ExecutionContext + from sqlmesh.core.engine_adapter import DF from sqlmesh.core.snapshot import Snapshot META_FIELD_CONVERTER: t.Dict[str, t.Callable] = { @@ -291,7 +293,6 @@ "partitioned_by_": lambda value: ( exp.to_identifier(value[0]) if len(value) == 1 else exp.Tuple(expressions=value) ), - "time_column": lambda value: value.expression, "depends_on_": lambda value: exp.Tuple(expressions=value), "columns_": lambda value: exp.Schema( expressions=[ @@ -650,9 +651,9 @@ def load( *, path: Path = Path(), module: str = "", + time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, macros: t.Optional[MacroRegistry] = None, dialect: t.Optional[str] = None, - time_column_format: t.Optional[str] = None, ) -> Model: """Load a model from a parsed SQLMesh model file. @@ -662,6 +663,7 @@ def load( path: An optional path to the file. dialect: The default dialect if no model dialect is configured. time_column_format: The default time column format to use if no model time column is configured. + The format must adhere to Python's strftime codes. """ if len(expressions) < 2: _raise_config_error( @@ -701,20 +703,7 @@ def load( ) model._path = path - - if time_column_format and model.time_column and not model.time_column.format: - if dialect != model.dialect: - # Transpile default time column format in default dialect to model's dialect - default_format = format_time( - time_column_format, d.Dialect.get_or_raise(dialect).time_mapping - ) - time_column_format = ( - d.Dialect.get_or_raise(model.dialect)() - .generator() - .format_time(default_format) - ) - model.time_column.format = time_column_format - + model.set_time_format(time_column_format) model.validate_definition() return model @@ -777,6 +766,27 @@ def render(self) -> t.List[exp.Expression]: if field_value is not None: if field.name == "description": comment = field_value + elif field.name == "time_column": + expression = field_value.expression + + # time_column.format is stored as python format in memory + # convert it back to the model dialect + if field_value.format: + expression.expressions.pop() + expression.append( + "expressions", + exp.Literal.string( + format_time( + field_value.format, + d.Dialect.get_or_raise( + self.dialect + ).inverse_time_mapping, + ) + ), + ) + expressions.append( + exp.Property(this="time_column", value=expression) + ) else: expressions.append( exp.Property( @@ -966,7 +976,7 @@ def render_query( def exec_python( self, - adapter: EngineAdapter, + context: ExecutionContext, *, start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, @@ -976,7 +986,7 @@ def exec_python( """Executes this model's python script. Args: - adapter: The engine adapter to use for fetching data. + context: The execution context used for fetching data. start: The start date/time of the run. end: The end date/time of the run. latest: The latest date/time to use for the run. @@ -989,14 +999,35 @@ def exec_python( f"Model '{self.name}' is a SQL model and cannot be executed as a Python script." ) + from sqlmesh.core.engine_adapter import pyspark + env: t.Dict[str, t.Any] = {} prepare_env(env, self.python_env) start, end = make_inclusive(start or EPOCH_DS, end or EPOCH_DS) latest = to_datetime(latest or EPOCH_DS) try: df = env[self.query.name]( - adapter, start=start, end=end, latest=latest, **kwargs + context, start=start, end=end, latest=latest, **kwargs ) + if self.kind == ModelKind.INCREMENTAL: + assert self.time_column + + if pyspark and isinstance(df, pyspark.sql.DataFrame): + self.convert_to_time_column(end) + df = df.where( + f""" + {self.time_column.column} BETWEEN + {self.convert_to_time_column(start).sql("spark")} AND + {self.convert_to_time_column(end).sql("spark")} + """ + ) + else: + if self.time_column.format: + start = start.strftime(self.time_column.format) + end = end.strftime(self.time_column.format) + + df_time = df[self.time_column.column] + df = df[(df_time >= start) & (df_time <= end)] return df except Exception as e: print_exception(e, self.python_env) @@ -1064,14 +1095,31 @@ def text_diff(self, other: Model) -> str: ) ).strip() + def set_time_format( + self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMAT + ) -> None: + """Sets the default time format for a model. + + Args: + default_time_format: A python time format used as the default format when none is provided. + """ + if not self.time_column: + return + + if self.time_column.format: + # Transpile the time column format into the generic dialect + self.time_column.format = format_time( + self.time_column.format, + d.Dialect.get_or_raise(self.dialect).time_mapping, + ) + else: + self.time_column.format = default_time_format + def convert_to_time_column(self, time: TimeLike) -> exp.Expression: """Convert a TimeLike object to the same time format and type as the model's time column.""" if self.time_column: if self.time_column.format: - mapping = d.Dialect.get_or_raise(self.dialect).time_mapping - fmt = format_time(self.time_column.format, mapping) - if fmt: - time = to_datetime(time).strftime(fmt) + time = to_datetime(time).strftime(self.time_column.format) time_column_type = self.columns[self.time_column.column] if time_column_type.this in exp.DataType.TEXT_TYPES: @@ -1224,7 +1272,13 @@ def __init__(self, definition: str = "", **kwargs): self.expressions = expressions self.name = self.meta.name - def model(self, module: str, path: Path) -> Model: + def model( + self, + *, + module: str, + path: Path, + time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, + ) -> Model: """Get the model registered by this function.""" env: t.Dict[str, t.Any] = {} name = self.func.__name__ @@ -1242,6 +1296,7 @@ def model(self, module: str, path: Path) -> Model: **self.meta.dict(exclude_defaults=True), ) + model.set_time_format(time_column_format) model._path = path return model diff --git a/sqlmesh/core/snapshot_evaluator.py b/sqlmesh/core/snapshot_evaluator.py index 3ba6a71bae..4aa919c1cb 100644 --- a/sqlmesh/core/snapshot_evaluator.py +++ b/sqlmesh/core/snapshot_evaluator.py @@ -28,7 +28,7 @@ from sqlglot import exp, select from sqlmesh.core.audit import AuditResult -from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.engine_adapter import DF, EngineAdapter from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotInfoLike from sqlmesh.utils.concurrency import concurrent_apply_to_snapshots from sqlmesh.utils.date import TimeLike @@ -60,10 +60,11 @@ def evaluate( start: TimeLike, end: TimeLike, latest: TimeLike, - snapshots: t.Dict[str, Snapshot], + limit: int = 0, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, mapping: t.Optional[t.Dict[str, str]] = None, **kwargs, - ) -> None: + ) -> t.Optional[DF]: """Evaluate a snapshot, creating its schema and table if it doesn't exist and then inserting it. Args: @@ -73,37 +74,48 @@ def evaluate( latest: The latest datetime to use for non-incremental queries. snapshots: All snapshots to use for mapping of physical locations. mapping: Mapping of model references to physical snapshots. + limit: If limit is >= 0, the query will not be persisted but evaluated and returned + as a dataframe. kwargs: Additional kwargs to pass to the renderer. """ if snapshot.is_embedded_kind: - return + return None - table_name = snapshot.table_name model = snapshot.model for sql_statement in model.sql_statements: self.adapter.execute(sql_statement) + mapping = mapping or { + name: snapshot.table_name for name, snapshot in (snapshots or {}).items() + } + if model.is_sql: query_or_df = model.render_query( start=start, end=end, latest=latest, - snapshots=snapshots, mapping=mapping, **kwargs, ) else: + from sqlmesh.core.context import ExecutionContext + query_or_df = model.exec_python( - self.adapter, + ExecutionContext(self.adapter, mapping), start=start, end=end, latest=latest, - snapshots=snapshots, - mapping=mapping, **kwargs, ) + if limit > 0: + if isinstance(query_or_df, exp.Expression): + query_or_df = self.adapter.fetchdf(query_or_df.limit(limit)) + return query_or_df.head(limit) + + table_name = snapshot.table_name + if snapshot.is_view_kind: logger.info("Replacing view '%s'", table_name) self.adapter.create_view(table_name, query_or_df, model.columns) @@ -127,6 +139,7 @@ def evaluate( ) else: self.adapter.insert_append(table_name, query_or_df, columns=columns) + return None def promote( self, target_snapshots: t.Iterable[SnapshotInfoLike], environment: str diff --git a/sqlmesh/engines/commands.py b/sqlmesh/engines/commands.py index 897326d3e9..09eba6113f 100644 --- a/sqlmesh/engines/commands.py +++ b/sqlmesh/engines/commands.py @@ -54,7 +54,6 @@ def evaluate( command_payload.start, command_payload.end, command_payload.latest, - {}, mapping=command_payload.table_mapping, ) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 3ab828f308..eb6ce41be6 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -213,19 +213,25 @@ def plan(self, line) -> None: @argument("--start", "-s", type=str, help="Start date to render.") @argument("--end", "-e", type=str, help="End date to render.") @argument("--latest", "-l", type=str, help="Latest date to render.") + @argument( + "--limit", + type=int, + help="The number of rows which the query should be limited to.", + ) @line_magic def evaluate(self, line): """Evaluate a model query and fetches a dataframe.""" self.context.refresh() args = parse_argstring(self.evaluate, line) - query = self.context.render( + df = self.context.evaluate( args.model, start=args.start, end=args.end, latest=args.latest, + limit=args.limit, ) - self.display(self.context.engine_adapter.fetchdf(query)) + self.display(df) @magic_arguments() @argument( @@ -239,7 +245,7 @@ def evaluate(self, line): def fetchdf(self, line, sql: str): """Fetches a dataframe from sql, optionally storing it in a variable.""" args = parse_argstring(self.fetchdf, line) - df = self.context.engine_adapter.fetchdf(sql) + df = self.context.fetchdf(sql) if args.df_var: self.shell.user_ns[args.df_var] = df self.display(df) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 8dbcb41af7..ddd0cd8445 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -384,8 +384,8 @@ def test_time_column(): ) model = Model.load(expressions) assert model.time_column.column == "ds" - assert model.time_column.format is None - assert model.time_column.expression == parse_one("ds") + assert model.time_column.format == "%Y-%m-%d" + assert model.time_column.expression == parse_one("(ds, '%Y-%m-%d')") expressions = parse( """ @@ -399,14 +399,15 @@ def test_time_column(): ) model = Model.load(expressions) assert model.time_column.column == "ds" - assert model.time_column.format is None - assert model.time_column.expression == parse_one("ds") + assert model.time_column.format == "%Y-%m-%d" + assert model.time_column.expression == parse_one("(ds, '%Y-%m-%d')") expressions = parse( """ MODEL ( name db.table, - time_column (ds, 'yyyy-mm-dd') + time_column (ds, 'yyyy-MM'), + dialect 'hive', ); SELECT col::text, ds::text @@ -414,8 +415,8 @@ def test_time_column(): ) model = Model.load(expressions) assert model.time_column.column == "ds" - assert model.time_column.format == "yyyy-mm-dd" - assert model.time_column.expression == parse_one("(ds, 'yyyy-mm-dd')") + assert model.time_column.format == "%Y-%m" + assert model.time_column.expression == parse_one("(ds, '%Y-%m')") def test_default_time_column(): @@ -429,35 +430,35 @@ def test_default_time_column(): SELECT col::text, ds::text """ ) - model = Model.load(expressions, time_column_format="yyyy-mm-dd") - assert model.time_column.format == "yyyy-mm-dd" + model = Model.load(expressions, time_column_format="%Y") + assert model.time_column.format == "%Y" expressions = parse( """ MODEL ( name db.table, - time_column (ds, "mm-dd-yyyy") + time_column (ds, "%Y") ); SELECT col::text, ds::text """ ) - model = Model.load(expressions, time_column_format="yyyy-mm-dd") - assert model.time_column.format == "mm-dd-yyyy" + model = Model.load(expressions, time_column_format="%m") + assert model.time_column.format == "%Y" expressions = parse( """ MODEL ( name db.table, - dialect duckdb, - time_column ds, + dialect hive, + time_column (ds, "dd") ); SELECT col::text, ds::text """ ) - model = Model.load(expressions, dialect="hive", time_column_format="yy-M-ss") - assert model.time_column.format == "%y-%-m-%S" + model = Model.load(expressions, dialect="duckdb", time_column_format="%Y") + assert model.time_column.format == "%d" def test_convert_to_time_column(): @@ -474,7 +475,7 @@ def test_convert_to_time_column(): model = Model.load(expressions) assert model.convert_to_time_column("2022-01-01") == parse_one("'2022-01-01'") assert model.convert_to_time_column(to_datetime("2022-01-01")) == parse_one( - "'2022-01-01 00:00:00+00:00'" + "'2022-01-01'" ) expressions = parse( diff --git a/tests/schedulers/airflow/operators/test_targets.py b/tests/schedulers/airflow/operators/test_targets.py index 46c4c08061..55eed93068 100644 --- a/tests/schedulers/airflow/operators/test_targets.py +++ b/tests/schedulers/airflow/operators/test_targets.py @@ -51,7 +51,7 @@ def test_evaluation_target_execute( target.execute(context, lambda: mocker.Mock(), "spark") evaluator_evaluate_mock.assert_called_once_with( - snapshot, interval_ds, interval_ds, logical_ds, {}, mapping=table_mapping + snapshot, interval_ds, interval_ds, logical_ds, mapping=table_mapping ) add_interval_mock.assert_called_once_with(