Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example/models/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd

from example.helper import iter_dates
from sqlmesh import EngineAdapter, model
from sqlmesh import ExecutionContext, model
from sqlmesh.utils.date import to_ds

ITEMS = [
Expand Down Expand Up @@ -63,7 +63,7 @@
"""
)
def execute(
engine: EngineAdapter,
context: ExecutionContext,
start: datetime,
end: datetime,
latest: datetime,
Expand Down
21 changes: 8 additions & 13 deletions example/models/order_items.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import random
import typing as t
from datetime import datetime

import pandas as pd

from example.helper import iter_dates
from sqlmesh import EngineAdapter, Snapshot, model
from sqlmesh import ExecutionContext, model
from sqlmesh.utils.date import to_ds


Expand All @@ -29,37 +28,33 @@
"""
)
def execute(
engine: EngineAdapter,
context: ExecutionContext,
start: datetime,
end: datetime,
latest: datetime,
snapshots: t.Dict[str, Snapshot],
mapping: t.Optional[t.Dict[str, str]],
**kwargs,
) -> pd.DataFrame:
dfs = []

raw_orders = (
snapshots["sushi.orders"].table_name if snapshots else mapping["sushi.orders"]
)
orders_table = context.table("sushi.orders")
items_table = context.table("sushi.items")

for dt in iter_dates(start, end):
# this section not super clean, make it easier to fetch other snapshots
orders = engine.fetchdf(
orders = context.fetchdf(
f"""
SELECT *
FROM {raw_orders}
FROM {orders_table}
WHERE ds = '{to_ds(dt)}'
"""
)

if not isinstance(orders, pd.DataFrame):
orders = orders.toPandas()

items = engine.fetchdf(
items = context.fetchdf(
f"""
SELECT *
FROM {raw_orders}
FROM {items_table}
WHERE ds = '{to_ds(dt)}'
"""
)
Expand Down
4 changes: 2 additions & 2 deletions example/models/orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd

from example.helper import iter_dates
from sqlmesh import EngineAdapter, model
from sqlmesh import ExecutionContext, model
from sqlmesh.utils.date import to_ds

CUSTOMERS = list(range(0, 100))
Expand Down Expand Up @@ -33,7 +33,7 @@
"""
)
def execute(
engine: EngineAdapter,
context: ExecutionContext,
start: datetime,
end: datetime,
latest: datetime,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"requests",
"rich",
"ruamel.yaml",
"sqlglot>=10.2.4",
"sqlglot>=10.2.5",
],
extras_require={
"dev": [
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from enum import Enum

from sqlmesh.core.context import Context
from sqlmesh.core.context import Context, ExecutionContext
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.macros import macro
from sqlmesh.core.model import Model, model
Expand Down
30 changes: 30 additions & 0 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ def render(
ctx.obj.console.show_sql(sql)


@cli.command("evaluate")
@click.argument("model")
@opt.start_time
@opt.end_time
@opt.latest_time
@click.option(
    "--limit",
    type=int,
    help="The number of rows which the query should be limited to.",
)
@click.pass_context
def evaluate(
    ctx,
    model: str,
    start: TimeLike,
    end: TimeLike,
    latest: t.Optional[TimeLike] = None,
    limit: t.Optional[int] = None,
) -> None:
    """Evaluate a model and return a dataframe with a default limit of 1000."""
    # NOTE: click shows the docstring above as the command's help text, so it
    # is intentionally left word-for-word.
    context = ctx.obj

    # All options are forwarded untouched; an omitted --limit arrives as None.
    result = context.evaluate(
        model,
        start=start,
        end=end,
        latest=latest,
        limit=limit,
    )

    # Hand the resulting dataframe to the console for display.
    context.console.log_success(result)


@cli.command("format")
@click.pass_context
def format(ctx) -> None:
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class Config(PydanticModel):
physical_schema: The default schema used to store materialized tables.
snapshot_ttl: Duration before unpromoted snapshots are removed.
time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d.
This time format uses python format codes. https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes.
ddl_concurrent_tasks: The number of concurrent tasks used for DDL
operations (table / view creation, deletion, etc). Default: 1.
"""
Expand Down
140 changes: 123 additions & 17 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""
from __future__ import annotations

import abc
import contextlib
import importlib
import types
Expand All @@ -50,7 +51,7 @@
from sqlmesh.core.context_diff import ContextDiff
from sqlmesh.core.dag import DAG
from sqlmesh.core.dialect import extend_sqlglot, format_model_expressions
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.engine_adapter import DF, EngineAdapter
from sqlmesh.core.environment import Environment
from sqlmesh.core.macros import macro
from sqlmesh.core.model import Model
Expand All @@ -69,10 +70,76 @@
if t.TYPE_CHECKING:
import graphviz

MODEL_OR_SNAPSHOT = t.Union[str, Model, Snapshot]
Comment thread
tobymao marked this conversation as resolved.

extend_sqlglot()


class Context:
class BaseContext(abc.ABC):
    """Abstract interface for contexts that are able to execute a model.

    Concrete subclasses provide `model_tables` and `engine_adapter`; every
    other member of this class is implemented on top of those two properties.
    """

    @property
    @abc.abstractmethod
    def model_tables(self) -> t.Dict[str, str]:
        """Returns a mapping of model names to tables."""

    @property
    @abc.abstractmethod
    def engine_adapter(self) -> EngineAdapter:
        """Returns an engine adapter."""

    @property
    def spark(self) -> t.Optional["pyspark.sql.SparkSession"]:  # type: ignore
        """Returns the engine adapter's spark session if it exists."""
        adapter = self.engine_adapter
        return adapter.spark

    def table(self, model_name: str) -> str:
        """Looks up the physical table name of a model.

        Args:
            model_name: The model name.

        Returns:
            The physical table name.

        Raises:
            KeyError: If `model_name` is not present in `model_tables`.
        """
        tables = self.model_tables
        return tables[model_name]

    def fetchdf(self, query: t.Union[exp.Expression, str]) -> DF:
        """Runs a query through the engine adapter and returns a dataframe.

        Args:
            query: SQL string or sqlglot expression.

        Returns:
            The default dataframe is Pandas, but for Spark a PySpark dataframe is returned.
        """
        adapter = self.engine_adapter
        return adapter.fetchdf(query)


class ExecutionContext(BaseContext):
    """The minimal context needed to execute a model.

    Args:
        engine_adapter: The engine adapter to execute queries against.
        model_tables: A mapping of model names to physical table names.
    """

    def __init__(self, engine_adapter: EngineAdapter, model_tables: t.Dict[str, str]):
        # Stored privately and exposed read-only through the properties below,
        # which implement the abstract members of BaseContext.
        self._engine_adapter = engine_adapter
        self._model_tables = model_tables

    @property
    def engine_adapter(self) -> EngineAdapter:
        """Returns the engine adapter passed to the constructor."""
        return self._engine_adapter

    @property
    def model_tables(self) -> t.Dict[str, str]:
        """Returns the model name to table mapping passed to the constructor."""
        return self._model_tables


class Context(BaseContext):
"""Encapsulates a SQLMesh environment supplying convenient functions to perform various tasks.

Args:
Expand Down Expand Up @@ -137,7 +204,7 @@ def __init__(
ddl_concurrent_tasks or self.config.ddl_concurrent_tasks
)

self.engine_adapter = engine_adapter or EngineAdapter(
self._engine_adapter = engine_adapter or EngineAdapter(
self.config.engine_connection_factory,
self.config.engine_dialect,
multithreaded=self.ddl_concurrent_tasks > 1,
Expand Down Expand Up @@ -170,6 +237,11 @@ def __init__(
if load:
self.load()

@property
def engine_adapter(self) -> EngineAdapter:
"""Returns an engine adapter."""
return self._engine_adapter

def upsert_model(self, model: t.Union[str, Model] = "", **kwargs) -> Model:
"""Update or insert a model.

Expand Down Expand Up @@ -303,9 +375,22 @@ def snapshots(self) -> t.Dict[str, Snapshot]:
snapshots[model.name] = snapshot
return snapshots

@property
def model_tables(self) -> t.Dict[str, str]:
"""Mapping of model name to physical table name.

If a snapshot has not been versioned yet, its view name will be returned.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why return view? So that local evaluation works?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, in case you haven't pushed a snapshot yet (because you can run evaluate before plan)

"""
return {
name: snapshot.table_name
if snapshot.version
else snapshot.qualified_view_name.for_environment(c.PROD)
for name, snapshot in self.snapshots.items()
}

def render(
self,
model: t.Union[str, Model],
model_or_snapshot: MODEL_OR_SNAPSHOT,
*,
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
Expand All @@ -317,7 +402,7 @@ def render(
"""Renders a model's query, expanding macros with provided kwargs, and optionally expanding referenced models.

Args:
model: The model name or instance to render.
model_or_snapshot: The model, model name, or snapshot to render.
start: The start of the interval to render.
end: The end of the interval to render.
latest: The latest time used for non incremental datasets.
Expand All @@ -330,7 +415,14 @@ def render(
The rendered expression.
"""
latest = latest or yesterday_ds()
model = model if isinstance(model, Model) else self.models[model]

if isinstance(model_or_snapshot, str):
model = self.models[model_or_snapshot]
elif isinstance(model_or_snapshot, Snapshot):
model = model_or_snapshot.model
else:
model = model_or_snapshot

expand = self.dag.upstream(model.name) if expand is True else expand or []

return model.render_query(
Expand All @@ -345,33 +437,43 @@ def render(

def evaluate(
self,
snapshot: Snapshot | str,
model_or_snapshot: MODEL_OR_SNAPSHOT,
start: TimeLike,
end: TimeLike,
latest: TimeLike,
limit: t.Optional[int] = None,
**kwargs,
) -> None:
"""Evaluate a snapshot (running its query against a DB/Engine).
) -> DF:
"""Evaluate a model or snapshot (running its query against a DB/Engine).

This method is used to test or iterate on models without side effects.

Args:
snapshot: The snapshot to evaluate.
model_or_snapshot: The model, model name, or snapshot to evaluate.
start: The start of the interval to evaluate.
end: The end of the interval to evaluate.
latest: The latest time used for non incremental datasets.
limit: A limit applied to the model. If omitted or not > 0, it defaults to 1000.
"""
if isinstance(snapshot, str):
snapshot = self.snapshots[snapshot]
if isinstance(model_or_snapshot, str):
snapshot = self.snapshots[model_or_snapshot]
elif isinstance(model_or_snapshot, Model):
snapshot = self.snapshots[model_or_snapshot.name]
else:
snapshot = model_or_snapshot

if not limit or limit <= 0:
limit = 1000

self.snapshot_evaluator.evaluate(
return self.snapshot_evaluator.evaluate(
snapshot,
start,
end,
latest,
snapshots=self.snapshots,
mapping=self.model_tables,
limit=limit,
)

self.state_sync.add_interval(snapshot.snapshot_id, start, end)

def format(self) -> None:
"""Format all models in a given directory."""
for model in self.models.values():
Expand Down Expand Up @@ -680,7 +782,11 @@ def _load_models(self):
new = registry.keys() - registered
registered |= new
for name in new:
model = registry[name].model(module, path)
model = registry[name].model(
module=module,
path=path,
time_column_format=self.config.time_column_format,
)
self.models[model.name] = model
self._add_model_to_dag(model)

Expand Down
Loading