diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 7a12d3c07d..0000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,115 +0,0 @@ -version: 2.1 - -setup: true - -on_main_or_tag_filter: &on_main_or_tag_filter - filters: - branches: - only: main - tags: - only: /^v\d+\.\d+\.\d+/ - -on_tag_filter: &on_tag_filter - filters: - branches: - ignore: /.*/ - tags: - only: /^v\d+\.\d+\.\d+/ - -orbs: - path-filtering: circleci/path-filtering@1.2.0 - -jobs: - publish: - docker: - - image: cimg/python:3.10 - resource_class: small - steps: - - checkout - - attach_workspace: - at: web/client - - run: - name: Publish Python package - command: make publish - - run: - name: Update pypirc - command: ./.circleci/update-pypirc.sh - - run: - name: Publish Python Tests package - command: unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests - gh-release: - docker: - - image: cimg/node:20.19.0 - resource_class: small - steps: - - run: - name: Create release on GitHub - command: | - GITHUB_TOKEN="$GITHUB_TOKEN" \ - TARGET_TAG="$CIRCLE_TAG" \ - REPO_OWNER="$CIRCLE_PROJECT_USERNAME" \ - REPO_NAME="$CIRCLE_PROJECT_REPONAME" \ - CONTINUE_ON_ERROR="false" \ - npx https://github.com/TobikoData/circleci-gh-conventional-release - - ui-build: - docker: - - image: cimg/node:20.19.0 - resource_class: medium - steps: - - checkout - - run: - name: Install Dependencies - command: | - pnpm install - - run: - name: Build UI - command: pnpm --prefix web/client run build - - persist_to_workspace: - root: web/client - paths: - - dist - trigger_private_renovate: - docker: - - image: cimg/base:2021.11 - resource_class: small - steps: - - run: - name: Trigger private renovate - command: | - curl --request POST \ - --url $TOBIKO_PRIVATE_CIRCLECI_URL \ - --header "Circle-Token: $TOBIKO_PRIVATE_CIRCLECI_KEY" \ - --header "content-type: application/json" \ - --data '{ - "branch":"main", - "parameters":{ - "run_main_pr":false, - "run_sqlmesh_commit":false, - "run_renovate":true - } - }' - -workflows: - setup-workflow: - jobs: - - path-filtering/filter: - mapping: | - web/client/.* client true - (sqlmesh|tests|examples|web/server)/.* python true - pytest.ini|setup.cfg|setup.py|pyproject.toml python true - \.circleci/.*|Makefile|\.pre-commit-config\.yaml common true - vscode/extensions/.* vscode true - tag: "3.9" - - gh-release: - <<: *on_tag_filter - - ui-build: - <<: *on_main_or_tag_filter - - publish: - <<: *on_main_or_tag_filter - requires: - - ui-build - - trigger_private_renovate: - <<: *on_tag_filter - requires: - - publish diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml deleted file mode 100644 index bf27e03f47..0000000000 --- a/.circleci/continue_config.yml +++ /dev/null @@ -1,331 +0,0 @@ -version: 2.1 - -parameters: - client: - type: boolean - default: false - common: - type: boolean - default: false - python: - type: boolean - default: false - -orbs: - windows: circleci/windows@5.0 - -commands: - halt_unless_core: - steps: - - unless: - condition: - or: - - << pipeline.parameters.common >> - - << pipeline.parameters.python >> - - equal: [main, << pipeline.git.branch >>] - steps: - - run: circleci-agent step halt - halt_unless_client: - steps: - - unless: - condition: - or: - - << pipeline.parameters.common >> - - << pipeline.parameters.client >> - - equal: [main, << pipeline.git.branch >>] - steps: - - run: circleci-agent step halt - -jobs: - vscode_test: - docker: - - image: cimg/node:20.19.1-browsers - resource_class: small - steps: - - checkout - 
- run: - name: Install Dependencies - command: | - pnpm install - - run: - name: Run VSCode extension CI - command: | - cd vscode/extension - pnpm run ci - doc_tests: - docker: - - image: cimg/python:3.10 - resource_class: small - steps: - - halt_unless_core - - checkout - - run: - name: Install dependencies - command: make install-dev install-doc - - run: - name: Run doc tests - command: make doc-test - - style_and_cicd_tests: - parameters: - python_version: - type: string - docker: - - image: cimg/python:<< parameters.python_version >> - resource_class: large - environment: - PYTEST_XDIST_AUTO_NUM_WORKERS: 8 - steps: - - halt_unless_core - - checkout - - run: - name: Install OpenJDK - command: sudo apt-get update && sudo apt-get install default-jdk - - run: - name: Install ODBC - command: sudo apt-get install unixodbc-dev - - run: - name: Install SQLMesh dev dependencies - command: make install-dev - - run: - name: Fix Git URL override - command: git config --global --unset url."ssh://git@github.com".insteadOf - - run: - name: Run linters and code style checks - command: make py-style - - unless: - condition: - equal: ["3.9", << parameters.python_version >>] - steps: - - run: - name: Exercise the benchmarks - command: make benchmark-ci - - run: - name: Run cicd tests - command: make cicd-test - - store_test_results: - path: test-results - - cicd_tests_windows: - executor: - name: windows/default - size: large - steps: - - halt_unless_core - - run: - name: Enable symlinks in git config - command: git config --global core.symlinks true - - checkout - - run: - name: Install System Dependencies - command: | - choco install make which -y - refreshenv - - run: - name: Install SQLMesh dev dependencies - command: | - python -m venv venv - . ./venv/Scripts/activate - python.exe -m pip install --upgrade pip - make install-dev - - run: - name: Run fast unit tests - command: | - . 
./venv/Scripts/activate - which python - python --version - make fast-test - - store_test_results: - path: test-results - - migration_test: - docker: - - image: cimg/python:3.10 - resource_class: small - environment: - SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1" - steps: - - halt_unless_core - - checkout - - run: - name: Run the migration test - sushi - command: ./.circleci/test_migration.sh sushi "--gateway duckdb_persistent" - - run: - name: Run the migration test - sushi_dbt - command: ./.circleci/test_migration.sh sushi_dbt "--config migration_test_config" - - ui_style: - docker: - - image: cimg/node:20.19.0 - resource_class: small - steps: - - checkout - - restore_cache: - name: Restore pnpm Package Cache - keys: - - pnpm-packages-{{ checksum "pnpm-lock.yaml" }} - - run: - name: Install Dependencies - command: | - pnpm install - - save_cache: - name: Save pnpm Package Cache - key: pnpm-packages-{{ checksum "pnpm-lock.yaml" }} - paths: - - .pnpm-store - - run: - name: Run linters and code style checks - command: pnpm run lint - - ui_test: - docker: - - image: mcr.microsoft.com/playwright:v1.54.1-jammy - resource_class: medium - steps: - - halt_unless_client - - checkout - - restore_cache: - name: Restore pnpm Package Cache - keys: - - pnpm-packages-{{ checksum "pnpm-lock.yaml" }} - - run: - name: Install pnpm package manager - command: | - npm install --global corepack@latest - corepack enable - corepack prepare pnpm@latest-10 --activate - pnpm config set store-dir .pnpm-store - - run: - name: Install Dependencies - command: | - pnpm install - - save_cache: - name: Save pnpm Package Cache - key: pnpm-packages-{{ checksum "pnpm-lock.yaml" }} - paths: - - .pnpm-store - - run: - name: Run tests - command: npm --prefix web/client run test - - engine_tests_docker: - parameters: - engine: - type: string - machine: - image: ubuntu-2404:2024.05.1 - docker_layer_caching: true - resource_class: large - environment: - SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1" - steps: - - halt_unless_core - - checkout - - run: - name: Install OS-level dependencies - command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" - - run: - name: Run tests - command: make << parameters.engine >>-test - no_output_timeout: 20m - - store_test_results: - path: test-results - - engine_tests_cloud: - parameters: - engine: - type: string - docker: - - image: cimg/python:3.12 - resource_class: medium - environment: - PYTEST_XDIST_AUTO_NUM_WORKERS: 4 - SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1" - steps: - - halt_unless_core - - checkout - - run: - name: Install OS-level dependencies - command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" - - run: - name: Generate database name - command: | - UUID=`cat /proc/sys/kernel/random/uuid` - TEST_DB_NAME="circleci_${UUID:0:8}" - echo "export TEST_DB_NAME='$TEST_DB_NAME'" >> "$BASH_ENV" - echo "export SNOWFLAKE_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" - echo "export DATABRICKS_CATALOG='$TEST_DB_NAME'" >> "$BASH_ENV" - echo "export REDSHIFT_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" - echo "export GCP_POSTGRES_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" - echo "export FABRIC_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" - - # Make snowflake private key available - echo $SNOWFLAKE_PRIVATE_KEY_RAW | base64 -d > /tmp/snowflake-keyfile.p8 - echo "export SNOWFLAKE_PRIVATE_KEY_FILE='/tmp/snowflake-keyfile.p8'" >> "$BASH_ENV" - - run: - name: Create test database - command: ./.circleci/manage-test-db.sh << parameters.engine >> "$TEST_DB_NAME" up - - run: - name: Run tests - 
command: | - make << parameters.engine >>-test - no_output_timeout: 20m - - run: - name: Tear down test database - command: ./.circleci/manage-test-db.sh << parameters.engine >> "$TEST_DB_NAME" down - when: always - - store_test_results: - path: test-results - -workflows: - main_pr: - jobs: - - doc_tests - - style_and_cicd_tests: - matrix: - parameters: - python_version: - - "3.9" - - "3.10" - - "3.11" - - "3.12" - - "3.13" - - cicd_tests_windows - - engine_tests_docker: - name: engine_<< matrix.engine >> - matrix: - parameters: - engine: - - duckdb - - postgres - - mysql - - mssql - - trino - - spark - - clickhouse - - risingwave - - engine_tests_cloud: - name: cloud_engine_<< matrix.engine >> - context: - - sqlmesh_cloud_database_integration - requires: - - engine_tests_docker - matrix: - parameters: - engine: - - snowflake - - databricks - - redshift - - bigquery - - clickhouse-cloud - - athena - - fabric - - gcp-postgres - filters: - branches: - only: - - main - - ui_style - - ui_test - - vscode_test - - migration_test diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000..7585f0ce10 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,16 @@ +## Description + + + +## Test Plan + + + +## Checklist + +- [ ] I have run `make style` and fixed any issues +- [ ] I have added tests for my changes (if applicable) +- [ ] All existing tests pass (`make fast-test`) +- [ ] My commits are signed off (`git commit -s`) per the [DCO](DCO) + + diff --git a/.circleci/install-prerequisites.sh b/.github/scripts/install-prerequisites.sh similarity index 89% rename from .circleci/install-prerequisites.sh rename to .github/scripts/install-prerequisites.sh index 446221dba6..6ab602fc37 100755 --- a/.circleci/install-prerequisites.sh +++ b/.github/scripts/install-prerequisites.sh @@ -1,6 +1,6 @@ #!/bin/bash -# This script is intended to be run by an Ubuntu build agent on CircleCI +# This script is intended to be run by an Ubuntu CI build agent # The goal is to install OS-level dependencies that are required before trying to install Python dependencies set -e @@ -25,7 +25,7 @@ elif [ "$ENGINE" == "fabric" ]; then sudo dpkg -i packages-microsoft-prod.deb rm packages-microsoft-prod.deb - ENGINE_DEPENDENCIES="msodbcsql18" + ENGINE_DEPENDENCIES="msodbcsql18" fi ALL_DEPENDENCIES="$COMMON_DEPENDENCIES $ENGINE_DEPENDENCIES" @@ -39,4 +39,4 @@ if [ "$ENGINE" == "spark" ]; then java -version fi -echo "All done" \ No newline at end of file +echo "All done" diff --git a/.circleci/manage-test-db.sh b/.github/scripts/manage-test-db.sh similarity index 88% rename from .circleci/manage-test-db.sh rename to .github/scripts/manage-test-db.sh index b6e9c265c9..29d11afcc0 100755 --- a/.circleci/manage-test-db.sh +++ b/.github/scripts/manage-test-db.sh @@ -68,10 +68,10 @@ redshift_down() { EXIT_CODE=1 ATTEMPTS=0 while [ $EXIT_CODE -ne 0 ] && [ $ATTEMPTS -lt 5 ]; do - # note: sometimes this pg_terminate_backend() call can randomly fail with: ERROR: Insufficient privileges + # note: sometimes this pg_terminate_backend() call can randomly fail with: ERROR: Insufficient privileges # if it does, let's proceed with the drop anyway rather than aborting and never attempting the drop redshift_exec "select pg_terminate_backend(procpid) from pg_stat_activity where datname = '$1'" || true - + # perform drop redshift_exec "drop database $1;" && EXIT_CODE=$? || EXIT_CODE=$? 
if [ $EXIT_CODE -ne 0 ]; then @@ -103,14 +103,16 @@ clickhouse-cloud_init() { # GCP Postgres gcp-postgres_init() { - # Download and start Cloud SQL Proxy - curl -fsSL -o cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.18.0/cloud-sql-proxy.linux.amd64 - chmod +x cloud-sql-proxy + # Download Cloud SQL Proxy if not already present + if [ ! -f cloud-sql-proxy ]; then + curl -fsSL -o cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.18.0/cloud-sql-proxy.linux.amd64 + chmod +x cloud-sql-proxy + fi echo "$GCP_POSTGRES_KEYFILE_JSON" > /tmp/keyfile.json - ./cloud-sql-proxy --credentials-file /tmp/keyfile.json $GCP_POSTGRES_INSTANCE_CONNECTION_STRING & - - # Wait for proxy to start - sleep 5 + if ! pgrep -x cloud-sql-proxy > /dev/null; then + ./cloud-sql-proxy --credentials-file /tmp/keyfile.json $GCP_POSTGRES_INSTANCE_CONNECTION_STRING & + sleep 5 + fi } gcp-postgres_exec() { @@ -126,13 +128,13 @@ gcp-postgres_down() { } # Fabric -fabric_init() { +fabric_init() { python --version #note: as at 2025-08-20, ms-fabric-cli is pinned to Python >= 3.10, <3.13 pip install ms-fabric-cli - + # to prevent the '[EncryptionFailed] An error occurred with the encrypted cache.' error # ref: https://microsoft.github.io/fabric-cli/#switch-to-interactive-mode-optional - fab config set encryption_fallback_enabled true + fab config set encryption_fallback_enabled true echo "Logging in to Fabric" fab auth login -u $FABRIC_CLIENT_ID -p $FABRIC_CLIENT_SECRET --tenant $FABRIC_TENANT_ID diff --git a/.circleci/test_migration.sh b/.github/scripts/test_migration.sh similarity index 91% rename from .circleci/test_migration.sh rename to .github/scripts/test_migration.sh index bb1776550a..ec45772c73 100755 --- a/.circleci/test_migration.sh +++ b/.github/scripts/test_migration.sh @@ -30,12 +30,14 @@ cp -r "$EXAMPLE_DIR" "$TEST_DIR" git checkout $LAST_TAG # Install dependencies from the previous release. +uv venv .venv --clear +source .venv/bin/activate make install-dev # this is only needed temporarily until the released tag for $LAST_TAG includes this config if [ "$EXAMPLE_NAME" == "sushi_dbt" ]; then echo 'migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb")' >> $TEST_DIR/config.py -fi +fi # Run initial plan pushd $TEST_DIR @@ -44,10 +46,12 @@ sqlmesh $SQLMESH_OPTS plan --no-prompts --auto-apply rm -rf .cache popd -# Switch back to the starting state of the repository +# Switch back to the starting state of the repository git checkout - # Install updated dependencies. +uv venv .venv --clear +source .venv/bin/activate make install-dev # Migrate and make sure the diff is empty diff --git a/.circleci/update-pypirc.sh b/.github/scripts/update-pypirc.sh similarity index 100% rename from .circleci/update-pypirc.sh rename to .github/scripts/update-pypirc.sh diff --git a/.circleci/wait-for-db.sh b/.github/scripts/wait-for-db.sh similarity index 98% rename from .circleci/wait-for-db.sh rename to .github/scripts/wait-for-db.sh index a313320279..07502e3898 100755 --- a/.circleci/wait-for-db.sh +++ b/.github/scripts/wait-for-db.sh @@ -80,4 +80,4 @@ while [ $EXIT_CODE -ne 0 ]; do fi done -echo "$ENGINE is ready!" \ No newline at end of file +echo "$ENGINE is ready!" 
diff --git a/.github/workflows/dco.yml b/.github/workflows/dco.yml new file mode 100644 index 0000000000..a1c4e07300 --- /dev/null +++ b/.github/workflows/dco.yml @@ -0,0 +1,17 @@ +name: Sanity check +on: [pull_request] + +jobs: + commits_check_job: + runs-on: ubuntu-latest + name: Commits Check + steps: + - name: Get PR Commits + id: 'get-pr-commits' + uses: tim-actions/get-pr-commits@master + with: + token: ${{ secrets.GITHUB_TOKEN }} + - name: DCO Check + uses: tim-actions/dco@master + with: + commits: ${{ steps.get-pr-commits.outputs.commits }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 69e93635dc..4395c56313 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -6,11 +6,392 @@ on: branches: - main concurrency: - group: 'pr-${{ github.event.pull_request.number }}' + group: pr-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true permissions: contents: read jobs: + changes: + runs-on: ubuntu-latest + outputs: + python: ${{ steps.filter.outputs.python }} + client: ${{ steps.filter.outputs.client }} + ci: ${{ steps.filter.outputs.ci }} + steps: + - uses: actions/checkout@v5 + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: | + python: + - 'sqlmesh/**' + - 'tests/**' + - 'examples/**' + - 'web/server/**' + - 'pytest.ini' + - 'setup.cfg' + - 'setup.py' + - 'pyproject.toml' + client: + - 'web/client/**' + ci: + - '.github/**' + - 'Makefile' + - '.pre-commit-config.yaml' + + doc-tests: + needs: changes + if: + needs.changes.outputs.python == 'true' || needs.changes.outputs.ci == + 'true' || github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + env: + UV: '1' + steps: + - uses: actions/checkout@v5 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install dependencies + run: | + uv venv .venv + source .venv/bin/activate + make install-dev install-doc + - name: Run doc tests + run: | + source .venv/bin/activate + make doc-test + + style-and-cicd-tests: + needs: changes + if: + needs.changes.outputs.python == 'true' || needs.changes.outputs.ci == + 'true' || github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + env: + PYTEST_XDIST_AUTO_NUM_WORKERS: 2 + UV: '1' + steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install OpenJDK and ODBC + run: + sudo apt-get update && sudo apt-get install -y default-jdk + unixodbc-dev + - name: Install SQLMesh dev dependencies + run: | + uv venv .venv + source .venv/bin/activate + make install-dev + - name: Fix Git URL override + run: + git config --global --unset url."ssh://git@github.com".insteadOf || + true + - name: Run linters and code style checks + run: | + source .venv/bin/activate + make py-style + - name: Exercise the benchmarks + if: matrix.python-version != '3.9' + run: | + source .venv/bin/activate + make benchmark-ci + - name: Run cicd tests + run: | + source .venv/bin/activate + make cicd-test + - name: Upload test results + uses: actions/upload-artifact@v5 + if: ${{ !cancelled() }} + with: + name: test-results-style-cicd-${{ matrix.python-version }} + path: test-results/ + retention-days: 7 + + cicd-tests-windows: + needs: changes + if: + 
needs.changes.outputs.python == 'true' || needs.changes.outputs.ci == + 'true' || github.ref == 'refs/heads/main' + runs-on: windows-latest + steps: + - name: Enable symlinks in git config + run: git config --global core.symlinks true + - uses: actions/checkout@v5 + - name: Install make + run: choco install make which -y + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.12' + - name: Install SQLMesh dev dependencies + run: | + python -m venv venv + . ./venv/Scripts/activate + python.exe -m pip install --upgrade pip + make install-dev + - name: Run fast unit tests + run: | + . ./venv/Scripts/activate + which python + python --version + make fast-test + - name: Upload test results + uses: actions/upload-artifact@v5 + if: ${{ !cancelled() }} + with: + name: test-results-windows + path: test-results/ + retention-days: 7 + + migration-test: + needs: changes + if: + needs.changes.outputs.python == 'true' || needs.changes.outputs.ci == + 'true' || github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + env: + SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: '1' + UV: '1' + steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Run migration test - sushi + run: + ./.github/scripts/test_migration.sh sushi "--gateway + duckdb_persistent" + - name: Run migration test - sushi_dbt + run: + ./.github/scripts/test_migration.sh sushi_dbt "--config + migration_test_config" + + ui-style: + needs: [changes] + if: + needs.changes.outputs.client == 'true' || needs.changes.outputs.ci == + 'true' || github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-node@v6 + with: + node-version: '20' + - uses: pnpm/action-setup@v4 + with: + version: latest + - name: Get pnpm store directory + id: pnpm-cache + run: echo "store=$(pnpm store path)" >> $GITHUB_OUTPUT + - uses: actions/cache@v4 + with: + path: ${{ steps.pnpm-cache.outputs.store }} + key: pnpm-store-${{ hashFiles('pnpm-lock.yaml') }} + restore-keys: pnpm-store- + - name: Install dependencies + run: pnpm install + - name: Run linters and code style checks + run: pnpm run lint + + ui-test: + needs: changes + if: + needs.changes.outputs.client == 'true' || needs.changes.outputs.ci == + 'true' || github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + container: + image: mcr.microsoft.com/playwright:v1.54.1-jammy + steps: + - uses: actions/checkout@v5 + - name: Install pnpm via corepack + run: | + npm install --global corepack@latest + corepack enable + corepack prepare pnpm@latest-10 --activate + pnpm config set store-dir .pnpm-store + - name: Install dependencies + run: pnpm install + - name: Build UI + run: npm --prefix web/client run build + - name: Run unit tests + run: npm --prefix web/client run test:unit + - name: Run e2e tests + run: npm --prefix web/client run test:e2e + env: + PLAYWRIGHT_SKIP_BUILD: '1' + HOME: /root + + engine-tests-docker: + needs: changes + if: + needs.changes.outputs.python == 'true' || needs.changes.outputs.ci == + 'true' || github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + timeout-minutes: 25 + strategy: + fail-fast: false + matrix: + engine: + [duckdb, postgres, mysql, mssql, trino, spark, clickhouse, risingwave] + env: + PYTEST_XDIST_AUTO_NUM_WORKERS: 2 + SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: '1' + UV: '1' + steps: + - uses: actions/checkout@v5 + - name: Set up Python + 
uses: actions/setup-python@v6 + with: + python-version: '3.12' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install SQLMesh dev dependencies + run: | + uv venv .venv + source .venv/bin/activate + make install-dev + - name: Install OS-level dependencies + run: ./.github/scripts/install-prerequisites.sh "${{ matrix.engine }}" + - name: Run tests + run: | + source .venv/bin/activate + make ${{ matrix.engine }}-test + - name: Upload test results + uses: actions/upload-artifact@v5 + if: ${{ !cancelled() }} + with: + name: test-results-docker-${{ matrix.engine }} + path: test-results/ + retention-days: 7 + + engine-tests-cloud: + needs: engine-tests-docker + if: github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + timeout-minutes: 25 + strategy: + fail-fast: false + matrix: + engine: + [ + snowflake, + databricks, + redshift, + bigquery, + clickhouse-cloud, + athena, + fabric, + gcp-postgres, + ] + env: + PYTEST_XDIST_AUTO_NUM_WORKERS: 4 + SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: '1' + UV: '1' + SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }} + SNOWFLAKE_USER: ${{ secrets.SNOWFLAKE_USER }} + SNOWFLAKE_WAREHOUSE: ${{ secrets.SNOWFLAKE_WAREHOUSE }} + SNOWFLAKE_AUTHENTICATOR: SNOWFLAKE_JWT + DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_SERVER_HOSTNAME }} + DATABRICKS_HOST: ${{ secrets.DATABRICKS_SERVER_HOSTNAME }} + DATABRICKS_HTTP_PATH: ${{ secrets.DATABRICKS_HTTP_PATH }} + DATABRICKS_CLIENT_ID: ${{ secrets.DATABRICKS_CLIENT_ID }} + DATABRICKS_CLIENT_SECRET: ${{ secrets.DATABRICKS_CLIENT_SECRET }} + DATABRICKS_CONNECT_VERSION: ${{ secrets.DATABRICKS_CONNECT_VERSION }} + REDSHIFT_HOST: ${{ secrets.REDSHIFT_HOST }} + REDSHIFT_PORT: ${{ secrets.REDSHIFT_PORT }} + REDSHIFT_USER: ${{ secrets.REDSHIFT_USER }} + REDSHIFT_PASSWORD: ${{ secrets.REDSHIFT_PASSWORD }} + BIGQUERY_KEYFILE: ${{ secrets.BIGQUERY_KEYFILE }} + BIGQUERY_KEYFILE_CONTENTS: ${{ secrets.BIGQUERY_KEYFILE_CONTENTS }} + CLICKHOUSE_CLOUD_HOST: ${{ secrets.CLICKHOUSE_CLOUD_HOST }} + CLICKHOUSE_CLOUD_USERNAME: ${{ secrets.CLICKHOUSE_CLOUD_USERNAME }} + CLICKHOUSE_CLOUD_PASSWORD: ${{ secrets.CLICKHOUSE_CLOUD_PASSWORD }} + GCP_POSTGRES_KEYFILE_JSON: ${{ secrets.GCP_POSTGRES_KEYFILE_JSON }} + GCP_POSTGRES_INSTANCE_CONNECTION_STRING: + ${{ secrets.GCP_POSTGRES_INSTANCE_CONNECTION_STRING }} + GCP_POSTGRES_USER: ${{ secrets.GCP_POSTGRES_USER }} + GCP_POSTGRES_PASSWORD: ${{ secrets.GCP_POSTGRES_PASSWORD }} + ATHENA_S3_WAREHOUSE_LOCATION: ${{ secrets.ATHENA_S3_WAREHOUSE_LOCATION }} + ATHENA_WORK_GROUP: ${{ secrets.ATHENA_WORK_GROUP }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + AWS_REGION: ${{ secrets.AWS_REGION }} + FABRIC_HOST: ${{ secrets.FABRIC_HOST }} + FABRIC_CLIENT_ID: ${{ secrets.FABRIC_CLIENT_ID }} + FABRIC_CLIENT_SECRET: ${{ secrets.FABRIC_CLIENT_SECRET }} + FABRIC_TENANT_ID: ${{ secrets.FABRIC_TENANT_ID }} + FABRIC_WORKSPACE_ID: ${{ secrets.FABRIC_WORKSPACE_ID }} + steps: + - uses: actions/checkout@v5 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.12' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install OS-level dependencies + run: ./.github/scripts/install-prerequisites.sh "${{ matrix.engine }}" + - name: Install SQLMesh dev dependencies + run: | + uv venv .venv + source .venv/bin/activate + make install-dev + - name: Generate database name and setup credentials + run: | + UUID=$(cat /proc/sys/kernel/random/uuid) + TEST_DB_NAME="ci_${UUID:0:8}" + echo "TEST_DB_NAME=$TEST_DB_NAME" >> 
$GITHUB_ENV + echo "SNOWFLAKE_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV + echo "DATABRICKS_CATALOG=$TEST_DB_NAME" >> $GITHUB_ENV + echo "REDSHIFT_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV + echo "GCP_POSTGRES_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV + echo "FABRIC_DATABASE=$TEST_DB_NAME" >> $GITHUB_ENV + + echo "$SNOWFLAKE_PRIVATE_KEY_RAW" | base64 -d > /tmp/snowflake-keyfile.p8 + echo "SNOWFLAKE_PRIVATE_KEY_FILE=/tmp/snowflake-keyfile.p8" >> $GITHUB_ENV + env: + SNOWFLAKE_PRIVATE_KEY_RAW: ${{ secrets.SNOWFLAKE_PRIVATE_KEY_RAW }} + - name: Create test database + run: + ./.github/scripts/manage-test-db.sh "${{ matrix.engine }}" + "$TEST_DB_NAME" up + - name: Run tests + run: | + source .venv/bin/activate + make ${{ matrix.engine }}-test + - name: Tear down test database + if: always() + run: + ./.github/scripts/manage-test-db.sh "${{ matrix.engine }}" + "$TEST_DB_NAME" down + - name: Upload test results + uses: actions/upload-artifact@v5 + if: ${{ !cancelled() }} + with: + name: test-results-cloud-${{ matrix.engine }} + path: test-results/ + retention-days: 7 + test-vscode: env: PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD: 1 @@ -100,30 +481,30 @@ jobs: if [[ "${{ matrix.dbt-version }}" == "1.3" ]] || \ [[ "${{ matrix.dbt-version }}" == "1.4" ]] || \ [[ "${{ matrix.dbt-version }}" == "1.5" ]]; then - + echo "DBT version is ${{ matrix.dbt-version }} (< 1.6.0), removing semantic_models and metrics sections..." - + schema_file="tests/fixtures/dbt/sushi_test/models/schema.yml" if [[ -f "$schema_file" ]]; then echo "Modifying $schema_file..." - + # Create a temporary file temp_file=$(mktemp) - + # Use awk to remove semantic_models and metrics sections awk ' /^semantic_models:/ { in_semantic=1; next } /^metrics:/ { in_metrics=1; next } - /^[^ ]/ && (in_semantic || in_metrics) { - in_semantic=0; - in_metrics=0 + /^[^ ]/ && (in_semantic || in_metrics) { + in_semantic=0; + in_metrics=0 } !in_semantic && !in_metrics { print } ' "$schema_file" > "$temp_file" - + # Move the temp file back mv "$temp_file" "$schema_file" - + echo "Successfully removed semantic_models and metrics sections" else echo "Schema file not found at $schema_file, skipping..." 
diff --git a/.github/workflows/private-repo-test.yaml b/.github/workflows/private-repo-test.yaml deleted file mode 100644 index 9b2365f48a..0000000000 --- a/.github/workflows/private-repo-test.yaml +++ /dev/null @@ -1,97 +0,0 @@ -name: Private Repo Testing - -on: - pull_request_target: - branches: - - main - -concurrency: - group: 'private-test-${{ github.event.pull_request.number }}' - cancel-in-progress: true - -permissions: - contents: read - -jobs: - trigger-private-test: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v5 - with: - fetch-depth: 0 - ref: ${{ github.event.pull_request.head.sha || github.ref }} - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: '3.12' - - name: Install uv - uses: astral-sh/setup-uv@v7 - - name: Set up Node.js for UI build - uses: actions/setup-node@v6 - with: - node-version: '20' - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - version: latest - - name: Install UI dependencies - run: pnpm install - - name: Build UI - run: pnpm --prefix web/client run build - - name: Install Python dependencies - run: | - python -m venv .venv - source .venv/bin/activate - pip install build twine setuptools_scm - - name: Generate development version - id: version - run: | - source .venv/bin/activate - # Generate a PEP 440 compliant unique version including run attempt - BASE_VERSION=$(python .github/scripts/get_scm_version.py) - COMMIT_SHA=$(git rev-parse --short HEAD) - # Use PEP 440 compliant format: base.devN+pr.sha.attempt - UNIQUE_VERSION="${BASE_VERSION}+pr${{ github.event.pull_request.number }}.${COMMIT_SHA}.run${{ github.run_attempt }}" - echo "version=$UNIQUE_VERSION" >> $GITHUB_OUTPUT - echo "Generated unique version with run attempt: $UNIQUE_VERSION" - - name: Build package - env: - SETUPTOOLS_SCM_PRETEND_VERSION: ${{ steps.version.outputs.version }} - run: | - source .venv/bin/activate - python -m build - - name: Configure PyPI for private repository - env: - TOBIKO_PRIVATE_PYPI_URL: ${{ secrets.TOBIKO_PRIVATE_PYPI_URL }} - TOBIKO_PRIVATE_PYPI_KEY: ${{ secrets.TOBIKO_PRIVATE_PYPI_KEY }} - run: ./.circleci/update-pypirc.sh - - name: Publish to private PyPI - run: | - source .venv/bin/activate - python -m twine upload -r tobiko-private dist/* - - name: Publish Python Tests package - env: - SETUPTOOLS_SCM_PRETEND_VERSION: ${{ steps.version.outputs.version }} - run: | - source .venv/bin/activate - unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests - - name: Get GitHub App token - id: get_token - uses: actions/create-github-app-token@v2 - with: - private-key: ${{ secrets.TOBIKO_RENOVATE_BOT_PRIVATE_KEY }} - app-id: ${{ secrets.TOBIKO_RENOVATE_BOT_APP_ID }} - owner: ${{ secrets.PRIVATE_REPO_OWNER }} - - name: Trigger private repository workflow - uses: convictional/trigger-workflow-and-wait@v1.6.5 - with: - owner: ${{ secrets.PRIVATE_REPO_OWNER }} - repo: ${{ secrets.PRIVATE_REPO_NAME }} - github_token: ${{ steps.get_token.outputs.token }} - workflow_file_name: ${{ secrets.PRIVATE_WORKFLOW_FILE }} - client_payload: | - { - "package_version": "${{ steps.version.outputs.version }}", - "pr_number": "${{ github.event.pull_request.number }}" - } diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000000..75512ffd72 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,71 @@ +name: Release +on: + push: + tags: + - 'v*.*.*' +permissions: + contents: write +jobs: + ui-build: + runs-on: ubuntu-latest + steps: + - uses: 
actions/checkout@v5 + - uses: actions/setup-node@v6 + with: + node-version: '20' + - uses: pnpm/action-setup@v4 + with: + version: latest + - name: Install dependencies + run: pnpm install + - name: Build UI + run: pnpm --prefix web/client run build + - name: Upload UI build artifact + uses: actions/upload-artifact@v5 + with: + name: ui-dist + path: web/client/dist/ + retention-days: 1 + + publish: + needs: ui-build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Download UI build artifact + uses: actions/download-artifact@v4 + with: + name: ui-dist + path: web/client/dist/ + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install build dependencies + run: pip install build twine setuptools_scm + - name: Publish Python package + run: make publish + env: + TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} + TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} + - name: Update pypirc for private repository + run: ./.github/scripts/update-pypirc.sh + env: + TOBIKO_PRIVATE_PYPI_URL: ${{ secrets.TOBIKO_PRIVATE_PYPI_URL }} + TOBIKO_PRIVATE_PYPI_KEY: ${{ secrets.TOBIKO_PRIVATE_PYPI_KEY }} + - name: Publish Python Tests package + run: unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests + + gh-release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 0 + - name: Create release on GitHub + uses: softprops/action-gh-release@v2 + with: + generate_release_notes: true + tag_name: ${{ github.ref_name }} diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..287a87dab5 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +SQLMesh follows the [LF Projects Code of Conduct](https://lfprojects.org/policies/code-of-conduct/). All participants in the project are expected to abide by it. + +If you believe someone is violating the code of conduct, please report it by following the instructions in the [LF Projects Code of Conduct](https://lfprojects.org/policies/code-of-conduct/). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..0e1d8e1c6e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,90 @@ +# Contributing to SQLMesh + +## Welcome + +SQLMesh is a project of the Linux Foundation. We welcome contributions from anyone β whether you're fixing a bug, improving documentation, or proposing a new feature. + +## Technical Steering Committee (TSC) + +The TSC is responsible for technical oversight of the SQLMesh project, including coordinating technical direction, approving contribution policies, and maintaining community norms. + +Initial TSC voting members are the project's Maintainers: + +| Name | GitHub Handle | Affiliation | Role | +|---------------------|---------------|----------------|------------| +| Alexander Butler | z3z1ma | Harness | TSC Member | +| Alexander Filipchik | afilipchik | Cloud Kitchens | TSC Member | +| Reid Hooper | rhooper9711 | Benzinga | TSC Member | +| Yuki Kakegawa | StuffbyYuki | Jump.ai | TSC Member | +| Toby Mao | tobymao | Fivetran | TSC Chair | +| Alex Wilde | alexminerv | Minerva | TSC Member | + + +## Roles + +**Contributors**: Anyone who contributes code, documentation, or other technical artifacts to the project. + +**Maintainers**: Contributors who have earned the ability to modify source code, documentation, or other technical artifacts. A Contributor may become a Maintainer by majority approval of the TSC. 
A Maintainer may be removed by majority approval of the TSC. + +## How to Contribute + +1. Fork the repository on GitHub +2. Create a branch for your changes +3. Make your changes and commit them with a sign-off (see DCO section below) +4. Submit a pull request against the `main` branch + +File issues at [github.com/sqlmesh/sqlmesh/issues](https://github.com/sqlmesh/sqlmesh/issues). + +## Developer Certificate of Origin (DCO) + +All contributions must include a `Signed-off-by` line in the commit message per the [Developer Certificate of Origin](DCO). This certifies that you wrote the contribution or have the right to submit it under the project's open source license. + +Use `git commit -s` to add the sign-off automatically: + +```bash +git commit -s -m "Your commit message" +``` + +To fix a commit that is missing the sign-off: + +```bash +git commit --amend -s +``` + +To add a sign-off to multiple commits: + +```bash +git rebase HEAD~N --signoff +``` + +## Development Setup + +See [docs/development.md](docs/development.md) for full setup instructions. Key commands: + +```bash +python -m venv .venv +source .venv/bin/activate +make install-dev +make style # Run before submitting +make fast-test # Quick test suite +``` + +## Coding Standards + +- Run `make style` before submitting a pull request +- Follow existing code patterns and conventions in the codebase +- New files should include an SPDX license header: + ```python + # SPDX-License-Identifier: Apache-2.0 + ``` + +## Pull Request Process + +- Describe your changes clearly in the pull request description +- Ensure all CI checks pass +- Include a DCO sign-off on all commits (`git commit -s`) +- Be responsive to review feedback from maintainers + +## Licensing + +Code contributions are licensed under the [Apache License 2.0](LICENSE). Documentation contributions are licensed under [Creative Commons Attribution 4.0 International (CC-BY-4.0)](https://creativecommons.org/licenses/by/4.0/). See the LICENSE file and the [technical charter](sqlmesh-technical-charter.pdf) for details. diff --git a/DCO b/DCO new file mode 100644 index 0000000000..49b8cb0549 --- /dev/null +++ b/DCO @@ -0,0 +1,34 @@ +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. 
+ +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. diff --git a/GOVERNANCE.md b/GOVERNANCE.md new file mode 100644 index 0000000000..44b6bc9947 --- /dev/null +++ b/GOVERNANCE.md @@ -0,0 +1,62 @@ +# SQLMesh Project Governance + +## Overview + +SQLMesh is a Series of LF Projects, LLC. The project is governed by its [Technical Charter](sqlmesh-technical-charter.pdf) and overseen by the Technical Steering Committee (TSC). SQLMesh is a project of the [Linux Foundation](https://www.linuxfoundation.org/). + +## Technical Steering Committee + +The TSC is responsible for all technical oversight of the project, including: + +- Coordinating the technical direction of the project +- Approving project or system proposals +- Organizing sub-projects and removing sub-projects +- Creating sub-committees or working groups to focus on cross-project technical issues +- Appointing representatives to work with other open source or open standards communities +- Establishing community norms, workflows, issuing releases, and security vulnerability reports +- Approving and implementing policies for contribution requirements +- Coordinating any marketing, events, or communications regarding the project + +## TSC Composition + +TSC voting members are initially the project's Maintainers as listed in [CONTRIBUTING.md](CONTRIBUTING.md). The TSC may elect a Chair from among its voting members. The Chair presides over TSC meetings and serves as the primary point of contact with the Linux Foundation. + +## Decision Making + +The project operates as a consensus-based community. When a formal vote is required: + +- Each voting TSC member receives one vote +- A quorum of 50% of voting members is required to conduct a vote +- Decisions are made by a majority of those present when quorum is met +- Electronic votes (e.g., via GitHub issues or mailing list) require a majority of all voting members to pass +- Votes that do not meet quorum or remain unresolved may be referred to the Series Manager for resolution + +## Charter Amendments + +The technical charter may be amended by a two-thirds vote of the entire TSC, subject to approval by LF Projects, LLC. + +## Reference + +The full technical charter is available at [sqlmesh-technical-charter.pdf](sqlmesh-technical-charter.pdf). + +# TSC Meeting Minutes + +## 2026-03-10 β Initial TSC Meeting + +**Members present:** Toby Mao (tobymao) + +### Vote 1: Elect Toby Mao as TSC Chair +- **Motion by:** Toby Mao +- **Votes:** Toby Mao: Yes +- **Result:** Approved (1-0-0, yes-no-abstain) + +### Vote 2: Elect TSC founding members +- **Question:** Shall the following members be added to the TSC? + - Alexander Butler (z3z1ma) + - Alexander Filipchik (afilipchik) + - Reid Hooper (rhooper9711) + - Yuki Kakegawa (StuffbyYuki) + - Alex Wilde (alexminerv) +- **Motion by:** Toby Mao +- **Votes:** Toby Mao: Yes +- **Result:** Approved (1-0-0, yes-no-abstain) diff --git a/LICENSE b/LICENSE index eabfad022a..7e95724816 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2024 Tobiko Data Inc. 
+ Copyright Contributors to the SQLMesh project Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/Makefile b/Makefile index e7a78de472..843beb0624 100644 --- a/Makefile +++ b/Makefile @@ -130,7 +130,7 @@ slow-test: pytest -n auto -m "(fast or slow) and not cicdonly" && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated" cicd-test: - pytest -n auto -m "fast or slow" --junitxml=test-results/junit-cicd.xml && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated" + pytest -n auto -m "(fast or slow) and not pyspark" --junitxml=test-results/junit-cicd.xml && pytest -m "pyspark" && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated" core-fast-test: pytest -n auto -m "fast and not web and not github and not dbt and not jupyter" @@ -166,7 +166,7 @@ web-test: pytest -n auto -m "web" guard-%: - @ if [ "${${*}}" = "" ]; then \ + @ if ! printenv ${*} > /dev/null 2>&1; then \ echo "Environment variable $* not set"; \ exit 1; \ fi @@ -176,7 +176,7 @@ engine-%-install: engine-docker-%-up: docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml up -d - ./.circleci/wait-for-db.sh ${*} + ./.github/scripts/wait-for-db.sh ${*} engine-%-up: engine-%-install engine-docker-%-up @echo "Engine '${*}' is up and running" diff --git a/README.md b/README.md index 3215f7cceb..41f78cc138 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@
SQLMesh is a project of the Linux Foundation.
SQLMesh is a next-generation data transformation framework designed to ship data quickly, efficiently, and without error. Data teams can run and deploy data transformations written in SQL or Python with visibility and control at any size. @@ -12,7 +13,7 @@ It is more than just a [dbt alternative](https://tobikodata.com/reduce_costs_wit ## Core Features -
+
> Get instant SQL impact and context of your changes, both in the CLI and in the [SQLMesh VSCode Extension](https://sqlmesh.readthedocs.io/en/latest/guides/vscode/?h=vs+cod)
@@ -121,19 +122,19 @@ outputs:
* Never build a table [more than once](https://tobikodata.com/simplicity-or-efficiency-how-dbt-makes-you-choose.html)
* Track what dataβs been modified and run only the necessary transformations for [incremental models](https://tobikodata.com/correctly-loading-incremental-data-at-scale.html)
* Run [unit tests](https://tobikodata.com/we-need-even-greater-expectations.html) for free and configure automated audits
-* Run [table diffs](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/?h=crash#run-data-diff-against-prod) between prod and dev based on tables/views impacted by a change
+* Run [table diffs](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/?h=crash#run-data-diff-against-prod) between prod and dev based on tables/views impacted by a change
+
-
+
-
+
+
> Get instant SQL impact analysis of your changes, whether in the CLI or in [SQLMesh Plan Mode](https://sqlmesh.readthedocs.io/en/stable/guides/ui/?h=modes#working-with-an-ide)
@@ -121,7 +121,7 @@ It is more than just a [dbt alternative](https://tobikodata.com/reduce_costs_wit
??? tip "Level Up Your SQL"
Write SQL in any dialect and SQLMesh will transpile it to your target SQL dialect on the fly before sending it to the warehouse.
-
+
* Debug transformation errors *before* you run them in your warehouse in [10+ different SQL dialects](https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines)
* Definitions using [simply SQL](https://sqlmesh.readthedocs.io/en/stable/concepts/models/sql_models/#sql-based-definition) (no need for redundant and confusing `Jinja` + `YAML`)
@@ -153,7 +153,7 @@ Follow this [example](https://sqlmesh.readthedocs.io/en/stable/examples/incremen
Together, we want to build data transformation without the waste. Connect with us in the following ways:
* Join the [Tobiko Slack Community](https://tobikodata.com/slack) to ask questions, or just to say hi!
-* File an issue on our [GitHub](https://github.com/TobikoData/sqlmesh/issues/new)
+* File an issue on our [GitHub](https://github.com/SQLMesh/sqlmesh/issues/new)
* Send us an email at [hello@tobikodata.com](mailto:hello@tobikodata.com) with your questions or feedback
* Read our [blog](https://tobikodata.com/blog)
diff --git a/docs/integrations/dbt.md b/docs/integrations/dbt.md
index 7cbef5b8fa..5854236aa2 100644
--- a/docs/integrations/dbt.md
+++ b/docs/integrations/dbt.md
@@ -358,4 +358,4 @@ The dbt jinja methods that are not currently supported are:
## Missing something you need?
-Submit an [issue](https://github.com/TobikoData/sqlmesh/issues), and we'll look into it!
+Submit an [issue](https://github.com/SQLMesh/sqlmesh/issues), and we'll look into it!
diff --git a/docs/integrations/dlt.md b/docs/integrations/dlt.md
index a53dc184ea..7125510de9 100644
--- a/docs/integrations/dlt.md
+++ b/docs/integrations/dlt.md
@@ -70,7 +70,7 @@ SQLMesh will retrieve the data warehouse connection credentials from your dlt pr
### Example
-Generating a SQLMesh project dlt is quite simple. In this example, we'll use the example `sushi_pipeline.py` from the [sushi-dlt project](https://github.com/TobikoData/sqlmesh/tree/main/examples/sushi_dlt).
+Generating a SQLMesh project from dlt is quite simple. In this example, we'll use the example `sushi_pipeline.py` from the [sushi-dlt project](https://github.com/SQLMesh/sqlmesh/tree/main/examples/sushi_dlt).
First, run the pipeline within the project directory:
diff --git a/docs/integrations/github.md b/docs/integrations/github.md
index 923714888e..07903fce56 100644
--- a/docs/integrations/github.md
+++ b/docs/integrations/github.md
@@ -364,7 +364,7 @@ These are the possible outputs (based on how the bot is configured) that are cre
* `prod_plan_preview`
* `prod_environment_synced`
-[There are many possible conclusions](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/integrations/github/cicd/controller.py#L96-L102) so the best use case for this is likely to check for `success` conclusion in order to potentially run follow up steps.
+[There are many possible conclusions](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/integrations/github/cicd/controller.py#L96-L102) so the best use case for this is likely to check for `success` conclusion in order to potentially run follow up steps.
Note that in error cases conclusions may not be set and therefore you will get an empty string.
Example of running a step after pr environment has been synced:
diff --git a/docs/quickstart/cli.md b/docs/quickstart/cli.md
index 7b77b2af1e..a592847470 100644
--- a/docs/quickstart/cli.md
+++ b/docs/quickstart/cli.md
@@ -160,7 +160,7 @@ https://sqlmesh.readthedocs.io/en/stable/quickstart/cli/
Need help?
- Docs: https://sqlmesh.readthedocs.io
- Slack: https://www.tobikodata.com/slack
-- GitHub: https://github.com/TobikoData/sqlmesh/issues
+- GitHub: https://github.com/SQLMesh/sqlmesh/issues
```
??? info "Learn more about the project's configuration: `config.yaml`"
diff --git a/docs/reference/python.md b/docs/reference/python.md
index 14e0da84c8..1c4c9191ff 100644
--- a/docs/reference/python.md
+++ b/docs/reference/python.md
@@ -4,6 +4,6 @@ SQLMesh is built in Python, and its complete Python API reference is located [he
The Python API reference is comprehensive and includes the internal components of SQLMesh. Those components are likely only of interest if you want to modify SQLMesh itself.
-If you want to use SQLMesh via its Python API, the best approach is to study how the SQLMesh [CLI](./cli.md) calls it behind the scenes. The CLI implementation code shows exactly which Python methods are called for each CLI command and can be viewed [on Github](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/cli/main.py). For example, the Python code executed by the `plan` command is located [here](https://github.com/TobikoData/sqlmesh/blob/15c8788100fa1cfb8b0cc1879ccd1ad21dc3e679/sqlmesh/cli/main.py#L302).
+If you want to use SQLMesh via its Python API, the best approach is to study how the SQLMesh [CLI](./cli.md) calls it behind the scenes. The CLI implementation code shows exactly which Python methods are called for each CLI command and can be viewed [on Github](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/cli/main.py). For example, the Python code executed by the `plan` command is located [here](https://github.com/SQLMesh/sqlmesh/blob/15c8788100fa1cfb8b0cc1879ccd1ad21dc3e679/sqlmesh/cli/main.py#L302).
Almost all the relevant Python methods are in the [SQLMesh `Context` class](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/context.html#Context).
diff --git a/mkdocs.yml b/mkdocs.yml
index 47ddca54e9..86761de9d7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,6 +1,6 @@
site_name: SQLMesh
-repo_url: https://github.com/TobikoData/sqlmesh
-repo_name: TobikoData/sqlmesh
+repo_url: https://github.com/SQLMesh/sqlmesh
+repo_name: SQLMesh/sqlmesh
nav:
- "Overview": index.md
- Get started:
@@ -202,7 +202,7 @@ extra:
- icon: fontawesome/solid/paper-plane
link: mailto:hello@tobikodata.com
- icon: fontawesome/brands/github
- link: https://github.com/TobikoData/sqlmesh/issues/new
+ link: https://github.com/SQLMesh/sqlmesh/issues/new
analytics:
provider: google
property: G-JXQ1R227VS
diff --git a/pdoc/cli.py b/pdoc/cli.py
index 5833c59207..9301ae0444 100755
--- a/pdoc/cli.py
+++ b/pdoc/cli.py
@@ -29,7 +29,7 @@ def mocked_import(*args, **kwargs):
opts.logo_link = "https://tobikodata.com"
opts.footer_text = "Copyright Tobiko Data Inc. 2022"
opts.template_directory = Path(__file__).parent.joinpath("templates").absolute()
- opts.edit_url = ["sqlmesh=https://github.com/TobikoData/sqlmesh/tree/main/sqlmesh/"]
+ opts.edit_url = ["sqlmesh=https://github.com/SQLMesh/sqlmesh/tree/main/sqlmesh/"]
with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}):
cli()
diff --git a/posts/virtual_data_environments.md b/posts/virtual_data_environments.md
index dc3b2cb46e..5cde9dba51 100644
--- a/posts/virtual_data_environments.md
+++ b/posts/virtual_data_environments.md
@@ -8,7 +8,7 @@ In this post, I'm going to explain why existing approaches to managing developme
I'll introduce [Virtual Data Environments](#virtual-data-environments-1) - a novel approach that provides low-cost, efficient, scalable, and safe data environments that are easy to use and manage. They significantly boost the productivity of anyone who has to create or maintain data pipelines.
-Finally, I'm going to explain how **Virtual Data Environments** are implemented in [SQLMesh](https://github.com/TobikoData/sqlmesh) and share details on each core component involved:
+Finally, I'm going to explain how **Virtual Data Environments** are implemented in [SQLMesh](https://github.com/SQLMesh/sqlmesh) and share details on each core component involved:
- Data [fingerprinting](#fingerprinting)
- [Automatic change categorization](#automatic-change-categorization)
- Decoupling of [physical](#physical-layer) and [virtual](#virtual-layer) layers
@@ -156,6 +156,6 @@ With **Virtual Data Environments**, SQLMesh is able to provide fully **isolated*
- Rolling back a change happens almost instantaneously since no data movement is involved and only views that are part of the **virtual layer** get updated.
- Deploying changes to production is a **virtual layer** operation, which ensures that results observed during development are exactly the same in production and that data and code are always in sync.
-To streamline deploying changes to production, our team is about to release the SQLMesh [CI/CD bot](https://github.com/TobikoData/sqlmesh/blob/main/docs/integrations/github.md), which will help automate this process.
+To streamline deploying changes to production, our team is about to release the SQLMesh [CI/CD bot](https://github.com/SQLMesh/sqlmesh/blob/main/docs/integrations/github.md), which will help automate this process.
Don't miss out - join our [Slack channel](https://tobikodata.com/slack) and stay tuned!
diff --git a/pyproject.toml b/pyproject.toml
index 029d043704..bcc69c667e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "sqlmesh"
dynamic = ["version"]
description = "Next-generation data transformation framework"
readme = "README.md"
-authors = [{ name = "TobikoData Inc.", email = "engineering@tobikodata.com" }]
+authors = [{ name = "SQLMesh Contributors" }]
license = { file = "LICENSE" }
requires-python = ">= 3.9"
dependencies = [
@@ -24,7 +24,7 @@ dependencies = [
"requests",
"rich[jupyter]",
"ruamel.yaml",
- "sqlglot[rs]~=28.10.1",
+ "sqlglot~=30.4.2",
"tenacity",
"time-machine",
"json-stream"
@@ -154,8 +154,8 @@ sqlmesh_lsp = "sqlmesh.lsp.main:main"
[project.urls]
Homepage = "https://sqlmesh.com/"
Documentation = "https://sqlmesh.readthedocs.io/en/stable/"
-Repository = "https://github.com/TobikoData/sqlmesh"
-Issues = "https://github.com/TobikoData/sqlmesh/issues"
+Repository = "https://github.com/SQLMesh/sqlmesh"
+Issues = "https://github.com/SQLMesh/sqlmesh/issues"
[build-system]
requires = ["setuptools >= 61.0", "setuptools_scm"]
diff --git a/sqlmesh-technical-charter.pdf b/sqlmesh-technical-charter.pdf
new file mode 100644
index 0000000000..107f015050
Binary files /dev/null and b/sqlmesh-technical-charter.pdf differ
diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py
index 45f95d2abb..ec5acbea59 100644
--- a/sqlmesh/cli/main.py
+++ b/sqlmesh/cli/main.py
@@ -246,7 +246,7 @@ def init(
Need help?
• Docs: https://sqlmesh.readthedocs.io
• Slack: https://www.tobikodata.com/slack
-• GitHub: https://github.com/TobikoData/sqlmesh/issues
+• GitHub: https://github.com/SQLMesh/sqlmesh/issues
""")
diff --git a/sqlmesh/core/_typing.py b/sqlmesh/core/_typing.py
index 8e28312c1a..2bc69e901b 100644
--- a/sqlmesh/core/_typing.py
+++ b/sqlmesh/core/_typing.py
@@ -8,8 +8,8 @@
if t.TYPE_CHECKING:
TableName = t.Union[str, exp.Table]
SchemaName = t.Union[str, exp.Table]
- SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
- CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
+ SessionProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]]
+ CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]]
if sys.version_info >= (3, 11):
diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py
index 9f470872fe..4c90151ee4 100644
--- a/sqlmesh/core/audit/definition.py
+++ b/sqlmesh/core/audit/definition.py
@@ -67,7 +67,7 @@ class AuditMixin(AuditCommonMetaMixin):
"""
query_: ParsableSql
- defaults: t.Dict[str, exp.Expression]
+ defaults: t.Dict[str, exp.Expr]
expressions_: t.Optional[t.List[ParsableSql]]
jinja_macros: JinjaMacroRegistry
formatting: t.Optional[bool]
@@ -77,10 +77,10 @@ def query(self) -> t.Union[exp.Query, d.JinjaQuery]:
return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect))
@property
- def expressions(self) -> t.List[exp.Expression]:
+ def expressions(self) -> t.List[exp.Expr]:
if not self.expressions_:
return []
- result = []
+ result: t.List[exp.Expr] = []
for e in self.expressions_:
parsed = e.parse(self.dialect)
if not isinstance(parsed, exp.Semicolon):
@@ -95,7 +95,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]:
@field_validator("name", "dialect", mode="before", check_fields=False)
def audit_string_validator(cls: t.Type, v: t.Any) -> t.Optional[str]:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return v.name.lower()
return str(v).lower() if v is not None else None
@@ -111,9 +111,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A
if isinstance(v, dict):
dialect = get_dialect(values)
return {
- key: value
- if isinstance(value, exp.Expression)
- else d.parse_one(str(value), dialect=dialect)
+ key: value if isinstance(value, exp.Expr) else d.parse_one(str(value), dialect=dialect)
for key, value in v.items()
}
raise_config_error("Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError)
@@ -133,7 +131,7 @@ class ModelAudit(PydanticModel, AuditMixin, DbtInfoMixin, frozen=True):
blocking: bool = True
standalone: t.Literal[False] = False
query_: ParsableSql = Field(alias="query")
- defaults: t.Dict[str, exp.Expression] = {}
+ defaults: t.Dict[str, exp.Expr] = {}
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
formatting: t.Optional[bool] = Field(default=None, exclude=True)
@@ -169,7 +167,7 @@ class StandaloneAudit(_Node, AuditMixin):
blocking: bool = False
standalone: t.Literal[True] = True
query_: ParsableSql = Field(alias="query")
- defaults: t.Dict[str, exp.Expression] = {}
+ defaults: t.Dict[str, exp.Expr] = {}
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
default_catalog: t.Optional[str] = None
@@ -323,13 +321,13 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns the original list of sql expressions comprising the model definition.
Args:
include_python: Whether or not to include Python code in the rendered definition.
"""
- expressions: t.List[exp.Expression] = []
+ expressions: t.List[exp.Expr] = []
comment = None
for field_name in sorted(self.meta_fields):
field_value = getattr(self, field_name)
@@ -381,7 +379,7 @@ def meta_fields(self) -> t.Iterable[str]:
return set(AuditCommonMetaMixin.__annotations__) | set(_Node.all_field_infos())
@property
- def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]:
+ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expr]]]:
return [(self, {})]
@@ -389,7 +387,7 @@ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]
def load_audit(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
*,
path: Path = Path(),
module_path: Path = Path(),
@@ -499,7 +497,7 @@ def load_audit(
def load_multiple_audits(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
*,
path: Path = Path(),
module_path: Path = Path(),
@@ -510,7 +508,7 @@ def load_multiple_audits(
variables: t.Optional[t.Dict[str, t.Any]] = None,
project: t.Optional[str] = None,
) -> t.Generator[Audit, None, None]:
- audit_block: t.List[exp.Expression] = []
+ audit_block: t.List[exp.Expr] = []
for expression in expressions:
if isinstance(expression, d.Audit):
if audit_block:
@@ -543,7 +541,7 @@ def _raise_config_error(msg: str, path: pathlib.Path) -> None:
# mypy doesn't realize raise_config_error raises an exception
@t.no_type_check
-def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]:
+def _maybe_parse_arg_pair(e: exp.Expr) -> t.Tuple[str, exp.Expr]:
if isinstance(e, exp.EQ):
return e.left.name, e.right
diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py
index 9e3a210e5e..d930537711 100644
--- a/sqlmesh/core/config/connection.py
+++ b/sqlmesh/core/config/connection.py
@@ -34,6 +34,7 @@
ValidationInfo,
field_validator,
model_validator,
+ validation_data,
validation_error_message,
get_concrete_types_from_typehint,
)
@@ -1062,6 +1063,7 @@ class BigQueryConnectionConfig(ConnectionConfig):
job_retry_deadline_seconds: t.Optional[int] = None
priority: t.Optional[BigQueryPriority] = None
maximum_bytes_billed: t.Optional[int] = None
+ reservation: t.Optional[str] = None
concurrent_tasks: int = 1
register_comments: bool = True
@@ -1080,7 +1082,7 @@ def validate_execution_project(
v: t.Optional[str],
info: ValidationInfo,
) -> t.Optional[str]:
- if v and not info.data.get("project"):
+ if v and not validation_data(info).get("project"):
raise ConfigError(
"If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
)
@@ -1092,7 +1094,7 @@ def validate_quota_project(
v: t.Optional[str],
info: ValidationInfo,
) -> t.Optional[str]:
- if v and not info.data.get("project"):
+ if v and not validation_data(info).get("project"):
raise ConfigError(
"If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
)
@@ -1171,6 +1173,7 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]:
"job_retry_deadline_seconds",
"priority",
"maximum_bytes_billed",
+ "reservation",
}
}
@@ -2013,7 +2016,17 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
OAuth2Authentication,
)
+ auth: t.Optional[
+ t.Union[
+ BasicAuthentication,
+ KerberosAuthentication,
+ OAuth2Authentication,
+ JWTAuthentication,
+ CertificateAuthentication,
+ ]
+ ] = None
if self.method.is_basic or self.method.is_ldap:
+ assert self.password is not None # for mypy since validator already checks this
auth = BasicAuthentication(self.user, self.password)
elif self.method.is_kerberos:
if self.keytab:
@@ -2032,11 +2045,12 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
elif self.method.is_oauth:
auth = OAuth2Authentication()
elif self.method.is_jwt:
+ assert self.jwt_token is not None
auth = JWTAuthentication(self.jwt_token)
elif self.method.is_certificate:
+ assert self.client_certificate is not None
+ assert self.client_private_key is not None
auth = CertificateAuthentication(self.client_certificate, self.client_private_key)
- else:
- auth = None
return {
"auth": auth,
@@ -2328,23 +2342,20 @@ def init(cursor: t.Any) -> None:
return init
+_CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = {
+ ConnectionConfig, # type: ignore[type-abstract]
+ BaseDuckDBConnectionConfig, # type: ignore[type-abstract]
+}
+
CONNECTION_CONFIG_TO_TYPE = {
# Map all subclasses of ConnectionConfig to the value of their `type_` field.
tpe.all_field_infos()["type_"].default: tpe
- for tpe in subclasses(
- __name__,
- ConnectionConfig,
- exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
- )
+ for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE)
}
DIALECT_TO_TYPE = {
tpe.all_field_infos()["type_"].default: tpe.DIALECT
- for tpe in subclasses(
- __name__,
- ConnectionConfig,
- exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
- )
+ for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE)
}
INIT_DISPLAY_INFO_TO_TYPE = {
@@ -2352,11 +2363,7 @@ def init(cursor: t.Any) -> None:
tpe.DISPLAY_ORDER,
tpe.DISPLAY_NAME,
)
- for tpe in subclasses(
- __name__,
- ConnectionConfig,
- exclude={ConnectionConfig, BaseDuckDBConnectionConfig},
- )
+ for tpe in subclasses(__name__, ConnectionConfig, exclude=_CONNECTION_CONFIG_EXCLUDE)
}
diff --git a/sqlmesh/core/config/linter.py b/sqlmesh/core/config/linter.py
index c2a40e09aa..11d700c540 100644
--- a/sqlmesh/core/config/linter.py
+++ b/sqlmesh/core/config/linter.py
@@ -34,7 +34,7 @@ def _validate_rules(cls, v: t.Any) -> t.Set[str]:
v = v.unnest().name
elif isinstance(v, (exp.Tuple, exp.Array)):
v = [e.name for e in v.expressions]
- elif isinstance(v, exp.Expression):
+ elif isinstance(v, exp.Expr):
v = v.name
return {name.lower() for name in ensure_collection(v)}
diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py
index aeefdf2557..ac41d75fe3 100644
--- a/sqlmesh/core/config/model.py
+++ b/sqlmesh/core/config/model.py
@@ -71,9 +71,9 @@ class ModelDefaultsConfig(BaseConfig):
enabled: t.Optional[t.Union[str, bool]] = None
formatting: t.Optional[t.Union[str, bool]] = None
batch_concurrency: t.Optional[int] = None
- pre_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
- post_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
- on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
+ pre_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
+ post_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
+ on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expr]]] = None
_model_kind_validator = model_kind_validator
_on_destructive_change_validator = on_destructive_change_validator
diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py
index 970defee62..9d9d1d3c79 100644
--- a/sqlmesh/core/config/scheduler.py
+++ b/sqlmesh/core/config/scheduler.py
@@ -144,9 +144,10 @@ def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str
return default_catalogs_per_gateway
-SCHEDULER_CONFIG_TO_TYPE = {
+SCHEDULER_CONFIG_TO_TYPE: t.Dict[str, t.Type[SchedulerConfig]] = {
tpe.all_field_infos()["type_"].default: tpe
for tpe in subclasses(__name__, BaseConfig, exclude={BaseConfig})
+ if issubclass(tpe, SchedulerConfig)
}
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py
index e6b404c597..f1a7657704 100644
--- a/sqlmesh/core/context.py
+++ b/sqlmesh/core/context.py
@@ -234,7 +234,7 @@ def resolve_table(self, model_name: str) -> str:
)
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a dataframe given a sql string or sqlglot expression.
@@ -248,7 +248,7 @@ def fetchdf(
return self.engine_adapter.fetchdf(query, quote_identifiers=quote_identifiers)
def fetch_pyspark_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
"""Fetches a PySpark dataframe given a sql string or sqlglot expression.
@@ -692,8 +692,11 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
if snapshot.node.project in self._projects:
uncached.add(snapshot.name)
else:
- store = self._standalone_audits if snapshot.is_audit else self._models
- store[snapshot.name] = snapshot.node # type: ignore
+ local_store = self._standalone_audits if snapshot.is_audit else self._models
+ if snapshot.name in local_store:
+ uncached.add(snapshot.name)
+ else:
+ local_store[snapshot.name] = snapshot.node # type: ignore
for model in self._models.values():
self.dag.add(model.fqn, model.depends_on)
@@ -1102,7 +1105,7 @@ def render(
execution_time: t.Optional[TimeLike] = None,
expand: t.Union[bool, t.Iterable[str]] = False,
**kwargs: t.Any,
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""Renders a model's query, expanding macros with provided kwargs, and optionally expanding referenced models.
Args:
@@ -1602,9 +1605,11 @@ def plan_builder(
backfill_models = None
models_override: t.Optional[UniqueKeyDict[str, Model]] = None
+ selected_fqns: t.Set[str] = set()
+ selected_deletion_fqns: t.Set[str] = set()
if select_models:
try:
- models_override = model_selector.select_models(
+ models_override, selected_fqns = model_selector.select_models(
select_models,
environment,
fallback_env_name=create_from or c.PROD,
@@ -1619,12 +1624,17 @@ def plan_builder(
# Only backfill selected models unless explicitly specified.
backfill_models = model_selector.expand_model_selections(select_models)
+ if not backfill_models:
+ # The selection matched nothing locally. Check whether it matched models
+ # in the deployed environment that were deleted locally.
+ selected_deletion_fqns = selected_fqns - set(self._models)
+
expanded_restate_models = None
if restate_models is not None:
expanded_restate_models = model_selector.expand_model_selections(restate_models)
if (restate_models is not None and not expanded_restate_models) or (
- backfill_models is not None and not backfill_models
+ backfill_models is not None and not backfill_models and not selected_deletion_fqns
):
raise PlanError(
"Selector did not return any models. Please check your model selection and try again."
@@ -1633,7 +1643,7 @@ def plan_builder(
if always_include_local_changes is None:
            # default behaviour - if restatements are detected, we operate entirely out of state and ignore local changes
force_no_diff = restate_models is not None or (
- backfill_models is not None and not backfill_models
+ backfill_models is not None and not backfill_models and not selected_deletion_fqns
)
else:
force_no_diff = not always_include_local_changes
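Note (illustrative, not part of this patch): the selection handling above distinguishes "the selector matched nothing at all" from "the selector matched only models that exist in the deployed environment but were deleted locally", using a plain set difference. A toy version with made-up names:

    selected_fqns = {"db.sales.orders", "db.sales.customers"}  # matched in the environment
    local_models = {"db.sales.customers"}                      # orders was deleted locally

    selected_deletion_fqns = selected_fqns - set(local_models)
    print(selected_deletion_fqns)  # {'db.sales.orders'}: the deletion keeps the plan from erroring out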
@@ -1857,10 +1867,10 @@ def table_diff(
self,
source: str,
target: str,
- on: t.Optional[t.List[str] | exp.Condition] = None,
+ on: t.Optional[t.List[str] | exp.Expr] = None,
skip_columns: t.Optional[t.List[str]] = None,
select_models: t.Optional[t.Collection[str]] = None,
- where: t.Optional[str | exp.Condition] = None,
+ where: t.Optional[str | exp.Expr] = None,
limit: int = 20,
show: bool = True,
show_sample: bool = True,
@@ -1919,7 +1929,7 @@ def table_diff(
raise SQLMeshError(e)
models_to_diff: t.List[
- t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Condition]]
+ t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Expr]]
] = []
models_without_grain: t.List[Model] = []
source_snapshots_to_name = {
@@ -2038,9 +2048,9 @@ def _model_diff(
target_alias: str,
limit: int,
decimals: int,
- on: t.Optional[t.List[str] | exp.Condition] = None,
+ on: t.Optional[t.List[str] | exp.Expr] = None,
skip_columns: t.Optional[t.List[str]] = None,
- where: t.Optional[str | exp.Condition] = None,
+ where: t.Optional[str | exp.Expr] = None,
show: bool = True,
temp_schema: t.Optional[str] = None,
skip_grain_check: bool = False,
@@ -2080,10 +2090,10 @@ def _table_diff(
limit: int,
decimals: int,
adapter: EngineAdapter,
- on: t.Optional[t.List[str] | exp.Condition] = None,
+ on: t.Optional[t.List[str] | exp.Expr] = None,
model: t.Optional[Model] = None,
skip_columns: t.Optional[t.List[str]] = None,
- where: t.Optional[str | exp.Condition] = None,
+ where: t.Optional[str | exp.Expr] = None,
schema_diff_ignore_case: bool = False,
) -> TableDiff:
if not on:
@@ -2341,7 +2351,7 @@ def audit(
return not errors
@python_api_analytics
- def rewrite(self, sql: str, dialect: str = "") -> exp.Expression:
+ def rewrite(self, sql: str, dialect: str = "") -> exp.Expr:
"""Rewrite a sql expression with semantic references into an executable query.
https://sqlmesh.readthedocs.io/en/latest/concepts/metrics/overview/
@@ -3043,10 +3053,17 @@ def _get_plan_default_start_end(
modified_model_names: t.Set[str],
execution_time: t.Optional[TimeLike] = None,
) -> t.Tuple[t.Optional[int], t.Optional[int]]:
- if not max_interval_end_per_model:
+        # exclude seeds so that a seed's stale interval end does not become the default plan end date
+ # when they're the only ones that contain intervals in this plan
+ non_seed_interval_ends = {
+ model_fqn: end
+ for model_fqn, end in max_interval_end_per_model.items()
+ if model_fqn not in snapshots or not snapshots[model_fqn].is_seed
+ }
+ if not non_seed_interval_ends:
return None, None
- default_end = to_timestamp(max(max_interval_end_per_model.values()))
+ default_end = to_timestamp(max(non_seed_interval_ends.values()))
default_start: t.Optional[int] = None
# Infer the default start by finding the smallest interval start that corresponds to the default end.
for model_name in backfill_models or modified_model_names or max_interval_end_per_model:
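Note (illustrative, not part of this patch): with the change above, the default plan end is derived only from non-seed interval ends, so a seed's stale interval end is ignored when seeds are the only models with intervals. A toy version of the filtering:

    max_interval_end_per_model = {"seed_model": 100}  # only the seed has intervals
    is_seed = {"seed_model": True}                     # stand-in for snapshots[fqn].is_seed

    non_seed_ends = {m: e for m, e in max_interval_end_per_model.items() if not is_seed[m]}
    default_end = max(non_seed_ends.values()) if non_seed_ends else None
    print(default_end)  # None: the seed's stale end does not become the default plan end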
diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py
index 07d13b1c2f..047e58609a 100644
--- a/sqlmesh/core/context_diff.py
+++ b/sqlmesh/core/context_diff.py
@@ -36,7 +36,7 @@
from sqlmesh.utils.metaprogramming import Executable # noqa
from sqlmesh.core.environment import EnvironmentStatements
-IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
+IGNORED_PACKAGES = {"sqlmesh", "sqlglot", "sqlglotc"}
class ContextDiff(PydanticModel):
diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py
index c0a48326f2..565c629789 100644
--- a/sqlmesh/core/dialect.py
+++ b/sqlmesh/core/dialect.py
@@ -14,6 +14,8 @@
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects import DuckDB, Snowflake, TSQL
import sqlglot.dialects.athena as athena
+import sqlglot.generators.athena as athena_generators
+from sqlglot.parsers.athena import AthenaTrinoParser
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
@@ -52,7 +54,7 @@ class Metric(exp.Expression):
arg_types = {"expressions": True}
-class Jinja(exp.Func):
+class Jinja(exp.Expression, exp.Func):
arg_types = {"this": True}
@@ -76,7 +78,7 @@ class MacroVar(exp.Var):
pass
-class MacroFunc(exp.Func):
+class MacroFunc(exp.Expression, exp.Func):
@property
def name(self) -> str:
return self.this.name
@@ -102,7 +104,7 @@ class DColonCast(exp.Cast):
pass
-class MetricAgg(exp.AggFunc):
+class MetricAgg(exp.Expression, exp.AggFunc):
"""Used for computing metrics."""
arg_types = {"this": True}
@@ -118,7 +120,7 @@ class StagedFilePath(exp.Expression):
arg_types = exp.Table.arg_types.copy()
-def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
+def _parse_statement(self: Parser) -> t.Optional[exp.Expr]:
if self._curr is None:
return None
@@ -152,7 +154,7 @@ def _parse_statement(self: Parser) -> t.Optional[exp.Expression]:
raise
-def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expression]:
+def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expr]:
node = self.__parse_lambda(alias=alias) # type: ignore
if isinstance(node, exp.Lambda):
node.set("this", self._parse_alias(node.this))
@@ -163,7 +165,7 @@ def _parse_id_var(
self: Parser,
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
if self._prev and self._prev.text == SQLMESH_MACRO_PREFIX and self._match(TokenType.L_BRACE):
identifier = self.__parse_id_var(any_token=any_token, tokens=tokens) # type: ignore
if not self._match(TokenType.R_BRACE):
@@ -207,12 +209,12 @@ def _parse_id_var(
else:
self.raise_error("Expecting }")
- identifier = self.expression(exp.Identifier, this=this, quoted=identifier.quoted)
+ identifier = self.expression(exp.Identifier(this=this, quoted=identifier.quoted))
return identifier
-def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expression]:
+def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expr]:
if self._prev.text != SQLMESH_MACRO_PREFIX:
return self._parse_parameter()
@@ -220,7 +222,7 @@ def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expres
index = self._index
field = self._parse_primary() or self._parse_function(functions={}) or self._parse_id_var()
- def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ def _build_macro(field: t.Optional[exp.Expr]) -> t.Optional[exp.Expr]:
if isinstance(field, exp.Func):
macro_name = field.name.upper()
if macro_name != keyword_macro and macro_name in KEYWORD_MACROS:
@@ -230,37 +232,39 @@ def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression
if isinstance(field, exp.Anonymous):
if macro_name == "DEF":
return self.expression(
- MacroDef,
- this=field.expressions[0],
- expression=field.expressions[1],
+ MacroDef(
+ this=field.expressions[0],
+ expression=field.expressions[1],
+ ),
comments=comments,
)
if macro_name == "SQL":
into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None
return self.expression(
- MacroSQL, this=field.expressions[0], into=into, comments=comments
+ MacroSQL(this=field.expressions[0], into=into), comments=comments
)
else:
field = self.expression(
- exp.Anonymous,
- this=field.sql_name(),
- expressions=list(field.args.values()),
+ exp.Anonymous(
+ this=field.sql_name(),
+ expressions=list(field.args.values()),
+ ),
comments=comments,
)
- return self.expression(MacroFunc, this=field, comments=comments)
+ return self.expression(MacroFunc(this=field), comments=comments)
if field is None:
return None
if field.is_string or (isinstance(field, exp.Identifier) and field.quoted):
return self.expression(
- MacroStrReplace, this=exp.Literal.string(field.this), comments=comments
+ MacroStrReplace(this=exp.Literal.string(field.this)), comments=comments
)
if "@" in field.this:
- return field
- return self.expression(MacroVar, this=field.this, comments=comments)
+ return field # type: ignore[return-value]
+ return self.expression(MacroVar(this=field.this), comments=comments)
if isinstance(field, (exp.Window, exp.IgnoreNulls, exp.RespectNulls)):
field.set("this", _build_macro(field.this))
@@ -273,7 +277,7 @@ def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression
KEYWORD_MACROS = {"WITH", "JOIN", "WHERE", "GROUP_BY", "HAVING", "ORDER_BY", "LIMIT"}
-def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]:
+def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expr]:
if not self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False) or (
self._next and self._next.text.upper() != name.upper()
):
@@ -283,7 +287,7 @@ def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]
return _parse_macro(self, keyword_macro=name)
-def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expression]]:
+def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expr]]:
name = self._next and self._next.text.upper()
if name == "JOIN":
@@ -301,7 +305,7 @@ def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expression]]:
return ("", None)
-def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "WITH")
if not macro:
return self.__parse_with(skip_with_token=skip_with_token) # type: ignore
@@ -312,7 +316,7 @@ def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.E
def _parse_join(
self: Parser, skip_join_token: bool = False, parse_bracket: bool = False
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
index = self._index
method, side, kind = self._parse_join_parts()
macro = _parse_matching_macro(self, "JOIN")
@@ -351,7 +355,7 @@ def _parse_select(
parse_set_operation: bool = True,
consume_pipe: bool = True,
from_: t.Optional[exp.From] = None,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
select = self.__parse_select( # type: ignore
nested=nested,
table=table,
@@ -372,7 +376,7 @@ def _parse_select(
return select
-def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "WHERE")
if not macro:
return self.__parse_where(skip_where_token=skip_where_token) # type: ignore
@@ -381,7 +385,7 @@ def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp
return macro
-def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "GROUP_BY")
if not macro:
return self.__parse_group(skip_group_by_token=skip_group_by_token) # type: ignore
@@ -390,7 +394,7 @@ def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[
return macro
-def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
+def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "HAVING")
if not macro:
return self.__parse_having(skip_having_token=skip_having_token) # type: ignore
@@ -400,8 +404,8 @@ def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[e
def _parse_order(
- self: Parser, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
-) -> t.Optional[exp.Expression]:
+ self: Parser, this: t.Optional[exp.Expr] = None, skip_order_token: bool = False
+) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "ORDER_BY")
if not macro:
return self.__parse_order(this, skip_order_token=skip_order_token) # type: ignore
@@ -412,10 +416,10 @@ def _parse_order(
def _parse_limit(
self: Parser,
- this: t.Optional[exp.Expression] = None,
+ this: t.Optional[exp.Expr] = None,
top: bool = False,
skip_limit_token: bool = False,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
macro = _parse_matching_macro(self, "TOP" if top else "LIMIT")
if not macro:
return self.__parse_limit(this, top=top, skip_limit_token=skip_limit_token) # type: ignore
@@ -424,7 +428,7 @@ def _parse_limit(
return macro
-def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression]:
+def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expr]:
wrapped = self._match(TokenType.L_PAREN, advance=False)
# The base _parse_value method always constructs a Tuple instance. This is problematic when
@@ -438,11 +442,11 @@ def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression
return expr
-def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expression]:
+def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expr]:
return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser()
-def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
+def _parse_props(self: Parser) -> t.Optional[exp.Expr]:
key = self._parse_id_var(any_token=True)
if not key:
return None
@@ -460,7 +464,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
elif name == "merge_filter":
value = self._parse_conjunction()
elif self._match(TokenType.L_PAREN):
- value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality))
+ value = self.expression(exp.Tuple(expressions=self._parse_csv(self._parse_equality)))
self._match_r_paren()
else:
value = self._parse_bracket(self._parse_field(any_token=True))
@@ -469,7 +473,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
# Make sure if we get a windows path that it is converted to posix
value = exp.Literal.string(value.this.replace("\\", "/")) # type: ignore
- return self.expression(exp.Property, this=name, value=value)
+ return self.expression(exp.Property(this=name, value=value))
def _parse_types(
@@ -477,7 +481,7 @@ def _parse_types(
check_func: bool = False,
schema: bool = False,
allow_identifiers: bool = True,
-) -> t.Optional[exp.Expression]:
+) -> t.Optional[exp.Expr]:
start = self._curr
parsed_type = self.__parse_types( # type: ignore
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
@@ -496,13 +500,20 @@ def _parse_types(
#
# See: https://docs.snowflake.com/en/user-guide/querying-stage
def _parse_table_parts(
- self: Parser, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
+ self: Parser,
+ schema: bool = False,
+ is_db_reference: bool = False,
+ wildcard: bool = False,
+ fast: bool = False,
) -> exp.Table | StagedFilePath:
index = self._index
table = self.__parse_table_parts( # type: ignore
- schema=schema, is_db_reference=is_db_reference, wildcard=wildcard
+ schema=schema, is_db_reference=is_db_reference, wildcard=wildcard, fast=fast
)
+ if table is None:
+ return table # type: ignore[return-value]
+
table_arg = table.this
name = table_arg.name if isinstance(table_arg, exp.Var) else ""
@@ -526,7 +537,9 @@ def _parse_table_parts(
)
):
self._retreat(index)
- return Parser._parse_table_parts(self, schema=schema, is_db_reference=is_db_reference)
+ return Parser._parse_table_parts(
+ self, schema=schema, is_db_reference=is_db_reference, fast=fast
+ ) # type: ignore[return-value]
table_arg.replace(MacroVar(this=name[1:]))
return StagedFilePath(**table.args)
@@ -534,7 +547,7 @@ def _parse_table_parts(
return table
-def _parse_if(self: Parser) -> t.Optional[exp.Expression]:
+def _parse_if(self: Parser) -> t.Optional[exp.Expr]:
# If we fail to parse an IF function with expressions as arguments, we then try
# to parse a statement / command to support the macro @IF(condition, statement)
index = self._index
@@ -554,6 +567,10 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expression]:
if last_token.token_type == TokenType.R_PAREN:
self._tokens[-2].comments.extend(last_token.comments)
self._tokens.pop()
+ if hasattr(self, "_tokens_size"):
+                # keep _tokens_size in sync: sqlglot 30.0.3 caches len(_tokens), so
+                # _advance() would otherwise try to read tokens[index + 1] past the new end
+ self._tokens_size -= 1
else:
self.raise_error("Expecting )")
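Note (illustrative, not part of this patch): the _tokens_size guard above exists because popping a token invalidates a cached length; if the cache is not decremented, a bounds check against the stale size lets a later read run past the new end of the token list. A standalone illustration of the invariant:

    tokens = ["IF", "(", "cond", ",", "stmt", ")"]
    tokens_size = len(tokens)   # cached once, like the parser's cached _tokens_size

    tokens.pop()                # the trailing ")" is removed
    tokens_size -= 1            # keep the cache honest; otherwise a check such as
                                # index + 1 < tokens_size would permit reading past the end
    assert tokens_size == len(tokens)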
@@ -566,11 +583,11 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expression]:
return exp.Anonymous(this="IF", expressions=[cond, stmt])
-def _create_parser(expression_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable:
- def parse(self: Parser) -> t.Optional[exp.Expression]:
+def _create_parser(expression_type: t.Type[exp.Expr], table_keys: t.List[str]) -> t.Callable:
+ def parse(self: Parser) -> t.Optional[exp.Expr]:
from sqlmesh.core.model.kind import ModelKindName
- expressions: t.List[exp.Expression] = []
+ expressions: t.List[exp.Expr] = []
while True:
prev_property = seq_get(expressions, -1)
@@ -589,7 +606,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
key = key_expression.name.lower()
start = self._curr
- value: t.Optional[exp.Expression | str]
+ value: t.Optional[exp.Expr | str]
if key in table_keys:
value = self._parse_table_parts()
@@ -629,7 +646,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
else:
props = None
- value = self.expression(ModelKind, this=kind.value, expressions=props)
+ value = self.expression(ModelKind(this=kind.value, expressions=props))
elif key == "expression":
value = self._parse_conjunction()
elif key == "partitioned_by":
@@ -641,12 +658,12 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
else:
value = self._parse_bracket(self._parse_field(any_token=True))
- if isinstance(value, exp.Expression):
+ if isinstance(value, exp.Expr):
value.meta["sql"] = self._find_sql(start, self._prev)
- expressions.append(self.expression(exp.Property, this=key, value=value))
+ expressions.append(self.expression(exp.Property(this=key, value=value)))
- return self.expression(expression_type, expressions=expressions)
+ return self.expression(expression_type(expressions=expressions))
return parse
@@ -658,7 +675,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
}
-def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
+def _props_sql(self: Generator, expressions: t.List[exp.Expr]) -> str:
props = []
size = len(expressions)
@@ -676,7 +693,7 @@ def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
return "\n".join(props)
-def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
+def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expr]) -> str:
statements = "\n".join(
self.sql(expression)
if isinstance(expression, JinjaStatement)
@@ -697,7 +714,7 @@ def _model_kind_sql(self: Generator, expression: ModelKind) -> str:
return expression.name.upper()
-def _macro_keyword_func_sql(self: Generator, expression: exp.Expression) -> str:
+def _macro_keyword_func_sql(self: Generator, expression: exp.Expr) -> str:
name = expression.name
keyword = name.replace("_", " ")
*args, clause = expression.expressions
@@ -731,7 +748,7 @@ def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:
def format_model_expressions(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
dialect: t.Optional[str] = None,
rewrite_casts: bool = True,
**kwargs: t.Any,
@@ -752,7 +769,7 @@ def format_model_expressions(
if rewrite_casts:
- def cast_to_colon(node: exp.Expression) -> exp.Expression:
+ def cast_to_colon(node: exp.Expr) -> exp.Expr:
if isinstance(node, exp.Cast) and not any(
# Only convert CAST into :: if it doesn't have additional args set, otherwise this
# conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST)
@@ -784,8 +801,8 @@ def cast_to_colon(node: exp.Expression) -> exp.Expression:
def text_diff(
- a: t.List[exp.Expression],
- b: t.List[exp.Expression],
+ a: t.List[exp.Expr],
+ b: t.List[exp.Expr],
a_dialect: t.Optional[str] = None,
b_dialect: t.Optional[str] = None,
) -> str:
@@ -860,7 +877,7 @@ def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool:
return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos)
-def virtual_statement(statements: t.List[exp.Expression]) -> VirtualUpdateStatement:
+def virtual_statement(statements: t.List[exp.Expr]) -> VirtualUpdateStatement:
return VirtualUpdateStatement(expressions=statements)
@@ -874,7 +891,7 @@ class ChunkType(Enum):
def parse_one(
sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None
-) -> exp.Expression:
+) -> exp.Expr:
expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into)
if not expressions:
raise SQLMeshError(f"No expressions found in '{sql}'")
@@ -888,7 +905,7 @@ def parse(
default_dialect: t.Optional[str] = None,
match_dialect: bool = True,
into: t.Optional[exp.IntoType] = None,
-) -> t.List[exp.Expression]:
+) -> t.List[exp.Expr]:
"""Parse a sql string.
Supports parsing model definition.
@@ -952,10 +969,10 @@ def parse(
pos += 1
parser = dialect.parser()
- expressions: t.List[exp.Expression] = []
+ expressions: t.List[exp.Expr] = []
- def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expression]:
- parsed_expressions: t.List[t.Optional[exp.Expression]] = (
+ def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expr]:
+ parsed_expressions: t.List[t.Optional[exp.Expr]] = (
parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
)
expressions = []
@@ -966,7 +983,7 @@ def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.E
expressions.append(expression)
return expressions
- def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expression:
+ def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expr:
start, *_, end = chunk
segment = sql[start.end + 2 : end.start - 1]
factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
@@ -977,9 +994,9 @@ def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expres
def parse_virtual_statement(
chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int
- ) -> t.Tuple[t.List[exp.Expression], int]:
+ ) -> t.Tuple[t.List[exp.Expr], int]:
# For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk
- virtual_update_statements = []
+ virtual_update_statements: t.List[exp.Expr] = []
start = chunks[pos][0][0].start
while (
@@ -1031,9 +1048,9 @@ def extend_sqlglot() -> None:
# so this ensures that the extra ones it defines are also extended
if dialect == athena.Athena:
tokenizers.add(athena._TrinoTokenizer)
- parsers.add(athena._TrinoParser)
- generators.add(athena._TrinoGenerator)
- generators.add(athena._HiveGenerator)
+ parsers.add(AthenaTrinoParser)
+ generators.add(athena_generators.AthenaTrinoGenerator)
+ generators.add(athena_generators._HiveGenerator)
if hasattr(dialect, "Tokenizer"):
tokenizers.add(dialect.Tokenizer)
@@ -1251,7 +1268,7 @@ def normalize_model_name(
def find_tables(
- expression: exp.Expression, default_catalog: t.Optional[str], dialect: DialectType = None
+ expression: exp.Expr, default_catalog: t.Optional[str], dialect: DialectType = None
) -> t.Set[str]:
"""Find all tables referenced in a query.
@@ -1274,10 +1291,10 @@ def find_tables(
return expression.meta[TABLES_META]
-def add_table(node: exp.Expression, table: str) -> exp.Expression:
+def add_table(node: exp.Expr, table: str) -> exp.Expr:
"""Add a table to all columns in an expression."""
- def _transform(node: exp.Expression) -> exp.Expression:
+ def _transform(node: exp.Expr) -> exp.Expr:
if isinstance(node, exp.Column) and not node.table:
return exp.column(node.this, table=table)
if isinstance(node, exp.Identifier):
@@ -1387,7 +1404,7 @@ def normalize_and_quote(
quote_identifiers(query, dialect=dialect)
-def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | float | bool:
+def interpret_expression(e: exp.Expr) -> exp.Expr | str | int | float | bool:
if e.is_int:
return int(e.this)
if e.is_number:
@@ -1399,13 +1416,13 @@ def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | floa
def interpret_key_value_pairs(
e: exp.Tuple,
-) -> t.Dict[str, exp.Expression | str | int | float | bool]:
+) -> t.Dict[str, exp.Expr | str | int | float | bool]:
return {i.this.name: interpret_expression(i.expression) for i in e.expressions}
def extract_func_call(
- v: exp.Expression, allow_tuples: bool = False
-) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
+ v: exp.Expr, allow_tuples: bool = False
+) -> t.Tuple[str, t.Dict[str, exp.Expr]]:
kwargs = {}
if isinstance(v, exp.Anonymous):
@@ -1442,7 +1459,7 @@ def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.A
return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
if isinstance(func_calls, exp.Paren):
return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
- if isinstance(func_calls, exp.Expression):
+ if isinstance(func_calls, exp.Expr):
return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
if isinstance(func_calls, list):
function_calls = []
@@ -1474,9 +1491,7 @@ def is_meta_expression(v: t.Any) -> bool:
return isinstance(v, (Audit, Metric, Model))
-def replace_merge_table_aliases(
- expression: exp.Expression, dialect: t.Optional[str] = None
-) -> exp.Expression:
+def replace_merge_table_aliases(expression: exp.Expr, dialect: t.Optional[str] = None) -> exp.Expr:
"""
Resolves references from the "source" and "target" tables (or their DBT equivalents)
with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)
diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py
index bd84ba5276..338381549b 100644
--- a/sqlmesh/core/engine_adapter/athena.py
+++ b/sqlmesh/core/engine_adapter/athena.py
@@ -158,7 +158,7 @@ def _create_schema(
schema_name: SchemaName,
ignore_if_exists: bool,
warn_on_error: bool,
- properties: t.List[exp.Expression],
+ properties: t.List[exp.Expr],
kind: str,
) -> None:
if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)):
@@ -177,14 +177,14 @@ def _create_schema(
def _build_create_table_exp(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
**kwargs: t.Any,
) -> exp.Create:
exists = False if replace else exists
@@ -235,18 +235,18 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
table: t.Optional[exp.Table] = None,
- expression: t.Optional[exp.Expression] = None,
+ expression: t.Optional[exp.Expr] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
table_properties = table_properties or {}
is_hive = self._table_type(table_format) == "hive"
@@ -266,7 +266,7 @@ def _build_table_properties_exp(
properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description)))
if partitioned_by:
- schema_expressions: t.List[exp.Expression] = []
+ schema_expressions: t.List[exp.Expr] = []
if is_hive and target_columns_to_types:
# For Hive-style tables, you cannot include the partitioned by columns in the main set of columns
                # In the PARTITIONED BY expression, you also can't just include the column names, you need to include the data type as well
@@ -381,7 +381,7 @@ def _is_hive_partitioned_table(self, table: exp.Table) -> bool:
raise e
def _table_location_or_raise(
- self, table_properties: t.Optional[t.Dict[str, exp.Expression]], table: exp.Table
+ self, table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table
) -> exp.LocationProperty:
location = self._table_location(table_properties, table)
if not location:
@@ -392,7 +392,7 @@ def _table_location_or_raise(
def _table_location(
self,
- table_properties: t.Optional[t.Dict[str, exp.Expression]],
+ table_properties: t.Optional[t.Dict[str, exp.Expr]],
table: exp.Table,
) -> t.Optional[exp.LocationProperty]:
base_uri: str
@@ -402,7 +402,7 @@ def _table_location(
s3_base_location_property = table_properties.pop(
"s3_base_location"
            )  # pop because it's handled differently and we don't want it to end up in the TBLPROPERTIES clause
- if isinstance(s3_base_location_property, exp.Expression):
+ if isinstance(s3_base_location_property, exp.Expr):
base_uri = s3_base_location_property.name
else:
base_uri = s3_base_location_property
@@ -419,7 +419,7 @@ def _table_location(
return exp.LocationProperty(this=exp.Literal.string(full_uri))
def _find_matching_columns(
- self, partitioned_by: t.List[exp.Expression], columns_to_types: t.Dict[str, exp.DataType]
+ self, partitioned_by: t.List[exp.Expr], columns_to_types: t.Dict[str, exp.DataType]
) -> t.List[t.Tuple[str, exp.DataType]]:
matches = []
for col in partitioned_by:
@@ -557,7 +557,7 @@ def _chunks() -> t.Iterable[t.List[t.List[str]]]:
PartitionsToDelete=[{"Values": v} for v in batch],
)
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
table = exp.to_table(table_name)
table_type = self._query_table_type(table)
diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py
index e2dbb51459..5465ea1197 100644
--- a/sqlmesh/core/engine_adapter/base.py
+++ b/sqlmesh/core/engine_adapter/base.py
@@ -236,7 +236,7 @@ def _casted_columns(
cls,
target_columns_to_types: t.Dict[str, exp.DataType],
source_columns: t.Optional[t.List[str]] = None,
- ) -> t.List[exp.Alias]:
+ ) -> t.List[exp.Expr]:
source_columns_lookup = set(source_columns or target_columns_to_types)
return [
exp.alias_(
@@ -591,7 +591,7 @@ def create_index(
def _pop_creatable_type_from_properties(
self,
- properties: t.Dict[str, exp.Expression],
+ properties: t.Dict[str, exp.Expr],
) -> t.Optional[exp.Property]:
"""Pop out the creatable_type from the properties dictionary (if exists (return it/remove it) else return none).
It also checks that none of the expressions are MATERIALIZE as that conflicts with the `materialize` parameter.
@@ -652,9 +652,9 @@ def create_managed_table(
table_name: TableName,
query: Query,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
source_columns: t.Optional[t.List[str]] = None,
@@ -964,7 +964,7 @@ def _create_table_from_source_queries(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1002,7 +1002,7 @@ def _create_table(
def _build_create_table_exp(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1203,7 +1203,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
@@ -1382,7 +1382,7 @@ def create_schema(
schema_name: SchemaName,
ignore_if_exists: bool = True,
warn_on_error: bool = True,
- properties: t.Optional[t.List[exp.Expression]] = None,
+ properties: t.Optional[t.List[exp.Expr]] = None,
) -> None:
properties = properties or []
return self._create_schema(
@@ -1398,7 +1398,7 @@ def _create_schema(
schema_name: SchemaName,
ignore_if_exists: bool,
warn_on_error: bool,
- properties: t.List[exp.Expression],
+ properties: t.List[exp.Expr],
kind: str,
) -> None:
"""Create a schema from a name or qualified table name."""
@@ -1423,7 +1423,7 @@ def drop_schema(
schema_name: SchemaName,
ignore_if_not_exists: bool = True,
cascade: bool = False,
- **drop_args: t.Dict[str, exp.Expression],
+ **drop_args: t.Dict[str, exp.Expr],
) -> None:
return self._drop_object(
name=schema_name,
@@ -1494,7 +1494,7 @@ def table_exists(self, table_name: TableName) -> bool:
except Exception:
return False
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
self.execute(exp.delete(table_name, where))
def insert_append(
@@ -1552,7 +1552,7 @@ def insert_overwrite_by_partition(
self,
table_name: TableName,
query_or_df: QueryOrDF,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -1583,10 +1583,8 @@ def insert_overwrite_by_time_partition(
query_or_df: QueryOrDF,
start: TimeLike,
end: TimeLike,
- time_formatter: t.Callable[
- [TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expression
- ],
- time_column: TimeColumn | exp.Expression | str,
+ time_formatter: t.Callable[[TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expr],
+ time_column: TimeColumn | exp.Expr | str,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
@@ -1726,7 +1724,7 @@ def _merge(
self,
target_table: TableName,
query: Query,
- on: exp.Expression,
+ on: exp.Expr,
whens: exp.Whens,
) -> None:
this = exp.alias_(exp.to_table(target_table), alias=MERGE_TARGET_ALIAS, table=True)
@@ -1741,7 +1739,7 @@ def scd_type_2_by_time(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
@@ -1777,11 +1775,11 @@ def scd_type_2_by_column(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
- check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]],
+ check_columns: t.Union[exp.Star, t.Sequence[exp.Expr]],
invalidate_hard_deletes: bool = True,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1813,13 +1811,13 @@ def _scd_type_2(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
invalidate_hard_deletes: bool = True,
updated_at_col: t.Optional[exp.Column] = None,
- check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
+ check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
updated_at_as_valid_from: bool = False,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1908,7 +1906,7 @@ def remove_managed_columns(
raise SQLMeshError(
"Cannot use `updated_at_as_valid_from` without `updated_at_name` for SCD Type 2"
)
- update_valid_from_start: t.Union[str, exp.Expression] = updated_at_col
+ update_valid_from_start: t.Union[str, exp.Expr] = updated_at_col
# If using check_columns and the user doesn't always want execution_time for valid from
# then we only use epoch 0 if we are truncating the table and loading rows for the first time.
# All future new rows should have execution time.
@@ -2207,9 +2205,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -2382,7 +2380,7 @@ def get_data_objects(
def fetchone(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.Optional[t.Tuple]:
@@ -2396,7 +2394,7 @@ def fetchone(
def fetchall(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.List[t.Tuple]:
@@ -2409,7 +2407,7 @@ def fetchall(
return self.cursor.fetchall()
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
with self.transaction():
@@ -2432,7 +2430,7 @@ def _native_df_to_pandas_df(
raise NotImplementedError(f"Unable to convert {type(query_or_df)} to Pandas")
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a Pandas DataFrame from the cursor"""
import pandas as pd
@@ -2445,7 +2443,7 @@ def fetchdf(
return df
def fetch_pyspark_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
"""Fetches a PySpark DataFrame from the cursor"""
raise NotImplementedError(f"Engine does not support PySpark DataFrames: {type(self)}")
@@ -2575,7 +2573,7 @@ def _is_session_active(self) -> bool:
def execute(
self,
- expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]],
+ expressions: t.Union[str, exp.Expr, t.Sequence[exp.Expr]],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = True,
track_rows_processed: bool = False,
@@ -2587,7 +2585,7 @@ def execute(
)
with self.transaction():
for e in ensure_list(expressions):
- if isinstance(e, exp.Expression):
+ if isinstance(e, exp.Expr):
self._check_identifier_length(e)
sql = self._to_sql(e, quote=quote_identifiers, **to_sql_kwargs)
else:
@@ -2597,7 +2595,7 @@ def execute(
self._log_sql(
sql,
- expression=e if isinstance(e, exp.Expression) else None,
+ expression=e if isinstance(e, exp.Expr) else None,
quote_identifiers=quote_identifiers,
)
self._execute(sql, track_rows_processed, **kwargs)
@@ -2610,7 +2608,7 @@ def _attach_correlation_id(self, sql: str) -> str:
def _log_sql(
self,
sql: str,
- expression: t.Optional[exp.Expression] = None,
+ expression: t.Optional[exp.Expr] = None,
quote_identifiers: bool = True,
) -> None:
if not logger.isEnabledFor(self._execute_log_level):
@@ -2702,7 +2700,7 @@ def temp_table(
self.drop_table(table)
def _table_or_view_properties_to_expressions(
- self, table_or_view_properties: t.Optional[t.Dict[str, exp.Expression]] = None
+ self, table_or_view_properties: t.Optional[t.Dict[str, exp.Expr]] = None
) -> t.List[exp.Property]:
"""Converts model properties (either physical or virtual) to a list of property expressions."""
if not table_or_view_properties:
@@ -2714,7 +2712,7 @@ def _table_or_view_properties_to_expressions(
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
*,
partition_interval_unit: t.Optional[IntervalUnit] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -2725,7 +2723,7 @@ def _build_partitioned_by_exp(
def _build_clustered_by_exp(
self,
- clustered_by: t.List[exp.Expression],
+ clustered_by: t.List[exp.Expr],
**kwargs: t.Any,
) -> t.Optional[exp.Cluster]:
return None
@@ -2735,17 +2733,17 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for ddl."""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -2764,12 +2762,12 @@ def _build_table_properties_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -2791,7 +2789,7 @@ def _truncate_table_comment(self, comment: str) -> str:
def _truncate_column_comment(self, comment: str) -> str:
return self._truncate_comment(comment, self.MAX_COLUMN_COMMENT_LENGTH)
- def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
+ def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
"""
Converts an expression to a SQL string. Has a set of default kwargs to apply, and then default
kwargs defined for the given dialect, and then kwargs provided by the user when defining the engine
@@ -2852,7 +2850,7 @@ def _order_projections_and_filter(
self,
query: Query,
target_columns_to_types: t.Dict[str, exp.DataType],
- where: t.Optional[exp.Expression] = None,
+ where: t.Optional[exp.Expr] = None,
coerce_types: bool = False,
) -> Query:
if not isinstance(query, exp.Query) or (
@@ -2863,7 +2861,7 @@ def _order_projections_and_filter(
query = t.cast(exp.Query, query.copy())
with_ = query.args.pop("with_", None)
- select_exprs: t.List[exp.Expression] = [
+ select_exprs: t.List[exp.Expr] = [
exp.column(c, quoted=True) for c in target_columns_to_types
]
if coerce_types and columns_to_types_all_known(target_columns_to_types):
@@ -2914,7 +2912,7 @@ def _replace_by_key(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- key: t.Sequence[exp.Expression],
+ key: t.Sequence[exp.Expr],
is_unique_key: bool,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -2922,7 +2920,11 @@ def _replace_by_key(
target_columns_to_types = self.columns(target_table)
temp_table = self._get_temp_table(target_table)
- key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key) if len(key) > 1 else key[0]
+ key_exp = (
+ exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key, dialect=self.dialect)
+ if len(key) > 1
+ else key[0]
+ )
column_names = list(target_columns_to_types or [])
with self.transaction():
@@ -3055,7 +3057,7 @@ def _select_columns(
)
)
- def _check_identifier_length(self, expression: exp.Expression) -> None:
+ def _check_identifier_length(self, expression: exp.Expr) -> None:
if self.MAX_IDENTIFIER_LENGTH is None or not isinstance(expression, exp.DDL):
return
@@ -3147,7 +3149,7 @@ def _apply_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns SQLGlot Grant expressions to apply grants to a table.
Args:
@@ -3170,7 +3172,7 @@ def _revoke_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns SQLGlot expressions to revoke grants from a table.
Args:
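
Aside from the `exp.Expr` rename, the behavioral change in this file is that `_replace_by_key` now passes the adapter's dialect to `exp.func` when building the composite key, so `CONCAT_WS` is resolved through the engine's own dialect rather than the default parser. A minimal sketch of the call shape; the column names and the `duckdb` dialect are illustrative only:

```python
from sqlglot import exp

# Composite key for an assumed model keyed on (id, ds).
key = [exp.column("id"), exp.column("ds")]

# Passing dialect= lets sqlglot resolve CONCAT_WS via that dialect's function
# parser instead of the default one, so the composite-key expression is built
# the way the target engine expects.
key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key, dialect="duckdb")
print(key_exp.sql(dialect="duckdb"))
```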
diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py
index 11f56da133..e2347b1263 100644
--- a/sqlmesh/core/engine_adapter/base_postgres.py
+++ b/sqlmesh/core/engine_adapter/base_postgres.py
@@ -110,7 +110,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py
index 59a56b6ace..d136445114 100644
--- a/sqlmesh/core/engine_adapter/bigquery.py
+++ b/sqlmesh/core/engine_adapter/bigquery.py
@@ -67,7 +67,7 @@ class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchema
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_CLONING = True
SUPPORTS_GRANTS = True
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("session_user")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("session_user")
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True
USE_CATALOG_IN_GRANTS = True
GRANT_INFORMATION_SCHEMA_TABLE_NAME = "OBJECT_PRIVILEGES"
@@ -140,8 +140,10 @@ def _job_params(self) -> t.Dict[str, t.Any]:
"priority", BigQueryPriority.INTERACTIVE.bigquery_constant
),
}
- if self._extra_config.get("maximum_bytes_billed"):
+ if self._extra_config.get("maximum_bytes_billed") is not None:
params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed")
+ if self._extra_config.get("reservation") is not None:
+ params["reservation"] = self._extra_config.get("reservation")
if self.correlation_id:
# BigQuery label keys must be lowercase
key = self.correlation_id.job_type.value.lower()
@@ -288,7 +290,7 @@ def create_schema(
schema_name: SchemaName,
ignore_if_exists: bool = True,
warn_on_error: bool = True,
- properties: t.List[exp.Expression] = [],
+ properties: t.List[exp.Expr] = [],
) -> None:
"""Create a schema from a name or qualified table name."""
from google.api_core.exceptions import Conflict
@@ -433,7 +435,7 @@ def alter_table(
def fetchone(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.Optional[t.Tuple]:
@@ -453,7 +455,7 @@ def fetchone(
def fetchall(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.List[t.Tuple]:
@@ -689,7 +691,7 @@ def insert_overwrite_by_partition(
self,
table_name: TableName,
query_or_df: QueryOrDF,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -803,7 +805,7 @@ def _table_name(self, table_name: TableName) -> str:
return ".".join(part.name for part in exp.to_table(table_name).parts)
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
self.execute(query, quote_identifiers=quote_identifiers)
query_job = self._query_job
@@ -863,7 +865,7 @@ def _build_description_property_exp(
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
*,
partition_interval_unit: t.Optional[IntervalUnit] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -909,16 +911,16 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if partitioned_by and (
partitioned_by_prop := self._build_partitioned_by_exp(
@@ -1025,12 +1027,12 @@ def _build_col_comment_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -1106,7 +1108,9 @@ def _execute(
else []
)
+ # Build the QueryJobConfig from the adapter's job params and any session connection properties
job_config = QueryJobConfig(**self._job_params, connection_properties=connection_properties)
+
self._query_job = self._db_call(
self.client.query,
query=sql,
@@ -1257,10 +1261,10 @@ def _update_clustering_key(self, operation: TableAlterClusterByOperation) -> Non
)
)
- def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr:
return exp.func("FORMAT", exp.Literal.string(f"%.{precision}f"), col)
- def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression:
+ def _normalize_nested_value(self, col: exp.Expr) -> exp.Expr:
return exp.func("TO_JSON_STRING", col, dialect=self.dialect)
@t.overload
@@ -1338,7 +1342,7 @@ def _get_current_schema(self) -> str:
def _get_bq_dataset_location(self, project: str, dataset: str) -> str:
return self._db_call(self.client.get_dataset, dataset_ref=f"{project}.{dataset}").location
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
if not table.db:
raise ValueError(
f"Table {table.sql(dialect=self.dialect)} does not have a schema (dataset)"
@@ -1392,8 +1396,8 @@ def _dcl_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
- expressions: t.List[exp.Expression] = []
+ ) -> t.List[exp.Expr]:
+ expressions: t.List[exp.Expr] = []
if not grants_config:
return expressions
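
Beyond the rename, this file tightens the `maximum_bytes_billed` check from truthiness to an explicit `None` check and forwards a new `reservation` setting in the same way. A standalone sketch of the truthiness pitfall the stricter check avoids, using made-up config values:

```python
import typing as t

extra_config: t.Dict[str, t.Any] = {"maximum_bytes_billed": 0}  # an explicit zero-byte cap
params: t.Dict[str, t.Any] = {}

# Old check: falsy values such as 0 were silently dropped, so the cap never
# made it into the job config.
if extra_config.get("maximum_bytes_billed"):
    params["maximum_bytes_billed"] = extra_config["maximum_bytes_billed"]

# New check: only an unset value (None) is skipped.
if extra_config.get("maximum_bytes_billed") is not None:
    params["maximum_bytes_billed"] = extra_config["maximum_bytes_billed"]

print(params)  # {'maximum_bytes_billed': 0}
```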
diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py
index 45c22a6e55..698b2f4128 100644
--- a/sqlmesh/core/engine_adapter/clickhouse.py
+++ b/sqlmesh/core/engine_adapter/clickhouse.py
@@ -64,7 +64,7 @@ def cluster(self) -> t.Optional[str]:
# doesn't use the row index at all
def fetchone(
self,
- query: t.Union[exp.Expression, str],
+ query: t.Union[exp.Expr, str],
ignore_unsupported_errors: bool = False,
quote_identifiers: bool = False,
) -> t.Tuple:
@@ -77,13 +77,11 @@ def fetchone(
return self.cursor.fetchall()[0]
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a Pandas DataFrame from the cursor"""
return self.cursor.client.query_df(
- self._to_sql(query, quote=quote_identifiers)
- if isinstance(query, exp.Expression)
- else query,
+ self._to_sql(query, quote=quote_identifiers) if isinstance(query, exp.Expr) else query,
use_extended_dtypes=True,
)
@@ -168,7 +166,7 @@ def create_schema(
schema_name: SchemaName,
ignore_if_exists: bool = True,
warn_on_error: bool = True,
- properties: t.List[exp.Expression] = [],
+ properties: t.List[exp.Expr] = [],
) -> None:
"""Create a Clickhouse database from a name or qualified table name.
@@ -229,7 +227,7 @@ def _insert_overwrite_by_condition(
# REPLACE BY KEY: extract kwargs if present
dynamic_key = kwargs.get("dynamic_key")
if dynamic_key:
- dynamic_key_exp = t.cast(exp.Expression, kwargs.get("dynamic_key_exp"))
+ dynamic_key_exp = t.cast(exp.Expr, kwargs.get("dynamic_key_exp"))
dynamic_key_unique = t.cast(bool, kwargs.get("dynamic_key_unique"))
try:
@@ -414,7 +412,7 @@ def _replace_by_key(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- key: t.Sequence[exp.Expression],
+ key: t.Sequence[exp.Expr],
is_unique_key: bool,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -425,7 +423,11 @@ def _replace_by_key(
source_columns=source_columns,
)
- key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key) if len(key) > 1 else key[0]
+ key_exp = (
+ exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key, dialect=self.dialect)
+ if len(key) > 1
+ else key[0]
+ )
self._insert_overwrite_by_condition(
target_table,
@@ -440,7 +442,7 @@ def insert_overwrite_by_partition(
self,
table_name: TableName,
query_or_df: QueryOrDF,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
@@ -487,7 +489,7 @@ def _get_partition_ids(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -595,7 +597,7 @@ def _rename_table(
self.execute(f"RENAME TABLE {old_table_sql} TO {new_table_sql}{self._on_cluster_sql()}")
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
delete_expr = exp.delete(table_name, where)
if self.engine_run_mode.is_cluster:
delete_expr.set("cluster", exp.OnCluster(this=exp.to_identifier(self.cluster)))
@@ -649,7 +651,7 @@ def _drop_object(
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
**kwargs: t.Any,
) -> t.Optional[t.Union[exp.PartitionedByProperty, exp.Property]]:
return exp.PartitionedByProperty(
@@ -714,14 +716,14 @@ def use_server_nulls_for_unmatched_after_join(
return query
def _build_settings_property(
- self, key: str, value: exp.Expression | str | int | float
+ self, key: str, value: exp.Expr | str | int | float
) -> exp.SettingsProperty:
return exp.SettingsProperty(
expressions=[
exp.EQ(
this=exp.var(key.lower()),
expression=value
- if isinstance(value, exp.Expression)
+ if isinstance(value, exp.Expr)
else exp.Literal(this=value, is_string=isinstance(value, str)),
)
]
@@ -732,17 +734,17 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
empty_ctas: bool = False,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
table_engine = self.DEFAULT_TABLE_ENGINE
if storage_format:
@@ -809,9 +811,7 @@ def _build_table_properties_exp(
ttl = table_properties_copy.pop("TTL", None)
if ttl:
properties.append(
- exp.MergeTreeTTL(
- expressions=[ttl if isinstance(ttl, exp.Expression) else exp.var(ttl)]
- )
+ exp.MergeTreeTTL(expressions=[ttl if isinstance(ttl, exp.Expr) else exp.var(ttl)])
)
if (
@@ -845,12 +845,12 @@ def _build_table_properties_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
view_properties_copy = view_properties.copy() if view_properties else {}
diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py
index 870b946e7d..e3d029a17d 100644
--- a/sqlmesh/core/engine_adapter/databricks.py
+++ b/sqlmesh/core/engine_adapter/databricks.py
@@ -163,7 +163,7 @@ def _grant_object_kind(table_type: DataObjectType) -> str:
return "MATERIALIZED VIEW"
return "TABLE"
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
# We only care about explicitly granted privileges and not inherited ones
# if this is removed you would see grants inherited from the catalog get returned
expression = super()._get_grant_expression(table)
@@ -210,7 +210,7 @@ def query_factory() -> Query:
return [SourceQuery(query_factory=query_factory)]
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
if self.is_spark_session_connection:
@@ -223,7 +223,7 @@ def _fetch_native_df(
return self.cursor.fetchall_arrow().to_pandas()
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""
Returns a Pandas DataFrame from a query or expression.
@@ -364,10 +364,10 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py
index 3b057219e0..ebfcaa7901 100644
--- a/sqlmesh/core/engine_adapter/duckdb.py
+++ b/sqlmesh/core/engine_adapter/duckdb.py
@@ -145,7 +145,7 @@ def _get_data_objects(
for row in df.itertuples()
]
- def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr:
"""
duckdb truncates instead of rounding when casting to decimal.
@@ -163,7 +163,7 @@ def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.E
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py
index c8ef32b9da..bf4bb970a2 100644
--- a/sqlmesh/core/engine_adapter/mixins.py
+++ b/sqlmesh/core/engine_adapter/mixins.py
@@ -38,9 +38,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -58,18 +58,14 @@ def merge(
class PandasNativeFetchDFSupportMixin(EngineAdapter):
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""Fetches a Pandas DataFrame from a SQL query."""
from warnings import catch_warnings, filterwarnings
from pandas.io.sql import read_sql_query
- sql = (
- self._to_sql(query, quote=quote_identifiers)
- if isinstance(query, exp.Expression)
- else query
- )
+ sql = self._to_sql(query, quote=quote_identifiers) if isinstance(query, exp.Expr) else query
logger.debug(f"Executing SQL:\n{sql}")
with catch_warnings(), self.transaction():
filterwarnings(
@@ -87,7 +83,7 @@ class HiveMetastoreTablePropertiesMixin(EngineAdapter):
def _build_partitioned_by_exp(
self,
- partitioned_by: t.List[exp.Expression],
+ partitioned_by: t.List[exp.Expr],
*,
catalog_name: t.Optional[str] = None,
**kwargs: t.Any,
@@ -120,16 +116,16 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_format and self.dialect == "spark":
properties.append(exp.FileFormatProperty(this=exp.Var(this=table_format)))
@@ -166,12 +162,12 @@ def _build_table_properties_exp(
def _build_view_properties_exp(
self,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
"""Creates a SQLGlot table properties expression for view"""
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
if table_description:
properties.append(
@@ -194,7 +190,7 @@ def _truncate_comment(self, comment: str, length: t.Optional[int]) -> str:
class GetCurrentCatalogFromFunctionMixin(EngineAdapter):
- CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
+ CURRENT_CATALOG_EXPRESSION: exp.Expr = exp.func("current_catalog")
def get_current_catalog(self) -> t.Optional[str]:
"""Returns the catalog name of the current connection."""
@@ -240,7 +236,7 @@ def _default_precision_to_max(
def _build_create_table_exp(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -322,11 +318,11 @@ def is_destructive(self) -> bool:
return False
@property
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
return [exp.Cluster(expressions=self.cluster_key_expressions)]
@property
- def cluster_key_expressions(self) -> t.List[exp.Expression]:
+ def cluster_key_expressions(self) -> t.List[exp.Expr]:
# Note: Assumes `clustering_key` as a string like:
# - "(col_a)"
# - "(col_a, col_b)"
@@ -346,14 +342,14 @@ def is_destructive(self) -> bool:
return False
@property
- def _alter_actions(self) -> t.List[exp.Expression]:
+ def _alter_actions(self) -> t.List[exp.Expr]:
return [exp.Command(this="DROP", expression="CLUSTERING KEY")]
class ClusteredByMixin(EngineAdapter):
def _build_clustered_by_exp(
self,
- clustered_by: t.List[exp.Expression],
+ clustered_by: t.List[exp.Expr],
**kwargs: t.Any,
) -> t.Optional[exp.Cluster]:
return exp.Cluster(expressions=[c.copy() for c in clustered_by])
@@ -410,9 +406,9 @@ def logical_merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
) -> None:
"""
@@ -452,12 +448,12 @@ def concat_columns(
decimal_precision: int = 3,
timestamp_precision: int = MAX_TIMESTAMP_PRECISION,
delimiter: str = ",",
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""
Produces an expression that generates a string version of a record, that is:
- every column converted to a string representation and joined into a single string using the specified :delimiter
"""
- expressions_to_concat: t.List[exp.Expression] = []
+ expressions_to_concat: t.List[exp.Expr] = []
for idx, (column, type) in enumerate(columns_to_types.items()):
expressions_to_concat.append(
exp.func(
@@ -475,11 +471,11 @@ def concat_columns(
def normalize_value(
self,
- expr: exp.Expression,
+ expr: exp.Expr,
type: exp.DataType,
decimal_precision: int = 3,
timestamp_precision: int = MAX_TIMESTAMP_PRECISION,
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""
Return an expression that converts the values inside the column `col` to a normalized string
@@ -490,6 +486,7 @@ def normalize_value(
- `boolean` columns -> '1' or '0'
- NULLS -> "" (empty string)
"""
+ value: exp.Expr
if type.is_type(exp.DataType.Type.BOOLEAN):
value = self._normalize_boolean_value(expr)
elif type.is_type(*exp.DataType.INTEGER_TYPES):
@@ -512,12 +509,12 @@ def normalize_value(
return exp.cast(value, to=exp.DataType.build("VARCHAR"))
- def _normalize_nested_value(self, expr: exp.Expression) -> exp.Expression:
+ def _normalize_nested_value(self, expr: exp.Expr) -> exp.Expr:
return expr
def _normalize_timestamp_value(
- self, expr: exp.Expression, type: exp.DataType, precision: int
- ) -> exp.Expression:
+ self, expr: exp.Expr, type: exp.DataType, precision: int
+ ) -> exp.Expr:
if precision > self.MAX_TIMESTAMP_PRECISION:
raise ValueError(
f"Requested timestamp precision '{precision}' exceeds maximum supported precision: {self.MAX_TIMESTAMP_PRECISION}"
@@ -547,18 +544,18 @@ def _normalize_timestamp_value(
return expr
- def _normalize_integer_value(self, expr: exp.Expression) -> exp.Expression:
+ def _normalize_integer_value(self, expr: exp.Expr) -> exp.Expr:
return exp.cast(expr, "BIGINT")
- def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr:
return exp.cast(expr, f"DECIMAL(38,{precision})")
- def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression:
+ def _normalize_boolean_value(self, expr: exp.Expr) -> exp.Expr:
return exp.cast(expr, "INT")
class GrantsFromInfoSchemaMixin(EngineAdapter):
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("current_user")
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False
USE_CATALOG_IN_GRANTS = False
GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges"
@@ -578,8 +575,8 @@ def _dcl_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
- expressions: t.List[exp.Expression] = []
+ ) -> t.List[exp.Expr]:
+ expressions: t.List[exp.Expr] = []
if not grants_config:
return expressions
@@ -617,7 +614,7 @@ def _apply_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
return self._dcl_grants_config_expr(exp.Grant, table, grants_config, table_type)
def _revoke_grants_config_expr(
@@ -625,10 +622,10 @@ def _revoke_grants_config_expr(
table: exp.Table,
grants_config: GrantsConfig,
table_type: DataObjectType = DataObjectType.TABLE,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
return self._dcl_grants_config_expr(exp.Revoke, table, grants_config, table_type)
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
schema_identifier = table.args.get("db") or normalize_identifiers(
exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect
)
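
The `value: exp.Expr` declaration added to `normalize_value` above is there for the type checker: each branch assigns a different expression subclass, and without a prior annotation mypy may pin the variable to the first branch's type. A minimal sketch of the pattern, written against sqlglot's public `Expression` class rather than this change's `exp.Expr` alias:

```python
from sqlglot import exp

def normalize(flag: bool) -> exp.Expression:
    # Declare the variable up front so every branch may assign any Expression
    # subclass without the checker pinning it to the first branch's type.
    value: exp.Expression
    if flag:
        value = exp.cast(exp.column("col"), "INT")              # exp.Cast
    else:
        value = exp.func("TO_JSON_STRING", exp.column("col"))   # a Func expression
    return value

print(normalize(True).sql(), "|", normalize(False).sql())
```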
diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py
index 359d1f0818..e381c0a198 100644
--- a/sqlmesh/core/engine_adapter/mssql.py
+++ b/sqlmesh/core/engine_adapter/mssql.py
@@ -176,7 +176,7 @@ def drop_schema(
schema_name: SchemaName,
ignore_if_not_exists: bool = True,
cascade: bool = False,
- **drop_args: t.Dict[str, exp.Expression],
+ **drop_args: t.Dict[str, exp.Expr],
) -> None:
"""
MsSql doesn't support the CASCADE clause and drops schemas unconditionally.
@@ -205,9 +205,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -401,7 +401,7 @@ def _get_data_objects(
for row in dataframe.itertuples()
]
- def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
+ def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
sql = super()._to_sql(expression, quote=quote, **kwargs)
return f"{sql};"
@@ -448,7 +448,7 @@ def _insert_overwrite_by_condition(
**kwargs,
)
- def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
+ def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None:
if where == exp.true():
# "A TRUNCATE TABLE operation can be rolled back within a transaction."
# ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks
diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py
index 31773d6c63..66759dc440 100644
--- a/sqlmesh/core/engine_adapter/mysql.py
+++ b/sqlmesh/core/engine_adapter/mysql.py
@@ -73,7 +73,7 @@ def drop_schema(
schema_name: SchemaName,
ignore_if_not_exists: bool = True,
cascade: bool = False,
- **drop_args: t.Dict[str, exp.Expression],
+ **drop_args: t.Dict[str, exp.Expr],
) -> None:
# MySQL doesn't support the CASCADE clause and drops schemas unconditionally.
super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False)
diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py
index 3dd108cf91..6794169322 100644
--- a/sqlmesh/core/engine_adapter/postgres.py
+++ b/sqlmesh/core/engine_adapter/postgres.py
@@ -40,7 +40,7 @@ class PostgresEngineAdapter(
MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63
SUPPORTS_QUERY_EXECUTION_TRACKING = True
GRANT_INFORMATION_SCHEMA_TABLE_NAME = "role_table_grants"
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.column("current_role")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.column("current_role")
SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
@@ -73,7 +73,7 @@ class PostgresEngineAdapter(
}
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
"""
`read_sql_query` when using psycopg will result in a hanging transaction that must be committed
@@ -113,9 +113,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py
index 03dc89053e..c2a27954cd 100644
--- a/sqlmesh/core/engine_adapter/redshift.py
+++ b/sqlmesh/core/engine_adapter/redshift.py
@@ -143,7 +143,7 @@ def cursor(self) -> t.Any:
return cursor
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
"""Fetches a Pandas DataFrame from the cursor"""
import pandas as pd
@@ -217,7 +217,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
@@ -227,7 +227,7 @@ def create_view(
swap tables out from under views. Therefore, we create the view as non-binding.
"""
no_schema_binding = True
- if isinstance(query_or_df, exp.Expression):
+ if isinstance(query_or_df, exp.Expr):
# We can't include NO SCHEMA BINDING if the query has a recursive CTE
has_recursive_cte = any(
w.args.get("recursive", False) for w in query_or_df.find_all(exp.With)
@@ -367,9 +367,9 @@ def merge(
target_table: TableName,
source_table: QueryOrDF,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
when_matched: t.Optional[exp.Whens] = None,
- merge_filter: t.Optional[exp.Expression] = None,
+ merge_filter: t.Optional[exp.Expr] = None,
source_columns: t.Optional[t.List[str]] = None,
**kwargs: t.Any,
) -> None:
@@ -400,12 +400,12 @@ def _merge(
self,
target_table: TableName,
query: Query,
- on: exp.Expression,
+ on: exp.Expr,
whens: exp.Whens,
) -> None:
# Redshift does not support table aliases in the target table of a MERGE statement.
# So we must use the actual table name instead of an alias, as we do with the source table.
- def resolve_target_table(expression: exp.Expression) -> exp.Expression:
+ def resolve_target_table(expression: exp.Expr) -> exp.Expr:
if (
isinstance(expression, exp.Column)
and expression.table.upper() == MERGE_TARGET_ALIAS
@@ -436,7 +436,7 @@ def resolve_target_table(expression: exp.Expression) -> exp.Expression:
track_rows_processed=True,
)
- def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression:
+ def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr:
# Redshift is finicky. It truncates when the data is already in a table, but rounds when the data is generated as part of a SELECT.
#
# The following works:
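
`resolve_target_table` earlier in this file rewrites columns qualified with the merge-target alias to the physical table name, since Redshift rejects aliases on a MERGE target. A sketch of that rewrite in isolation; the alias constant and table name here are illustrative, not the adapter's actual values:

```python
from sqlglot import exp, parse_one

MERGE_TARGET_ALIAS = "__MERGE_TARGET__"  # illustrative stand-in for the real constant

def resolve_target_table(node: exp.Expression) -> exp.Expression:
    # Re-point columns qualified with the merge-target alias at the physical table.
    if isinstance(node, exp.Column) and node.table.upper() == MERGE_TARGET_ALIAS:
        node.set("table", exp.to_identifier("target_tbl"))
    return node

on_clause = parse_one(f"{MERGE_TARGET_ALIAS}.id = src.id")
print(on_clause.transform(resolve_target_table).sql())  # target_tbl.id = src.id
```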
diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py
index a8eabe070d..09c530b8f3 100644
--- a/sqlmesh/core/engine_adapter/snowflake.py
+++ b/sqlmesh/core/engine_adapter/snowflake.py
@@ -83,7 +83,7 @@ class SnowflakeEngineAdapter(
SNOWPARK = "snowpark"
SUPPORTS_QUERY_EXECUTION_TRACKING = True
SUPPORTS_GRANTS = True
- CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("CURRENT_ROLE")
+ CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("CURRENT_ROLE")
USE_CATALOG_IN_GRANTS = True
@contextlib.contextmanager
@@ -95,7 +95,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
if isinstance(warehouse, str):
warehouse = exp.to_identifier(warehouse)
- if not isinstance(warehouse, exp.Expression):
+ if not isinstance(warehouse, exp.Expr):
raise SQLMeshError(f"Invalid warehouse: '{warehouse}'")
warehouse_exp = quote_identifiers(
@@ -189,7 +189,7 @@ def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -225,9 +225,9 @@ def create_managed_table(
table_name: TableName,
query: Query,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
source_columns: t.Optional[t.List[str]] = None,
@@ -278,7 +278,7 @@ def create_view(
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
- view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ view_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
source_columns: t.Optional[t.List[str]] = None,
**create_kwargs: t.Any,
) -> None:
@@ -311,16 +311,16 @@ def _build_table_properties_exp(
catalog_name: t.Optional[str] = None,
table_format: t.Optional[str] = None,
storage_format: t.Optional[str] = None,
- partitioned_by: t.Optional[t.List[exp.Expression]] = None,
+ partitioned_by: t.Optional[t.List[exp.Expr]] = None,
partition_interval_unit: t.Optional[IntervalUnit] = None,
- clustered_by: t.Optional[t.List[exp.Expression]] = None,
- table_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+ clustered_by: t.Optional[t.List[exp.Expr]] = None,
+ table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
table_description: t.Optional[str] = None,
table_kind: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
- properties: t.List[exp.Expression] = []
+ properties: t.List[exp.Expr] = []
# TODO: there is some overlap with the base class and other engine adapters
# we need a better way of filtering table properties relevant to the current engine
@@ -471,7 +471,7 @@ def cleanup() -> None:
return [SourceQuery(query_factory=query_factory, cleanup_func=cleanup)]
def _fetch_native_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> DF:
import pandas as pd
from snowflake.connector.errors import NotSupportedError
@@ -561,7 +561,7 @@ def _get_data_objects(
for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples()
]
- def _get_grant_expression(self, table: exp.Table) -> exp.Expression:
+ def _get_grant_expression(self, table: exp.Table) -> exp.Expr:
# Upon execution, the catalog in table expressions is properly normalized to handle the case where a user provides
# the default catalog in their connection config. However, that normalization doesn't update catalogs embedded in strings,
# such as when querying the information schema, so we need to manually replace those here.
@@ -586,7 +586,7 @@ def set_current_catalog(self, catalog: str) -> None:
def set_current_schema(self, schema: str) -> None:
self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema)))
- def _normalize_catalog(self, expression: exp.Expression) -> exp.Expression:
+ def _normalize_catalog(self, expression: exp.Expr) -> exp.Expr:
# note: important to use self._default_catalog instead of the self.default_catalog property
# otherwise we get RecursionError: maximum recursion depth exceeded
# because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc
@@ -604,7 +604,7 @@ def unquote_and_lower(identifier: str) -> str:
self._default_catalog, dialect=self.dialect
)
- def catalog_rewriter(node: exp.Expression) -> exp.Expression:
+ def catalog_rewriter(node: exp.Expr) -> exp.Expr:
if isinstance(node, exp.Table):
if node.catalog:
# only replace the catalog on the model with the target catalog if the two are functionally equivalent
@@ -621,7 +621,7 @@ def catalog_rewriter(node: exp.Expression) -> exp.Expression:
expression = expression.transform(catalog_rewriter)
return expression
- def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
+ def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str:
return super()._to_sql(
expression=self._normalize_catalog(expression), quote=quote, **kwargs
)
diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py
index 5216b0a329..9199aa3bcd 100644
--- a/sqlmesh/core/engine_adapter/spark.py
+++ b/sqlmesh/core/engine_adapter/spark.py
@@ -340,12 +340,12 @@ def _get_temp_table(
return table
def fetchdf(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> pd.DataFrame:
return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas()
def fetch_pyspark_df(
- self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+ self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False
) -> PySparkDataFrame:
return self._ensure_pyspark_df(
self._fetch_native_df(query, quote_identifiers=quote_identifiers)
@@ -437,7 +437,7 @@ def _native_df_to_pandas_df(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py
index 89470728f2..00acddb26c 100644
--- a/sqlmesh/core/engine_adapter/trino.py
+++ b/sqlmesh/core/engine_adapter/trino.py
@@ -129,7 +129,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
yield
return
- if not isinstance(authorization, exp.Expression):
+ if not isinstance(authorization, exp.Expr):
authorization = exp.Literal.string(authorization)
if not authorization.is_string:
@@ -326,13 +326,13 @@ def _scd_type_2(
self,
target_table: TableName,
source_table: QueryOrDF,
- unique_key: t.Sequence[exp.Expression],
+ unique_key: t.Sequence[exp.Expr],
valid_from_col: exp.Column,
valid_to_col: exp.Column,
execution_time: t.Union[TimeLike, exp.Column],
invalidate_hard_deletes: bool = True,
updated_at_col: t.Optional[exp.Column] = None,
- check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
+ check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None,
updated_at_as_valid_from: bool = False,
execution_time_as_valid_from: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -409,7 +409,7 @@ def _create_schema(
schema_name: SchemaName,
ignore_if_exists: bool,
warn_on_error: bool,
- properties: t.List[exp.Expression],
+ properties: t.List[exp.Expr],
kind: str,
) -> None:
if mapped_location := self._schema_location(schema_name):
@@ -426,7 +426,7 @@ def _create_schema(
def _create_table(
self,
table_name_or_schema: t.Union[exp.Schema, TableName],
- expression: t.Optional[exp.Expression],
+ expression: t.Optional[exp.Expr],
exists: bool = True,
replace: bool = False,
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
diff --git a/sqlmesh/core/environment.py b/sqlmesh/core/environment.py
index 4a1f417468..4594dc120d 100644
--- a/sqlmesh/core/environment.py
+++ b/sqlmesh/core/environment.py
@@ -56,7 +56,8 @@ def _sanitize_name(cls, v: str) -> str:
@classmethod
def _validate_boolean_field(cls, v: t.Any, info: ValidationInfo) -> bool:
if v is None:
- return info.field_name == "normalize_name"
+ # Pydantic 2.13+ sets field_name to None during model_validate_json()
+ return (info.field_name or "") == "normalize_name"
return bool(v)
@t.overload
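
The guard added to `_validate_boolean_field` handles Pydantic 2.13+ passing a `ValidationInfo` whose `field_name` is `None` when validating from JSON; coalescing to an empty string keeps the comparison well-defined and defaults every field other than `normalize_name` to `False`. A tiny standalone check of the expression:

```python
# field_name may be a str or, in newer Pydantic versions, None.
for field_name in ("normalize_name", "end_at", None):
    print(field_name, (field_name or "") == "normalize_name")
# normalize_name True
# end_at False
# None False
```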
diff --git a/sqlmesh/core/lineage.py b/sqlmesh/core/lineage.py
index 777a2a7d9a..8363979034 100644
--- a/sqlmesh/core/lineage.py
+++ b/sqlmesh/core/lineage.py
@@ -16,7 +16,7 @@
from sqlmesh.core.model import Model
-CACHE: t.Dict[str, t.Tuple[int, exp.Expression, Scope]] = {}
+CACHE: t.Dict[str, t.Tuple[int, exp.Expr, Scope]] = {}
def lineage(
@@ -25,8 +25,8 @@ def lineage(
trim_selects: bool = True,
**kwargs: t.Any,
) -> Node:
- query = None
- scope = None
+ query: t.Optional[exp.Expr] = None
+ scope: t.Optional[Scope] = None
if model.name in CACHE:
obj_id, query, scope = CACHE[model.name]
diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py
index 4547ac0528..8dc4172f9f 100644
--- a/sqlmesh/core/linter/rules/builtin.py
+++ b/sqlmesh/core/linter/rules/builtin.py
@@ -318,4 +318,5 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return None
-BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude={Rule}))
+_RULE_EXCLUDE: t.Set[t.Type[Rule]] = {Rule} # type: ignore[type-abstract]
+BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude=_RULE_EXCLUDE))
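
The `_RULE_EXCLUDE` binding here (mirrored in loader.py below) sidesteps mypy's `type-abstract` error: a bare `{Rule}` where `Set[Type[Rule]]` is expected is rejected because `Type[X]` is presumed instantiable. A minimal, self-contained reproduction of the shape; `subclasses` here is a placeholder, not SQLMesh's helper, which walks a module instead of a class:

```python
import abc
import typing as t

class Rule(abc.ABC):
    @abc.abstractmethod
    def check(self) -> None: ...

def subclasses(root: t.Type[Rule], exclude: t.Set[t.Type[Rule]]) -> t.List[t.Type[Rule]]:
    # Placeholder: return concrete subclasses of the root, minus any excluded classes.
    return [c for c in root.__subclasses__() if c not in exclude]

# mypy would flag `{Rule}` at a Set[Type[Rule]] call site with [type-abstract];
# binding it to an annotated name with a targeted ignore keeps the call site clean.
_RULE_EXCLUDE: t.Set[t.Type[Rule]] = {Rule}  # type: ignore[type-abstract]
print(subclasses(Rule, exclude=_RULE_EXCLUDE))  # [] here, since no concrete rules are defined
```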
diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py
index 4b7b1bac02..cb951b4f9e 100644
--- a/sqlmesh/core/loader.py
+++ b/sqlmesh/core/loader.py
@@ -840,7 +840,8 @@ def _load_linting_rules(self) -> RuleSet:
if os.path.getsize(path):
self._track_file(path)
module = import_python_file(path, self.config_path)
- module_rules = subclasses(module.__name__, Rule, exclude={Rule})
+ _rule_exclude: t.Set[t.Type[Rule]] = {Rule} # type: ignore[type-abstract]
+ module_rules = subclasses(module.__name__, Rule, exclude=_rule_exclude)
for user_rule in module_rules:
user_rules[user_rule.name] = user_rule
diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py
index af7c344081..9370bffdeb 100644
--- a/sqlmesh/core/macros.py
+++ b/sqlmesh/core/macros.py
@@ -110,7 +110,7 @@ def _macro_sql(sql: str, into: t.Optional[str] = None) -> str:
return f"self.parse_one({', '.join(args)})"
-def _macro_func_sql(self: Generator, e: exp.Expression) -> str:
+def _macro_func_sql(self: Generator, e: exp.Expr) -> str:
func = e.this
if isinstance(func, exp.Anonymous):
@@ -178,7 +178,7 @@ def __init__(
schema: t.Optional[MappingSchema] = None,
runtime_stage: RuntimeStage = RuntimeStage.LOADING,
resolve_table: t.Optional[t.Callable[[str | exp.Table], str]] = None,
- resolve_tables: t.Optional[t.Callable[[exp.Expression], exp.Expression]] = None,
+ resolve_tables: t.Optional[t.Callable[[exp.Expr], exp.Expr]] = None,
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
default_catalog: t.Optional[str] = None,
path: t.Optional[Path] = None,
@@ -237,7 +237,7 @@ def __init__(
def send(
self, name: str, *args: t.Any, **kwargs: t.Any
- ) -> t.Union[None, exp.Expression, t.List[exp.Expression]]:
+ ) -> t.Union[None, exp.Expr, t.List[exp.Expr]]:
func = self.macros.get(normalize_macro_name(name))
if not callable(func):
@@ -253,14 +253,12 @@ def send(
+ format_evaluated_code_exception(e, self.python_env)
)
- def transform(
- self, expression: exp.Expression
- ) -> exp.Expression | t.List[exp.Expression] | None:
+ def transform(self, expression: exp.Expr) -> exp.Expr | t.List[exp.Expr] | None:
changed = False
def evaluate_macros(
- node: exp.Expression,
- ) -> exp.Expression | t.List[exp.Expression] | None:
+ node: exp.Expr,
+ ) -> exp.Expr | t.List[exp.Expr] | None:
nonlocal changed
if isinstance(node, MacroVar):
@@ -281,14 +279,10 @@ def evaluate_macros(
value = self.locals.get(var_name, variables.get(var_name))
if isinstance(value, list):
return exp.convert(
- tuple(
- self.transform(v) if isinstance(v, exp.Expression) else v for v in value
- )
+ tuple(self.transform(v) if isinstance(v, exp.Expr) else v for v in value)
)
- return exp.convert(
- self.transform(value) if isinstance(value, exp.Expression) else value
- )
+ return exp.convert(self.transform(value) if isinstance(value, exp.Expr) else value)
if isinstance(node, exp.Identifier) and "@" in node.this:
text = self.template(node.this, {})
if node.this != text:
@@ -300,7 +294,9 @@ def evaluate_macros(
return node
transformed = exp.replace_tree(
- expression.copy(), evaluate_macros, prune=lambda n: isinstance(n, exp.Lambda)
+ expression.copy(),
+ evaluate_macros, # type: ignore[arg-type]
+ prune=lambda n: isinstance(n, exp.Lambda),
)
if changed:
@@ -311,7 +307,7 @@ def evaluate_macros(
self.parse_one(node.sql(dialect=self.dialect, copy=False))
for node in transformed
]
- if isinstance(transformed, exp.Expression):
+ if isinstance(transformed, exp.Expr):
return self.parse_one(transformed.sql(dialect=self.dialect, copy=False))
return transformed
@@ -339,7 +335,7 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
}
return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping))
- def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
+ def evaluate(self, node: MacroFunc) -> exp.Expr | t.List[exp.Expr] | None:
if isinstance(node, MacroDef):
if isinstance(node.expression, exp.Lambda):
_, fn = _norm_var_arg_lambda(self, node.expression)
@@ -353,7 +349,7 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
return node
if isinstance(node, (MacroSQL, MacroStrReplace)):
- result: t.Optional[exp.Expression | t.List[exp.Expression]] = exp.convert(
+ result: t.Optional[exp.Expr | t.List[exp.Expr]] = exp.convert(
self.eval_expression(node)
)
else:
@@ -421,7 +417,7 @@ def eval_expression(self, node: t.Any) -> t.Any:
Returns:
The return value of the evaled Python Code.
"""
- if not isinstance(node, exp.Expression):
+ if not isinstance(node, exp.Expr):
return node
code = node.sql()
try:
@@ -434,8 +430,8 @@ def eval_expression(self, node: t.Any) -> t.Any:
)
def parse_one(
- self, sql: str | exp.Expression, into: t.Optional[exp.IntoType] = None, **opts: t.Any
- ) -> exp.Expression:
+ self, sql: str | exp.Expr, into: t.Optional[exp.IntoType] = None, **opts: t.Any
+ ) -> exp.Expr:
"""Parses the given SQL string and returns a syntax tree for the first
parsed SQL statement.
@@ -497,7 +493,7 @@ def resolve_table(self, table: str | exp.Table) -> str:
)
return self._resolve_table(table)
- def resolve_tables(self, query: exp.Expression) -> exp.Expression:
+ def resolve_tables(self, query: exp.Expr) -> exp.Expr:
"""Resolves queries with references to SQLMesh model names to their physical tables."""
if not self._resolve_tables:
raise SQLMeshError(
@@ -588,7 +584,7 @@ def variables(self) -> t.Dict[str, t.Any]:
**self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}),
}
- def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:
+ def _coerce(self, expr: exp.Expr, typ: t.Any, strict: bool = False) -> t.Any:
"""Coerces the given expression to the specified type on a best-effort basis."""
return _coerce(expr, typ, self.dialect, self._path, strict)
@@ -648,8 +644,8 @@ def _norm_var_arg_lambda(
"""
def substitute(
- node: exp.Expression, args: t.Dict[str, exp.Expression]
- ) -> exp.Expression | t.List[exp.Expression] | None:
+ node: exp.Expr, args: t.Dict[str, exp.Expr]
+ ) -> exp.Expr | t.List[exp.Expr] | None:
if isinstance(node, (exp.Identifier, exp.Var)):
if not isinstance(node.parent, exp.Column):
name = node.name.lower()
@@ -798,8 +794,8 @@ def filter_(evaluator: MacroEvaluator, *args: t.Any) -> t.List[t.Any]:
def _optional_expression(
evaluator: MacroEvaluator,
condition: exp.Condition,
- expression: exp.Expression,
-) -> t.Optional[exp.Expression]:
+ expression: exp.Expr,
+) -> t.Optional[exp.Expr]:
"""Inserts expression when the condition is True
The following examples express the usage of this function in the context of the macros which wrap it.
@@ -864,7 +860,7 @@ def star(
suffix: exp.Literal = exp.Literal.string(""),
quote_identifiers: exp.Boolean = exp.true(),
except_: t.Union[exp.Array, exp.Tuple] = exp.Tuple(expressions=[]),
-) -> t.List[exp.Alias]:
+) -> t.List[exp.Expr]:
"""Returns a list of projections for the given relation.
Args:
@@ -939,7 +935,7 @@ def star(
@macro()
def generate_surrogate_key(
evaluator: MacroEvaluator,
- *fields: exp.Expression,
+ *fields: exp.Expr,
hash_function: exp.Literal = exp.Literal.string("MD5"),
) -> exp.Func:
"""Generates a surrogate key (string) for the given fields.
@@ -956,7 +952,7 @@ def generate_surrogate_key(
>>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql, dialect="bigquery")).sql("bigquery")
"SELECT SHA256(CONCAT(COALESCE(CAST(a AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS STRING), '_sqlmesh_surrogate_key_null_'))) FROM foo"
"""
- string_fields: t.List[exp.Expression] = []
+ string_fields: t.List[exp.Expr] = []
for i, field in enumerate(fields):
if i > 0:
string_fields.append(exp.Literal.string("|"))
@@ -980,7 +976,7 @@ def generate_surrogate_key(
@macro()
-def safe_add(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
+def safe_add(_: MacroEvaluator, *fields: exp.Expr) -> exp.Case:
"""Adds numbers together, substitutes nulls for 0s and only returns null if all fields are null.
Example:
@@ -998,7 +994,7 @@ def safe_add(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
@macro()
-def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
+def safe_sub(_: MacroEvaluator, *fields: exp.Expr) -> exp.Case:
"""Subtract numbers, substitutes nulls for 0s and only returns null if all fields are null.
Example:
@@ -1016,7 +1012,7 @@ def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
@macro()
-def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expression) -> exp.Div:
+def safe_div(_: MacroEvaluator, numerator: exp.Expr, denominator: exp.Expr) -> exp.Div:
"""Divides numbers, returns null if the denominator is 0.
Example:
@@ -1032,7 +1028,7 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr
@macro()
def union(
evaluator: MacroEvaluator,
- *args: exp.Expression,
+ *args: exp.Expr,
) -> exp.Query:
"""Returns a UNION of the given tables. Only choosing columns that have the same name and type.
@@ -1107,10 +1103,10 @@ def union(
@macro()
def haversine_distance(
_: MacroEvaluator,
- lat1: exp.Expression,
- lon1: exp.Expression,
- lat2: exp.Expression,
- lon2: exp.Expression,
+ lat1: exp.Expr,
+ lon1: exp.Expr,
+ lat2: exp.Expr,
+ lon2: exp.Expr,
unit: exp.Literal = exp.Literal.string("mi"),
) -> exp.Mul:
"""Returns the haversine distance between two points.
@@ -1150,17 +1146,17 @@ def haversine_distance(
def pivot(
evaluator: MacroEvaluator,
column: SQL,
- values: t.List[exp.Expression],
+ values: t.List[exp.Expr],
alias: bool = True,
- agg: exp.Expression = exp.Literal.string("SUM"),
- cmp: exp.Expression = exp.Literal.string("="),
- prefix: exp.Expression = exp.Literal.string(""),
- suffix: exp.Expression = exp.Literal.string(""),
+ agg: exp.Expr = exp.Literal.string("SUM"),
+ cmp: exp.Expr = exp.Literal.string("="),
+ prefix: exp.Expr = exp.Literal.string(""),
+ suffix: exp.Expr = exp.Literal.string(""),
then_value: SQL = SQL("1"),
else_value: SQL = SQL("0"),
quote: bool = True,
distinct: bool = False,
-) -> t.List[exp.Expression]:
+) -> t.List[exp.Expr]:
"""Returns a list of projections as a result of pivoting the given column on the given values.
Example:
@@ -1173,14 +1169,14 @@ def pivot(
>>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql)).sql("bigquery")
"SELECT SUM(CASE WHEN a = 'v' THEN tv ELSE 0 END) AS v_sfx"
"""
- aggregates: t.List[exp.Expression] = []
+ aggregates: t.List[exp.Expr] = []
for value in values:
proj = f"{agg.name}("
if distinct:
proj += "DISTINCT "
proj += f"CASE WHEN {column} {cmp.name} {value.sql(evaluator.dialect)} THEN {then_value} ELSE {else_value} END) "
- node = evaluator.parse_one(proj)
+ node: exp.Expr = evaluator.parse_one(proj)
if alias:
node = node.as_(
@@ -1196,7 +1192,7 @@ def pivot(
@macro("AND")
-def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> exp.Condition:
+def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expr]) -> exp.Condition:
"""Returns an AND statement filtering out any NULL expressions."""
conditions = [e for e in expressions if not isinstance(e, exp.Null)]
@@ -1207,7 +1203,7 @@ def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) ->
@macro("OR")
-def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> exp.Condition:
+def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expr]) -> exp.Condition:
"""Returns an OR statement filtering out any NULL expressions."""
conditions = [e for e in expressions if not isinstance(e, exp.Null)]
@@ -1219,8 +1215,8 @@ def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) ->
@macro("VAR")
def var(
- evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None
-) -> exp.Expression:
+ evaluator: MacroEvaluator, var_name: exp.Expr, default: t.Optional[exp.Expr] = None
+) -> exp.Expr:
"""Returns the value of a variable or the default value if the variable is not set."""
if not var_name.is_string:
raise SQLMeshError(f"Invalid variable name '{var_name.sql()}'. Expected a string literal.")
@@ -1230,8 +1226,8 @@ def var(
@macro("BLUEPRINT_VAR")
def blueprint_var(
- evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None
-) -> exp.Expression:
+ evaluator: MacroEvaluator, var_name: exp.Expr, default: t.Optional[exp.Expr] = None
+) -> exp.Expr:
"""Returns the value of a blueprint variable or the default value if the variable is not set."""
if not var_name.is_string:
raise SQLMeshError(
@@ -1244,8 +1240,8 @@ def blueprint_var(
@macro()
def deduplicate(
evaluator: MacroEvaluator,
- relation: exp.Expression,
- partition_by: t.List[exp.Expression],
+ relation: exp.Expr,
+ partition_by: t.List[exp.Expr],
order_by: t.List[str],
) -> exp.Query:
"""Returns a QUERY to deduplicate rows within a table
@@ -1301,9 +1297,9 @@ def deduplicate(
@macro()
def date_spine(
evaluator: MacroEvaluator,
- datepart: exp.Expression,
- start_date: exp.Expression,
- end_date: exp.Expression,
+ datepart: exp.Expr,
+ start_date: exp.Expr,
+ end_date: exp.Expr,
) -> exp.Select:
"""Returns a query that produces a date spine with the given datepart, and range of start_date and end_date. Useful for joining as a date lookup table.
@@ -1491,7 +1487,7 @@ def _coerce(
"""Coerces the given expression to the specified type on a best-effort basis."""
base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'."
try:
- if typ is None or typ is t.Any or not isinstance(expr, exp.Expression):
+ if typ is None or typ is t.Any or not isinstance(expr, exp.Expr):
return expr
base = t.get_origin(typ) or typ
@@ -1503,7 +1499,7 @@ def _coerce(
except Exception:
pass
raise SQLMeshError(base_err_msg)
- if base is SQL and isinstance(expr, exp.Expression):
+ if base is SQL and isinstance(expr, exp.Expr):
return expr.sql(dialect)
if base is t.Literal:
@@ -1528,7 +1524,7 @@ def _coerce(
if isinstance(expr, base):
return expr
- if issubclass(base, exp.Expression):
+ if issubclass(base, exp.Expr):
d = Dialect.get_or_raise(dialect)
into = base if base in d.parser_class.EXPRESSION_PARSERS else None
if into is None:
@@ -1603,7 +1599,7 @@ def _convert_sql(v: t.Any, dialect: DialectType) -> t.Any:
except Exception:
pass
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
if (isinstance(v, exp.Column) and not v.table) or (
isinstance(v, exp.Identifier) or v.is_string
):
diff --git a/sqlmesh/core/metric/definition.py b/sqlmesh/core/metric/definition.py
index dd11cfd38d..6119a883ed 100644
--- a/sqlmesh/core/metric/definition.py
+++ b/sqlmesh/core/metric/definition.py
@@ -10,13 +10,13 @@
from sqlmesh.core.node import str_or_exp_to_str
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.errors import ConfigError
-from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
+from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, validation_data
MeasureAndDimTables = t.Tuple[str, t.Tuple[str, ...]]
def load_metric_ddl(
- expression: exp.Expression, dialect: t.Optional[str], path: Path = Path(), **kwargs: t.Any
+ expression: exp.Expr, dialect: t.Optional[str], path: Path = Path(), **kwargs: t.Any
) -> MetricMeta:
"""Returns a MetricMeta from raw Metric DDL."""
if not isinstance(expression, d.Metric):
@@ -70,7 +70,7 @@ class MetricMeta(PydanticModel, frozen=True):
name: str
dialect: str
- expression: exp.Expression
+ expression: exp.Expr
description: t.Optional[str] = None
owner: t.Optional[str] = None
@@ -87,11 +87,11 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]:
return str_or_exp_to_str(v)
@field_validator("expression", mode="before")
- def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expression:
+ def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expr:
if isinstance(v, str):
- dialect = info.data.get("dialect")
+ dialect = validation_data(info).get("dialect")
return d.parse_one(v, dialect=dialect)
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return v
return v
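
The switch from `info.data` to `validation_data(info)` above points at a small compatibility helper in `sqlmesh.utils.pydantic`. Its definition is not part of this diff; the sketch below is a hypothetical reading based only on the call sites, not the actual implementation.

```python
import typing as t

def validation_data(info: t.Any) -> t.Mapping[str, t.Any]:
    """Hypothetical: expose the already-validated sibling field values as a plain
    mapping, whether the active pydantic version hands validators a ValidationInfo
    object or a bare dict."""
    if info is None:
        return {}
    data = getattr(info, "data", info)
    return data or {}
```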
@@ -139,7 +139,7 @@ def to_metric(
class Metric(MetricMeta, frozen=True):
- expanded: exp.Expression
+ expanded: exp.Expr
@property
def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]:
@@ -150,7 +150,7 @@ def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]:
return {
t.cast(
exp.AggFunc,
- t.cast(exp.Expression, agg.parent).transform(
+ t.cast(exp.Expr, agg.parent).transform(
lambda node: (
exp.column(node.this, table=remove_namespace(node))
if isinstance(node, exp.Column) and node.table
@@ -162,7 +162,7 @@ def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]:
}
@property
- def formula(self) -> exp.Expression:
+ def formula(self) -> exp.Expr:
"""Returns the post aggregation formula of a metric.
For simple metrics it is just the metric name. For derived metrics,
@@ -181,7 +181,7 @@ def _raise_metric_config_error(msg: str, path: Path) -> None:
raise ConfigError(f"{msg}. '{path}'")
-def _get_measure_and_dim_tables(expression: exp.Expression) -> MeasureAndDimTables:
+def _get_measure_and_dim_tables(expression: exp.Expr) -> MeasureAndDimTables:
"""Finds all the table references in a metric definition.
Additionally ensures that the first table returned is the 'measure' or numeric value being aggregated.
@@ -190,7 +190,7 @@ def _get_measure_and_dim_tables(expression: exp.Expression) -> MeasureAndDimTabl
tables = {}
measure_table = None
- def is_measure(node: exp.Expression) -> bool:
+ def is_measure(node: exp.Expr) -> bool:
parent = node.parent
if isinstance(parent, exp.AggFunc) and node.arg_key == "this":
diff --git a/sqlmesh/core/metric/rewriter.py b/sqlmesh/core/metric/rewriter.py
index bbdc6c6135..6c9ec429a8 100644
--- a/sqlmesh/core/metric/rewriter.py
+++ b/sqlmesh/core/metric/rewriter.py
@@ -34,13 +34,13 @@ def __init__(
self.join_type = join_type
self.semantic_name = f"{semantic_schema}.{semantic_table}"
- def rewrite(self, expression: exp.Expression) -> exp.Expression:
+ def rewrite(self, expression: exp.Expr) -> exp.Expr:
for select in list(expression.find_all(exp.Select)):
self._expand(select)
return expression
- def _build_sources(self, projections: t.List[exp.Expression]) -> SourceAggsAndJoins:
+ def _build_sources(self, projections: t.List[exp.Expr]) -> SourceAggsAndJoins:
sources: SourceAggsAndJoins = {}
for projection in projections:
@@ -78,7 +78,7 @@ def _expand(self, select: exp.Select) -> None:
explicit_joins = {exp.table_name(join.this): join for join in select.args.pop("joins", [])}
for i, (name, (aggs, joins)) in enumerate(sources.items()):
- source: exp.Expression = exp.to_table(name)
+ source: exp.Expr = exp.to_table(name)
table_name = remove_namespace(name)
if not isinstance(source, exp.Select):
@@ -110,7 +110,7 @@ def _expand(self, select: exp.Select) -> None:
copy=False,
)
- for node in find_all_in_scope(query, (exp.Column, exp.TableAlias)):
+ for node in find_all_in_scope(query, exp.Column, exp.TableAlias): # type: ignore[arg-type,var-annotated]
if isinstance(node, exp.Column):
if node.table in mapping:
node.set("table", exp.to_identifier(mapping[node.table]))
@@ -123,7 +123,7 @@ def _add_joins(
source: exp.Select,
name: str,
joins: t.Dict[str, t.Optional[exp.Join]],
- group_by: t.List[exp.Expression],
+ group_by: t.List[exp.Expr],
mapping: t.Dict[str, str],
) -> exp.Select:
grain = [e.copy() for e in group_by]
@@ -177,7 +177,7 @@ def _add_joins(
return source.select(*grain, copy=False).group_by(*grain, copy=False)
-def _replace_table(node: exp.Expression, table: str, base_alias: str) -> exp.Expression:
+def _replace_table(node: exp.Expr, table: str, base_alias: str) -> exp.Expr:
for column in find_all_in_scope(node, exp.Column):
if column.table == base_alias:
column.args["table"] = exp.to_identifier(table)
@@ -185,11 +185,11 @@ def _replace_table(node: exp.Expression, table: str, base_alias: str) -> exp.Exp
def rewrite(
- sql: str | exp.Expression,
+ sql: str | exp.Expr,
graph: ReferenceGraph,
metrics: t.Dict[str, Metric],
dialect: t.Optional[str] = "",
-) -> exp.Expression:
+) -> exp.Expr:
rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect)
return optimize(
diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py
index 774bfa402b..1f038c5d79 100644
--- a/sqlmesh/core/model/cache.py
+++ b/sqlmesh/core/model/cache.py
@@ -81,7 +81,7 @@ def get(self, name: str, entry_id: str = "") -> t.List[Model]:
@dataclass
class OptimizedQueryCacheEntry:
- optimized_rendered_query: t.Optional[exp.Expression]
+ optimized_rendered_query: t.Optional[exp.Query]
renderer_violations: t.Optional[t.Dict[type[Rule], t.Any]]
diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py
index 9e117b56fb..c75531afb8 100644
--- a/sqlmesh/core/model/common.py
+++ b/sqlmesh/core/model/common.py
@@ -21,7 +21,13 @@
prepare_env,
serialize_env,
)
-from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect
+from sqlmesh.utils.pydantic import (
+ PydanticModel,
+ ValidationInfo,
+ field_validator,
+ get_dialect,
+ validation_data,
+)
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@@ -33,8 +39,8 @@
def make_python_env(
expressions: t.Union[
- exp.Expression,
- t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]],
+ exp.Expr,
+ t.List[t.Union[exp.Expr, t.Tuple[exp.Expr, bool]]],
],
jinja_macro_references: t.Optional[t.Set[MacroReference]],
module_path: Path,
@@ -71,7 +77,7 @@ def make_python_env(
visited_macro_funcs: t.Set[int] = set()
def _is_metadata_var(
- name: str, expression: exp.Expression, appears_in_metadata_expression: bool
+ name: str, expression: exp.Expr, appears_in_metadata_expression: bool
) -> t.Optional[bool]:
is_metadata_so_far = used_variables.get(name, True)
if is_metadata_so_far is False:
@@ -202,7 +208,7 @@ def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool:
def _extract_macro_func_variable_references(
- macro_func: exp.Expression,
+ macro_func: exp.Expr,
is_metadata: bool,
) -> t.Tuple[t.Set[str], t.Dict[int, bool], t.Set[int]]:
var_references = set()
@@ -255,7 +261,7 @@ def _add_variables_to_python_env(
# - appear in metadata-only expressions, such as `audits (...)`, virtual statements, etc
# - appear in the ASTs or definitions of metadata-only macros
#
- # See also: https://github.com/TobikoData/sqlmesh/pull/4936#issuecomment-3136339936,
+ # See also: https://github.com/SQLMesh/sqlmesh/pull/4936#issuecomment-3136339936,
# specifically the "Terminology" and "Observations" section.
metadata_used_variables = {
var_name for var_name, is_metadata in used_variables.items() if is_metadata
@@ -275,7 +281,7 @@ def _add_variables_to_python_env(
if overlapping_variables := (non_metadata_used_variables & metadata_used_variables):
raise ConfigError(
f"Variables {', '.join(overlapping_variables)} are both metadata and non-metadata, "
- "which is unexpected. Please file an issue at https://github.com/TobikoData/sqlmesh/issues/new."
+ "which is unexpected. Please file an issue at https://github.com/SQLMesh/sqlmesh/issues/new."
)
metadata_variables = {
@@ -292,12 +298,12 @@ def _add_variables_to_python_env(
if blueprint_variables:
metadata_blueprint_variables = {
- k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
+ k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expr) else v
for k, v in blueprint_variables.items()
if k in metadata_used_variables
}
blueprint_variables = {
- k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
+ k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expr) else v
for k, v in blueprint_variables.items()
if k in non_metadata_used_variables
}
@@ -469,9 +475,9 @@ def single_value_or_tuple(values: t.Sequence) -> exp.Identifier | exp.Tuple:
def parse_expression(
cls: t.Type,
- v: t.Union[t.List[str], t.List[exp.Expression], str, exp.Expression, t.Callable, None],
+ v: t.Union[t.List[str], t.List[exp.Expr], str, exp.Expr, t.Callable, None],
info: t.Optional[ValidationInfo],
-) -> t.List[exp.Expression] | exp.Expression | t.Callable | None:
+) -> t.List[exp.Expr] | exp.Expr | t.Callable | None:
"""Helper method to deserialize SQLGlot expressions in Pydantic Models."""
if v is None:
return None
@@ -479,11 +485,11 @@ def parse_expression(
if callable(v):
return v
- dialect = info.data.get("dialect") if info else ""
+ dialect = validation_data(info).get("dialect") if info else ""
if isinstance(v, list):
return [
- e if isinstance(e, exp.Expression) else d.parse_one(e, dialect=dialect)
+ e if isinstance(e, exp.Expr) else d.parse_one(e, dialect=dialect) # type: ignore[misc]
for e in v
if not isinstance(e, exp.Semicolon)
]
@@ -498,7 +504,7 @@ def parse_expression(
def parse_bool(v: t.Any) -> bool:
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
if not isinstance(v, exp.Boolean):
from sqlglot.optimizer.simplify import simplify
@@ -519,12 +525,12 @@ def parse_properties(
if v is None:
return v
- dialect = info.data.get("dialect") if info else ""
+ dialect = validation_data(info).get("dialect") if info else ""
if isinstance(v, str):
v = d.parse_one(v, dialect=dialect)
if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)):
- eq_expressions: t.List[exp.Expression] = (
+ eq_expressions: t.List[exp.Expr] = (
[v.unnest()] if isinstance(v, exp.Paren) else v.expressions
)
@@ -557,8 +563,9 @@ def default_catalog(cls: t.Type, v: t.Any) -> t.Optional[str]:
def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[str]]:
- dialect = info.data.get("dialect")
- default_catalog = info.data.get("default_catalog")
+ data = validation_data(info)
+ dialect = data.get("dialect")
+ default_catalog = data.get("default_catalog")
if isinstance(v, exp.Paren):
v = v.unnest()
@@ -665,18 +672,18 @@ class ParsableSql(PydanticModel):
sql: str
transaction: t.Optional[bool] = None
- _parsed: t.Optional[exp.Expression] = None
+ _parsed: t.Optional[exp.Expr] = None
_parsed_dialect: t.Optional[str] = None
- def parse(self, dialect: str) -> exp.Expression:
+ def parse(self, dialect: str) -> exp.Expr:
if self._parsed is None or self._parsed_dialect != dialect:
self._parsed = d.parse_one(self.sql, dialect=dialect)
self._parsed_dialect = dialect
- return self._parsed
+ return self._parsed # type: ignore[return-value]
@classmethod
def from_parsed_expression(
- cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False
+ cls, parsed_expression: exp.Expr, dialect: str, use_meta_sql: bool = False
) -> ParsableSql:
sql = (
parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect)
@@ -697,7 +704,7 @@ def _validate_parsable_sql(
return v
if isinstance(v, str):
return ParsableSql(sql=v)
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return ParsableSql.from_parsed_expression(
v, get_dialect(info.data), use_meta_sql=False
)
@@ -707,7 +714,7 @@ def _validate_parsable_sql(
ParsableSql(sql=s)
if isinstance(s, str)
else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False)
- if isinstance(s, exp.Expression)
+ if isinstance(s, exp.Expr)
else ParsableSql.parse_obj(s)
for s in v
]
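
A minimal usage sketch of the `ParsableSql.parse` caching shown in this hunk; the SQL text and dialect names are illustrative only.

```python
from sqlmesh.core.model.common import ParsableSql

p = ParsableSql(sql="SELECT 1 AS x")
first = p.parse("duckdb")       # parses the SQL and caches the expression for "duckdb"
cached = p.parse("duckdb")      # same dialect: the cached expression object is reused
reparsed = p.parse("bigquery")  # dialect changed: the SQL is re-parsed and the cache replaced
assert first is cached and reparsed is not first
```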
diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py
index 73452cc165..328b763f9f 100644
--- a/sqlmesh/core/model/decorator.py
+++ b/sqlmesh/core/model/decorator.py
@@ -193,7 +193,7 @@ def model(
)
rendered_name = rendered_fields["name"]
- if isinstance(rendered_name, exp.Expression):
+ if isinstance(rendered_name, exp.Expr):
rendered_fields["name"] = rendered_name.sql(dialect=dialect)
rendered_defaults = (
diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py
index 831b52a44e..5b3a656a54 100644
--- a/sqlmesh/core/model/definition.py
+++ b/sqlmesh/core/model/definition.py
@@ -215,7 +215,7 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Returns the original list of sql expressions comprising the model definition.
Args:
@@ -366,7 +366,7 @@ def render_pre_statements(
engine_adapter: t.Optional[EngineAdapter] = None,
inside_transaction: t.Optional[bool] = True,
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Renders pre-statements for a model.
Pre-statements are statements that precede the model's SELECT query.
@@ -413,7 +413,7 @@ def render_post_statements(
engine_adapter: t.Optional[EngineAdapter] = None,
inside_transaction: t.Optional[bool] = True,
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
"""Renders post-statements for a model.
Post-statements are statements that follow the model's SELECT query.
@@ -460,7 +460,7 @@ def render_on_virtual_update(
deployability_index: t.Optional[DeployabilityIndex] = None,
engine_adapter: t.Optional[EngineAdapter] = None,
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
return self._render_statements(
self.on_virtual_update,
start=start,
@@ -552,15 +552,15 @@ def render_audit_query(
return rendered_query
@property
- def pre_statements(self) -> t.List[exp.Expression]:
+ def pre_statements(self) -> t.List[exp.Expr]:
return self._get_parsed_statements("pre_statements_")
@property
- def post_statements(self) -> t.List[exp.Expression]:
+ def post_statements(self) -> t.List[exp.Expr]:
return self._get_parsed_statements("post_statements_")
@property
- def on_virtual_update(self) -> t.List[exp.Expression]:
+ def on_virtual_update(self) -> t.List[exp.Expr]:
return self._get_parsed_statements("on_virtual_update_")
@property
@@ -572,7 +572,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]:
if isinstance(s, d.MacroDef)
]
- def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]:
+ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expr]:
value = getattr(self, attr_name)
if not value:
return []
@@ -587,9 +587,9 @@ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]:
def _render_statements(
self,
- statements: t.Iterable[exp.Expression],
+ statements: t.Iterable[exp.Expr],
**kwargs: t.Any,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
rendered = (
self._statement_renderer(statement).render(**kwargs)
for statement in statements
@@ -597,7 +597,7 @@ def _render_statements(
)
return [r for expressions in rendered if expressions for r in expressions]
- def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
+ def _statement_renderer(self, expression: exp.Expr) -> ExpressionRenderer:
expression_key = id(expression)
if expression_key not in self._statement_renderer_cache:
self._statement_renderer_cache[expression_key] = ExpressionRenderer(
@@ -631,7 +631,7 @@ def render_signals(
The list of rendered expressions.
"""
- def _render(e: exp.Expression) -> str | int | float | bool:
+ def _render(e: exp.Expr) -> str | int | float | bool:
rendered_exprs = (
self._create_renderer(e).render(start=start, end=end, execution_time=execution_time)
or []
@@ -676,7 +676,7 @@ def render_merge_filter(
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
execution_time: t.Optional[TimeLike] = None,
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Expr]:
if self.merge_filter is None:
return None
rendered_exprs = (
@@ -690,9 +690,9 @@ def render_merge_filter(
return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect)
def _render_properties(
- self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any
+ self, properties: t.Dict[str, exp.Expr] | SessionProperties, **render_kwargs: t.Any
) -> t.Dict[str, t.Any]:
- def _render(expression: exp.Expression) -> exp.Expression | None:
+ def _render(expression: exp.Expr) -> exp.Expr | None:
# note: we use the _statement_renderer instead of _create_renderer because it sets model_fqn which
# in turn makes @this_model available in the evaluation context
rendered_exprs = self._statement_renderer(expression).render(**render_kwargs)
@@ -714,7 +714,7 @@ def _render(expression: exp.Expression) -> exp.Expression | None:
return {
k: rendered
for k, v in properties.items()
- if (rendered := (_render(v) if isinstance(v, exp.Expression) else v))
+ if (rendered := (_render(v) if isinstance(v, exp.Expr) else v))
}
def render_physical_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]:
@@ -726,7 +726,7 @@ def render_virtual_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any
def render_session_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]:
return self._render_properties(properties=self.session_properties, **render_kwargs)
- def _create_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
+ def _create_renderer(self, expression: exp.Expr) -> ExpressionRenderer:
return ExpressionRenderer(
expression,
self.dialect,
@@ -822,7 +822,7 @@ def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMA
def convert_to_time_column(
self, time: TimeLike, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
- ) -> exp.Expression:
+ ) -> exp.Expr:
"""Convert a TimeLike object to the same time format and type as the model's time column."""
if self.time_column:
if columns_to_types is None:
@@ -970,7 +970,7 @@ def validate_definition(self) -> None:
col.name
for expr in values
for col in t.cast(
- exp.Expression, exp.maybe_parse(expr, dialect=self.dialect)
+ exp.Expr, exp.maybe_parse(expr, dialect=self.dialect)
).find_all(exp.Column)
]
@@ -1158,11 +1158,7 @@ def _audit_metadata_hash_values(self) -> t.List[str]:
for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]):
metadata.append(audit_name)
- if audit_name in BUILT_IN_AUDITS:
- for arg_name, arg_value in audit_args.items():
- metadata.append(arg_name)
- metadata.append(gen(arg_value))
- else:
+ if audit_name not in BUILT_IN_AUDITS:
audit = self.audit_definitions[audit_name]
metadata.extend(
[
@@ -1172,6 +1168,9 @@ def _audit_metadata_hash_values(self) -> t.List[str]:
str(audit.blocking),
]
)
+ for arg_name, arg_value in audit_args.items():
+ metadata.append(arg_name)
+ metadata.append(gen(arg_value))
return metadata
@@ -1266,7 +1265,7 @@ def _additional_metadata(self) -> t.List[str]:
return additional_metadata
- def _is_metadata_statement(self, statement: exp.Expression) -> bool:
+ def _is_metadata_statement(self, statement: exp.Expr) -> bool:
if isinstance(statement, d.MacroDef):
return True
if isinstance(statement, d.MacroFunc):
@@ -1295,7 +1294,7 @@ def full_depends_on(self) -> t.Set[str]:
return self._full_depends_on
@property
- def partitioned_by(self) -> t.List[exp.Expression]:
+ def partitioned_by(self) -> t.List[exp.Expr]:
"""Columns to partition the model by, including the time column if it is not already included."""
if self.time_column and not self._is_time_column_in_partitioned_by:
# This allows the user to opt out of automatic time_column injection
@@ -1323,7 +1322,7 @@ def partition_interval_unit(self) -> t.Optional[IntervalUnit]:
return None
@property
- def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]:
+ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expr]]]:
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
audits_by_name = {**BUILT_IN_AUDITS, **self.audit_definitions}
@@ -1422,8 +1421,8 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
- result = super().render_definition(
+ ) -> t.List[exp.Expr]:
+ result: t.List[exp.Expr] = super().render_definition(
include_python=include_python, include_defaults=include_defaults
)
@@ -1946,7 +1945,7 @@ def render_definition(
include_python: bool = True,
include_defaults: bool = False,
render_query: bool = False,
- ) -> t.List[exp.Expression]:
+ ) -> t.List[exp.Expr]:
# Ignore the provided value for the include_python flag, since the Python model's
# definition without Python code is meaningless.
return super().render_definition(
@@ -2001,7 +2000,7 @@ class AuditResult(PydanticModel):
"""The model this audit is for."""
count: t.Optional[int] = None
"""The number of records returned by the audit query. This could be None if the audit was skipped."""
- query: t.Optional[exp.Expression] = None
+ query: t.Optional[exp.Expr] = None
"""The rendered query used by the audit. This could be None if the audit was skipped."""
skipped: bool = False
"""Whether or not the audit was blocking. This can be overriden by the user."""
@@ -2009,7 +2008,7 @@ class AuditResult(PydanticModel):
class EvaluatableSignals(PydanticModel):
- signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]
+ signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expr]]]
"""A mapping of signal names to the kwargs passed to the signal."""
python_env: t.Dict[str, Executable]
"""The Python environment that should be used to evaluated the rendered signal calls."""
@@ -2054,7 +2053,7 @@ def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t.
def create_models_from_blueprints(
- gateway: t.Optional[str | exp.Expression],
+ gateway: t.Optional[str | exp.Expr],
blueprints: t.Any,
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
loader: t.Callable[..., Model],
@@ -2065,7 +2064,9 @@ def create_models_from_blueprints(
**loader_kwargs: t.Any,
) -> t.List[Model]:
model_blueprints: t.List[Model] = []
+ original_default_catalog = loader_kwargs.get("default_catalog")
for blueprint in _extract_blueprints(blueprints, path):
+ loader_kwargs["default_catalog"] = original_default_catalog
blueprint_variables = _extract_blueprint_variables(blueprint, path)
if gateway:
@@ -2083,12 +2084,15 @@ def create_models_from_blueprints(
else:
gateway_name = None
- if (
- default_catalog_per_gateway
- and gateway_name
- and (catalog := default_catalog_per_gateway.get(gateway_name)) is not None
- ):
- loader_kwargs["default_catalog"] = catalog
+ if default_catalog_per_gateway and gateway_name:
+ catalog = default_catalog_per_gateway.get(gateway_name)
+ if catalog is not None:
+ loader_kwargs["default_catalog"] = catalog
+ else:
+ # Gateway exists but has no entry in the dict (e.g., catalog-unsupported
+ # engines like ClickHouse). Clear the default catalog so the global
+ # default from the primary gateway doesn't leak into this model's name.
+ loader_kwargs["default_catalog"] = None
model_blueprints.append(
loader(
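
The restructured branch above can be read as the following resolution rule; the helper name and signature are illustrative, not part of the codebase.

```python
import typing as t

def resolve_default_catalog(
    original: t.Optional[str],
    gateway_name: t.Optional[str],
    default_catalog_per_gateway: t.Optional[t.Dict[str, str]],
) -> t.Optional[str]:
    # No per-gateway mapping, or the blueprint doesn't select a gateway: keep the original default.
    if not default_catalog_per_gateway or not gateway_name:
        return original
    # Gateway selected: use its catalog, or None when the gateway has no entry
    # (e.g. catalog-less engines) so the primary gateway's default doesn't leak in.
    return default_catalog_per_gateway.get(gateway_name)
```

For example, a blueprint bound to a ClickHouse-style gateway with no catalog entry resolves to `None`, while a blueprint that selects no gateway keeps the original default.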
@@ -2105,7 +2109,7 @@ def create_models_from_blueprints(
def load_sql_based_models(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
path: Path = Path(),
module_path: Path = Path(),
@@ -2113,8 +2117,8 @@ def load_sql_based_models(
default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None,
**loader_kwargs: t.Any,
) -> t.List[Model]:
- gateway: t.Optional[exp.Expression] = None
- blueprints: t.Optional[exp.Expression] = None
+ gateway: t.Optional[exp.Expr] = None
+ blueprints: t.Optional[exp.Expr] = None
model_meta = seq_get(expressions, 0)
for prop in (isinstance(model_meta, d.Model) and model_meta.expressions) or []:
@@ -2160,7 +2164,7 @@ def load_sql_based_models(
def load_sql_based_model(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
*,
defaults: t.Optional[t.Dict[str, t.Any]] = None,
path: t.Optional[Path] = None,
@@ -2306,7 +2310,7 @@ def load_sql_based_model(
if kind_prop.name.lower() == "merge_filter":
meta_fields["kind"].expressions[idx] = unrendered_merge_filter
- if isinstance(meta_fields.get("dialect"), exp.Expression):
+ if isinstance(meta_fields.get("dialect"), exp.Expr):
meta_fields["dialect"] = meta_fields["dialect"].name
# The name of the model will be inferred from its path relative to `models/`, if it's not explicitly specified
@@ -2367,7 +2371,7 @@ def load_sql_based_model(
def create_sql_model(
name: TableName,
- query: t.Optional[exp.Expression],
+ query: t.Optional[exp.Expr],
**kwargs: t.Any,
) -> Model:
"""Creates a SQL model.
@@ -2492,7 +2496,7 @@ def create_python_model(
)
depends_on = {
dep.sql(dialect=dialect)
- for dep in t.cast(t.List[exp.Expression], depends_on_rendered)[0].expressions
+ for dep in t.cast(t.List[exp.Expr], depends_on_rendered)[0].expressions
}
used_variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables}
@@ -2597,7 +2601,7 @@ def _create_model(
if not issubclass(klass, SqlModel):
defaults.pop("optimize_query", None)
- statements: t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]] = []
+ statements: t.List[t.Union[exp.Expr, t.Tuple[exp.Expr, bool]]] = []
if "query" in kwargs:
statements.append(kwargs["query"])
@@ -2636,11 +2640,11 @@ def _create_model(
if isinstance(property_values, exp.Tuple):
statements.extend(property_values.expressions)
- if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expression):
+ if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expr):
statements.append(kwargs["kind"].merge_filter)
jinja_macro_references, referenced_variables = extract_macro_references_and_variables(
- *(gen(e if isinstance(e, exp.Expression) else e[0]) for e in statements)
+ *(gen(e if isinstance(e, exp.Expr) else e[0]) for e in statements)
)
if jinja_macros:
@@ -2687,7 +2691,7 @@ def _create_model(
model.audit_definitions.update(audit_definitions)
# Any macro referenced in audits or signals needs to be treated as metadata-only
- statements.extend((audit.query, True) for audit in audit_definitions.values())
+ statements.extend((audit.query, True) for audit in audit_definitions.values()) # type: ignore[misc]
# Ensure that all audits referenced in the model are defined
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
@@ -2743,14 +2747,14 @@ def _create_model(
def _split_sql_model_statements(
- expressions: t.List[exp.Expression],
+ expressions: t.List[exp.Expr],
path: t.Optional[Path],
dialect: t.Optional[str] = None,
) -> t.Tuple[
- t.Optional[exp.Expression],
- t.List[exp.Expression],
- t.List[exp.Expression],
- t.List[exp.Expression],
+ t.Optional[exp.Expr],
+ t.List[exp.Expr],
+ t.List[exp.Expr],
+ t.List[exp.Expr],
UniqueKeyDict[str, ModelAudit],
]:
"""Extracts the SELECT query from a sequence of expressions.
@@ -2811,8 +2815,8 @@ def _split_sql_model_statements(
def _resolve_properties(
default: t.Optional[t.Dict[str, t.Any]],
- provided: t.Optional[exp.Expression | t.Dict[str, t.Any]],
-) -> t.Optional[exp.Expression]:
+ provided: t.Optional[exp.Expr | t.Dict[str, t.Any]],
+) -> t.Optional[exp.Expr]:
if isinstance(provided, dict):
properties = {k: exp.Literal.string(k).eq(v) for k, v in provided.items()}
elif provided:
@@ -2834,7 +2838,7 @@ def _resolve_properties(
return None
-def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expression:
+def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expr:
return exp.Tuple(
expressions=[
exp.Anonymous(
@@ -2849,16 +2853,16 @@ def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> ex
)
-def _is_projection(expr: exp.Expression) -> bool:
+def _is_projection(expr: exp.Expr) -> bool:
parent = expr.parent
return isinstance(parent, exp.Select) and expr.arg_key == "expressions"
-def _single_expr_or_tuple(values: t.Sequence[exp.Expression]) -> exp.Expression | exp.Tuple:
+def _single_expr_or_tuple(values: t.Sequence[exp.Expr]) -> exp.Expr | exp.Tuple:
return values[0] if len(values) == 1 else exp.Tuple(expressions=values)
-def _refs_to_sql(values: t.Any) -> exp.Expression:
+def _refs_to_sql(values: t.Any) -> exp.Expr:
return exp.Tuple(expressions=values)
@@ -2874,7 +2878,7 @@ def render_meta_fields(
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Dict[str, t.Any]:
def render_field_value(value: t.Any) -> t.Any:
- if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value):
+ if isinstance(value, exp.Expr) or (isinstance(value, str) and "@" in value):
expression = exp.maybe_parse(value, dialect=dialect)
rendered_expr = render_expression(
expression=expression,
@@ -3011,7 +3015,7 @@ def parse_defaults_properties(
def render_expression(
- expression: exp.Expression,
+ expression: exp.Expr,
module_path: Path,
path: t.Optional[Path],
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
@@ -3020,7 +3024,7 @@ def render_expression(
variables: t.Optional[t.Dict[str, t.Any]] = None,
default_catalog: t.Optional[str] = None,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
-) -> t.Optional[t.List[exp.Expression]]:
+) -> t.Optional[t.List[exp.Expr]]:
meta_python_env = make_python_env(
expressions=expression,
jinja_macro_references=None,
@@ -3092,8 +3096,8 @@ def get_model_name(path: Path) -> str:
# function applied to time column when automatically used for partitioning in INCREMENTAL_BY_TIME_RANGE models
def clickhouse_partition_func(
- column: exp.Expression, columns_to_types: t.Optional[t.Dict[str, exp.DataType]]
-) -> exp.Expression:
+ column: exp.Expr, columns_to_types: t.Optional[t.Dict[str, exp.DataType]]
+) -> exp.Expr:
# `toMonday()` function accepts a Date or DateTime type column
col_type = (columns_to_types and columns_to_types.get(column.name)) or exp.DataType.build(
diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py
index 9abaa9c650..7ae1ef8c0d 100644
--- a/sqlmesh/core/model/kind.py
+++ b/sqlmesh/core/model/kind.py
@@ -279,7 +279,7 @@ def model_kind_name(self) -> t.Optional[ModelKindName]:
return self.name
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
kwargs["expressions"] = expressions
return d.ModelKind(this=self.name.value.upper(), **kwargs)
@@ -294,7 +294,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
class TimeColumn(PydanticModel):
- column: exp.Expression
+ column: exp.Expr
format: t.Optional[str] = None
@classmethod
@@ -306,7 +306,7 @@ def _time_column_validator(v: t.Any, info: ValidationInfo) -> TimeColumn:
@field_validator("column", mode="before")
@classmethod
- def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression:
+ def _column_validator(cls, v: t.Union[str, exp.Expr]) -> exp.Expr:
if not v:
raise ConfigError("Time Column cannot be empty.")
if isinstance(v, str):
@@ -314,14 +314,14 @@ def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression:
return v
@property
- def expression(self) -> exp.Expression:
+ def expression(self) -> exp.Expr:
"""Convert this pydantic model into a time_column SQLGlot expression."""
if not self.format:
return self.column
return exp.Tuple(expressions=[self.column, exp.Literal.string(self.format)])
- def to_expression(self, dialect: str) -> exp.Expression:
+ def to_expression(self, dialect: str) -> exp.Expr:
"""Convert this pydantic model into a time_column SQLGlot expression."""
if not self.format:
return self.column
@@ -346,7 +346,7 @@ def create(cls, v: t.Any, dialect: str) -> Self:
exp.column(column_expr) if isinstance(column_expr, exp.Identifier) else column_expr
)
format = v.expressions[1].name if len(v.expressions) > 1 else None
- elif isinstance(v, exp.Expression):
+ elif isinstance(v, exp.Expr):
column = exp.column(v) if isinstance(v, exp.Identifier) else v
format = None
elif isinstance(v, str):
@@ -400,7 +400,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -444,7 +444,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -473,7 +473,7 @@ class IncrementalByTimeRangeKind(_IncrementalBy):
_time_column_validator = TimeColumn.validator()
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -513,7 +513,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
)
unique_key: SQLGlotListOfFields
when_matched: t.Optional[exp.Whens] = None
- merge_filter: t.Optional[exp.Expression] = None
+ merge_filter: t.Optional[exp.Expr] = None
batch_concurrency: t.Literal[1] = 1
@field_validator("when_matched", mode="before")
@@ -543,9 +543,9 @@ def _when_matched_validator(
@field_validator("merge_filter", mode="before")
def _merge_filter_validator(
cls,
- v: t.Optional[exp.Expression],
+ v: t.Optional[exp.Expr],
info: ValidationInfo,
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Expr]:
if v is None:
return v
@@ -568,7 +568,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -590,7 +590,7 @@ class IncrementalByPartitionKind(_Incremental):
disable_restatement: SQLGlotBool = False
@field_validator("forward_only", mode="before")
- def _forward_only_validator(cls, v: t.Union[bool, exp.Expression]) -> t.Literal[True]:
+ def _forward_only_validator(cls, v: t.Union[bool, exp.Expr]) -> t.Literal[True]:
if v is not True:
raise ConfigError(
"Do not specify the `forward_only` configuration key - INCREMENTAL_BY_PARTITION models are always forward_only."
@@ -606,7 +606,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -640,7 +640,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -669,7 +669,7 @@ def supports_python_models(self) -> bool:
return False
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -690,7 +690,7 @@ class SeedKind(_ModelKind):
def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]:
if v is None or isinstance(v, CsvSettings):
return v
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
tuple_exp = parse_properties(cls, v, None)
if not tuple_exp:
return None
@@ -700,7 +700,7 @@ def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]:
return v
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
"""Convert the seed kind into a SQLGlot expression."""
return super().to_expression(
@@ -756,13 +756,16 @@ class _SCDType2Kind(_Incremental):
@field_validator("time_data_type", mode="before")
@classmethod
- def _time_data_type_validator(
- cls, v: t.Union[str, exp.Expression], values: t.Any
- ) -> exp.Expression:
- if isinstance(v, exp.Expression) and not isinstance(v, exp.DataType):
+ def _time_data_type_validator(cls, v: t.Union[str, exp.Expr], values: t.Any) -> exp.Expr:
+ if isinstance(v, exp.Expr) and not isinstance(v, exp.DataType):
v = v.name
dialect = get_dialect(values)
data_type = exp.DataType.build(v, dialect=dialect)
+ # Clear meta["sql"] (set by our parser extension) so the pydantic encoder
+ # uses dialect-aware rendering: e.sql(dialect=meta["dialect"]). Without this,
+ # the raw SQL text takes priority, which can be wrong for dialect-normalized
+ # types (e.g., default "TIMESTAMP" should render as "DATETIME" in BigQuery).
+ data_type.meta.pop("sql", None)
data_type.meta["dialect"] = dialect
return data_type
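
A hedged illustration of the dialect-aware rendering that clearing `meta["sql"]` falls back to, using the BigQuery example from the comment above (assumes sqlglot's public `DataType` API):

```python
from sqlglot import exp

# Built without a raw-SQL override, the default TIMESTAMP type renders per dialect.
dt = exp.DataType.build("TIMESTAMP")
dt.meta["dialect"] = "bigquery"
print(dt.sql(dialect=dt.meta["dialect"]))  # expected: DATETIME on BigQuery
```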
@@ -783,7 +786,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
gen(self.valid_to_name),
str(self.invalidate_hard_deletes),
self.time_data_type.sql(self.dialect),
- gen(self.batch_size) if self.batch_size is not None else None,
+ str(self.batch_size) if self.batch_size is not None else None,
]
@property
@@ -795,7 +798,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -835,7 +838,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -871,7 +874,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -922,7 +925,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -1005,7 +1008,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]:
]
def to_expression(
- self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
+ self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any
) -> d.ModelKind:
return super().to_expression(
expressions=[
@@ -1142,7 +1145,7 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M
)
return kind_type(**props)
- name = (v.name if isinstance(v, exp.Expression) else str(v)).upper()
+ name = (v.name if isinstance(v, exp.Expr) else str(v)).upper()
return model_kind_type_from_name(name)(name=name) # type: ignore
diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py
index c48b7d1524..d5a93c459c 100644
--- a/sqlmesh/core/model/meta.py
+++ b/sqlmesh/core/model/meta.py
@@ -44,13 +44,14 @@
list_of_fields_validator,
model_validator,
get_dialect,
+ validation_data,
)
if t.TYPE_CHECKING:
from sqlmesh.core._typing import CustomMaterializationProperties, SessionProperties
from sqlmesh.core.engine_adapter._typing import GrantsConfig
-FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]]
+FunctionCall = t.Tuple[str, t.Dict[str, exp.Expr]]
class GrantsTargetLayer(str, Enum):
@@ -92,8 +93,8 @@ class ModelMeta(_Node):
retention: t.Optional[int] = None # not implemented yet
table_format: t.Optional[str] = None
storage_format: t.Optional[str] = None
- partitioned_by_: t.List[exp.Expression] = Field(default=[], alias="partitioned_by")
- clustered_by: t.List[exp.Expression] = []
+ partitioned_by_: t.List[exp.Expr] = Field(default=[], alias="partitioned_by")
+ clustered_by: t.List[exp.Expr] = []
default_catalog: t.Optional[str] = None
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
columns_to_types_: t.Optional[t.Dict[str, exp.DataType]] = Field(default=None, alias="columns")
@@ -101,8 +102,8 @@ class ModelMeta(_Node):
default=None, alias="column_descriptions"
)
audits: t.List[FunctionCall] = []
- grains: t.List[exp.Expression] = []
- references: t.List[exp.Expression] = []
+ grains: t.List[exp.Expr] = []
+ references: t.List[exp.Expr] = []
physical_schema_override: t.Optional[str] = None
physical_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="physical_properties")
virtual_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="virtual_properties")
@@ -135,7 +136,7 @@ def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any:
@field_validator("tags", mode="before")
def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
- return ensure_list(cls._validate_value_or_tuple(v, info.data))
+ return ensure_list(cls._validate_value_or_tuple(v, validation_data(info)))
@classmethod
def _validate_value_or_tuple(
@@ -151,11 +152,11 @@ def _normalize(value: t.Any) -> t.Any:
if isinstance(v, (exp.Tuple, exp.Array)):
return [_normalize(e).name for e in v.expressions]
- if isinstance(v, exp.Expression):
+ if isinstance(v, exp.Expr):
return _normalize(v).name
if isinstance(v, str):
value = _normalize(v)
- return value.name if isinstance(value, exp.Expression) else value
+ return value.name if isinstance(value, exp.Expr) else value
if isinstance(v, (list, tuple)):
return [cls._validate_value_or_tuple(elm, data, normalize=normalize) for elm in v]
@@ -163,8 +164,8 @@ def _normalize(value: t.Any) -> t.Any:
@field_validator("table_format", "storage_format", mode="before")
def _format_validator(cls, v: t.Any, info: ValidationInfo) -> t.Optional[str]:
- if isinstance(v, exp.Expression) and not (isinstance(v, (exp.Literal, exp.Identifier))):
- return v.sql(info.data.get("dialect"))
+ if isinstance(v, exp.Expr) and not (isinstance(v, (exp.Literal, exp.Identifier))):
+ return v.sql(validation_data(info).get("dialect"))
return str_or_exp_to_str(v)
@field_validator("dialect", mode="before")
@@ -188,13 +189,11 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]:
return gateway and gateway.lower()
@field_validator("partitioned_by_", "clustered_by", mode="before")
- def _partition_and_cluster_validator(
- cls, v: t.Any, info: ValidationInfo
- ) -> t.List[exp.Expression]:
+ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.List[exp.Expr]:
if (
isinstance(v, list)
and all(isinstance(i, str) for i in v)
- and info.field_name == "partitioned_by_"
+ and (info.field_name or "") == "partitioned_by_"
):
# this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str]
# however, we should only invoke this if the list contains strings because this validator is also
@@ -207,7 +206,7 @@ def _partition_and_cluster_validator(
)
v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v
- expressions = list_of_fields_validator(v, info.data)
+ expressions = list_of_fields_validator(v, validation_data(info))
for expression in expressions:
num_cols = len(list(expression.find_all(exp.Column)))
@@ -230,7 +229,7 @@ def _columns_validator(
cls, v: t.Any, info: ValidationInfo
) -> t.Optional[t.Dict[str, exp.DataType]]:
columns_to_types = {}
- dialect = info.data.get("dialect")
+ dialect = validation_data(info).get("dialect")
if isinstance(v, exp.Schema):
for column in v.expressions:
@@ -244,9 +243,33 @@ def _columns_validator(
return columns_to_types
if isinstance(v, dict):
- udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
+ dialect_obj = Dialect.get_or_raise(dialect)
+ udt = dialect_obj.SUPPORTS_USER_DEFINED_TYPES
for k, data_type in v.items():
+ is_string_type = isinstance(data_type, str)
expr = exp.DataType.build(data_type, dialect=dialect, udt=udt)
+ # When deserializing from a string (e.g. JSON roundtrip), normalize the type
+ # through the dialect's type system so that aliases (e.g. INT in BigQuery,
+ # which is an alias for INT64/BIGINT) are resolved to their canonical form.
+ # This ensures stable data hash computation across serialization/deserialization
+ # roundtrips. We skip this for DataType objects passed directly (Python API)
+ # since those should be used as-is.
+ if (
+ is_string_type
+ and dialect
+ and expr.this
+ not in (
+ exp.DataType.Type.USERDEFINED,
+ exp.DataType.Type.UNKNOWN,
+ )
+ ):
+ sql_repr = expr.sql(dialect=dialect)
+ try:
+ normalized = parse_one(sql_repr, read=dialect, into=exp.DataType)
+ if normalized is not None:
+ expr = normalized
+ except Exception:
+ pass
expr.meta["dialect"] = dialect
columns_to_types[normalize_identifiers(k, dialect=dialect).name] = expr
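
A condensed sketch of the alias-normalization roundtrip added in this hunk, using the BigQuery `INT` example from the comment (assumes sqlglot's `DataType.build`/`parse_one` APIs):

```python
from sqlglot import exp, parse_one

dialect = "bigquery"
raw = exp.DataType.build("INT", dialect=dialect, udt=False)  # type as written by the user
canonical = parse_one(raw.sql(dialect=dialect), read=dialect, into=exp.DataType)
# `raw` and `canonical` may differ (e.g. INT vs. INT64/BIGINT); hashing the canonical
# form keeps the data hash stable across serialization roundtrips.
```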
@@ -258,7 +281,8 @@ def _columns_validator(
def _column_descriptions_validator(
cls, vs: t.Any, info: ValidationInfo
) -> t.Optional[t.Dict[str, str]]:
- dialect = info.data.get("dialect")
+ data = validation_data(info)
+ dialect = data.get("dialect")
if vs is None:
return None
@@ -280,7 +304,7 @@ def _column_descriptions_validator(
for k, v in raw_col_descriptions.items()
}
- columns_to_types = info.data.get("columns_to_types_")
+ columns_to_types = data.get("columns_to_types_")
if columns_to_types:
from sqlmesh.core.console import get_console
@@ -288,15 +312,15 @@ def _column_descriptions_validator(
for column_name in list(col_descriptions):
if column_name not in columns_to_types:
console.log_warning(
- f"In model '{info.data['name']}', a description is provided for column '{column_name}' but it is not a column in the model."
+ f"In model '{data.get('name', '