Merged
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230313-135917.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Make model contracts agnostic to ordering
time: 2023-03-13T13:59:17.255368-04:00
custom:
Author: gshank
Issue: 6975 7064
2 changes: 1 addition & 1 deletion core/dbt/clients/jinja.py
@@ -483,7 +483,7 @@ def get_environment(
native: bool = False,
) -> jinja2.Environment:
args: Dict[str, List[Union[str, Type[jinja2.ext.Extension]]]] = {
"extensions": ["jinja2.ext.do"]
"extensions": ["jinja2.ext.do", "jinja2.ext.loopcontrols"]
}

if capture_macros:
6 changes: 6 additions & 0 deletions core/dbt/context/exceptions_jinja.py
@@ -23,6 +23,7 @@
PropertyYMLError,
NotImplementedError,
RelationWrongTypeError,
ContractError,
ColumnTypeMissingError,
)

@@ -66,6 +67,10 @@ def raise_compiler_error(msg, node=None) -> NoReturn:
raise CompilationError(msg, node)


def raise_contract_error(yaml_columns, sql_columns) -> NoReturn:
raise ContractError(yaml_columns, sql_columns)


def raise_database_error(msg, node=None) -> NoReturn:
raise DbtDatabaseError(msg, node)

@@ -124,6 +129,7 @@ def column_type_missing(column_names) -> NoReturn:
raise_invalid_property_yml_version,
raise_not_implemented,
relation_wrong_type,
raise_contract_error,
column_type_missing,
]
}
17 changes: 17 additions & 0 deletions core/dbt/exceptions.py
@@ -2124,6 +2124,23 @@ def get_message(self) -> str:
return msg


class ContractError(CompilationError):
def __init__(self, yaml_columns, sql_columns):
self.yaml_columns = yaml_columns
self.sql_columns = sql_columns
super().__init__(msg=self.get_message())

def get_message(self) -> str:
msg = (
"Contracts are enabled for this model. "
"Please ensure the name, data_type, and number of columns in your `yml` file "
"match the columns in your SQL file.\n"
f"Schema File Columns: {self.yaml_columns}\n"
Contributor:

It might be useful to sort these both alphabetically (message only) so that it's easier for the user to spot the difference.

f"SQL File Columns: {self.sql_columns}"
)
return msg
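
A minimal sketch of the reviewer's alphabetical-sort suggestion (a hypothetical helper, not part of this diff; display-only, the comparison logic is unchanged):

def format_contract_mismatch(yaml_columns: str, sql_columns: str) -> str:
    # Sort the comma-separated "name data_type" entries so matching columns line up;
    # assumes no commas appear inside individual data types (nested types would need smarter splitting).
    def sort_cols(cols: str) -> str:
        return ", ".join(sorted(c.strip() for c in cols.split(",")))
    return (
        "Contracts are enabled for this model. "
        "Please ensure the name, data_type, and number of columns in your `yml` file "
        "match the columns in your SQL file.\n"
        f"Schema File Columns: {sort_cols(yaml_columns)}\n"
        f"SQL File Columns: {sort_cols(sql_columns)}"
    )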


# not modifying these since rpc should be deprecated soon
class UnknownAsyncIDException(Exception):
CODE = 10012
@@ -37,13 +37,32 @@
{#--Obtain the column schema provided by the schema file by generating an 'empty schema' query from the model's columns. #}
{%- set schema_file_provided_columns = get_column_schema_from_query(get_empty_schema_sql(model['columns'])) -%}

{%- set sql_file_provided_columns_formatted = format_columns(sql_file_provided_columns) -%}
{%- set schema_file_provided_columns_formatted = format_columns(schema_file_provided_columns) -%}
{#-- For compiler error msg #}
{%- set sql_columns = (format_columns(sql_file_provided_columns)|trim) -%}
{%- set yaml_columns = (format_columns(schema_file_provided_columns)|trim) -%}

{%- if sql_file_provided_columns_formatted != schema_file_provided_columns_formatted -%}
{%- do exceptions.raise_compiler_error('Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file.\nSchema File Columns: ' ~ (schema_file_provided_columns_formatted|trim) ~ '\n\nSQL File Columns: ' ~ (sql_file_provided_columns_formatted|trim) ~ ' ' ) %}
{%- if sql_file_provided_columns|length != schema_file_provided_columns|length -%}
{%- do exceptions.raise_contract_error(yaml_columns, sql_columns) -%}
{%- endif -%}

{%- for sql_col in sql_file_provided_columns -%}
{%- set yaml_col = [] -%}
{%- for schema_col in schema_file_provided_columns -%}
{%- if schema_col.name == sql_col.name -%}
{%- do yaml_col.append(schema_col) -%}
{%- break -%}
{%- endif -%}
{%- endfor -%}
{%- if not yaml_col -%}
{#-- Column with name not found in yaml --#}
{%- do exceptions.raise_contract_error(yaml_columns, sql_columns) -%}
{%- endif -%}
{%- if sql_col.dtype != yaml_col[0].dtype -%}
Contributor:

Previously column types were compared based on their formatted representation, which could be data platform specific as implemented by the <adapter>__format_column macro.

For example, BigQuery's format_column implementation compares data_type values by column.data_type as opposed to column.dtype in order to make comparisons on nested data types.

Strictly comparing SQL and yml data_type values by dtype would allow the following contract to be accepted in BigQuery:

SELECT
STRUCT("test" AS name, [1.0,2.0] AS laps) as some_struct
models:
  - name: test_schema
    config:
      contract: true
    columns:
      - name: some_struct
        data_type: STRUCT<name FLOAT64, laps STRING> #wrong! but accepted because dtype == STRUCT for both SQL and schema.yml

One workaround would be to do this comparison using format_column instead, i.e: adapter.dispatch('format_column', 'dbt')(sql_col) != adapter.dispatch('format_column', 'dbt')(yaml_col[0]). This would also ensure the comparison and error messaging are using consistent logic. An alternative would be to implement a default__compare_data_types macro to enable adapter-specific implementations.
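
A small self-contained Python illustration of this concern (hypothetical values, not adapter code): both columns report dtype == "STRUCT", so a dtype-only comparison passes even though the full data_type strings disagree.

sql_col = {"name": "some_struct", "dtype": "STRUCT",
           "data_type": "STRUCT<name STRING, laps ARRAY<FLOAT64>>"}
yml_col = {"name": "some_struct", "dtype": "STRUCT",
           "data_type": "STRUCT<name FLOAT64, laps STRING>"}
assert sql_col["dtype"] == yml_col["dtype"]           # dtype-only check: passes
assert sql_col["data_type"] != yml_col["data_type"]   # full data_type check: catches the mismatch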

Contributor Author:

I looked at using format_column but the default implementation of that just uses the dtype. Is that different in the other adapters?

Contributor Author:

I don't actually see any implementations of format_column in the adapters.

Contributor Author:

Oh, I was on the wrong bigquery branch. Bigquery is the only adapter that implements it.

Contributor Author:

Could we include the formatted column in the Column structures returned by "get_column_schema_from_query"?

Contributor Author:

Or concatenate the other parts of the column structure for comparison?

Contributor:

This looks like a lot of logic to put in a jinja template. Could we put this in a python function and then wrap that python function in this macro? I'm cutting corners, but this is what's in my head:

columns_spec_ddl.sql

{% macro assert_columns_equivalent(sql) %}

  {% set sql_schema = get_column_schema_from_query(sql) %}

  {% set model_contract = model_contract(model['columns']) %}

  {% do assert_schema_meets_contract(sql_schema, model_contract) %}

{% endmacro %}

dbt/adapters/base/impl.py (for lack of a better spot)

# these are defined elsewhere, but look something like this
ModelContract = List[ColumnContract]
Schema = List[Column]


def model_contract(model) -> ModelContract:
    # I assume we have a way of creating a model contract from a `schema.yml` file
    return ModelContract(model)


def assert_schema_meets_contract(schema: Schema, model_contract: ModelContract):
    if len(schema) != len(model_contract):
        raise ContractError(msg)
    for schema_column, contract_column in zip(sorted(schema), sorted(model_contract)):
        try:
            assert schema_column.name == contract_column.name
            assert schema_column.dtype == contract_column.dtype
        except AssertionError:
            raise ContractError(msg)

I think the python version would be much easier to unit test.

{#-- Column data types don't match --#}
{%- do exceptions.raise_contract_error(yaml_columns, sql_columns) -%}
{%- endif -%}
{%- endfor -%}

{% endmacro %}

{% macro format_columns(columns) %}
@@ -25,11 +25,26 @@

Contributor:

I'm thinking through what this might look like across all potential adapters, and if we need to add contract-related items in the future. With that context, I have the following questions:

  1. Do you think this could benefit from becoming a "dispatch" method? The goal of this would be to create a hard divide between contracted models and non-contracted models.
  2. Do we need to include the column select statement in default__get_select_subquery if we already validated get_assert_columns_equivalent? We validate that the number of columns are the same, and the names are the same. So I think the subquery already limits to just the columns that we want.
  3. It looks like get_assert_columns_equivalent might have been renamed assert_columns_equivalent.

With my assumptions (not necessarily true of course):

{% macro default__create_table_as(temporary, relation, sql) -%}
  {% if config.get('contract', False) %}
    {{ default__create_table_as_with_contract(temporary, relation, sql) }}
  {% else %}
    {{ default__create_table_as_without_contract(temporary, relation, sql) }}
  {% endif %}
{% endmacro %}

{% macro default__create_table_as_with_contract(temporary, relation, sql) %}
  {{ get_assert_columns_equivalent(sql) }}
  
  create {% if temporary: -%}temporary{%- endif %} table
    {{ relation.include(database=(not temporary), schema=(not temporary)) }}
    {{ get_columns_spec_ddl() }}
  as ({{ sql }})

{% endmacro %}

{% macro default__create_table_as_without_contract(temporary, relation, sql) %}
  create {% if temporary: -%}temporary{%- endif %} table
    {{ relation.include(database=(not temporary), schema=(not temporary)) }}
  as ({{ sql }})
{% endmacro %}

This would maintain backwards compatibility because we're keeping the macro create_table_as, which is what would have been overridden. And if we need to update only one of these in the future, it isolates the updates, instead of impacting all "create table" workflows. I'm open to feedback and I'm just trying to communicate a thought here. Please let me know what you think of these recommendations.

Contributor Author:

We need the get_select_subquery in order to put the columns in the right order. We removed validating the order because we added the subquery. As far as splitting out the macro into two, probably @jtcohen6 and @MichelleArk should weigh in on that.
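
A quick Python illustration of this point (hypothetical values; only the model_subq alias is taken from the macro below): names and count match, so the equivalence check passes, and the wrapping subquery re-projects the columns into the yaml contract's order.

contract_columns = ["id", "color", "date_day"]  # order declared in schema.yml
inner_sql = "select 'blue' as color, 1 as id, cast('2019-01-01' as date) as date_day"
reordered_sql = "select " + ", ".join(contract_columns) + " from ( " + inner_sql + " ) as model_subq"
print(reordered_sql)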

create {% if temporary: -%}temporary{%- endif %} table
{{ relation.include(database=(not temporary), schema=(not temporary)) }}
{% if config.get('contract', False) %}
{{ get_assert_columns_equivalent(sql) }}
{{ get_columns_spec_ddl() }}
{% endif %}
{% if config.get('contract', False) %}
{{ get_assert_columns_equivalent(sql) }}
{{ get_columns_spec_ddl() }}
{%- set sql = get_select_subquery(sql) %}
{% endif %}
as (
{{ sql }}
);
{%- endmacro %}

{% macro get_select_subquery(sql) %}
{{ return(adapter.dispatch('get_select_subquery', 'dbt')(sql)) }}
{% endmacro %}

{% macro default__get_select_subquery(sql) %}
select
{% for column in model['columns'] %}
{{ column }}{{ ", " if not loop.last }}
{% endfor %}
from (
{{ sql }}
) as model_subq
{%- endmacro %}
7 changes: 4 additions & 3 deletions plugins/postgres/dbt/include/postgres/macros/adapters.sql
@@ -13,10 +13,11 @@
{{ get_assert_columns_equivalent(sql) }}
{{ get_columns_spec_ddl() }} ;
insert into {{ relation }} {{ get_column_names() }}
{% else %}
as
{%- set sql = get_select_subquery(sql) %}
{% else %}
as
{% endif %}
(
(
{{ sql }}
);
{%- endmacro %}
39 changes: 18 additions & 21 deletions tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
@@ -65,8 +65,9 @@ def data_types(self, schema_int_type, int_type, string_type):
]

def test__constraints_wrong_column_order(self, project, string_type, int_type):
# This no longer causes an error, since we enforce yaml column order
results, log_output = run_dbt_and_capture(
["run", "-s", "my_model_wrong_order"], expect_pass=False
["run", "-s", "my_model_wrong_order"], expect_pass=True
)
manifest = get_manifest(project.project_root)
model_id = "model.test.my_model_wrong_order"
@@ -75,18 +76,6 @@ def test__constraints_wrong_column_order(self, project, string_type, int_type):

assert contract_actual_config is True

expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
expected_schema_file_columns = (
f"Schema File Columns: id {int_type}, color {string_type}, date_day DATE"
)
expected_sql_file_columns = (
f"SQL File Columns: color {string_type}, id {int_type}, date_day DATE"
)

assert expected_compile_error in log_output
assert expected_schema_file_columns in log_output
assert expected_sql_file_columns in log_output

def test__constraints_wrong_column_names(self, project, string_type, int_type):
results, log_output = run_dbt_and_capture(
["run", "-s", "my_model_wrong_name"], expect_pass=False
@@ -98,7 +87,7 @@ def test__constraints_wrong_column_names(self, project, string_type, int_type):

assert contract_actual_config is True

expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
expected_compile_error = "Please ensure the name, data_type, and number of columns in your `yml` file match the columns in your SQL file."
expected_schema_file_columns = (
f"Schema File Columns: id {int_type}, color {string_type}, date_day DATE"
)
@@ -147,7 +136,7 @@ def test__constraints_wrong_column_data_types(

assert contract_actual_config is True

expected_compile_error = "Please ensure the name, data_type, order, and number of columns in your `yml` file match the columns in your SQL file."
expected_compile_error = "Please ensure the name, data_type, and number of columns in your `yml` file match the columns in your SQL file."
expected_sql_file_columns = (
f"SQL File Columns: wrong_data_type_column_name {error_data_type}"
)
@@ -196,11 +185,19 @@ def test__constraints_correct_column_data_types(self, project, data_types):
id ,
color ,
date_day
) (
)
(
select
1 as id,
'blue' as color,
cast('2019-01-01' as date) as date_day
id,
color,
date_day
from
(
select
1 as id,
'blue' as color,
cast('2019-01-01' as date) as date_day
) as model_subq
);
"""

@@ -248,10 +245,10 @@ def test__constraints_ddl(self, project, expected_sql):
expected_sql_check == generated_sql_check
), f"""
-- GENERATED SQL
{generated_sql}
{generated_sql_check}

-- EXPECTED SQL
{expected_sql}
{expected_sql_check}
"""

def test__constraints_enforcement_rollback(
@@ -1,6 +1,6 @@
import pytest
from dbt.exceptions import ParsingError
from dbt.tests.util import run_dbt, get_manifest, run_dbt_and_capture
from dbt.tests.util import run_dbt, get_manifest, get_artifact, run_dbt_and_capture

my_model_sql = """
{{
@@ -10,8 +10,8 @@
}}

select
1 as id,
'blue' as color,
1 as id,
cast('2019-01-01' as date) as date_day
"""

@@ -29,7 +29,7 @@
cast('2019-01-01' as date) as date_day
"""

my_model_constraints_disabled_sql = """
my_model_contract_disabled_sql = """
{{
config(
materialized = "table",
@@ -171,7 +171,7 @@ def model(dbt, _):
"""


class TestModelLevelConstraintsEnabledConfigs:
class TestModelLevelContractEnabledConfigs:
@pytest.fixture(scope="class")
def models(self):
return {
@@ -180,11 +180,12 @@ def models(self):
}

def test__model_contract_true(self, project):
run_dbt(["parse"])
run_dbt(["run"])
manifest = get_manifest(project.project_root)
model_id = "model.test.my_model"
my_model_columns = manifest.nodes[model_id].columns
my_model_config = manifest.nodes[model_id].config
model = manifest.nodes[model_id]
my_model_columns = model.columns
my_model_config = model.config
contract_actual_config = my_model_config.contract

assert contract_actual_config is True
@@ -193,8 +194,17 @@ def test__model_contract_true(self, project):

assert expected_columns == str(my_model_columns)

# compiled fields aren't in the manifest above because it only has parsed fields
manifest_json = get_artifact(project.project_root, "target", "manifest.json")
compiled_code = manifest_json["nodes"][model_id]["compiled_code"]
cleaned_code = " ".join(compiled_code.split())
assert (
"select 'blue' as color, 1 as id, cast('2019-01-01' as date) as date_day"
== cleaned_code
)


class TestProjectConstraintsEnabledConfigs:
class TestProjectContractEnabledConfigs:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
@@ -221,7 +231,7 @@ def test_defined_column_type(self, project):
assert contract_actual_config is True


class TestProjectConstraintsEnabledConfigsError:
class TestProjectContractEnabledConfigsError:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
@@ -253,7 +263,7 @@ def test_undefined_column_type(self, project):
assert expected_compile_error in log_output


class TestModelConstraintsEnabledConfigs:
class TestModelContractEnabledConfigs:
@pytest.fixture(scope="class")
def models(self):
return {"my_model.sql": my_model_contract_sql, "constraints_schema.yml": model_schema_yml}
@@ -267,7 +277,7 @@ def test__model_contract(self, project):
assert contract_actual_config is True


class TestModelConstraintsEnabledConfigsMissingDataTypes:
class TestModelContractEnabledConfigsMissingDataTypes:
@pytest.fixture(scope="class")
def models(self):
return {
@@ -289,11 +299,11 @@ def test_undefined_column_type(self, project):
assert expected_compile_error in log_output


class TestModelLevelConstraintsDisabledConfigs:
class TestModelLevelContractDisabledConfigs:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_constraints_disabled_sql,
"my_model.sql": my_model_contract_disabled_sql,
"constraints_schema.yml": model_schema_yml,
}

@@ -308,7 +318,7 @@ def test__model_contract_false(self, project):
assert contract_actual_config is False


class TestModelLevelConstraintsErrorMessages:
class TestModelLevelContractErrorMessages:
@pytest.fixture(scope="class")
def models(self):
return {
@@ -330,7 +340,7 @@ def test__config_errors(self, project):
assert expected_empty_data_type_error not in str(exc_str)


class TestSchemaConstraintsEnabledConfigs:
class TestSchemaContractEnabledConfigs:
@pytest.fixture(scope="class")
def models(self):
return {
@@ -347,7 +357,7 @@ def test__schema_error(self, project):
assert schema_error_expected in str(exc_str)


class TestPythonModelLevelConstraintsErrorMessages:
class TestPythonModelLevelContractErrorMessages:
@pytest.fixture(scope="class")
def models(self):
return {