Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example/models/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd

from example.helper import iter_dates
from sqlmesh import EngineAdapter, model
from sqlmesh import ExecutionContext, model
from sqlmesh.utils.date import to_ds

ITEMS = [
Expand Down Expand Up @@ -63,7 +63,7 @@
"""
)
def execute(
engine: EngineAdapter,
context: ExecutionContext,
start: datetime,
end: datetime,
latest: datetime,
Expand Down
15 changes: 5 additions & 10 deletions example/models/order_items.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import random
import typing as t
from datetime import datetime

import pandas as pd

from example.helper import iter_dates
from sqlmesh import EngineAdapter, Snapshot, model
from sqlmesh import ExecutionContext, model
from sqlmesh.utils.date import to_ds


Expand All @@ -29,23 +28,19 @@
"""
)
def execute(
engine: EngineAdapter,
context: ExecutionContext,
start: datetime,
end: datetime,
latest: datetime,
snapshots: t.Dict[str, Snapshot],
mapping: t.Optional[t.Dict[str, str]],
**kwargs,
) -> pd.DataFrame:
dfs = []

raw_orders = (
snapshots["sushi.orders"].table_name if snapshots else mapping["sushi.orders"]
)
raw_orders = context.table("sushi.orders")

for dt in iter_dates(start, end):
# this section not super clean, make it easier to fetch other snapshots
orders = engine.fetchdf(
orders = context.fetchdf(
f"""
SELECT *
FROM {raw_orders}
Expand All @@ -56,7 +51,7 @@ def execute(
if not isinstance(orders, pd.DataFrame):
orders = orders.toPandas()

items = engine.fetchdf(
items = context.fetchdf(
f"""
SELECT *
FROM {raw_orders}
Expand Down
4 changes: 2 additions & 2 deletions example/models/orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd

from example.helper import iter_dates
from sqlmesh import EngineAdapter, model
from sqlmesh import ExecutionContext, model
from sqlmesh.utils.date import to_ds

CUSTOMERS = list(range(0, 100))
Expand Down Expand Up @@ -33,7 +33,7 @@
"""
)
def execute(
engine: EngineAdapter,
context: ExecutionContext,
start: datetime,
end: datetime,
latest: datetime,
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from enum import Enum

from sqlmesh.core.context import Context
from sqlmesh.core.context import Context, ExecutionContext
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.macros import macro
from sqlmesh.core.model import Model, model
Expand Down
30 changes: 30 additions & 0 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ def render(
ctx.obj.console.show_sql(sql)


@cli.command("evaluate")
@click.argument("model")
@opt.start_time
@opt.end_time
@opt.latest_time
@click.option(
"--limit",
type=int,
help="The number of rows which the query should be limited to.",
)
@click.pass_context
def evaluate(
ctx,
model: str,
start: TimeLike,
end: TimeLike,
latest: t.Optional[TimeLike] = None,
limit: t.Optional[int] = None,
) -> None:
"""Evaluate a model and return a dataframe with a default limit of 1000."""
df = ctx.obj.evaluate(
model,
start=start,
end=end,
latest=latest,
limit=limit,
)
ctx.obj.console.log_success(df)


@cli.command("format")
@click.pass_context
def format(ctx) -> None:
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class Config(PydanticModel):
physical_schema: The default schema used to store materialized tables.
snapshot_ttl: Duration before unpromoted snapshots are removed.
time_column_format: The default format to use for all model time columns. Defaults to %Y-%m-%d.
This time format uses python format codes. https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes.
ddl_concurrent_task: The number of concurrent tasks used for DDL
operations (table / view creation, deletion, etc). Default: 1.
"""
Expand Down
107 changes: 91 additions & 16 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from sqlmesh.core.context_diff import ContextDiff
from sqlmesh.core.dag import DAG
from sqlmesh.core.dialect import extend_sqlglot, format_model_expressions
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.engine_adapter import DF, EngineAdapter
from sqlmesh.core.environment import Environment
from sqlmesh.core.macros import macro
from sqlmesh.core.model import Model
Expand All @@ -69,10 +69,51 @@
if t.TYPE_CHECKING:
import graphviz

MODEL_OR_SNAPSHOT = t.Union[str, Model, Snapshot]
Comment thread
tobymao marked this conversation as resolved.

extend_sqlglot()


class Context:
class ExecutionContext:
"""The minimal context needed in order to execute a query.
Args:
engine_adapter: The engine adapter to execute queries against.
mapping: A mapping of models to physical tables.
"""

def __init__(self, engine_adapter: EngineAdapter, mapping: t.Dict[str, str]):
self.engine_adapter = engine_adapter
self.spark = self.engine_adapter.spark
Comment thread
vchan marked this conversation as resolved.
Outdated
self._mapping = mapping

@property
def mapping(self) -> t.Dict[str, str]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Nitpick] TBH I really dislike the name mapping. Given its name and the type signature it can be anything at all and it's impossible to tell without reading the docs (if they are even available). Can we be more specific? Like physical_tables_to_model or model_tables or model_to_table_mapping etc.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to model_tables

return self._mapping

def table(self, model_name: str) -> str:
"""Gets the physical table name for a given model.

Args:
model_name: The model name.

Returns:
The physical table name.
"""
return self.mapping[model_name]

def fetchdf(self, query: t.Union[exp.Expression, str]) -> DF:
"""Fetches a dataframe given a sql string or sqlglot expression.

Args:
query: SQL string or sqlglot expression.

Returns:
The default dataframe is Pandas, but for Spark a PySpark dataframe is returned.
"""
return self.engine_adapter.fetchdf(query)


class Context(ExecutionContext):
Copy link
Copy Markdown
Collaborator

@izeigerman izeigerman Dec 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way we use the base class here is quite sketchy and can lead to unintended consequences. For example we never invoke the base class constructor and only rely on method overriding hoping it would just do the right thing. As the code evolves custom initialization can be added to the constructor of ExecutionContext which wouldn't be a part of the Context.

I'd rather have an ABC for this instead and 2 concrete implementations. Also we may want to create context package since this module is pretty huge already.

"""Encapsulates a SQLMesh environment supplying convenient functions to perform various tasks.

Args:
Expand Down Expand Up @@ -303,9 +344,22 @@ def snapshots(self) -> t.Dict[str, Snapshot]:
snapshots[model.name] = snapshot
return snapshots

@property
def mapping(self) -> t.Dict[str, str]:
"""Mapping of model name to physical table name.

If a snapshot has not been versioned yet, it's view name will be returned.
Comment thread
tobymao marked this conversation as resolved.
Outdated
"""
return {
name: snapshot.table_name
if snapshot.version
else snapshot.qualified_view_name.for_environment(c.PROD)
for name, snapshot in self.snapshots.items()
}

def render(
self,
model: t.Union[str, Model],
model_or_snapshot: MODEL_OR_SNAPSHOT,
*,
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
Expand All @@ -317,7 +371,7 @@ def render(
"""Renders a model's query, expanding macros with provided kwargs, and optionally expanding referenced models.

Args:
model: The model name or instance to render.
model_or_snapshot: The model, model name, or snapshot to render.
start: The start of the interval to render.
end: The end of the interval to render.
latest: The latest time used for non incremental datasets.
Expand All @@ -330,7 +384,14 @@ def render(
The rendered expression.
"""
latest = latest or yesterday_ds()
model = model if isinstance(model, Model) else self.models[model]

if isinstance(model_or_snapshot, str):
model = self.models[model_or_snapshot]
elif isinstance(model_or_snapshot, Snapshot):
model = model_or_snapshot.model
else:
model = model_or_snapshot

expand = self.dag.upstream(model.name) if expand is True else expand or []

return model.render_query(
Expand All @@ -345,33 +406,43 @@ def render(

def evaluate(
self,
snapshot: Snapshot | str,
model_or_snapshot: MODEL_OR_SNAPSHOT,
start: TimeLike,
end: TimeLike,
latest: TimeLike,
limit: t.Optional[int] = None,
**kwargs,
) -> None:
"""Evaluate a snapshot (running its query against a DB/Engine).
) -> DF:
"""Evaluate a model or snapshot (running its query against a DB/Engine).

This method is used to test or iterate on models without side effects.

Args:
snapshot: The snapshot to evaluate.
model_or_snapshot: The model, model name, or snapshot to render.
start: The start of the interval to evaluate.
end: The end of the interval to evaluate.
latest: The latest time used for non incremental datasets.
limit: A limit applied to the model, this must be > 0.
"""
if isinstance(snapshot, str):
snapshot = self.snapshots[snapshot]
if isinstance(model_or_snapshot, str):
snapshot = self.snapshots[model_or_snapshot]
elif isinstance(model_or_snapshot, Model):
snapshot = self.snapshots[model_or_snapshot.name]
else:
snapshot = model_or_snapshot

if not limit or limit <= 0:
limit = 1000

self.snapshot_evaluator.evaluate(
return self.snapshot_evaluator.evaluate(
snapshot,
start,
end,
latest,
snapshots=self.snapshots,
mapping=self.mapping,
limit=limit,
)

self.state_sync.add_interval(snapshot.snapshot_id, start, end)

def format(self) -> None:
"""Format all models in a given directory."""
for model in self.models.values():
Expand Down Expand Up @@ -680,7 +751,11 @@ def _load_models(self):
new = registry.keys() - registered
registered |= new
for name in new:
model = registry[name].model(module, path)
model = registry[name].model(
module=module,
path=path,
time_column_format=self.config.time_column_format,
)
self.models[model.name] = model
self._add_model_to_dag(model)

Expand Down
Loading