diff --git a/.azure-pipelines/_release-template.yml b/.azure-pipelines/_release-template.yml
deleted file mode 100644
index 9bf1d9d0cf..0000000000
--- a/.azure-pipelines/_release-template.yml
+++ /dev/null
@@ -1,21 +0,0 @@
-# Template steps for the release pipeline
-
-steps:
- - task: UsePythonVersion@0
- inputs:
- versionSpec: '3.11'
- displayName: 'Set Up Python'
- - script: python -m pip install --upgrade pip build wheel
- displayName: 'Install Python build dependencies'
- - script: python -m build
- displayName: 'Build ONNX Script wheel dev version'
- - task: CopyFiles@2
- displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)'
- inputs:
- SourceFolder: 'dist'
- Contents: '*.*'
- TargetFolder: '$(Build.ArtifactStagingDirectory)'
- - task: PublishBuildArtifacts@1
- displayName: 'Publish onnxscript'
- inputs:
- ArtifactName: onnxscript
diff --git a/.azure-pipelines/publish-dev.yml b/.azure-pipelines/publish-dev.yml
new file mode 100644
index 0000000000..77968d313b
--- /dev/null
+++ b/.azure-pipelines/publish-dev.yml
@@ -0,0 +1,45 @@
+trigger: none
+name: onnxscript-publish-dev.$(Date:yyyyMMdd).$(Rev:r)
+resources:
+ repositories:
+ - repository: 1ESPipelineTemplates
+ type: git
+ name: 1ESPipelineTemplates/1ESPipelineTemplates
+ ref: refs/tags/release
+ pipelines:
+ - pipeline: onnxscript-release-dev
+ source: onnxscript-release-dev
+ trigger: true
+extends:
+ template: v1/1ES.Official.PipelineTemplate.yml@1ESPipelineTemplates
+ parameters:
+ stages:
+ - stage: Release
+ dependsOn: []
+ jobs:
+ - job: onnxscript_publish_dev
+ templateContext:
+ type: releaseJob
+ isProduction: true
+ inputs:
+ - input: pipelineArtifact
+ artifactName: drop
+ pipeline: onnxscript-release-dev
+ targetPath: $(Pipeline.Workspace)/drop
+ pool:
+ name: 'onnxruntime-Win-CPU-2022'
+ steps:
+ - task: EsrpRelease@9
+ displayName: 'ESRP Release'
+ inputs:
+ connectedservicename: esrp_release
+ keyvaultname: 'ortbuildkeyvault'
+ signcertname: 'esrpcodesign'
+ clientid: '53d54d02-978d-4305-8572-583cf6711c4f'
+ contenttype: PyPi
+ folderlocation: '$(Pipeline.Workspace)/drop'
+ owners: 'justinchu@microsoft.com'
+ approvers: 'grama@microsoft.com'
+ mainpublisher: AIFrameworks
+ usemanagedidentity: true
+ domaintenantid: '975f013f-7f24-47e8-a7d3-abc4752bf346'
diff --git a/.azure-pipelines/publish.yml b/.azure-pipelines/publish.yml
new file mode 100644
index 0000000000..e37d34a282
--- /dev/null
+++ b/.azure-pipelines/publish.yml
@@ -0,0 +1,50 @@
+trigger: none
+name: onnxscript-publish.$(Date:yyyyMMdd).$(Rev:r)
+resources:
+ repositories:
+ - repository: 1ESPipelineTemplates
+ type: git
+ name: 1ESPipelineTemplates/1ESPipelineTemplates
+ ref: refs/tags/release
+ pipelines:
+ - pipeline: onnxscript-release
+ source: onnxscript-release
+ trigger: true
+extends:
+ template: v1/1ES.Official.PipelineTemplate.yml@1ESPipelineTemplates
+ parameters:
+ stages:
+ - stage: Release
+ dependsOn: []
+ jobs:
+ - deployment: onnxscript_publish
+ templateContext:
+ type: releaseJob
+ isProduction: true
+ inputs:
+ - input: pipelineArtifact
+ artifactName: drop
+ pipeline: onnxscript-release
+ targetPath: $(Pipeline.Workspace)/drop
+ environment:
+ name: 'onnxscript-release'
+ pool:
+ name: 'onnxruntime-Win-CPU-2022'
+ strategy:
+ runOnce:
+ deploy:
+ steps:
+ - task: EsrpRelease@9
+ displayName: 'ESRP Release'
+ inputs:
+ connectedservicename: esrp_release
+ keyvaultname: 'ortbuildkeyvault'
+ signcertname: 'esrpcodesign'
+ clientid: '53d54d02-978d-4305-8572-583cf6711c4f'
+ contenttype: PyPi
+ folderlocation: '$(Pipeline.Workspace)/drop'
+ owners: 'justinchu@microsoft.com'
+ approvers: 'grama@microsoft.com'
+ mainpublisher: AIFrameworks
+ usemanagedidentity: true
+ domaintenantid: '975f013f-7f24-47e8-a7d3-abc4752bf346'
diff --git a/.azure-pipelines/release-dev.yml b/.azure-pipelines/release-dev.yml
index 81ffa68b3a..61f780ed31 100644
--- a/.azure-pipelines/release-dev.yml
+++ b/.azure-pipelines/release-dev.yml
@@ -3,9 +3,28 @@
# To configure triggers, see https://github.com/microsoft/onnx-converters-private/wiki/ONNX-Script-release
trigger: none
-pool:
- vmImage: ubuntu-latest
variables:
CI: 'true'
-steps:
- - template: _release-template.yml
+
+resources:
+ repositories:
+ - repository: 1esPipelines
+ type: git
+ name: 1ESPipelineTemplates/1ESPipelineTemplates
+ ref: refs/tags/release
+
+extends:
+ # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks.
+ # For non-production pipelines, use "Unofficial" as defined below.
+  # For production pipelines, use "Official".
+ template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines
+ parameters:
+ sdl:
+ sourceAnalysisPool:
+ name: onnxruntime-Win-CPU-2022
+ os: windows
+ pool:
+ name: 'onnxruntime-Ubuntu2204-AMD-CPU'
+ os: 'linux'
+ stages:
+ - template: stages/release-stage.yml
diff --git a/.azure-pipelines/release.yml b/.azure-pipelines/release.yml
index 130ae5a09c..b5fde4c319 100644
--- a/.azure-pipelines/release.yml
+++ b/.azure-pipelines/release.yml
@@ -2,19 +2,30 @@
trigger: none
-pool:
- vmImage: ubuntu-latest
variables:
CI: 'true'
# Set the release environment variable to build a release version of the wheel
ONNX_SCRIPT_RELEASE: 1
-steps:
- - template: _release-template.yml
- # Test the wheels. This needs to happen after PublishBuildArtifacts
- # to avoid interference with the artifacts
- - script: python -m pip install -r requirements-dev.txt
- displayName: 'Install Python dependencies'
- - script: python -m pip install dist/*.whl --no-deps
- displayName: 'Install wheel'
- - script: python -m pytest -v -n auto
- displayName: 'Run tests'
+
+resources:
+ repositories:
+ - repository: 1esPipelines
+ type: git
+ name: 1ESPipelineTemplates/1ESPipelineTemplates
+ ref: refs/tags/release
+
+extends:
+ # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks.
+ # For non-production pipelines, use "Unofficial" as defined below.
+  # For production pipelines, use "Official".
+ template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines
+ parameters:
+ sdl:
+ sourceAnalysisPool:
+ name: onnxruntime-Win-CPU-2022
+ os: windows
+ pool:
+ name: 'onnxruntime-Ubuntu2204-AMD-CPU'
+ os: 'linux'
+ stages:
+ - template: stages/release-stage.yml
diff --git a/.azure-pipelines/stages/jobs/steps/release-steps.yml b/.azure-pipelines/stages/jobs/steps/release-steps.yml
new file mode 100644
index 0000000000..be1d9e8860
--- /dev/null
+++ b/.azure-pipelines/stages/jobs/steps/release-steps.yml
@@ -0,0 +1,20 @@
+steps:
+- task: UsePythonVersion@0
+ inputs:
+ versionSpec: '3.11'
+ displayName: 'Set Up Python'
+- script: python -m pip install --upgrade pip build wheel
+ displayName: 'Install Python build dependencies'
+- script: python -m build
+ displayName: 'Build ONNX Script wheel'
+- task: CopyFiles@2
+ displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)'
+ inputs:
+ SourceFolder: 'dist'
+ Contents: '*.*'
+ TargetFolder: '$(Build.ArtifactStagingDirectory)'
+- task: 1ES.PublishPipelineArtifact@1
+ displayName: 'Publish Python Wheel'
+ inputs:
+ ArtifactName: 'onnxscript'
+ targetPath: '$(Build.ArtifactStagingDirectory)'
diff --git a/.azure-pipelines/stages/release-stage.yml b/.azure-pipelines/stages/release-stage.yml
new file mode 100644
index 0000000000..881fdbd60b
--- /dev/null
+++ b/.azure-pipelines/stages/release-stage.yml
@@ -0,0 +1,11 @@
+stages:
+- stage: Stage
+ jobs:
+ - job: Job
+ steps:
+ - template: jobs/steps/release-steps.yml
+    # Test the wheel. This needs to happen after the artifact is published
+    # to avoid interfering with the published artifacts
+ - script: python -m pip install dist/*.whl --no-deps
+ displayName: 'Install wheel'
+ condition: eq(variables['ONNX_SCRIPT_RELEASE'], 1)
diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
new file mode 100644
index 0000000000..b74c06fed3
--- /dev/null
+++ b/.github/copilot-instructions.md
@@ -0,0 +1,5 @@
+## Code Standards
+
+### Required Before Each Commit
+- Run `lintrunner -a` before committing any changes to ensure proper code formatting
+- This will run lintrunner on all updated files to maintain consistent style
diff --git a/.github/release.yml b/.github/release.yml
new file mode 100644
index 0000000000..2434ad5390
--- /dev/null
+++ b/.github/release.yml
@@ -0,0 +1,30 @@
+changelog:
+ exclude:
+ authors:
+ - dependabot
+ categories:
+ - title: Breaking Changes
+ labels:
+ - "topic: breaking changes"
+ - title: Core ONNX Script
+ labels:
+ - "topic: onnxscript core"
+ - "topic: ast converter"
+ - title: Optimizer and rewriter
+ labels:
+ - "module: rewriter"
+ - "module: optimizer"
+ - "topic: ort-fusions"
+ - title: ONNX IR
+ labels:
+ - "module: IR"
+ - "topic: passes"
+ - title: Torch Lib
+ labels:
+ - "module: torchlib"
+ - title: Documentation
+ labels:
+ - "topic: documentation"
+ - title: Other Changes
+ labels:
+ - "*"
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
index a4cedc9daa..6953a76929 100644
--- a/.github/workflows/codeql-analysis.yml
+++ b/.github/workflows/codeql-analysis.yml
@@ -41,7 +41,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index d0ecd01ebf..3fe51a3a5a 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -6,7 +6,7 @@ on:
- main
- 'gh/**/base' # ghstack base branches
pull_request:
- types: [opened, synchronize, reopened, ready_for_review]
+ merge_group:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
@@ -20,7 +20,7 @@ jobs:
pull-requests: write
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: misspell # Check spelling
uses: reviewdog/action-misspell@v1
with:
@@ -43,19 +43,20 @@ jobs:
permissions:
security-events: write
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
# Version range or exact version of Python to use, using SemVer's version range syntax. Reads from .python-version if unset.
python-version: "3.10"
- name: Install ONNXScript
run: |
- # The code is from azure-pipelines.yml
# Install dependencies
python -m pip install --upgrade pip
python -m pip install --upgrade setuptools
- python -m pip install -q -r requirements-dev.txt
+ python -m pip install -r requirements-dev.txt
+ # FIXME: numpy 2.2 has some typing changes that break the mypy CI but it's otherwise fine
+ python -m pip install "numpy<2.2"
# Install packages
python -m pip install -e .
lintrunner init
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index f28d6ce349..faf40b9ec3 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -13,6 +13,7 @@ on:
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
+ merge_group:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
@@ -25,32 +26,23 @@ jobs:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
name:
- - py312-torch-nightly
+ - py312
- py311
- py311-torch-nightly
- py311-onnx-weekly
- py311-ort-nightly
- - py311-experimental-torchlib-tracing
- - py311-experimental-torchlib-onnx-ir
+ - py311-onnx-ir-git
- py310
- - py39
- - py38
include:
+ - name: py312
+ python-version: "3.12"
+ nox-tag: test build
- name: py311
python-version: "3.11"
- nox-tag: test build
+ nox-tag: test
- name: py310
python-version: "3.10"
nox-tag: test
- - name: py39
- python-version: "3.9"
- nox-tag: test
- - name: py38
- python-version: "3.8"
- nox-tag: test
- - name: py312-torch-nightly
- python-version: "3.12"
- nox-tag: test-torch-nightly
- name: py311-torch-nightly
python-version: "3.11"
nox-tag: test-torch-nightly
@@ -60,17 +52,14 @@ jobs:
- name: py311-ort-nightly
python-version: "3.11"
nox-tag: test-ort-nightly
- - name: py311-experimental-torchlib-tracing
- python-version: "3.11"
- nox-tag: test-experimental-torchlib-tracing
- - name: py311-experimental-torchlib-onnx-ir
+ - name: py311-onnx-ir-git
python-version: "3.11"
- nox-tag: test-experimental-torchlib-onnx-ir
+ nox-tag: test-onnx-ir-git
runs-on: ${{ matrix.os }}
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install nox
@@ -78,33 +67,26 @@ jobs:
- name: Pull Test Data
run: git lfs pull
- name: Run tests
- run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml
+ run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junitxml junit.xml
env:
CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
- name: Upload coverage to Codecov
if: always()
- uses: codecov/codecov-action@v4
+ uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
- - name: Upload Test Results
- if: always()
- uses: actions/upload-artifact@v3
+ - name: Upload test results to Codecov
+ if: ${{ !cancelled() }}
+ uses: codecov/test-results-action@v1
with:
- name: Test Results (${{ matrix.name }}-${{ matrix.os }})
- path: pytest.xml
+ token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload torchlib error reports
if: always()
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: Error reports (${{ matrix.name }}-${{ matrix.os }})
path: error_reports
- - name: Upload IR profiling results
- if: matrix.name == 'py311' || matrix.name == 'py311-onnx-weekly'
- uses: actions/upload-artifact@v3
- with:
- name: IR profiling results
- path: tests/ir/serde_test_profiles
build_docs:
strategy:
@@ -113,9 +95,9 @@ jobs:
os: [ubuntu-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: "3.10"
cache: pip
@@ -137,9 +119,9 @@ jobs:
update_readme:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
- name: Update readme
run: |
python docs/update_readme.py
@@ -148,23 +130,3 @@ jobs:
echo "Update readme by running `python docs/update_readme.py`"
exit 1
fi
-
- publish-test-results:
- name: "Publish Tests Results to Github"
- needs: test
- runs-on: ubuntu-latest
- permissions:
- checks: write
- # only needed unless run with comment_mode: off
- pull-requests: write
- if: always()
- steps:
- - name: Download Artifacts
- uses: actions/download-artifact@v3
- with:
- path: artifacts
-
- - name: Publish Test Results
- uses: EnricoMi/publish-unit-test-result-action@v2
- with:
- files: "artifacts/**/*.xml"
diff --git a/.github/workflows/pages.yaml b/.github/workflows/pages.yaml
index 1e6aa4142c..ce638dc60d 100644
--- a/.github/workflows/pages.yaml
+++ b/.github/workflows/pages.yaml
@@ -25,14 +25,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
- name: Setup Pages
uses: actions/configure-pages@v4
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: "3.10"
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
@@ -42,7 +42,7 @@ jobs:
- name: Build documentation
run: python -m sphinx docs dist/html
- name: Upload documentation archive
- uses: actions/upload-pages-artifact@v3
+ uses: actions/upload-pages-artifact@v4
with:
path: 'dist/html'
- name: Deploy to GitHub Pages
diff --git a/.gitignore b/.gitignore
index 0e9a057b9f..3344aa7659 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,9 +41,11 @@ coverage.xml
.pytest_cache/
cover/
test-output.xml
+*.sarif
# Sphinx documentation
docs/_build/
+docs/sg_execution_times.rst
# Jupyter Notebook
.ipynb_checkpoints
@@ -93,10 +95,13 @@ dmypy.json
# Generated files
*.onnx
+*.csv
+*.xlsx
!testdata/**/*.onnx
*.onnxlib
**/onnx_backend_test_code/**
docs/auto_examples/*
+docs/**/generated/*
tests/export/*
tests/models/testoutputs/*
tests/mylib.onnxlib
diff --git a/.lintrunner.toml b/.lintrunner.toml
index aa88d1f66e..907f3bfce6 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -46,17 +46,17 @@ exclude_patterns = [
'onnxscript/onnx_types.py',
'onnxscript/**/*_test.py', # Skip linting test files for speed
'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy
- 'onnxscript/optimizer/evaluator.py', # FIXME
- 'onnxscript/optimizer/constant_folding.py', # FIXME
+ 'onnxscript/optimizer/_legacy/evaluator.py', # FIXME
+ 'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
- 'onnxscript/_legacy_ir/irbuilder.py', # FIXME
- 'onnxscript/optimizer/fold_constants_v0.py', # FIXME
+ 'onnxscript/rewriter/ort_fusions/models/*.py', # onnxscript code
+ 'onnxscript/rewriter/ort_fusions/models/_phi2lm.py', # onnxscript code
+ 'onnxscript/rewriter/ort_fusions/models/_phi4lm.py', # onnxscript code
+ 'onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py', # onnxscript code
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
'onnxscript/tools/function_unittest_producer.py', # FIXME
- 'onnxscript/_legacy_ir/visitor.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
- 'onnxscript/rewriter/generic_pattern.py', # FIXME
]
command = [
'python',
@@ -113,16 +113,14 @@ include_patterns = [
'**/*.py',
]
exclude_patterns = [
- 'examples/**', # TODO: Merge with docs/examples
- 'docs/examples/**',
- 'docs/tutorial/examples/**',
+ 'examples/**',
+ 'docs/**',
'onnxscript/converter_test.py',
'tests/functions/**',
'tests/models/**',
'tests/onnx_backend_test_code/**',
'onnxscript/optimizer/**', # FIXME
'onnxscript/rewriter/**', # FIXME
- 'onnxscript/_legacy_ir/**', # FIXME
]
command = [
'python',
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index f9ba8cf65f..686e5e7a09 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -7,3 +7,4 @@ Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
+- Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 66d4781c4f..346fad1f6a 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,19 +1,3 @@
-
-
-⚠️ |
-
-NOTE: ONNX Script is in very early
-and active development and the team anticipates
-breaking changes as the project evolves.
-ONNX Script is not ready for production,
-but early feedback is welcome.
- |
-⚠️ |
-
-
-
-----
-
# Contributing to ONNX Script
We're always looking for your help to improve the product (bug fixes, new features, documentation, etc). Currently ONNX Script is under early and heavy development, so we encourage proposing any major changes by [filing an issue](https://github.com/microsoft/onnxscript/issues) to discuss your idea with the team first.
diff --git a/README.md b/README.md
index 484917be66..ec3ce7bcc8 100644
--- a/README.md
+++ b/README.md
@@ -15,9 +15,19 @@ models using a subset of Python. ONNX Script is:
* **Debuggable:** allows for eager-mode evaluation that provides for a
more delightful ONNX model debugging experience.
+This repo also covers:
+
+* **ONNX Script Optimizer:** provides functionality to optimize an ONNX
+  model through clean-ups such as constant folding and dead code elimination.
+* **ONNX Rewriter:** provides functionality to replace certain patterns in
+ an ONNX graph with replacement patterns based on user-defined rewrite rules.
+
Note however that ONNX Script does **not** intend to support the entirety
of the Python language.
+Website: [https://microsoft.github.io/onnxscript/](https://microsoft.github.io/onnxscript/)
+
## Design Overview
ONNX Script provides a few major capabilities for authoring and debugging
@@ -140,6 +150,63 @@ result = Hardmax(v)
More examples can be found in the [docs/examples](docs/examples) directory.
+## ONNX Script Tools
+
+### ONNX Optimizer
+
+The ONNX Script Optimizer optimizes an ONNX model by performing clean-ups such as constant folding and dead code elimination. To use the optimizer:
+
+```python
+import onnxscript
+
+onnxscript.optimizer.optimize(onnx_model)
+```
+
+For a detailed summary of all the optimizations applied by the optimizer call, refer to the tutorial [Optimizing a Model using the Optimizer](https://microsoft.github.io/onnxscript/tutorial/optimizer/optimize.html).
+
+### ONNX Rewriter
+
+The ONNX Rewriter provides functionality to replace certain patterns in an ONNX graph with replacement patterns based on user-defined rewrite rules. The rewriter supports two different methods by which patterns in the graph can be rewritten.
+
+#### Pattern-based rewriting
+
+For this style of rewriting, the user provides a `target_pattern` to be replaced, a `replacement_pattern`, and a `match_condition` (the rewrite occurs only if the match condition is satisfied). A simple example of how to use the pattern-based rewriter follows:
+
+```python
+import math
+
+import onnxscript
+from onnxscript import ir
+from onnxscript.rewriter import pattern
+
+# The target pattern
+def erf_gelu_pattern(op, x):
+ return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
+
+def erf_gelu_pattern_2(op, x):
+ return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
+
+# The replacement pattern
+def gelu(op, x: ir.Value):
+ return op.Gelu(x, domain="com.microsoft")
+
+# Create multiple rules
+rule1 = pattern.RewriteRule(
+ erf_gelu_pattern, # Target Pattern
+ gelu, # Replacement
+)
+rule2 = pattern.RewriteRule(
+ erf_gelu_pattern_2, # Target Pattern
+ gelu, # Replacement
+)
+# Create a Rewrite Rule Set with multiple rules.
+rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
+# Apply rewrites
+model_with_rewrite_applied = onnxscript.rewriter.rewrite(
+ model, # Original ONNX Model
+ pattern_rewrite_rules=rewrite_rule_set,
+)
+# model_with_rewrite_applied now holds the rewritten model.
+```
+
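+The rules above omit the `match_condition`. A condition function can be passed
+as a third argument to `RewriteRule`; the sketch below is illustrative (this
+particular check is an assumption, not part of the actual GELU rules):
+
+```python
+def erf_gelu_condition(context, x, **_) -> bool:
+    # Hypothetical check: only rewrite when the input's dtype is known.
+    return x.dtype is not None
+
+rule_with_condition = pattern.RewriteRule(
+    erf_gelu_pattern, gelu, erf_gelu_condition
+)
+```
+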
+For a detailed tutorial on how to create the `target_pattern`, `replacement_pattern`, and `match_condition` blocks in order to utilize the pattern-based rewriter, refer to the tutorial [Pattern-based Rewrite Using Rules](https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html).
+
## Development Guidelines
Every change impacting the converter or the eager evaluation must be
diff --git a/VERSION b/VERSION
index 6e8bf73aa5..7d8568351b 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.0
+0.5.4
diff --git a/docs/_templates/classtemplate.rst b/docs/_templates/classtemplate.rst
new file mode 100644
index 0000000000..24a5ac1803
--- /dev/null
+++ b/docs/_templates/classtemplate.rst
@@ -0,0 +1,14 @@
+.. role:: hidden
+ :class: hidden-section
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :members:
+ :undoc-members:
+ :member-order: bysource
+
+..
+ autogenerated from docs/_templates/classtemplate.rst
diff --git a/docs/_templates/classtemplate_inherited.rst b/docs/_templates/classtemplate_inherited.rst
new file mode 100644
index 0000000000..07c84a9068
--- /dev/null
+++ b/docs/_templates/classtemplate_inherited.rst
@@ -0,0 +1,16 @@
+.. role:: hidden
+ :class: hidden-section
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :members:
+ :undoc-members:
+ :inherited-members:
+ :member-order: bysource
+
+
+..
+ autogenerated from docs/_templates/classtemplate.rst
diff --git a/docs/_templates/functiontemplate.rst b/docs/_templates/functiontemplate.rst
new file mode 100644
index 0000000000..f41fb0d764
--- /dev/null
+++ b/docs/_templates/functiontemplate.rst
@@ -0,0 +1,12 @@
+.. role:: hidden
+ :class: hidden-section
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autofunction:: {{ name }}
+
+
+..
+ autogenerated from docs/_templates/functiontemplate.rst
diff --git a/docs/api/index.md b/docs/api/index.md
index 59162fb166..a6dd4bd59b 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -1,8 +1,31 @@
# API
+## Author Models
+
```{toctree}
+:maxdepth: 1
+
decorator
opsets
converter
values
```
+
+## Model transformation
+
+```{toctree}
+:maxdepth: 1
+
+optimizer
+rewriter
+rewriter_pattern
+version_converter
+```
+
+## Testing
+
+```{toctree}
+:maxdepth: 1
+
+testing
+```
diff --git a/docs/api/optimizer.md b/docs/api/optimizer.md
new file mode 100644
index 0000000000..6c8adf21bb
--- /dev/null
+++ b/docs/api/optimizer.md
@@ -0,0 +1,18 @@
+# onnxscript.optimizer
+
+```{eval-rst}
+.. automodule:: onnxscript.optimizer
+.. currentmodule:: onnxscript
+```
+
+```{eval-rst}
+.. autosummary::
+ :toctree: generated
+ :template: functiontemplate.rst
+ :nosignatures:
+
+ optimizer.optimize
+ optimizer.inline
+ optimizer.basic_constant_propagation
+ optimizer.fold_constants
+```
diff --git a/docs/api/rewriter.md b/docs/api/rewriter.md
new file mode 100644
index 0000000000..8ff015844b
--- /dev/null
+++ b/docs/api/rewriter.md
@@ -0,0 +1,26 @@
+# onnxscript.rewriter
+
+```{eval-rst}
+.. automodule:: onnxscript.rewriter
+.. currentmodule:: onnxscript
+```
+
+```{eval-rst}
+.. autosummary::
+ :toctree: generated
+ :template: functiontemplate.rst
+ :nosignatures:
+
+ rewriter.rewrite
+```
+
+## IR passes
+
+```{eval-rst}
+.. autosummary::
+ :toctree: generated
+ :template: classtemplate.rst
+ :nosignatures:
+
+ rewriter.RewritePass
+```
diff --git a/docs/api/rewriter_pattern.md b/docs/api/rewriter_pattern.md
new file mode 100644
index 0000000000..c7deccc6dd
--- /dev/null
+++ b/docs/api/rewriter_pattern.md
@@ -0,0 +1,40 @@
+# onnxscript.rewriter.pattern
+
+```{eval-rst}
+.. automodule:: onnxscript.rewriter.pattern
+.. currentmodule:: onnxscript
+```
+
+```{eval-rst}
+.. autosummary::
+ :toctree: generated
+ :template: classtemplate.rst
+ :nosignatures:
+
+ rewriter.pattern.Pattern
+ rewriter.pattern.StringPattern
+ rewriter.pattern.StringConstantPattern
+ rewriter.pattern.PrefixPattern
+ rewriter.pattern.AttrPattern
+ rewriter.pattern.AttrConstantPattern
+ rewriter.pattern.OpsetPatternBuilder
+ rewriter.pattern.OpPatternBuilder
+ rewriter.pattern.MatchResult
+ rewriter.pattern.ValuePattern
+ rewriter.pattern.NodePattern
+ rewriter.pattern.NodeOutputPattern
+ rewriter.pattern.AnyValue
+ rewriter.pattern.Constant
+ rewriter.pattern.OrValue
+ rewriter.pattern.GraphPattern
+ rewriter.pattern.ReplacementSubgraph
+ rewriter.pattern.ReplacementPatternFunction
+ rewriter.pattern.PatternMatcher
+ rewriter.pattern.SimplePatternMatcher
+ rewriter.pattern.RewriteRule
+ rewriter.pattern.RewriteRuleSet
+ rewriter.pattern.RewriteRuleClassBase
+ rewriter.pattern.MatchStatus
+ rewriter.pattern.MatchInfo
+ rewriter.pattern.MatchingTracer
+```
diff --git a/docs/api/testing.md b/docs/api/testing.md
new file mode 100644
index 0000000000..d7d5fca800
--- /dev/null
+++ b/docs/api/testing.md
@@ -0,0 +1,6 @@
+# Testing
+
+```{eval-rst}
+.. automodule:: onnxscript.testing
+ :members:
+```
diff --git a/docs/api/version_converter.md b/docs/api/version_converter.md
new file mode 100644
index 0000000000..0478efbf5a
--- /dev/null
+++ b/docs/api/version_converter.md
@@ -0,0 +1,28 @@
+# onnxscript.version_converter
+
+```{eval-rst}
+.. automodule:: onnxscript.version_converter
+.. currentmodule:: onnxscript
+```
+
+## Functions
+
+```{eval-rst}
+.. autosummary::
+ :toctree: generated
+ :template: functiontemplate.rst
+ :nosignatures:
+
+ version_converter.convert_version
+```
+
+## IR passes
+
+```{eval-rst}
+.. autosummary::
+ :toctree: generated
+ :template: classtemplate.rst
+ :nosignatures:
+
+ version_converter.ConvertVersionPass
+```
diff --git a/docs/conf.py b/docs/conf.py
index 63dd8e7d44..d96ffe067f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,5 +1,9 @@
-# Configuration file for the Sphinx documentation builder.
-# To run the documentation: python -m sphinx docs dist/html
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Configuration file for the Sphinx documentation builder.
+
+To run the documentation: python -m sphinx docs dist/html
+"""
import os
import re
@@ -20,7 +24,7 @@
# -- General configuration ---------------------------------------------------
extensions = [
- "myst_parser",
+ "myst_nb",
"sphinx_copybutton",
"sphinx_exec_code",
"sphinx_gallery.gen_gallery",
@@ -84,7 +88,11 @@
"python": (f"https://docs.python.org/{sys.version_info.major}", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
+ "onnx": ("https://onnx.ai/onnx/", None),
+ "onnx_ir": ("https://onnx.ai/ir-py/", None),
"onnxruntime": ("https://onnxruntime.ai/docs/api/python/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/", None),
+ "torch": ("https://pytorch.org/docs/main/", None),
}
# -- Options for Sphinx Gallery ----------------------------------------------
diff --git a/docs/examples/01_plot_selu.py b/docs/examples/01_plot_selu.py
index 57a1f03c11..5ad3c49355 100644
--- a/docs/examples/01_plot_selu.py
+++ b/docs/examples/01_plot_selu.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""
Generating a FunctionProto
==========================
diff --git a/docs/examples/02_plot_square_loss.py b/docs/examples/02_plot_square_loss.py
index 5dce3545c8..181e4cd2ac 100644
--- a/docs/examples/02_plot_square_loss.py
+++ b/docs/examples/02_plot_square_loss.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""
Generating a ModelProto
=======================
diff --git a/docs/examples/03_export_lib.py b/docs/examples/03_export_lib.py
index 8a8993b7a8..f710fcb880 100644
--- a/docs/examples/03_export_lib.py
+++ b/docs/examples/03_export_lib.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""
Generating a LibProto
=====================
diff --git a/docs/examples/04_plot_eager_mode_evaluation.py b/docs/examples/04_plot_eager_mode_evaluation.py
index 740e2275af..d1c8f7fb75 100644
--- a/docs/examples/04_plot_eager_mode_evaluation.py
+++ b/docs/examples/04_plot_eager_mode_evaluation.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""
Eager mode evaluation
=====================
diff --git a/docs/examples/05_plot_model_props.py b/docs/examples/05_plot_model_props.py
index 4e10339bea..950b0e3467 100644
--- a/docs/examples/05_plot_model_props.py
+++ b/docs/examples/05_plot_model_props.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""
ModelProto Properties
=====================
diff --git a/docs/examples/06_plot_model_local_funs.py b/docs/examples/06_plot_model_local_funs.py
index 3a60b3e6cc..fdb0e434bb 100644
--- a/docs/examples/06_plot_model_local_funs.py
+++ b/docs/examples/06_plot_model_local_funs.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""
Model Local Functions
=====================
diff --git a/docs/index.md b/docs/index.md
index 3cd5e3db30..4dd0472706 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -103,7 +103,7 @@ result = MatmulAdd(x, wt, bias)
Overview
tutorial/index
api/index
-intermediate_representation/index
+ir/index
auto_examples/index
articles/index
```
diff --git a/docs/intermediate_representation/index.md b/docs/intermediate_representation/index.md
deleted file mode 100644
index fd3199671b..0000000000
--- a/docs/intermediate_representation/index.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# ONNX IR
-
-```{toctree}
-:maxdepth: 1
-
-tensors
-ir_api
-```
diff --git a/docs/intermediate_representation/ir_api.md b/docs/intermediate_representation/ir_api.md
deleted file mode 100644
index 2d1d8ebcb6..0000000000
--- a/docs/intermediate_representation/ir_api.md
+++ /dev/null
@@ -1,9 +0,0 @@
-# onnxscript.ir
-
-
-
-```{eval-rst}
-.. automodule:: onnxscript.ir
- :members:
- :undoc-members:
-```
diff --git a/docs/intermediate_representation/tensors.md b/docs/intermediate_representation/tensors.md
deleted file mode 100644
index cca80264ee..0000000000
--- a/docs/intermediate_representation/tensors.md
+++ /dev/null
@@ -1,322 +0,0 @@
-# Tensor Representation in the IR
-
-The ONNX IR offers the {py:class}`ir.TensorProtocol ` interface for usings different data structures as backing data for tensors. Besides the traditional {py:class}`onnx.TensorProto`, you can also use {py:class}`np.ndarray`, {py:class}`torch.Tensor`, {py:class}`jax.Array`, and virtually anything else to represent tensors in the graph. This allows for them to be accessed and serialized via the same `TensorProtocol` interface, without incurring additional copies at initialization.
-
-## The `TensorProtocol`
-
-{py:class}`ir.TensorProtocol ` defines a read-only interface for representing tensors. A tensor class implementing the interface has attributes like `name`, `shape`, `dtype`, `size`, `nbytes` and `metadata_props` to describe basic properties of the tensor. Additionally, it should implement two methods {py:meth}`numpy ` and {py:meth}`__array__ ` which will produce equivalent NumPy arrays from the backing data.
-
-:::{note}
-When interacting with initializers, constant values and tensor attributes, it is best to assume `TensorProtocol` and only use `isinstance` to check for concrete classes when there is a need.
-:::
-
-## Tensor Classes
-
-### ir.TensorProtoTensor
-
-The ONNX spec defines [different ways](https://github.com/onnx/onnx/blob/d6f87121ba256ac6cc4d1da0463c300c278339d2/onnx/onnx.proto#L567-L654) for storing tensor data as an {py:class}`onnx.TensorProto ` protocol buffer message. The IR has corresponding classes for each of these data storage methods.
-
-We use the {py:class}`ir.TensorProtoTensor ` as a wrapper around the proto to implement the `ir.TensorProtocol` interface. You can access `shape`, `dtype` etc. as usual. A copy is incurred only when `numpy()` is called.
-
-:::{note}
-Directly initializing an `ir.TensorProtoTensor`, as below, is possible. However, it is usually recommended to use `ir.serde.deserialize_tensor` because it handles all types of `TensorProto`s (`ir.TensorProtoTensor` doesn't handle external tensors, for example). Please refer to [From `TensorProto`s and back](#from-tensorprotos-and-back) for an example.
-:::
-
-```{eval-rst}
-.. exec_code::
-
- import onnx
- from onnxscript import ir
-
- tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3])
- tensor = ir.TensorProtoTensor(tensor_proto)
- print("tensor: ", tensor) # TensorProtoTensor(name='tensor')
- print("shape: ", tensor.shape) # ir.Shape([3])
- print("dtype: ", tensor.dtype) # ir.DataType.INT16
- print(tensor.raw == tensor_proto) # The raw field is the exact tensor_proto provided at initialization
- print("tobytes: ", tensor.tobytes()) # b'\x01\x00\x02\x00\x03\x00'
- print("numpy: ", tensor.numpy()) # array([1, 2, 3], dtype=int16)
-```
-
-### ir.ExternalTensor
-
-Tensor data stored externally in the disk are typically large and will take up memory when loaded. The {py:class}`ir.ExternalTensor ` class uses memory mapping to avoid loading the tensor into memory. You are able to use the tensor as a normal NumPy array with minimal memory usage.
-
-Refer to {py:func}`ir.serde.deserialize_tensor ` to find an example on converting an `onnx.TensorProto` to an {py:class}`ir.ExternalTensor `.
-
-### ir.Tensor
-
-{py:class}`ir.Tensor ` is a wrapper around NumPy array compatible array objects like {py:class}`np.ndarray` and {py:class}`torch.Tensor`. It is best for creating in-memory tensors without converting it to a `TensorProto` to reduce the conversion overhead.
-
-:::{tip}
-An array object is compatible if it defines the `__array__` method.
-:::
-
-To create a tensor from an array, simply initialize it with an NumPy array
-
-```python
-tensor = ir.Tensor(np.random.rand(1, 2))
-```
-
-The initializer will obtain dtype and shape information from the array.
-
-To create a tensor from objects other than NumPy array, you need to specify the dtype:
-
-```{eval-rst}
-.. exec_code::
-
- import torch
- from onnxscript import ir
-
- torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16)
- tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16)
- print(tensor.numpy()) # array([1., 2., 3.], dtype=float16)
-```
-
-### String Tensor
-
-Use {py:class}`ir.StringTensor ` to create a string tensor.
-
-
-
-### Sparse Tensor
-
-Sparse tensors are not yet supported, but they are on our roadmap.
-
-## From `TensorProto`s and back
-
-In the following scenario, we show how to go from a `TensorProto` to an `ir.Tensor`, run some computation, then turn it back to an `ir.Tensor` and finally `TensorProto`
-
-```{eval-rst}
-.. exec_code::
-
- from onnxscript import ir
- import onnx
- import numpy as np
-
- # 1. Create the TensorProto
- proto = onnx.helper.make_tensor(
- "tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6]
- )
-
- # 2. Create an IR Tensor from the Protobuf message
- tensor = ir.serde.deserialize_tensor(proto)
- # Note that we get a TensorProtoTensor that implements the TensorProtocol
- print("tensor:", tensor) # TensorProtoTensor(name='tensor')
- print("tensor.numpy():", tensor.numpy()) # [[1. 2. 3.]
- # [4. 5. 6.]]
- print("tensor.tobytes():", tensor.tobytes()) # b'\x00<\x00@\x00B\x00D\x00E\x00F'
-
- # 3. Do computation using numpy
- mean = tensor.numpy().mean(axis=0)
- print("mean:", mean) # array([2.5, 3.5, 4.5], dtype=float16)
-
- # 4. Create a Tensor from the ndarray. Note that we use ir.Tensor
- tensor_mean = ir.Tensor(mean)
- print("tensor_mean:", tensor_mean) # Tensor(array([2.5, 3.5, 4.5], dtype=float16), name='')
-
- # 5. Obtain the TensorProto from ir.Tensor
- mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean)
- print("mean_tensor_proto:", mean_tensor_proto)
- print(
- "onnx.numpy_helper.to_array(mean_tensor_proto):",
- onnx.numpy_helper.to_array(mean_tensor_proto)
- # array([2.5, 3.5, 4.5], dtype=float16)
- )
-
- # You can obtain the bytes data as well
- print("tensor_mean.tobytes():", tensor_mean.tobytes())
- print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes())
-
- # Explore other methods defined by TensorProtocol:
- print("\n# Explore other methods defined by TensorProtocol:")
- print("tensor_mean.shape:", tensor_mean.shape)
- print("tensor_mean.dtype:", tensor_mean.dtype)
- print("tensor_mean.name:", tensor_mean.name)
- print("tensor_mean.doc_string:", tensor_mean.doc_string)
- print("tensor_mean.raw:", tensor_mean.raw)
- print("tensor_mean.metadata_props:", tensor_mean.metadata_props)
- print("tensor_mean.size:", tensor_mean.size)
- print("tensor_mean.nbytes:", tensor_mean.nbytes)
- print("tensor_mean.raw:", tensor_mean.raw)
- print("\nUse the display() method to view the tensor")
- tensor_mean.display()
-```
-
-## Working with non-native NumPy dtypes: bfloat16, float8, int4
-
-`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, the value is the bit representation for the dtype:
-
-- `int8` for (unpacked) int4, with the sign bit extended to 8 bits.
-- `uint8` for (unpacked) uint4.
-- `uint8` for 8-bit data types like float8.
-- `uint16` for bfloat16.
-
-uint4/int4 is always unpacked; `tobyte()` produces a packed representation as expected.
-
-Initialization of `ir.Tensor` requires the NumPy array to follow these typing constraints as well.
-
-:::{tip}
-You can use the [ml_dtypes package](https://github.com/jax-ml/ml_dtypes) to extend NumPy and work with these values.
-
-```bash
-pip install --upgrade ml_dtypes
-```
-
-:::
-
-The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its values, and create a new tensor to store the transformed values.
-
-```{eval-rst}
-.. exec_code::
-
- from onnxscript import ir
- import numpy as np
-
- array = np.array([0b1, 0b11], dtype=np.uint8)
- tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)
- print(tensor) # Tensor(array([1, 3], dtype=uint8), name='')
- print("tensor.numpy():", tensor.numpy()) # array([1, 3], dtype=uint8)
-
- # You can use the ml_dtypes package to work with these values in NumPy
- import ml_dtypes
- float8_array = tensor.numpy().view(ml_dtypes.float8_e4m3fn)
- print("float8_array:", float8_array) # array([0.00195312, 0.00585938], dtype='float8_e4m3fn')
-
- # Compute
- times_100 = float8_array * 100
- print("times_100:", times_100)
-
- # Create a new tensor out of the new value; dtype must be specified
- new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)
- print("new_tensor:", new_tensor) # Tensor(array([36, 49], dtype=uint8), name='')
- print("new_tensor == times_100", new_tensor.numpy().view(ml_dtypes.float8_e4m3fn) == times_100) # array([ True, True])
-
-```
-
-## Advanced Usage
-
-### Subclass ir.Tensor for More Efficient Access and Broader dtype Support
-
-{py:class}`ir.Tensor` internally converts any array compatible objects into NumPy arrays to produce the byte representation in `tobytes()`. This can be inefficient due to the additional conversion. It also limits support for dtypes not supported by NumPy like bfloat16, because the `__array__` method would fail.
-
-To fully support arrays from other frameworks, it is usually a good idea to create specialized classes to handle them. The `TorchTensor` class below demonstrates how you can subclass `ir.Tensor` to handle PyTorch tensors:
-
-```{eval-rst}
-.. exec_code::
-
- import ctypes
- from typing import Any
-
- import torch
- from onnxscript import ir
-
- # Define utilities to convert PyTorch data types so users do not need to specify manually
- _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
- torch.bfloat16: ir.DataType.BFLOAT16,
- torch.bool: ir.DataType.BOOL,
- torch.complex128: ir.DataType.COMPLEX128,
- torch.complex64: ir.DataType.COMPLEX64,
- torch.float16: ir.DataType.FLOAT16,
- torch.float32: ir.DataType.FLOAT,
- torch.float64: ir.DataType.DOUBLE,
- torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
- torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
- torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
- torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
- torch.int16: ir.DataType.INT16,
- torch.int32: ir.DataType.INT32,
- torch.int64: ir.DataType.INT64,
- torch.int8: ir.DataType.INT8,
- torch.uint8: ir.DataType.UINT8,
- }
-
-
- def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
- return _TORCH_DTYPE_TO_ONNX[dtype]
-
- class TorchTensor(ir.Tensor):
- def __init__(self, tensor: torch.Tensor):
- # Pass the tensor as the raw data to ir.Tensor's constructor
- super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
-
- def __array__(self, dtype: Any = None) -> "np.ndarray":
- # numpy() calls __array__ in ir.Tensor
- if self.dtype == ir.DataType.BFLOAT16:
- return self.raw.view(torch.uint16).__array__(dtype)
- if self.dtype in {
- ir.DataType.FLOAT8E4M3FN,
- ir.DataType.FLOAT8E4M3FNUZ,
- ir.DataType.FLOAT8E5M2,
- ir.DataType.FLOAT8E5M2FNUZ
- }:
- return self.raw.view(torch.uint8).__array__(dtype)
- return self.raw.__array__(dtype)
-
- def tobytes(self) -> bytes:
- # Implement tobytes to support native PyTorch types so we can use types like bloat16
- # Reading from memory directly is also more efficient because
- # it avoids the copy to NumPy array
- tensor = self.raw.detach().cpu().contiguous()
- return bytes(
- (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
- tensor.data_ptr()
- )
- )
-
- # Test the implementation
- torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
- tensor = TorchTensor(torch_tensor)
- print("tensor: ", tensor)
- print("numpy: ", tensor.numpy())
- print("tobytes: ", tensor.tobytes()) # b'\x80?\x00@@@'
- print("nbytes: ", tensor.nbytes) # 6
-```
-
-The `TorchTensor` class above implements `tobytes()` to produce the correct bytes representation for the tensor when it is serialized into an ONNX file / TensorProto. The class also implements the `__array__()` method to return the bit representation for types NumPy does not support. This way analysis passes can still perform computation on these values.
-
-### Computation with different Frameworks
-
-Since `ir.Tensor` implements the `__array__` method and `__dlpack__` methods, its content can be shared with computation frameworks without copying. For example:
-
-```{eval-rst}
-.. exec_code::
-
- from onnxscript import ir
-
- # We can call numpy methods directly on ir.Tensor
- import numpy as np
- print(np.multiply(ir.Tensor(np.array([1, 2])), 42)) # array([42., 84.])
-
- # We can transfer arrays to different frameworks
- import jax.numpy as jnp
- import jax
- import torch
-
- # Create ir.Tensor
- jax_array = jnp.array([10., 20.])
- ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT)
- torch_tensor = torch.tensor([30., 40.])
- ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)
-
- # Use numpy for computation
- print(np.multiply(ir_tensor_jax, ir_tensor_torch)) # array([300., 800.], dtype=float32)
-
- # Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same
- jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch)
- print(jax_array_from_ir + jax_array) # [40. 60.]
-
- # Use PyTorch for computation
- torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax)
- print(torch_tensor_from_ir - torch_tensor) # tensor([-20., -20.])
-
- # They can all be serialized into TensorProto
- proto = ir.serde.serialize_tensor(ir_tensor_jax)
- print(type(proto)) #
- print(proto)
-
- # The value is exactly the same as jax_array
- print(ir.serde.deserialize_tensor(proto).numpy()) # [10. 20.]
-```
-
-This is particularly useful if you are creating passes on the graph that requires doing computation on concrete values. You are free to use your favorite frameworks to create the passes. The transformed graph that contains newly created `ir.Tensor`s will be compatible with downstream passes even if they leverage other computation frameworks.
diff --git a/docs/ir/index.md b/docs/ir/index.md
new file mode 100644
index 0000000000..ae6b0802b5
--- /dev/null
+++ b/docs/ir/index.md
@@ -0,0 +1,5 @@
+# ONNX IR
+
+ONNX IR is now an official ONNX project! Documentation has been migrated to [onnx.ai/ir-py/](https://onnx.ai/ir-py/).
+
+You may continue to use `onnxscript.ir` unchanged for compatibility with older (<0.3) versions of ONNX Script.
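+
+For example, code that imports the IR through ONNX Script keeps working; a
+minimal sketch:
+
+```python
+# `onnxscript.ir` is kept as an alias of the standalone ONNX IR package
+# for backward compatibility (see the note above).
+from onnxscript import ir
+
+shape = ir.Shape([1, 3, 224, 224])
+print(shape.rank())  # 4
+```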
diff --git a/docs/test/test_documentation_examples.py b/docs/test/test_documentation_examples.py
index dcdcde2818..3cf7ac3b30 100644
--- a/docs/test/test_documentation_examples.py
+++ b/docs/test/test_documentation_examples.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
import os
import subprocess
@@ -36,6 +34,9 @@ def do_test_folder(self, folder):
if tested == 0:
raise RuntimeError(f"No example was tested in folder {folder}.")
+ @unittest.skipIf(
        sys.platform != "linux", reason="No need to run the documentation tests on every OS."
+ )
def test_documentation_examples(self):
this = os.path.abspath(os.path.dirname(__file__))
onxc = os.path.normpath(os.path.join(this, "..", ".."))
diff --git a/docs/tutorial/examples/dropout.py b/docs/tutorial/examples/dropout.py
index 850b22edc4..4530c7f34d 100644
--- a/docs/tutorial/examples/dropout.py
+++ b/docs/tutorial/examples/dropout.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/firstdim.py b/docs/tutorial/examples/firstdim.py
index 187fedf569..63476949fd 100644
--- a/docs/tutorial/examples/firstdim.py
+++ b/docs/tutorial/examples/firstdim.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/forloop.py b/docs/tutorial/examples/forloop.py
index 3b32b1a0eb..75a13205d7 100644
--- a/docs/tutorial/examples/forloop.py
+++ b/docs/tutorial/examples/forloop.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/forwhileloop.py b/docs/tutorial/examples/forwhileloop.py
index 100f246c76..ffca170d43 100644
--- a/docs/tutorial/examples/forwhileloop.py
+++ b/docs/tutorial/examples/forwhileloop.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/hardmax_end_to_end.py b/docs/tutorial/examples/hardmax_end_to_end.py
index e4cd881eb3..9b49a5ca77 100644
--- a/docs/tutorial/examples/hardmax_end_to_end.py
+++ b/docs/tutorial/examples/hardmax_end_to_end.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
import onnx
# We use ONNX opset 15 to define the function below.
diff --git a/docs/tutorial/examples/leaky_relu.py b/docs/tutorial/examples/leaky_relu.py
index 92fce52b10..e1d09a2a3d 100644
--- a/docs/tutorial/examples/leaky_relu.py
+++ b/docs/tutorial/examples/leaky_relu.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/leaky_relu_attr_promoted.py b/docs/tutorial/examples/leaky_relu_attr_promoted.py
index eb736162e3..058dc19366 100644
--- a/docs/tutorial/examples/leaky_relu_attr_promoted.py
+++ b/docs/tutorial/examples/leaky_relu_attr_promoted.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/omitted_input.py b/docs/tutorial/examples/omitted_input.py
index b4e839dd26..df35f49686 100644
--- a/docs/tutorial/examples/omitted_input.py
+++ b/docs/tutorial/examples/omitted_input.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/outerscope_redef_error.py b/docs/tutorial/examples/outerscope_redef_error.py
index a810e8eb71..41bd820d93 100644
--- a/docs/tutorial/examples/outerscope_redef_error.py
+++ b/docs/tutorial/examples/outerscope_redef_error.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import graph, script
from onnxscript import opset15 as op
@@ -13,7 +15,7 @@ def Sum(sum_in, next):
return sum_out, sum_out
g = op.Constant(value=1)
- all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1)
+ _all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1)
return cumulative_sum
except Exception as e:
diff --git a/docs/tutorial/examples/scanloop.py b/docs/tutorial/examples/scanloop.py
index c12da498da..6a409716a7 100644
--- a/docs/tutorial/examples/scanloop.py
+++ b/docs/tutorial/examples/scanloop.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import graph, script
from onnxscript import opset15 as op
@@ -9,5 +11,5 @@ def Sum(sum_in, next):
sum_out = sum_in + next
return sum_out, sum_out
- all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1)
+ _all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1)
return cumulative_sum
diff --git a/docs/tutorial/examples/softplus.py b/docs/tutorial/examples/softplus.py
index 0929bc0a0b..18c194ea5d 100644
--- a/docs/tutorial/examples/softplus.py
+++ b/docs/tutorial/examples/softplus.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
# We use ONNX opset 15 to define the function below.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/tensor_attr.py b/docs/tutorial/examples/tensor_attr.py
index de24de9f70..312ad7c5eb 100644
--- a/docs/tutorial/examples/tensor_attr.py
+++ b/docs/tutorial/examples/tensor_attr.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnx import TensorProto, helper
from onnxscript import opset15 as op
diff --git a/docs/tutorial/examples/tensor_attr2.py b/docs/tutorial/examples/tensor_attr2.py
index a602b914c8..eb60b04bcd 100644
--- a/docs/tutorial/examples/tensor_attr2.py
+++ b/docs/tutorial/examples/tensor_attr2.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnx import TensorProto, helper
from onnxscript import opset15 as op
diff --git a/docs/tutorial/examples/tensor_attr_short.py b/docs/tutorial/examples/tensor_attr_short.py
index ddf32295cf..b6a2452b9b 100644
--- a/docs/tutorial/examples/tensor_attr_short.py
+++ b/docs/tutorial/examples/tensor_attr_short.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnxscript import opset15 as op
from onnxscript import script
diff --git a/docs/tutorial/examples/whileloop.py b/docs/tutorial/examples/whileloop.py
index 36b153c810..68bcfbea46 100644
--- a/docs/tutorial/examples/whileloop.py
+++ b/docs/tutorial/examples/whileloop.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from onnx import TensorProto
from onnx.helper import make_tensor
diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md
index f3d2173270..708793a8a0 100644
--- a/docs/tutorial/index.md
+++ b/docs/tutorial/index.md
@@ -123,7 +123,9 @@ subsequently modified, this modification has no effect on the attribute-value
or the ONNX function/model created. This may potentially cause the behavior
of eager-mode execution to be inconsistent with the ONNX construct generated.
-Thus, the example shown above is equivalent to the following:
+Thus, the second assignment to `script_const` in the following code has no effect
+on the subsequent call to `tensor_attr.to_function_proto()`, which will use the
+original value of `script_const`:
```{literalinclude} examples/tensor_attr2.py
```
@@ -268,7 +270,7 @@ ONNX perspective, the two assignments to *g* represent two distinct tensors
```{toctree}
:maxdepth: 1
-optimizer/index
rewriter/index
-```
+optimizer/index
+```
diff --git a/docs/tutorial/optimizer/optimize.md b/docs/tutorial/optimizer/optimize.md
index 5ceb7dfb80..8ff36f4c67 100644
--- a/docs/tutorial/optimizer/optimize.md
+++ b/docs/tutorial/optimizer/optimize.md
@@ -15,6 +15,7 @@ onnxscript.optimizer.optimize(model)
```
### optimize API
+
The `onnxscript.optimizer.optimize` call takes in several optional parameters that allow the caller to further fine-tune the process of optimization.
```{eval-rst}
@@ -24,12 +25,8 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th
## Description of optimizations applied by `onnxscript.optimizer.optimize`
-:::{table}
-:widths: auto
-:align: center
-
-| Optimization 'onnxscript.optimizer.` + .. | Description |
-| - | - |
+| Optimization | Description |
+|-------------|-------------|
| **Constant folding**<br>`constant_folding.fold_constants` | Applies constant folding optimization to the model. |
| **Constant propagation**<br>`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. |
| **Sequence simplification**<br>`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. |
@@ -37,17 +34,3 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th
| **Remove unused functions** `remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. |
| **Inline functions with unused outputs** `simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. |
| **Inline simple functions** `simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. |
-:::
-
-## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize`
-
-```{eval-rst}
-.. autosummary::
- :nosignatures:
-
- onnxscript.rewriter.broadcast_to_matmul
- onnxscript.rewriter.cast_constant_of_shape
- onnxscript.rewriter.gemm_to_matmul_add
- onnxscript.rewriter.no_op
-
-```
diff --git a/docs/tutorial/rewriter/allow_other_inputs.md b/docs/tutorial/rewriter/allow_other_inputs.md
new file mode 100644
index 0000000000..29ccabca03
--- /dev/null
+++ b/docs/tutorial/rewriter/allow_other_inputs.md
@@ -0,0 +1,27 @@
+# Specifying variable inputs in the pattern
+
+This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting.
+The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs
+beyond those specified in the pattern. If it is set to `False` (the default), then the node must
+have exactly the specified inputs for a successful match. If set to `True`, the pattern will
+match nodes that have the specified inputs plus any number of additional inputs.
+
+This is particularly useful when matching operations like `Conv` that can have optional inputs
+(such as bias), or when creating generic patterns that should work with various input configurations.
+
+```{literalinclude} examples/allow_other_inputs.py
+:pyobject: conv_pattern
+```
+
+```{literalinclude} examples/allow_other_inputs.py
+:pyobject: conv_replacement
+```
+
+```{literalinclude} examples/allow_other_inputs.py
+:pyobject: apply_rewrite
+```
+
+In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation
+might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting
+`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs
+in the pattern definition.
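+
+Conversely, with `_allow_other_inputs` left at its default of `False`, the same pattern
+would match only the 2-input form (a schematic sketch):
+
+```python
+def strict_conv_pattern(op, input, weight):
+    # Matches Conv(input, weight) but not Conv(input, weight, bias).
+    return op.Conv(input, weight)
+```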
diff --git a/docs/tutorial/rewriter/attributes.md b/docs/tutorial/rewriter/attributes.md
new file mode 100644
index 0000000000..ba72cc5ade
--- /dev/null
+++ b/docs/tutorial/rewriter/attributes.md
@@ -0,0 +1,23 @@
+# Specifying attributes in the pattern
+
+This section demonstrates the use of attribute values in pattern-based rewriting.
+First, write a target pattern and replacement pattern in a similar way to the previous examples.
+The example pattern below will match successfully only against Dropout nodes with the
+attribute value `training_mode` set to `False`.
+
+The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes
+not specified in the pattern. If it is set to `False`, then the node must have only the specified
+attribute values, and no other attributes, for a successful match. The default value for this
+option is `True`.
+
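+For example, a stricter variant of the pattern shown below would set the option to
+`False`, so that a `Dropout` node carrying any extra attribute fails to match
+(a schematic sketch, not part of the example script):
+
+```python
+def strict_dropout_pattern(op, input):
+    # Match only Dropout nodes whose sole attribute is training_mode=False.
+    return op.Dropout(input, training_mode=False, _allow_other_attributes=False)
+```
+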
+```{literalinclude} examples/allow_other_attributes.py
+:pyobject: add_pattern
+```
+
+```{literalinclude} examples/allow_other_attributes.py
+:pyobject: add_replacement
+```
+
+```{literalinclude} examples/allow_other_attributes.py
+:pyobject: apply_rewrite
+```
diff --git a/docs/tutorial/rewriter/commute.md b/docs/tutorial/rewriter/commute.md
new file mode 100644
index 0000000000..d0690892f2
--- /dev/null
+++ b/docs/tutorial/rewriter/commute.md
@@ -0,0 +1,83 @@
+(heading-target-commute)=
+# Utilizing `commute` parameter for pattern-matching
+
+```{warning}
+Please note that the section below describes a convenience feature for handling commutative operators
+in pattern matching. However, the implementation is a simple, brute-force, technique that generates a collection
+of rewrite-rules from a given rule, taking commutativity of addition and multiplication into account. This can
+lead to an exponential increase in the number of rewrite-rules. So, it should be used with caution. Pattern
+disjunctions (_OR Patterns_) described earlier can be used judiciously to get a somewhat more efficient
+implementation in practice (even though the potential for exponential increase still exists within the
+pattern matching algorithm). Reimplementing commutativity handling using pattern disjunctions is future
+work.
+```
+
+Extending the previous [simple example](heading-target-simple), consider a scenario where the graph has the following structure:
+
+*(figure: the example graph, containing two Erf-based GELU subgraphs)*
+
+In this graph, there exist two node patterns that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either pattern, the order of the input values being multiplied is switched.
+
+*(figures: the two GELU subgraph variants shown side by side, with the `Mul` input order switched)*
+
+
+If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of the two `GELU` patterns will be matched.
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: erf_gelu_pattern
+```
+
+```{image} examples/img/erfgelu_06_commute.png
+:alt: The resulting graph after matching.
+:width: 400px
+:align: center
+```
+
+Only one of the patterns has been successfully matched and replaced by a `GELU` node. To rewrite both of the existing patterns in the graph, there are two methods.
+
+(heading-target-commute-ruleset)=
+
+## 1. Creating a rule-set with different patterns.
+
+This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option, but either can be used. For example, to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2`:
+
+```python
+from onnxscript.rewriter import pattern
+rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
+```
+
+In order to apply this method to the example above, first create the two separate target patterns as follows:
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: erf_gelu_pattern
+```
+```{literalinclude} examples/erfgelu.py
+:pyobject: erf_gelu_pattern_2
+```
+
+:::{note}
+:name: rule-application-order-matters
+
+When you pass multiple rules in `pattern_rewrite_rules`, the **order in which they appear is important**.
+This is because some rules may depend on patterns created or modified by earlier rules. For example, if `rule2` can only match after `rule1` has made a specific change in the model, then `rule1` must come **before** `rule2` in the list.
+If you're not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur.
+:::
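+
+For instance, a minimal sketch of applying a rule set repeatedly until the model stops
+changing (assuming `rule1` and `rule2` are previously created rules and `model` is an
+`onnx.ModelProto`):
+
+```python
+import onnxscript
+from onnxscript.rewriter import pattern
+
+# rule1 and rule2 are assumed to exist; model is an onnx.ModelProto.
+rules = pattern.RewriteRuleSet([rule1, rule2])
+previous = None
+while previous != model.SerializeToString():
+    previous = model.SerializeToString()
+    # For a ModelProto input, rewrite returns a new (possibly unchanged) model.
+    model = onnxscript.rewriter.rewrite(model, pattern_rewrite_rules=rules)
+```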
+
+
+Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter.
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: apply_rewrite_with_ruleset
+```
+
+## 2. Using the `commute` parameter while creating a rule.
+
+Creating multiple target patterns for similar subgraphs can be tedious. To avoid this, the `commute` parameter can be used when creating the `RewriteRuleSet`. Simply set `commute=True` to avoid writing a separate target pattern for each variant that arises from commutativity. Rules for all the patterns that are equivalent up to commutativity are generated and packed into a `RewriteRuleSet` object automatically. Then apply the rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter.
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: apply_rewrite_with_commute
+```
+
+For both of the aforementioned methods, the final graph with both rewrites applied should look as follows:
+
+*(figure: the final graph, with both subgraphs replaced by `Gelu` nodes)*
diff --git a/docs/tutorial/rewriter/conditional_rewrite.md b/docs/tutorial/rewriter/conditional_rewrite.md
new file mode 100644
index 0000000000..379788e657
--- /dev/null
+++ b/docs/tutorial/rewriter/conditional_rewrite.md
@@ -0,0 +1,103 @@
+# Using the `match_condition` parameter for pattern-matching
+
+This section describes how to utilize the `match_condition` parameter. The `match_condition` parameter makes a rewrite apply only when the matched pattern additionally satisfies user-specified constraints.
+
+Let us consider a model which consists of the following pattern.
+
+*(figure: a `Reshape` → `MatMul` → `Reshape` subgraph)*
+
+Based on the [ONNX MatMul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), ONNX `MatMul` behaves like `numpy.matmul` and also follows numpy broadcasting. So, in this particular pattern, if MatMul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following:
+
+1. Input shapes check: `input_a` and `input_b` should be broadcastable
+2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)`
+
+If the above are true, then we don't need the reshapes and we can eliminate them using a pattern-based rewrite. For instance, if `input_a` has shape `[1, 4, 512, 512]` and `input_b` has shape `[1, 4, 512, 64]`, `MatMul` already produces the `[1, 4, 512, 64]` result directly, so the surrounding reshapes are redundant.
+
+First, write a target pattern and replacement pattern in a similar way to the first example.
+
+```{literalinclude} examples/broadcast_matmul.py
+:pyobject: two_reshapes_matmul_reshape_pattern
+```
+
+```{literalinclude} examples/broadcast_matmul.py
+:pyobject: matmul_pattern
+```
+
+:::{note}
+:name: omitting inputs in signature
+
+The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters.
+
+Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature.
+:::
+
+In order to validate whether MatMul broadcast is sufficient, we write a condition checking function as follows.
+Note that the relevant inputs passed to the check function are all instances of {py:class}`onnx_ir.Value`. These represent
+the values in the input graph IR that matched against the corresponding _pattern variables_ in the target
+pattern. See the documentation of the [IR API](https://onnx.ai/ir-py/) for more details on how to use it, for example,
+to identify the type, shape, or rank of these values.
+
+```{literalinclude} examples/broadcast_matmul.py
+:pyobject: check_if_not_need_reshape
+```
+
+With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite.
+
+```{literalinclude} examples/broadcast_matmul.py
+:pyobject: apply_rewrite
+```
+
+The final graph with the applied rewrite looks as follows:
+
+*(figure: the rewritten graph, with a single `MatMul` and no reshapes)*
+
+## Using MatchContext for Advanced Condition Checking
+
+The `context` parameter passed to condition functions is an instance of {py:class}`onnxscript.rewriter.MatchContext`, which provides access to additional information about the pattern match that can be useful for sophisticated condition checking.
+
+### MatchContext Properties
+
+The MatchContext provides the following read-only properties:
+
+- `model`: The entire ONNX model being matched
+- `graph_or_function`: The specific graph or function being matched
+- `root`: The root node of the matching subgraph
+- `output_values`: The output values of the matching subgraph
+- `nodes`: All nodes that are part of the matching subgraph
+
+### Example Usage
+
+Here's an example showing how to use the MatchContext to implement more sophisticated condition checking:
+
+```python
+def advanced_condition_check(context, x, y, **_):
+ """Example condition function using MatchContext."""
+
+ # Access the main node of the pattern match
+ main_node = context.root
+
+ # Check that the main_node does not have an attribute called "alpha"
+ if "alpha" in main_node.attributes:
+ return False
+
+ # Access the broader graph context and check that x occurs as a graph-input
+ model = context.model
+ if x not in model.graph.inputs:
+ return False
+
+ # You can inspect the matched nodes for advanced validation
+ for node in context.nodes:
+ if node.op_type == "Constant":
+ # Check properties of constant nodes in the match
+ pass
+
+ # Access output values for shape/type validation
+ outputs = context.output_values
+ if len(outputs) > 0 and outputs[0].shape is not None:
+ # Validate output shapes
+ pass
+
+ return True
+```
+
+This context information enables condition functions to make decisions based on the broader graph structure, the specific nodes involved in the match, and relationships between matched patterns and the rest of the model.
diff --git a/docs/tutorial/rewriter/domain_option.md b/docs/tutorial/rewriter/domain_option.md
new file mode 100644
index 0000000000..30a7384b59
--- /dev/null
+++ b/docs/tutorial/rewriter/domain_option.md
@@ -0,0 +1,38 @@
+# Specifying domains in the pattern
+
+This section demonstrates the use of the `_domain` option in pattern-based rewriting.
+The `_domain` option allows you to specify which operator domain the pattern should match against,
+and also allows you to create replacement operations in specific domains.
+
+ONNX operators can belong to different domains:
+- The default ONNX domain (empty string or "ai.onnx")
+- Custom domains like "com.microsoft" for Microsoft-specific operations
+- User-defined domains for custom operations
+
+## Matching operations from a specific domain
+
+```{literalinclude} examples/domain_option.py
+:pyobject: custom_relu_pattern
+```
+
+In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the
+"custom.domain" domain will be matched, not standard ONNX `Relu` operations.
+
+## Creating replacement operations in a specific domain
+
+```{literalinclude} examples/domain_option.py
+:pyobject: microsoft_relu_replacement
+```
+
+Here, the replacement operation is created in the "com.microsoft" domain, which might
+provide optimized implementations of standard operations.
+
+## Complete rewrite example
+
+```{literalinclude} examples/domain_option.py
+:pyobject: apply_rewrite
+```
+
+This example shows how domain-specific pattern matching can be used to migrate operations
+between different operator domains, such as replacing custom domain operations with
+standard ONNX operations or vice versa.
diff --git a/docs/tutorial/rewriter/examples/allow_other_attributes.py b/docs/tutorial/rewriter/examples/allow_other_attributes.py
new file mode 100644
index 0000000000..67e14ad659
--- /dev/null
+++ b/docs/tutorial/rewriter/examples/allow_other_attributes.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Onnx Pattern Rewriting with attributes
+
+This script shows how to define a rewriting rule based on patterns that
+are dependent on the attributes of the nodes.
+"""
+
+import onnx
+
+import onnxscript
+from onnxscript import FLOAT, opset18, script
+from onnxscript.rewriter import pattern
+
+
+@script()
+def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]:
+ add = opset18.Add(A, B)
+ result = opset18.Dropout(add, training_mode=False)
+ return result
+
+
+_model = original_model.to_model_proto()
+onnx.checker.check_model(_model)
+
+
+####################################
+# The target pattern
+# =====================
+
+
+def add_pattern(op, input):
+ return op.Dropout(input, training_mode=False, _allow_other_attributes=True)
+
+
+####################################
+# The replacement pattern
+# =====================
+
+
+def add_replacement(op, input, **_):
+ return op.Identity(input)
+
+
+####################################
+# Create Rewrite Rule and Apply to Model
+# =====================
+
+
+def apply_rewrite(model):
+ # Create rewrite rules
+ add_rule = pattern.RewriteRule(
+ add_pattern, # target pattern
+ add_replacement, # replacement pattern
+ )
+ # Create a Rewrite Rule Set
+ rewrite_rule_set = pattern.RewriteRuleSet([add_rule])
+ # Apply the rewrite rule set to the model
+ model_with_rewrite = onnxscript.rewriter.rewrite(
+ model,
+ pattern_rewrite_rules=rewrite_rule_set,
+ )
+ return model_with_rewrite
+
+
+_model_with_rewrite = apply_rewrite(_model)
+onnx.checker.check_model(_model_with_rewrite)
diff --git a/docs/tutorial/rewriter/examples/allow_other_inputs.py b/docs/tutorial/rewriter/examples/allow_other_inputs.py
new file mode 100644
index 0000000000..cc3a3d926f
--- /dev/null
+++ b/docs/tutorial/rewriter/examples/allow_other_inputs.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""ONNX Pattern Rewriting with variable number of inputs
+
+This script shows how to define a rewriting rule based on patterns that
+can match nodes with additional inputs beyond those specified in the pattern.
+"""
+
+import onnx
+
+import onnxscript
+from onnxscript import FLOAT, opset18, script
+from onnxscript.rewriter import pattern
+
+
+@script()
+def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]:
+ # Conv with bias - has 3 inputs: input, weight, bias
+ result = opset18.Conv(A, B, C)
+ return result
+
+
+_model = original_model.to_model_proto()
+onnx.checker.check_model(_model)
+
+
+####################################
+# The target pattern
+# =====================
+
+
+def conv_pattern(op, input, weight):
+ # Pattern to match Conv operations, allowing additional inputs like bias
+ # _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs)
+ # even though we only specify 2 inputs in the pattern
+ return op.Conv(input, weight, _allow_other_inputs=True)
+
+
+####################################
+# The replacement pattern
+# =====================
+
+
+def conv_replacement(op, input, weight, **_):
+ # Replace with a custom operation in a different domain
+ return op.OptimizedConv(input, weight, _domain="custom.domain")
+
+
+####################################
+# Create Rewrite Rule and Apply to Model
+# =====================
+
+
+def apply_rewrite(model):
+ # Create rewrite rules
+ conv_rule = pattern.RewriteRule(
+ conv_pattern, # target pattern
+ conv_replacement, # replacement pattern
+ )
+ # Create a Rewrite Rule Set
+ rewrite_rule_set = pattern.RewriteRuleSet([conv_rule])
+ # Apply rewrite
+ model_with_rewrite = onnxscript.rewriter.rewrite(
+ model,
+ pattern_rewrite_rules=rewrite_rule_set,
+ )
+ return model_with_rewrite
+
+
+_model_with_rewrite = apply_rewrite(_model)
+onnx.checker.check_model(_model_with_rewrite)
diff --git a/docs/tutorial/rewriter/examples/broadcast_matmul.py b/docs/tutorial/rewriter/examples/broadcast_matmul.py
index 22b374e5b2..cf56b49f07 100644
--- a/docs/tutorial/rewriter/examples/broadcast_matmul.py
+++ b/docs/tutorial/rewriter/examples/broadcast_matmul.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Onnx Pattern Rewriting with match condition parameter.
This script shows how to define a rewriting rule based on patterns while
@@ -9,12 +11,11 @@
import logging
-import numpy as np
import onnx
import onnxscript
from onnxscript import FLOAT, ir, opset18, script
-from onnxscript.rewriter import _ir_utils, pattern
+from onnxscript.rewriter import pattern
logger = logging.getLogger(__name__)
@@ -40,14 +41,12 @@ def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1
# The target pattern
# =====================
-_op = pattern.onnxop
-
-def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c):
- reshape_a = _op.Reshape(input_a, shape_a)
- reshape_b = _op.Reshape(input_b, shape_b)
- matmul = _op.MatMul(reshape_a, reshape_b)
- return _op.Reshape(matmul, shape_c)
+def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
+ reshape_a = op.Reshape(input_a, shape_a)
+ reshape_b = op.Reshape(input_b, shape_b)
+ matmul = op.MatMul(reshape_a, reshape_b)
+ return op.Reshape(matmul, shape_c)
####################################
@@ -65,72 +64,79 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
def check_if_not_need_reshape(
- input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
+ context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
- """If matmul broadcasting is enough, then we don't need the reshapes.
+ """Condition to check if we need to replace the pattern.
+
+ If matmul broadcasting is enough, then we don't need the reshapes.
To validate this, we need to check the following:
1. Input shapes check: input_a and input_b should be broadcastable
2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
If the above are true, then we don't need the reshapes.
+
+ Returns:
+ True if we need to replace the pattern, False otherwise.
"""
input_a_shape = input_a.shape
input_b_shape = input_b.shape
- # TODO: Get a helper func to get const_value
- shape_c_value = _ir_utils.propagate_const_value(shape_c)
- shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr]
- if shape_c is None:
- return False
- if not isinstance(shape_c, np.ndarray):
- logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c))
+ shape_c_tensor = shape_c.const_value
+ if shape_c_tensor is None:
+ logger.info("The value 'shape_c' is not statically known.")
return False
- if len(shape_c.shape) != 1:
+
+ if len(shape_c_tensor.shape) != 1:
logger.info(
"Unexpected final shape. The shape of 'shape' value is %s",
- shape_c.shape,
+ shape_c_tensor.shape,
)
return False
- shape_c_list = shape_c.tolist()
# NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
# information. So, we need to check if the shape is None and return False.
- if input_a_shape is None or input_b_shape is None or shape_c is None:
+ if input_a_shape is None or input_b_shape is None:
logger.info("Shape information is not available for the inputs and outputs.")
return False
- input_a_shape = list(input_a_shape)
- input_b_shape = list(input_b_shape)
+ input_a_shape = input_a_shape.numpy()
+ input_b_shape = input_b_shape.numpy()
+ shape_c = shape_c_tensor.numpy().tolist()
+
+ a_rank = len(input_a_shape)
+ b_rank = len(input_b_shape)
- dim_a = len(input_a_shape)
- dim_b = len(input_b_shape)
+ # TODO(justinchuby): Check shape size
# 1. Check if input shapes are broadcastable
# 1.a. If the first input is 1-D, check whether
# the dim matches the last second dim of the second input.
mimic_matmul_broadcast_behavior = False
- if dim_a < 2:
+ if a_rank < 2:
+ if b_rank < 2:
+ logger.info("Optimization of dot product is not supported yet.")
+ return False
if input_a_shape[-1] != input_b_shape[-2]:
logger.info("Original shape is not MatMul compatible.")
return False
else:
input_a_shape = [1, *input_a_shape]
- dim_a = len(input_a_shape)
+ a_rank = len(input_a_shape)
mimic_matmul_broadcast_behavior = True
# 1.b. If the second input is 1-D, check whether
# the dim matches the last dim of the first input.
- if dim_b < 2:
+ if b_rank < 2:
if input_b_shape[-1] != input_a_shape[-1]:
logger.info("Original shape is not MatMul compatible.")
return False
else:
input_b_shape = [*input_b_shape, 1]
- dim_b = len(input_b_shape)
+ b_rank = len(input_b_shape)
mimic_matmul_broadcast_behavior = True
# 1.c. If both inputs are at least 2-D, check whether
# the last dimension of the first input matches the second
# last dimension of the second input, and shape[:-2] are
# broadcastable.
- input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]]
+ input_a_shape_except_second_last_dim = [*input_a_shape[:-2], input_a_shape[-1]]
input_b_shape_except_last_dim = input_b_shape[:-1]
broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
for idx, (dim_from_a, dim_from_b) in enumerate(
@@ -150,23 +156,27 @@ def check_if_not_need_reshape(
# 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
# Prepend the broadcast_matmul_output_shape with the longer shape of input
- if dim_a > dim_b:
+ if a_rank > b_rank:
longer_shape = input_a_shape
shorter_shape = input_b_shape
else:
longer_shape = input_b_shape
shorter_shape = input_a_shape
- broadcast_matmul_output_shape = (
- longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape
- )
- if mimic_matmul_broadcast_behavior and dim_b == 2:
+ broadcast_matmul_output_shape = [
+ *longer_shape[: -len(shorter_shape)],
+ *broadcast_matmul_output_shape,
+ ]
+ if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
+ # If input_b is expanded to 2-D, then we need to remove the last dimension
broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
- if mimic_matmul_broadcast_behavior and dim_a == 2:
+ if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
+ # If input_a is expanded to 2-D, then we need to remove the first dimension
+ # of input_a, which would be the -2nd dimension of the output shape.
broadcast_matmul_output_shape.pop(-2)
- if shape_c_list != broadcast_matmul_output_shape:
+ if shape_c != broadcast_matmul_output_shape:
logger.info(
"Final output shape is not the same. Expected %s vs actual %s",
- shape_c_list,
+ shape_c,
broadcast_matmul_output_shape,
)
return False
diff --git a/docs/tutorial/rewriter/examples/domain_option.py b/docs/tutorial/rewriter/examples/domain_option.py
new file mode 100644
index 0000000000..7018c04719
--- /dev/null
+++ b/docs/tutorial/rewriter/examples/domain_option.py
@@ -0,0 +1,86 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""ONNX Pattern Rewriting with domain specification
+
+This script shows how to define a rewriting rule that targets operations
+from specific domains and replaces them with operations in other domains.
+"""
+
+import onnx
+
+import onnxscript
+from onnxscript import script
+from onnxscript.rewriter import pattern
+from onnxscript.values import Opset
+
+# Create an opset for the custom domain
+opset = Opset("custom.domain", 1)
+
+
+@script(opset)
+def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]:
+ """Create a model with a Relu operation in a custom domain."""
+ return opset.Relu(input)
+
+
+_model = create_model_with_custom_domain.to_model_proto()
+_model = onnx.shape_inference.infer_shapes(_model)
+onnx.checker.check_model(_model)
+
+
+####################################
+# The target pattern
+# =====================
+
+
+def custom_relu_pattern(op, input):
+ # Pattern to match Relu operations from a specific domain
+ # _domain="custom.domain" specifies we only want to match operations from this domain
+ return op.Relu(input, _domain="custom.domain")
+
+
+####################################
+# The replacement pattern
+# =====================
+
+
+def standard_relu_replacement(op, input, **_):
+ # Replace with standard ONNX Relu (default domain)
+ return op.Relu(input)
+
+
+####################################
+# Alternative: Replace with operation in different domain
+# =====================
+
+
+def microsoft_relu_replacement(op, input, **_):
+ # Replace with operation in Microsoft's domain
+ return op.OptimizedRelu(input, _domain="com.microsoft")
+
+
+####################################
+# Create Rewrite Rule and Apply to Model
+# =====================
+
+
+def apply_rewrite(model):
+ # Create rewrite rules
+ relu_rule = pattern.RewriteRule(
+ custom_relu_pattern, # target pattern - matches custom domain operations
+ standard_relu_replacement, # replacement pattern - uses standard domain
+ )
+ # Create a Rewrite Rule Set
+ rewrite_rule_set = pattern.RewriteRuleSet([relu_rule])
+ # Apply rewrite
+ model_with_rewrite = onnxscript.rewriter.rewrite(
+ model,
+ pattern_rewrite_rules=rewrite_rule_set,
+ )
+ return model_with_rewrite
+
+
+# The rewrite rule will now match the Relu operation in the custom domain
+# and replace it with a standard ONNX Relu operation
+_model_with_rewrite = apply_rewrite(_model)
+onnx.checker.check_model(_model_with_rewrite)
diff --git a/docs/tutorial/rewriter/examples/erfgelu.py b/docs/tutorial/rewriter/examples/erfgelu.py
index f8723da594..e042d9f337 100644
--- a/docs/tutorial/rewriter/examples/erfgelu.py
+++ b/docs/tutorial/rewriter/examples/erfgelu.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Onnx Pattern Rewriting.
This script shows how to define a rewriting rule based on patterns.
@@ -70,15 +72,13 @@ def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
# The target pattern
# =====================
-_op = pattern.onnxop
+def erf_gelu_pattern(op, x):
+ return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
-def erf_gelu_pattern(x):
- return 0.5 * (x * (_op.Erf(x / math.sqrt(2)) + 1.0))
-
-def erf_gelu_pattern_2(x):
- return (x * (_op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
+def erf_gelu_pattern_2(op, x):
+ return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
####################################
@@ -87,7 +87,7 @@ def erf_gelu_pattern_2(x):
def gelu(op, x: ir.Value):
- return op.Gelu(x, domain="com.microsoft")
+ return op.Gelu(x, _domain="com.microsoft")
####################################
@@ -98,7 +98,7 @@ def gelu(op, x: ir.Value):
def apply_rewrite(model):
rule = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
- gelu, # Replacement Pattern
+ gelu, # Replacement
)
model_with_rewrite_applied = onnxscript.rewriter.rewrite(
model,
@@ -107,15 +107,31 @@ def apply_rewrite(model):
return model_with_rewrite_applied
+####################################
+# Rewrite Rule as a Class
+# =====================
+
+
+class ErfGeluFusion(pattern.RewriteRuleClassBase):
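+ """Rewrite rule (as a class) that replaces the Erf-based GELU computation with a single com.microsoft Gelu node."""
+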
+ def pattern(self, op, x):
+ return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
+
+ def rewrite(self, op, x):
+ return op.Gelu(x, _domain="com.microsoft")
+
+
+erf_gelu_rule_from_class = ErfGeluFusion.rule()
+
+
def apply_rewrite_with_ruleset(model):
# Create multiple rules
rule1 = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
- gelu, # Replacement Pattern
+ gelu, # Replacement
)
rule2 = pattern.RewriteRule(
erf_gelu_pattern_2, # Target Pattern
- gelu, # Replacement Pattern
+ gelu, # Replacement
)
# Create a Rewrite Rule Set with multiple rules.
rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
@@ -131,7 +147,7 @@ def apply_rewrite_with_ruleset(model):
def apply_rewrite_with_commute(model):
rule = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
- gelu, # Replacement Pattern
+ gelu, # Replacement
)
# Create a Rewrite Rule Set with commute=True
rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True)
diff --git a/docs/tutorial/rewriter/examples/or_pattern.py b/docs/tutorial/rewriter/examples/or_pattern.py
new file mode 100644
index 0000000000..0e9231cc1f
--- /dev/null
+++ b/docs/tutorial/rewriter/examples/or_pattern.py
@@ -0,0 +1,93 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""OR-patterns.
+
+This script shows how to define a rewriting rule based on OR-patterns.
+"""
+
+import onnx
+
+import onnxscript
+from onnxscript import FLOAT, opset18, script
+from onnxscript.rewriter import pattern
+
+####################################
+# The target pattern
+# =====================
+
+
+def scaled_matmul(op, x, y, factor):
+ xy = op.MatMul(x, y)
+ choice1 = op.Mul(xy, factor)
+ choice2 = op.Div(xy, factor)
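+ # OrValue matches either alternative and binds the variable named by tag_var
+ # ("op_type") to "Mul" or "Div", recording which branch matched.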
+ scaled_xy = pattern.OrValue(
+ [choice1, choice2], tag_var="op_type", tag_values=["Mul", "Div"]
+ )
+ return op.Relu(scaled_xy)
+
+
+####################################
+# The replacement pattern
+# =====================
+
+
+def scaled_matmul_replacement(op, x, y, factor, op_type):
+ if op_type == "Mul":
+ return op.MatMulMulRelu(x, y, factor, _domain="some.domain")
+ elif op_type == "Div":
+ return op.MatMulDivRelu(x, y, factor, _domain="some.domain")
+ else:
+ raise ValueError(f"Unknown operation type: {op_type}")
+
+
+####################################
+# Rewrite Rule
+# =====================
+def apply_rewrite(model):
+ rule = pattern.RewriteRule(
+ scaled_matmul, # target pattern
+ scaled_matmul_replacement, # replacement pattern
+ )
+ # Create a Rewrite Rule Set
+ rewrite_rule_set = pattern.RewriteRuleSet([rule])
+ return onnxscript.rewriter.rewrite(
+ model,
+ pattern_rewrite_rules=rewrite_rule_set,
+ )
+
+
+@script()
+def original_model1(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]:
+ t1 = opset18.MatMul(A, B)
+ c = opset18.Constant(value_float=2.0)
+ t2 = opset18.Mul(t1, c)
+ t3 = opset18.Relu(t2)
+ return t3
+
+
+_model = original_model1.to_model_proto()
+onnx.checker.check_model(_model)
+
+_model_with_rewrite = apply_rewrite(_model)
+onnx.checker.check_model(_model_with_rewrite)
+
+assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulMulRelu"]
+
+
+@script()
+def original_model2(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]:
+ t1 = opset18.MatMul(A, B)
+ c = opset18.Constant(value_float=2.0)
+ t2 = opset18.Div(t1, c)
+ t3 = opset18.Relu(t2)
+ return t3
+
+
+_model = original_model2.to_model_proto()
+onnx.checker.check_model(_model)
+
+_model_with_rewrite = apply_rewrite(_model)
+onnx.checker.check_model(_model_with_rewrite)
+
+assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulDivRelu"]
diff --git a/docs/tutorial/rewriter/examples/outputs_option.py b/docs/tutorial/rewriter/examples/outputs_option.py
new file mode 100644
index 0000000000..88483385dc
--- /dev/null
+++ b/docs/tutorial/rewriter/examples/outputs_option.py
@@ -0,0 +1,76 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""ONNX Pattern Rewriting with output specification
+
+This script shows how to define a rewriting rule that specifies
+the number and names of outputs from operations.
+"""
+
+import onnx
+
+import onnxscript
+from onnxscript import FLOAT, opset18, script
+from onnxscript.rewriter import pattern
+
+
+@script()
+def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]:
+ # Split operation that produces 2 outputs
+ result1, _result2 = opset18.Split(A, num_outputs=2, axis=0)
+ # We only return the first output for simplicity
+ return result1
+
+
+_model = original_model.to_model_proto()
+onnx.checker.check_model(_model)
+
+
+####################################
+# The target pattern with multiple outputs
+# =====================
+
+
+def split_pattern(op, input):
+ # Pattern to match Split operations with 2 outputs
+ # num_outputs=2 corresponds to the attribute of the ONNX Split op
+ # _outputs=2 is an option controlling the pattern constructor
+ return op.Split(input, num_outputs=2, axis=0, _outputs=2)
+
+
+####################################
+# The replacement pattern with named outputs
+# =====================
+
+
+def custom_split_replacement(op, input, **_):
+ # Replace with a custom split operation using named outputs
+ # _outputs=["first_half", "second_half"] assigns names to the outputs
+ # IMPORTANT: The number of outputs must match the pattern (2 outputs)
+ return op.CustomSplit(
+ input, _domain="custom.domain", _outputs=["first_half", "second_half"]
+ )
+
+
+####################################
+# Create Rewrite Rule and Apply to Model
+# =====================
+
+
+def apply_rewrite(model):
+ # Create rewrite rules
+ split_rule = pattern.RewriteRule(
+ split_pattern, # target pattern - matches Split with 2 outputs
+ custom_split_replacement, # replacement pattern - uses named outputs
+ )
+ # Create a Rewrite Rule Set
+ rewrite_rule_set = pattern.RewriteRuleSet([split_rule])
+ # Apply rewrite
+ model_with_rewrite = onnxscript.rewriter.rewrite(
+ model,
+ pattern_rewrite_rules=rewrite_rule_set,
+ )
+ return model_with_rewrite
+
+
+_model_with_rewrite = apply_rewrite(_model)
+onnx.checker.check_model(_model_with_rewrite)
diff --git a/docs/tutorial/rewriter/index.md b/docs/tutorial/rewriter/index.md
index 3b4e01e149..d86ae9a474 100644
--- a/docs/tutorial/rewriter/index.md
+++ b/docs/tutorial/rewriter/index.md
@@ -1,4 +1,4 @@
-# Rewriter Tutorials
+# Rewriter Tutorial
```{toctree}
rewrite_patterns
diff --git a/docs/tutorial/rewriter/node_value_checkers.md b/docs/tutorial/rewriter/node_value_checkers.md
new file mode 100644
index 0000000000..e9e5661431
--- /dev/null
+++ b/docs/tutorial/rewriter/node_value_checkers.md
@@ -0,0 +1,187 @@
+(heading-target-checkers)=
+# Node and Value Level Checkers
+
+The pattern matching infrastructure supports custom validation logic at both the node and value levels through checker functions. These checkers allow for more sophisticated pattern matching by enabling additional constraints beyond basic operator and structure matching.
+
+## Value-Level Checkers
+
+Value-level checkers validate properties of specific values in the pattern. They are particularly useful for checking constants, shapes, or other value-specific properties.
+
+### Basic Usage
+
+A value checker is a function that takes a `MatchContext` and an `ir.Value`, and returns either a boolean or a `MatchResult`:
+
+```python
+from onnxscript import ir
+
+
+def is_positive_constant(context, value: ir.Value):
+ """Check if a value is a positive constant."""
+ if value.const_value is not None:
+ # Get the numpy array from const_value
+ numpy_array = value.const_value.numpy()
+
+ # Check if it represents a single value and is positive
+ if numpy_array.size != 1:
+ return False
+
+ return float(numpy_array.item()) > 0
+
+ return False
+```
+
+You can use this checker directly in your pattern by passing the callable as an input:
+
+```python
+def add_pattern(op, x, y):
+ # Use callable as input to create ValuePattern with checker
+ return op.Add(is_positive_constant, y)
+```
+
+This pattern will only match `Add` operations where the first input is a positive constant value.
+
+### Example Usage
+
+```python
+from onnxscript.rewriter import pattern
+from onnxscript import ir, optimizer
+import onnx
+
+# Create a model with different Add operations
+model_proto = onnx.parser.parse_model("""
+ <ir_version: 7, opset_import: ["" : 17]>
+ agraph (float[N] x, float[N] y) => (float[N] z1, float[N] z2, float[N] z3)
+ {
+ pos_const = Constant <value = float[1] {1.0}> ()
+ neg_const = Constant <value = float[1] {-1.0}> ()
+ z1 = Add(x, y) # non-constant first parameter
+ z2 = Add(pos_const, y) # positive constant first parameter
+ z3 = Add(neg_const, y) # negative constant first parameter
+ }
+""")
+model = ir.serde.deserialize_model(model_proto)
+
+# Apply constant propagation to set const_value fields
+optimizer.basic_constant_propagation(model.graph.all_nodes())
+
+# Create the pattern with value checker
+rule_pattern = pattern.Pattern(add_pattern)
+
+# Test matching against different Add nodes
+add_nodes = [node for node in model.graph if node.op_type == "Add"]
+
+# Non-constant first parameter - will not match
+match_result = rule_pattern.match(model, model.graph, add_nodes[0])
+print(f"Non-constant: {bool(match_result)}") # False
+
+# Positive constant first parameter - will match
+match_result = rule_pattern.match(model, model.graph, add_nodes[1])
+print(f"Positive constant: {bool(match_result)}") # True
+
+# Negative constant first parameter - will not match
+match_result = rule_pattern.match(model, model.graph, add_nodes[2])
+print(f"Negative constant: {bool(match_result)}") # False
+```
+
+## Node-Level Checkers
+
+Node-level checkers validate properties of the operation nodes themselves, such as attributes, operation types, or other node-specific properties.
+
+### Basic Usage
+
+A node checker is a function that takes a `MatchContext` and an `ir.Node`, and returns either a boolean or a `MatchResult`:
+
+```python
+def shape_node_checker(context, node):
+ """Check if a Shape operation has start attribute equal to 0."""
+ return node.attributes.get_int("start", 0) == 0
+```
+
+You can use this checker by passing it to the `_check` parameter of an operation:
+
+```python
+def shape_pattern(op, x):
+ return op.Shape(x, _check=shape_node_checker)
+```
+
+This pattern will only match `Shape` operations where the `start` attribute is 0 (or not present, as the default is 0).
+
+### Example Usage
+
+```python
+from onnxscript.rewriter import pattern
+from onnxscript import ir
+import onnx
+
+# Create a model with different Shape operations
+model_proto = onnx.parser.parse_model("""
+ <ir_version: 7, opset_import: ["" : 17]>
+ agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3)
+ {
+ z1 = Shape(x)
+ z2 = Shape <start = 0> (x)
+ z3 = Shape <start = 1> (x)
+ }
+""")
+model = ir.serde.deserialize_model(model_proto)
+
+# Create the pattern with node checker
+rule_pattern = pattern.Pattern(shape_pattern)
+
+# Test matching against different Shape nodes
+nodes = list(model.graph)
+shape_nodes = [node for node in nodes if node.op_type == "Shape"]
+
+# Shape without start attribute (default 0) - will match
+match_result = rule_pattern.match(model, model.graph, shape_nodes[0])
+print(f"No start attr: {bool(match_result)}") # True
+
+# Shape with start=0 - will match
+match_result = rule_pattern.match(model, model.graph, shape_nodes[1])
+print(f"Start=0: {bool(match_result)}") # True
+
+# Shape with start=1 - will not match
+match_result = rule_pattern.match(model, model.graph, shape_nodes[2])
+print(f"Start=1: {bool(match_result)}") # False
+```
+
+## Combining Checkers
+
+You can combine both node-level and value-level checkers in the same pattern for more sophisticated matching:
+
+```python
+def complex_pattern(op, x, y):
+ # Value-level checker for first input
+ validated_x = is_positive_constant
+ # Node-level checker for the operation
+ return op.Add(validated_x, y, _check=lambda ctx, node: len(node.attributes) == 0)
+```
+
+This pattern will only match `Add` operations where:
+1. The first input is a positive constant (value-level check)
+2. The node has no custom attributes (node-level check)
+
+## Execution Timing and Limitations
+
+### When Checkers Are Called
+
+Node-level and value-level checkers are called **only at the end of the complete structural match**. This means:
+
+1. **Structural matching happens first**: The pattern matching engine first validates that the graph structure matches the pattern (correct operators, connections, etc.)
+2. **Checkers run after structural validation**: Only after the structural match succeeds do the node and value checkers execute
+3. **Order of execution**: Value-level checkers run first, followed by node-level checkers, and finally the pattern's condition function
+
+### Limitations with Pattern Disjunctions
+
+One important limitation of this design is that these checks don't compose well with pattern disjunctions (multiple alternative patterns). When searching among multiple value patterns:
+
+- **Only structural checking is performed initially**: If structural matching succeeds for the first alternative, other alternatives are not considered
+- **Checker failures don't trigger backtracking**: If a checker fails, the entire pattern match fails rather than trying the next alternative pattern
+
+This means you should be careful when designing patterns with multiple alternatives that rely on checkers, as the checker logic may prevent exploration of valid alternative matches.
+
+## Error Handling
+
+Checkers can return either:
+- `True`: Check passed, continue matching
+- `False`: Check failed, pattern does not match
+- `MatchResult`: More detailed result with potential failure reasons
+
+If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions.
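+
+For example, a value checker can record why it failed by returning a `MatchResult`
+instead of a plain `False`. The following is a sketch; it assumes `MatchResult` can be
+constructed directly and exposes a `fail(reason)` helper, as in
+`onnxscript.rewriter.pattern`:
+
+```python
+from onnxscript.rewriter import pattern
+
+
+def is_scalar_constant(context, value):
+    """Value checker that records a failure reason instead of returning False."""
+    result = pattern.MatchResult()
+    if value.const_value is None:
+        return result.fail("Value is not a statically known constant.")
+    if value.const_value.numpy().size != 1:
+        return result.fail("Constant is not a scalar.")
+    return True
+```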
diff --git a/docs/tutorial/rewriter/or_pattern.md b/docs/tutorial/rewriter/or_pattern.md
new file mode 100644
index 0000000000..6c42112467
--- /dev/null
+++ b/docs/tutorial/rewriter/or_pattern.md
@@ -0,0 +1,20 @@
+# OR Patterns
+
+*Note*: This feature is a work in progress.
+
+Consider the following pattern:
+
+```{literalinclude} examples/or_pattern.py
+:pyobject: scaled_matmul
+```
+
+This pattern will successfully match against the sequence "MatMul => Mul => Relu" as
+well as the sequence "MatMul => Div => Relu". The matcher will bind the variable
+specified in `tag_var` (`op_type` in the above example) to a value from those
+listed in `tag_values` to indicate which of the alternatives was used for a
+successful match. We can use this in the rewrite function to determine how
+we want to rewrite the matched sub-graph, as illustrated by the following code:
+
+```{literalinclude} examples/or_pattern.py
+:pyobject: scaled_matmul_replacement
+```
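+
+The complete rule is then constructed and applied in the usual way; the `op_type` value
+bound by the matcher is passed to the replacement function automatically:
+
+```{literalinclude} examples/or_pattern.py
+:pyobject: apply_rewrite
+```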
diff --git a/docs/tutorial/rewriter/outputs_option.md b/docs/tutorial/rewriter/outputs_option.md
new file mode 100644
index 0000000000..cc73bcc561
--- /dev/null
+++ b/docs/tutorial/rewriter/outputs_option.md
@@ -0,0 +1,43 @@
+# Specifying outputs in the pattern
+
+This section demonstrates the use of the `_outputs` option in pattern-based rewriting.
+The `_outputs` option allows you to specify the number of outputs an operation produces
+and optionally assign names to those outputs for easier reference in replacement patterns.
+
+The `_outputs` option can be specified in two ways:
+- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs
+- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs
+
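+For instance, the following schematic pattern fragments (not part of the example script
+included below) specify the same output arity in the two styles:
+
+```python
+def split_two_outputs_unnamed(op, x):
+    # _outputs=2: two anonymous outputs
+    return op.Split(x, num_outputs=2, axis=0, _outputs=2)
+
+
+def split_two_outputs_named(op, x):
+    # _outputs=["first", "second"]: same arity, but the outputs are named
+    return op.Split(x, num_outputs=2, axis=0, _outputs=["first", "second"])
+```
+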
+## Matching operations with multiple outputs
+
+```{literalinclude} examples/outputs_option.py
+:pyobject: split_pattern
+```
+
+This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2`
+specification ensures the pattern only matches operations with this specific output count.
+
+## Creating replacement operations with named outputs
+
+```{literalinclude} examples/outputs_option.py
+:pyobject: custom_split_replacement
+```
+
+In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with
+descriptive names. This can make the replacement pattern more readable and maintainable.
+
+**Important**: The number of outputs in the replacement pattern must match the number of
+outputs in the target pattern. Since the pattern specifies `_outputs=2`, the replacement
+must also produce exactly 2 outputs.
+
+## Complete rewrite example
+
+```{literalinclude} examples/outputs_option.py
+:pyobject: apply_rewrite
+```
+
+The `_outputs` option is particularly important when:
+- Working with operations that have variable numbers of outputs (like `Split`)
+- Creating custom operations that need specific output configurations
+- Ensuring pattern matching precision by specifying exact output counts
+- Improving code readability by naming outputs in replacement patterns
diff --git a/docs/tutorial/rewriter/rewrite_patterns.md b/docs/tutorial/rewriter/rewrite_patterns.md
index 7312380446..50615945d1 100644
--- a/docs/tutorial/rewriter/rewrite_patterns.md
+++ b/docs/tutorial/rewriter/rewrite_patterns.md
@@ -1,10 +1,8 @@
-# Pattern-based Rewrite Using Rules
+# Introduction
-## Introduction
+The ONNX Rewriter tool lets users replace certain patterns in an ONNX graph with other patterns, based on user-provided conditional rewrite rules.
-The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on rewrite rules provided by the user.
-
-## Usage
+# Usage
There are three main components needed when rewriting patterns in the graph:
@@ -12,188 +10,40 @@ There are three main components needed when rewriting patterns in the graph:
2. `replacement_pattern` : Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators.
3. `match_condition` (optional) : Pattern rewrite will occur only if the match condition is satisfied.
-(heading-target-simple)=
-## A Simple Example
-
-An simple example demonstrating the usage of this functionality using the `GELU` activation function:
-
-`GELU` activation function can be computed using a Gauss Error Function using the given formula:
-
-```{math}
-\text{GELU} = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})]
-```
-
-We will show how we can find a subgraph matching this computation and replace it by a call to the function.
-
-Firstly, include all the rewriter relevant imports.
-
-```python
-from onnxscript.rewriter import pattern
-from onnxscript import ir
-
-_op = pattern.onnxop
-```
-
-Then create a target pattern that needs to be replaced using onnxscript operators.
-
-```{literalinclude} examples/erfgelu.py
-:pyobject: erf_gelu_pattern
-```
-
-After this, create a replacement pattern that consists of the GELU onnxscript operator.
-
-```{literalinclude} examples/erfgelu.py
-:pyobject: gelu
-```
-:::{note}
-:name: type annotate ir.Value
-
-The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value ` class.
-:::
-
-
-For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function.
-
-```python
-rule = pattern.RewriteRule(
- erf_gelu_pattern, # Target Pattern
- gelu, # Replacement Pattern
-)
-```
-
-Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components:
-
-1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`.
-2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]`
-3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of either one of these types:
- - `Sequence[PatternRewriteRule]`
- - `RewriteRuleSet`
-
-:::{note}
-:name: pattern_rewrite_rules input formatting
-
-`pattern_rewrite_rules` takes a sequence of `PatternRewriteRule` types or a RewriteRuleSet which is also essentially a rule set created using a sequence of `PatternRewriteRule` types, so if only a singular rewrite rule is to be passed, it needs to passed as part of a sequence. For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset).
-:::
-
-The snippet below below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above:
-
-```{literalinclude} examples/erfgelu.py
-:pyobject: apply_rewrite
-```
-
-The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended.
-
- 
-
+## Pattern Options
-(heading-target-commute)=
-## Utilizing `commute` parameter for pattern-matching
-Extending the previous [simple example](heading-target-simple), assumming a scenario where we have a graph with the following structure.
+When defining patterns, you can use several special options to control how patterns match and what they produce:
-{align=center width=500px}
+- `_allow_other_attributes`: Controls whether the pattern allows additional attributes not specified in the pattern (default: True)
+- `_allow_other_inputs`: Controls whether the pattern allows additional inputs beyond those specified (default: False)
+- `_domain`: Specifies the operator domain for matching or creating operations
+- `_outputs`: Specifies the number and optionally names of outputs from an operation
-In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched.
+These options are documented in detail in the following sections.
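+
+As a quick illustration, these options appear as keyword arguments on pattern operations.
+The following is a schematic sketch (the operator `SomeOp` and its domain are hypothetical):
+
+```python
+def example_pattern(op, x):
+    # Match a two-output SomeOp from a custom domain, tolerating extra
+    # inputs and attributes beyond those written in the pattern.
+    return op.SomeOp(
+        x,
+        _domain="custom.domain",
+        _allow_other_inputs=True,
+        _allow_other_attributes=True,
+        _outputs=2,
+    )
+```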
-{width=330px align=left} {width=330px align=center}
-
-
-If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched.
-
-```{literalinclude} examples/erfgelu.py
-:pyobject: erf_gelu_pattern
+```{include} simple_example.md
```
-```{image} examples/img/erfgelu_06_commute.png
-:alt: The resulting graph after matching.
-:width: 400px
-:align: center
+```{include} attributes.md
```
-Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods.
-
-(heading-target-commute-ruleset)=
-### 1. Creating a rule-set with different patterns.
-
-This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example:
-
-```python
-from onnxscript.rewriter import pattern
-rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
+```{include} allow_other_inputs.md
```
-In order to apply this method to the example above, first create the two separate target patterns as follows:
-
-```{literalinclude} examples/erfgelu.py
-:pyobject: erf_gelu_pattern
-```
-```{literalinclude} examples/erfgelu.py
-:pyobject: erf_gelu_pattern_2
+```{include} domain_option.md
```
-Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter.
-
-```{literalinclude} examples/erfgelu.py
-:pyobject: apply_rewrite_with_ruleset
+```{include} outputs_option.md
```
-
-### 2. Using the `commute` parameter while creating a rule.
-
-Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target pattern for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter.
-
-```{literalinclude} examples/erfgelu.py
-:pyobject: apply_rewrite_with_commute
+```{include} conditional_rewrite.md
```
-For the both of the aforementioned methods, the final graph with both rewrites applied should look as follows:
-
-{align=center width=300px}
-
-## Using the `match_condition` parameter for pattern-matching
-
-This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration.
-
-Let us consider a model which consists of the following pattern.
-
-{align=center}
-
-Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following:
-
-1. Input shapes check: `input_a` and `input_b` should be broadcastable
-2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)`
-
-If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite.
-
-First, write a target pattern and replacement pattern in a similar way to the first example.
-
-```{literalinclude} examples/broadcast_matmul.py
-:pyobject: two_reshapes_matmul_reshape_pattern
+```{include} or_pattern.md
```
-```{literalinclude} examples/broadcast_matmul.py
-:pyobject: matmul_pattern
+```{include} commute.md
```
-:::{note}
-:name: omitting inputs in signature
-
-The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters.
-
-Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature.
-:::
-
-In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows:
-
-```{literalinclude} examples/broadcast_matmul.py
-:pyobject: check_if_not_need_reshape
+```{include} node_value_checkers.md
```
-
-With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite.
-
-```{literalinclude} examples/broadcast_matmul.py
-:pyobject: apply_rewrite
-```
-
-The final graph with the applied rewrite looks as follows:
-
-{align=center}
diff --git a/docs/tutorial/rewriter/simple_example.md b/docs/tutorial/rewriter/simple_example.md
new file mode 100644
index 0000000000..53b3c89aff
--- /dev/null
+++ b/docs/tutorial/rewriter/simple_example.md
@@ -0,0 +1,81 @@
+(heading-target-simple)=
+# A Simple Example
+
+A simple example demonstrating this functionality, using the `GELU` activation function:
+
+The `GELU` activation function can be computed using the Gauss error function, per the formula:
+
+```{math}
+\text{GELU} = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})]
+```
+
+We will show how to find a subgraph matching this computation and replace it with a call to the function.
+
+First, import all the modules relevant to the rewriter.
+
+```python
+from onnxscript.rewriter import pattern
+from onnxscript import ir
+
+```
+
+Then, define the target pattern that needs to be replaced, expressed using onnxscript operators.
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: erf_gelu_pattern
+```
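+
+The `literalinclude` above pulls the function body from `examples/erfgelu.py`. As a rough, illustrative sketch (not a verbatim copy of that file), such a target pattern can be written as:
+
+```python
+import math
+
+def erf_gelu_pattern(op, x):
+    # Matches the subgraph computing 0.5 * x * (1 + erf(x / sqrt(2))).
+    # `op` builds pattern nodes; arithmetic on pattern values stands in for
+    # the corresponding ONNX ops (Mul, Add, Div).
+    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
+```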
+
+After this, create a replacement pattern that consists of the GELU onnxscript operator.
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: gelu
+```
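+
+Again as a sketch, assuming the fused `Gelu` op from the `com.microsoft` domain, the replacement pattern can be as small as:
+
+```python
+def gelu(op, x: ir.Value):
+    # Replacement: emit a single fused Gelu node in the com.microsoft domain.
+    return op.Gelu(x, domain="com.microsoft")
+```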
+:::{note}
+:name: type annotate ir.Value
+
+The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value`, refer to the {py:class}`ir.Value <onnxscript.ir.Value>` class.
+:::
+
+
+For this example, we do not require a `match_condition`, so that option is skipped for now. The rewrite rule is then created using the `RewriteRule` function.
+
+```python
+rule = pattern.RewriteRule(
+ erf_gelu_pattern, # Target Pattern
+ gelu, # Replacement Pattern
+)
+```
+
+More complex rewrite rules are often more convenient to organize as a class. The above rule can
+alternatively be expressed as follows.
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: ErfGeluFusion
+```
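+
+A minimal sketch of such a class, assuming the `RewriteRuleClassBase` helper from `onnxscript.rewriter.pattern` (which supplies the `rule()` constructor used below):
+
+```python
+class ErfGeluFusion(pattern.RewriteRuleClassBase):
+    def pattern(self, op, x):
+        # Target: the erf-based GELU computation.
+        return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
+
+    def rewrite(self, op, x):
+        # Replacement: a single fused Gelu node.
+        return op.Gelu(x, domain="com.microsoft")
+```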
+
+The corresponding rewrite rule can be obtained as follows:
+
+```python
+erf_gelu_rule_from_class = ErfGeluFusion.rule()
+```
+
+Now that the rewrite rule has been created, the next step is to apply it. The `rewriter.rewrite(model, pattern_rewrite_rules)` call applies the specified rewrite rules to the given model.
+
+1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `ir.Model` or `onnx.ModelProto`. If the model is an `ir.Model`, the rewriter applies the changes in place, modifying the input model. If it is a `ModelProto`, the rewriter returns a new `ModelProto` representing the transformed model.
+2. `pattern_rewrite_rules` : This parameter is either a `Sequence[PatternRewriteRule]` or a `RewriteRuleSet`.
+
+:::{note}
+:name: pattern_rewrite_rules input formatting
+
+For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset).
+:::
+
+The snippet below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above:
+
+```{literalinclude} examples/erfgelu.py
+:pyobject: apply_rewrite
+```
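+
+For reference, such an `apply_rewrite` helper is typically little more than the following sketch (assuming the single `rule` created earlier):
+
+```python
+from onnxscript import rewriter
+
+def apply_rewrite(model):
+    # Apply the rule; an ir.Model is modified in place.
+    return rewriter.rewrite(model, pattern_rewrite_rules=[rule])
+```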
+
+The graph (on the left) shows the target pattern before the rewrite rule is applied. Once the rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a single GELU node, as intended.
+
+ 
diff --git a/docs/update_readme.py b/docs/update_readme.py
index ddc5859cd5..7d39406883 100644
--- a/docs/update_readme.py
+++ b/docs/update_readme.py
@@ -1,4 +1,6 @@
-# Script to update end-to-end example in README.md.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Script to update end-to-end example in README.md."""
updated_readme = []
with open("README.md", encoding="utf-8") as f:
@@ -12,7 +14,7 @@
with open(
"docs/tutorial/examples/hardmax_end_to_end.py", encoding="utf-8"
) as example_f:
- example_code = example_f.readlines()
+ example_code = example_f.readlines()[2:] # Skip the copyright header
updated_readme += example_code
if line == "```\n" and in_stub:
updated_readme.append(line)
diff --git a/examples/pattern_matching_example.py b/examples/pattern_matching_example.py
new file mode 100644
index 0000000000..8de09ecd6a
--- /dev/null
+++ b/examples/pattern_matching_example.py
@@ -0,0 +1,140 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Example demonstrating the new pattern matching functionality."""
+
+import onnx.parser
+
+from onnxscript import ir
+from onnxscript.rewriter import pattern
+
+
+def example_standalone_pattern_matching():
+ """Example showing how to use Pattern for standalone pattern matching."""
+
+ print("=== Standalone Pattern Matching Example ===")
+
+ # Define a pattern that matches Identity nodes
+ def identity_pattern(op, x):
+ return op.Identity(x)
+
+ # Create a Pattern for standalone pattern matching (no replacement)
+ pattern_matcher = pattern.Pattern(identity_pattern, name="IdentityMatcher")
+
+ # Create a model with an Identity node
+ model_proto = onnx.parser.parse_model(
+ """
+        <ir_version: 7, opset_import: [ "" : 17]>
+ agraph (float[N] x) => (float[N] z)
+ {
+ z = Identity(x)
+ }
+ """
+ )
+ model = ir.serde.deserialize_model(model_proto)
+
+ # Find nodes to test pattern matching against
+ for node in model.graph:
+ print(f"Testing pattern against {node.op_type} node...")
+ match_result = pattern_matcher.match(model, model.graph, node)
+
+ if match_result is not None:
+ print(f" ✓ Pattern matched! Found {len(match_result.nodes)} nodes in match.")
+ print(f" Matched node: {match_result.nodes[0].op_type}")
+ else:
+ print(f" ✗ Pattern did not match {node.op_type} node.")
+
+
+def example_class_based_pattern():
+ """Example showing how to use PatternBase for class-based pattern definition."""
+
+ print("\n=== Class-Based Pattern Example ===")
+
+ class IdentityPatternClass(pattern.PatternBase):
+ """A class-based pattern that matches Identity nodes."""
+
+ def pattern(self, op, x):
+ return op.Identity(x)
+
+ def check(self, context, x):
+ """Custom condition - always succeeds for this example."""
+ print(f" Checking condition for input: {x}")
+ return pattern.MatchResult() # Always succeeds
+
+ # Create an instance of the pattern class
+ identity_pattern_class = IdentityPatternClass(name="ClassBasedIdentity")
+
+    # The Pattern is created internally; we can use the pattern directly
+ print(f"Created pattern matcher: {identity_pattern_class.name}")
+
+ # Use it directly with the match method
+ model_proto = onnx.parser.parse_model(
+ """
+        <ir_version: 7, opset_import: [ "" : 17]>
+ agraph (float[N] x) => (float[N] z)
+ {
+ z = Identity(x)
+ }
+ """
+ )
+ model = ir.serde.deserialize_model(model_proto)
+
+ for node in model.graph:
+ if node.op_type == "Identity":
+ print(f"Testing class-based pattern against {node.op_type} node...")
+ match_result = identity_pattern_class.match(model, model.graph, node)
+
+ if match_result is not None:
+ print(" ✓ Class-based pattern matched!")
+ else:
+ print(" ✗ Class-based pattern did not match.")
+
+
+def example_rewrite_rule_still_works():
+ """Example showing that existing RewriteRule functionality is preserved."""
+
+ print("\n=== Existing RewriteRule Still Works ===")
+
+ def identity_pattern(op, x):
+ return op.Identity(x)
+
+ def identity_replacement(op, x):
+ return op.Identity(x) # No-op replacement
+
+ # Create a RewriteRule (which now inherits from Pattern)
+ rule = pattern.RewriteRule(identity_pattern, identity_replacement, name="IdentityRule")
+
+ print(f"Created rewrite rule: {rule.name}")
+ print(f"Rule is also a Pattern: {isinstance(rule, pattern.Pattern)}")
+
+ # The rule can be used both for pattern matching and rewriting
+ model_proto = onnx.parser.parse_model(
+ """
+        <ir_version: 7, opset_import: [ "" : 17]>
+ agraph (float[N] x) => (float[N] z)
+ {
+ z = Identity(x)
+ }
+ """
+ )
+ model = ir.serde.deserialize_model(model_proto)
+
+ # Use it for just pattern matching (inherited from Pattern)
+ for node in model.graph:
+ if node.op_type == "Identity":
+ print(f"Using RewriteRule for pattern matching on {node.op_type}...")
+ match_result = rule.match(model, model.graph, node)
+
+ if match_result is not None:
+ print(" ✓ RewriteRule matched as a pattern matcher!")
+
+ # Use it for rewriting (original functionality)
+ print("Using RewriteRule for rewriting...")
+ count = rule.apply_to_model(model)
+ print(f" Applied rule {count} times")
+
+
+if __name__ == "__main__":
+ example_standalone_pattern_matching()
+ example_class_based_pattern()
+ example_rewrite_rule_still_works()
+ print("\n=== All Examples Completed ===")
diff --git a/examples/pattern_rewriting.py b/examples/pattern_rewriting.py
index 737ce02e84..fd84d7f3cb 100644
--- a/examples/pattern_rewriting.py
+++ b/examples/pattern_rewriting.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Onnx Pattern Rewriting.
This script shows how to define a rewriting rule based on patterns.
@@ -13,9 +15,8 @@
import onnx.helper as oh
import onnx.numpy_helper as onh
-import onnxscript
from onnxscript import ir
-from onnxscript.rewriter import generic_pattern
+from onnxscript.rewriter import pattern
def get_rotary_model(bad_model=False):
@@ -67,18 +68,17 @@ def get_rotary_model(bad_model=False):
# The rewriting pattern
# =====================
-op = onnxscript.opset18
-msft_op = onnxscript.values.Opset("com.microsoft", 1)
-
-def rotary_match_pattern(x, pos_ids, axis):
+def rotary_match_pattern(op, x, pos_ids, axis):
"""The pattern to match."""
unsqueeze = op.Unsqueeze(x, axis)
cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)
matmul = op.MatMul(pos_ids, cast)
transpose = op.Transpose(matmul)
- output, length = msft_op.ConcatTraining(transpose, transpose)
+ output, _length = op.ConcatTraining(
+ transpose, transpose, domain="com.microsoft", outputs=2
+ )
sin = op.Sin(output)
cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
@@ -87,25 +87,13 @@ def rotary_match_pattern(x, pos_ids, axis):
return cast1, cast2
-def validate_rotary_mapping(g, match_result) -> bool:
- """The validation post matching.
-
- Returns True to validate the replacement,
- False not to apply it.
-
- :param g: model
- :param match_result: matched nodes
- """
- del g
- del match_result
- return True
-
-
-def rotary_apply_pattern(x, pos_ids, axis):
+def rotary_apply_pattern(op, x, pos_ids, axis):
"""The replacement pattern."""
cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
- part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
+ part1, part2 = op.RotaryEmbedding(
+ x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
+ )
return part1, part2
@@ -115,18 +103,7 @@ def rotary_apply_pattern(x, pos_ids, axis):
#
# The rule is easy to create.
-
-rule_with_validation_function = generic_pattern.make_pattern_rule(
- rotary_match_pattern,
- rotary_apply_pattern,
- validate_rotary_mapping,
-)
-
-################################
-# ``validate_rotary_mapping`` always return True.
-# This argument can be ignored in that case.
-
-rule = generic_pattern.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)
+rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern)
##########################
# Let's apply it.
@@ -161,31 +138,6 @@ def rotary_apply_pattern(x, pos_ids, axis):
# The match did not happen.
# Let's increase the verbosity.
-rule = generic_pattern.make_pattern_rule(
- rotary_match_pattern, rotary_apply_pattern, verbose=10
-)
+rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
rule.apply_to_model(ir_model)
-
-######################################
-# The logs shows every time the algorithm rejected a pattern.
-# We can see the following:
-#
-# ::
-#
-# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
-# --hint--: BACKWARD: different node types
-# --pattern
-# ConcatTraining(transpose, transpose) -> (output, length)
-# -- model
-# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
-# iteration=1
-# --marked-- #2
-# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
-# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
-# len(stacked)=0:[]
-#
-# Line 673 in file `generic_pattern.py`, the match was rejected.
-# It says while comparing two nodes in the backward direction,
-# node types do not match.
-# It also says that two nodes were actually matched.
diff --git a/noxfile.py b/noxfile.py
index 3aad2dfc35..60c2bb901b 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Test with different environment configuration with nox.
Documentation:
@@ -10,14 +12,12 @@
COMMON_TEST_DEPENDENCIES = (
- "beartype==0.17.2",
"expecttest==0.1.6",
"hypothesis",
- 'numpy==1.24.4; python_version<"3.12"',
- 'numpy>1.26.0; python_version>="3.12"',
+ "numpy",
"packaging",
"parameterized",
- "pyinstrument",
+ 'psutil; sys_platform != "win32"',
"pytest-cov",
"pytest-randomly",
"pytest-subtests",
@@ -25,12 +25,14 @@
"pytest!=7.1.0",
"pyyaml",
"types-PyYAML",
- "typing_extensions",
+ "typing_extensions>=4.10",
+ "ml-dtypes",
)
-ONNX = "onnx==1.16"
-ONNX_RUNTIME = "onnxruntime==1.17.1"
-PYTORCH = "torch==2.2.2"
-TORCHVISON = "torchvision==0.17.2"
+ONNX = "onnx==1.17"
+ONNX_RUNTIME = "onnxruntime==1.23.0"
+PYTORCH = "torch==2.7.1"
+TORCHVISON = "torchvision==0.22.1"
+TRANSFORMERS = "transformers==4.37.2"
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
"flatbuffers",
"coloredlogs",
@@ -39,6 +41,8 @@
"packaging",
"protobuf",
)
+ONNX_IR = "onnx_ir==0.1.10"
+ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"
@nox.session(tags=["build"])
@@ -56,7 +60,9 @@ def test(session):
PYTORCH,
TORCHVISON,
ONNX,
+ ONNX_IR,
ONNX_RUNTIME,
+ TRANSFORMERS,
)
session.install(".", "--no-deps")
session.run("pip", "list")
@@ -70,9 +76,11 @@ def test_torch_nightly(session):
session.install(
*COMMON_TEST_DEPENDENCIES,
ONNX_RUNTIME,
+ TRANSFORMERS,
)
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install("-r", "requirements/ci/requirements-pytorch-nightly.txt")
+ session.install(ONNX_IR, "--no-deps")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run("pytest", "onnxscript", "--doctest-modules", *session.posargs)
@@ -82,7 +90,8 @@ def test_torch_nightly(session):
@nox.session(tags=["test-onnx-weekly"])
def test_onnx_weekly(session):
"""Test with ONNX weekly (preview) build."""
- session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON)
+ session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON, TRANSFORMERS)
+ session.install(ONNX_IR, "--no-deps")
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
@@ -98,6 +107,8 @@ def test_ort_nightly(session):
PYTORCH,
TORCHVISON,
ONNX,
+ ONNX_IR,
+ TRANSFORMERS,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
@@ -107,43 +118,19 @@ def test_ort_nightly(session):
session.run("pytest", "tests", *session.posargs)
-@nox.session(tags=["test-experimental-torchlib-tracing"])
-def test_experimental_torchlib_tracing(session):
- """Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
- session.install(
- *COMMON_TEST_DEPENDENCIES,
- PYTORCH,
- TORCHVISON,
- ONNX,
- *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
- )
- session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
- session.install(".", "--no-deps")
- session.run("pip", "list")
- session.run(
- "pytest",
- "tests/function_libs/torch_lib/ops_test.py",
- *session.posargs,
- env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"},
- )
-
-
-@nox.session(tags=["test-experimental-torchlib-onnx-ir"])
-def test_experimental_torchlib_onnx_ir(session):
- """Test TorchLib using the ONNX IR to build graphs."""
+@nox.session(tags=["test-onnx-ir-git"])
+def test_onnx_ir_git(session):
+ """Test with ONNX IR Git builds."""
session.install(
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
- *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
+ ONNX_RUNTIME,
+ TRANSFORMERS,
)
- session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
+ session.install(ONNX_IR_MAIN)
session.install(".", "--no-deps")
session.run("pip", "list")
- session.run(
- "pytest",
- "tests/function_libs/torch_lib/ops_test.py",
- *session.posargs,
- env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
- )
+ session.run("pytest", "onnxscript", "--doctest-modules", *session.posargs)
+ session.run("pytest", "tests", *session.posargs)
diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py
index bee5a1b230..b839093d2b 100644
--- a/onnxscript/__init__.py
+++ b/onnxscript/__init__.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
__all__ = [
"script",
@@ -9,6 +7,7 @@
"ir",
"optimizer",
"rewriter",
+ "version_converter",
"export_onnx_lib",
"OnnxFunction",
"TracedOnnxFunction",
@@ -54,10 +53,14 @@
"opset18",
"opset19",
"opset20",
+ "opset21",
+ "opset22",
"opset_ai_onnx_ml1",
"opset_ai_onnx_ml2",
"opset_ai_onnx_ml3",
"opset_ai_onnx_ml4",
+ "opset_ai_onnx_ml5",
+ "DEBUG",
]
import importlib.metadata
@@ -87,10 +90,13 @@
opset18,
opset19,
opset20,
+ opset21,
+ opset22,
opset_ai_onnx_ml1,
opset_ai_onnx_ml2,
opset_ai_onnx_ml3,
opset_ai_onnx_ml4,
+ opset_ai_onnx_ml5,
)
from .onnx_types import (
@@ -118,10 +124,13 @@
# isort: on
-from . import ir, optimizer, rewriter
+from . import ir, optimizer, rewriter, version_converter
from ._internal.utils import external_tensor
from .values import OnnxFunction, TracedOnnxFunction
+# Set DEBUG to True to enable additional debug checks
+DEBUG: bool = False
+
try: # noqa: SIM105
__version__ = importlib.metadata.version("onnxscript")
except importlib.metadata.PackageNotFoundError:
diff --git a/onnxscript/_framework_apis/__init__.py b/onnxscript/_framework_apis/__init__.py
new file mode 100644
index 0000000000..2aee3dcace
--- /dev/null
+++ b/onnxscript/_framework_apis/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Semi-private stable APIs for framework-specific usage only."""
diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py
new file mode 100644
index 0000000000..162faf4b75
--- /dev/null
+++ b/onnxscript/_framework_apis/torch_2_5.py
@@ -0,0 +1,119 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Stable APIs for PyTorch 2.5."""
+
+from __future__ import annotations
+
+__all__ = [
+ "check_model",
+ "convert_version",
+ "get_torchlib_ops",
+ "optimize",
+ "save_model_with_external_data",
+]
+
+import dataclasses
+import os
+import pathlib
+from typing import Callable
+
+from onnxscript import ir, optimizer, version_converter
+from onnxscript.function_libs.torch_lib import registration
+
+
+@dataclasses.dataclass(frozen=True)
+class _OnnxFunctionMeta:
+ """A wrapper of onnx-script function with additional metadata.
+
+ qualified_name: The qualified name of the aten operator.
+ function: The onnx-script function.
+ domain: The domain of the function.
+ name: The name of the function.
+ is_complex: Whether the function is a complex function.
+ """
+
+ qualified_name: str
+ function: Callable
+ domain: str
+ name: str
+ is_complex: bool = False
+
+
+def optimize(model: ir.Model) -> ir.Model:
+ """Optimize the model."""
+ # Internal flag. Will go away.
+ enabled = os.getenv("TORCH_ONNX_ENABLE_OPTIMIZATION") == "1"
+ if enabled:
+ optimizer.optimize_ir(model)
+ return model
+
+
+def convert_version(model: ir.Model, target_version: int) -> ir.Model:
+ """Convert the model to the specified ONNX opset version."""
+ # Internal flag. Will go away.
+ enabled = os.getenv("TORCH_ONNX_ENABLE_VERSION_CONVERSION") == "1"
+ if enabled:
+ version_converter.convert_version(model, target_version)
+ return model
+
+
+def check_model(model: ir.Model) -> None:
+ """Check the model."""
+
+ del model # Unused yet
+
+
+def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
+ """Save the model with external data. The model is unchanged after saving."""
+
+ # TODO(#1835): Decide if we want to externalize large attributes as well
+ uninitialized_values = [
+ value.name for value in model.graph.initializers.values() if value.const_value is None
+ ]
+ if uninitialized_values:
+ raise ValueError(
+ f"The model contains uninitialized initializer values ({uninitialized_values}). "
+ "Please make sure all initializer values are initialized."
+ )
+ destination_path = pathlib.Path(model_path)
+ data_path = f"{destination_path.name}.data"
+
+ ir.save(model, model_path, external_data=data_path)
+
+
+def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
+    """Return the registered torchlib ops as a list of _OnnxFunctionMeta."""
+ # Trigger op registration
+ from onnxscript.function_libs.torch_lib import ( # pylint: disable=import-outside-toplevel
+ ops,
+ )
+
+ del ops # Unused
+
+ torchlib_registry = registration.default_registry
+ function_metas = []
+
+ for qualified_name, aten_overloads_func in torchlib_registry.items():
+ if qualified_name.startswith("internal::"):
+ # Skip the custom defined internal functions
+ continue
+
+ for overload_func in aten_overloads_func.overloads:
+ function_meta = _OnnxFunctionMeta(
+ qualified_name=qualified_name,
+ function=overload_func,
+ domain=overload_func.function_ir.domain,
+ name=overload_func.name,
+ is_complex=False,
+ )
+ function_metas.append(function_meta)
+ for complex_func in aten_overloads_func.complex:
+ function_meta = _OnnxFunctionMeta(
+ qualified_name=qualified_name,
+ function=complex_func,
+ domain=complex_func.function_ir.domain,
+ name=complex_func.name,
+ is_complex=True,
+ )
+ function_metas.append(function_meta)
+
+ return function_metas
diff --git a/onnxscript/_framework_apis/torch_2_6.py b/onnxscript/_framework_apis/torch_2_6.py
new file mode 100644
index 0000000000..2d166cb967
--- /dev/null
+++ b/onnxscript/_framework_apis/torch_2_6.py
@@ -0,0 +1,57 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Stable APIs for PyTorch 2.6."""
+
+from __future__ import annotations
+
+__all__ = [
+ "check_model",
+ "convert_version",
+ "get_torchlib_ops",
+ "optimize",
+ "save_model_with_external_data",
+ "torchlib_opset",
+]
+import logging
+from typing import TYPE_CHECKING
+
+from onnxscript import ir, optimizer, version_converter
+from onnxscript._framework_apis.torch_2_5 import (
+ check_model,
+ get_torchlib_ops,
+ save_model_with_external_data,
+)
+
+if TYPE_CHECKING:
+ from onnxscript.onnx_opset._impl.opset18 import Opset18
+
+
+logger = logging.getLogger(__name__)
+
+
+def optimize(model: ir.Model) -> ir.Model:
+ """Optimize the model."""
+ optimizer.optimize_ir(model)
+ return model
+
+
+def convert_version(model: ir.Model, target_version: int) -> ir.Model:
+ """Convert the model to the specified ONNX opset version."""
+ if target_version < 18:
+ logger.warning("Conversion to opset < 18 is not supported.")
+ return model
+ version_converter.convert_version(model, target_version, fallback=True)
+ return model
+
+
+def torchlib_opset() -> Opset18:
+ """Return the default opset for torchlib."""
+ import onnxscript # pylint: disable=import-outside-toplevel
+
+ return onnxscript.opset18 # type: ignore
+
+
+def torchlib_opset_version() -> int:
+ """Return the default opset version for torchlib."""
+
+ return torchlib_opset().version
diff --git a/onnxscript/_framework_apis/torch_2_7.py b/onnxscript/_framework_apis/torch_2_7.py
new file mode 100644
index 0000000000..ee5e6089e5
--- /dev/null
+++ b/onnxscript/_framework_apis/torch_2_7.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Stable APIs for PyTorch 2.7."""
+
+from __future__ import annotations
+
+__all__ = [
+ "check_model",
+ "convert_version",
+ "get_torchlib_ops",
+ "optimize",
+ "save_model_with_external_data",
+]
+
+from onnxscript._framework_apis.torch_2_6 import (
+ check_model,
+ convert_version,
+ get_torchlib_ops,
+ optimize,
+ save_model_with_external_data,
+)
diff --git a/onnxscript/_framework_apis/torch_2_8.py b/onnxscript/_framework_apis/torch_2_8.py
new file mode 100644
index 0000000000..dca34086a0
--- /dev/null
+++ b/onnxscript/_framework_apis/torch_2_8.py
@@ -0,0 +1,31 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Stable APIs for PyTorch 2.8."""
+
+from __future__ import annotations
+
+__all__ = [
+ "check_model",
+ "convert_version",
+ "get_torchlib_ops",
+ "optimize",
+ "save_model_with_external_data",
+]
+
+import onnx_ir as ir
+
+import onnxscript.optimizer
+import onnxscript.rewriter.onnx_fusions
+from onnxscript._framework_apis.torch_2_6 import (
+ check_model,
+ convert_version,
+ get_torchlib_ops,
+ save_model_with_external_data,
+)
+
+
+def optimize(model: ir.Model) -> ir.Model:
+ """Optimize the model."""
+ onnxscript.optimizer.optimize_ir(model)
+ onnxscript.rewriter.onnx_fusions.fuse(model)
+ return model
diff --git a/onnxscript/_framework_apis/torch_2_9.py b/onnxscript/_framework_apis/torch_2_9.py
new file mode 100644
index 0000000000..88c9b85734
--- /dev/null
+++ b/onnxscript/_framework_apis/torch_2_9.py
@@ -0,0 +1,35 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Stable APIs for PyTorch 2.9."""
+
+from __future__ import annotations
+
+__all__ = [
+ "check_model",
+ "convert_version",
+ "get_torchlib_ops",
+ "optimize",
+ "save_model_with_external_data",
+]
+
+from typing import TYPE_CHECKING
+
+from onnxscript import version_converter
+from onnxscript._framework_apis.torch_2_8 import (
+ check_model,
+ get_torchlib_ops,
+ optimize,
+ save_model_with_external_data,
+)
+
+if TYPE_CHECKING:
+ import onnx_ir as ir
+
+
+def convert_version(model: ir.Model, target_version: int) -> ir.Model:
+ """Convert the model to the specified ONNX opset version.
+
+ Starting from PyTorch 2.9, down conversion is turned on and supported.
+ """
+ version_converter.convert_version(model, target_version, fallback=True)
+ return model
diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py
index 0901382eee..0403f60c91 100644
--- a/onnxscript/_internal/analysis.py
+++ b/onnxscript/_internal/analysis.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
import ast
diff --git a/onnxscript/_internal/analysis_test.py b/onnxscript/_internal/analysis_test.py
index 5531ec3833..74e7ca4c18 100644
--- a/onnxscript/_internal/analysis_test.py
+++ b/onnxscript/_internal/analysis_test.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from __future__ import annotations
import ast
diff --git a/onnxscript/_internal/ast_utils.py b/onnxscript/_internal/ast_utils.py
index 974ae75a09..4146f38e2f 100644
--- a/onnxscript/_internal/ast_utils.py
+++ b/onnxscript/_internal/ast_utils.py
@@ -1,23 +1,21 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Utilities for working with Python ASTs."""
from __future__ import annotations
import ast
import inspect
-import sys
import textwrap
-import types
+from typing import Callable
-PY_VERSION_GE_39 = sys.version_info >= (3, 9)
-
-def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
+def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]:
try:
- src = inspect.getsource(f)
+ src = inspect.getsource(func)
except OSError as e:
raise RuntimeError(
- f"Decorator script does not work on dynamically "
- f"compiled function {f.__name__}."
+ f"Decorator script does not work on dynamically compiled function {func.__name__}."
) from e
src = textwrap.dedent(src)
top_level_ast = ast.parse(src)
@@ -34,17 +32,10 @@ def normalize_subscript_expr(expr: ast.Subscript):
# Returns a list of expressions, denoting the indices, after stripping the extraneous "Index"
# wrapper present in python versions before 3.9
index_expr = expr.slice
- if PY_VERSION_GE_39:
- if isinstance(index_expr, ast.Tuple):
- return index_expr.elts # multiple indices
- else:
- return [index_expr] # single index
+ if isinstance(index_expr, ast.Tuple):
+ return index_expr.elts # multiple indices
else:
- if isinstance(index_expr, ast.ExtSlice):
- indices = index_expr.dims # type: ignore[attr-defined]
- else:
- indices = [index_expr] # single slice-index
- return [x.value if isinstance(x, ast.Index) else x for x in indices] # type: ignore[attr-defined]
+ return [index_expr] # single index
def is_print_call(stmt: ast.stmt) -> bool:
diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py
index b79180ae59..1defac3e53 100644
--- a/onnxscript/_internal/autocast.py
+++ b/onnxscript/_internal/autocast.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
@@ -9,10 +7,10 @@
import numpy as np
import onnx
-from onnx import helper, numpy_helper
+import onnx.helper # noqa: TID251
from onnx.defs import OpSchema
-from onnxscript import tensor
+from onnxscript import ir, tensor
if TYPE_CHECKING:
from onnxscript import converter
@@ -26,42 +24,8 @@
# Utilities to convert a python value to TensorProto (for use by the script converter)
-def _py_type_to_onnx_type(pytype: type):
- if pytype is bool:
- return onnx.TensorProto.BOOL
- if pytype is int:
- return onnx.TensorProto.INT64
- if pytype is float:
- return onnx.TensorProto.FLOAT
- if pytype is str:
- return onnx.TensorProto.STRING
- raise ValueError(f"Tensor element of type {pytype} not supported")
-
-
def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
- if isinstance(pyvalue, np.ndarray):
- return numpy_helper.from_array(pyvalue, tensor_name)
- if isinstance(pyvalue, list):
- if len(pyvalue) == 0:
- raise ValueError("Cannot convert an empty list to tensor")
- pytype = type(pyvalue[0])
- if not all(isinstance(e, pytype) for e in pyvalue):
- raise ValueError(
- "Cannot convert an list with elements of different types to tensor"
- )
- return helper.make_tensor(
- tensor_name,
- _py_type_to_onnx_type(pytype),
- [len(pyvalue)],
- pyvalue,
- )
- onnx_type = _py_type_to_onnx_type(type(pyvalue))
- if onnx_type is onnx.TensorProto.BOOL:
- return helper.make_tensor(tensor_name, onnx_type, [], [int(pyvalue)])
- if onnx_type is onnx.TensorProto.STRING:
- return helper.make_tensor(tensor_name, onnx_type, [], vals=[pyvalue.encode("utf-8")])
-
- return helper.make_tensor(tensor_name, onnx_type, [], [pyvalue])
+ return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name))
_REPEATED_ATTRIBUTE_TYPES = frozenset(
@@ -81,7 +45,7 @@ def pyvalue_to_onnx_attribute(
key: str,
value: Any,
name_generator: Callable[[], str],
- attr_type: Optional[onnx.AttributeProto.AttributeType] = None,
+ attr_type: onnx.AttributeProto.AttributeType | None = None,
) -> onnx.AttributeProto:
"""Helper function to create an ONNX AttributeProto.
@@ -105,7 +69,9 @@ def pyvalue_to_onnx_attribute(
name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
)
else:
- return onnx.helper.make_attribute(key, value)
+ # When the value is a subgraph, ONNX IR will complain that some values are
+ # not found from the scope.
+ return onnx.helper.make_attribute(key, value) # noqa: TID251
# Utilities to convert python values into onnxscript tensors.
diff --git a/onnxscript/_internal/deprecation.py b/onnxscript/_internal/deprecation.py
index 57769ba091..7bf18482a2 100644
--- a/onnxscript/_internal/deprecation.py
+++ b/onnxscript/_internal/deprecation.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
"""Utility for deprecating APIs."""
# Reference: https://github.com/pytorch/pytorch/blob/aed9bee0413dac190452fbfa9ab2a44b6e6843f5/torch/onnx/_deprecation.py
@@ -14,6 +12,12 @@
T = TypeVar("T")
+@functools.lru_cache(maxsize=1024)
+def _warn_once(message: str):
+ """Issue a FutureWarning only once per message."""
+ warnings.warn(message, category=FutureWarning, stacklevel=3)
+
+
def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T], T]:
"""Marks functions as deprecated.
@@ -32,12 +36,10 @@ def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T],
def decorator(function):
@functools.wraps(function)
def wrapper(*args, **kwargs):
- warnings.warn(
+ _warn_once(
f"'{function.__module__}.{function.__qualname__}' "
f"is deprecated in version {since} and will be "
f"removed in {removed_in}. Please {instructions}.",
- category=FutureWarning,
- stacklevel=2,
)
return function(*args, **kwargs)
diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py
index 54593abf32..b3591a0a8d 100644
--- a/onnxscript/_internal/param_manipulation.py
+++ b/onnxscript/_internal/param_manipulation.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Function for manipulating input parameters of an Op or a OnnxFunction."""
from __future__ import annotations
@@ -129,3 +131,18 @@ def tag_arguments_with_param_schemas(
raise TypeError(f"Required input/attribute '{param}' was not provided")
return tagged_args, tagged_kwargs
+
+
+def turn_to_kwargs_to_avoid_ordering(
+ param_schemas: Sequence[values.ParamSchema],
+ inputs: list[Any],
+ attributes: dict[str, Any],
+) -> dict[str, Any]:
+ """Return the inputs and attributes to the order of the function signature."""
+ for idx, param in enumerate(param_schemas):
+ if param.name not in attributes:
+ if param.is_variadic_input:
+ attributes[param.name] = inputs[idx:]
+ elif inputs:
+ attributes[param.name] = inputs.pop(0)
+ return attributes
diff --git a/onnxscript/_internal/param_manipulation_test.py b/onnxscript/_internal/param_manipulation_test.py
index f7148268e0..7b67e4380d 100644
--- a/onnxscript/_internal/param_manipulation_test.py
+++ b/onnxscript/_internal/param_manipulation_test.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
# mypy: disable-error-code=misc
import collections
diff --git a/onnxscript/_internal/runtime_typing.py b/onnxscript/_internal/runtime_typing.py
deleted file mode 100644
index 54e7dae0c0..0000000000
--- a/onnxscript/_internal/runtime_typing.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""An internal wrapper for the beartype library.
-
-Decorate a function with `@runtime_typing.checked` to enable runtime
-type checking. The decorator is a no-op when the `beartype` library is not
-installed.
-"""
-
-import typing
-import warnings
-
-__all__ = [
- "checked",
-]
-
-T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any])
-
-try:
- from beartype import beartype as checked
- from beartype import roar as _roar
-
- # Beartype warns when we import from typing because the types are deprecated
- # in Python 3.9. But there will be a long time until we can move to using
- # the native container types for type annotations (when 3.9 is the lowest
- # supported version). So we silence the warning.
- warnings.filterwarnings(
- "ignore",
- category=_roar.BeartypeDecorHintPep585DeprecationWarning,
- )
-except ImportError:
-
- def checked(func: T) -> T: # type: ignore[no-redef]
- return func
-
-except Exception as e: # pylint: disable=broad-exception-caught
- # Warn errors that are not import errors (unexpected).
- warnings.warn(f"{e}", stacklevel=2)
-
- def checked(func: T) -> T: # type: ignore[no-redef]
- return func
diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py
index c4537e3bcd..ce2b657cfd 100644
--- a/onnxscript/_internal/utils.py
+++ b/onnxscript/_internal/utils.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
import numbers
@@ -9,7 +7,6 @@
import numpy as np
import onnx
-import onnx.helper
from onnxscript import tensor
@@ -67,26 +64,26 @@ def add(k, v):
def value_to_type_proto(val):
"""Return the ONNX type of a python-value."""
if isinstance(val, (np.ndarray, tensor.Tensor)):
- elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)
+ elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) # noqa: TID251
shape = val.shape
- return onnx.helper.make_tensor_type_proto(elem_type, shape)
+ return onnx.helper.make_tensor_type_proto(elem_type, shape) # noqa: TID251
if isinstance(val, int):
- return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])
+ return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) # noqa: TID251
if isinstance(val, (float, np.float32)):
- return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])
+ return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) # noqa: TID251
if isinstance(val, list):
if len(val) > 0:
- return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
+ return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) # noqa: TID251
# Edge-case. Cannot determine a suitable ONNX type for an empty list.
# Should be using a typed-value instead.
# Treated as a sequence of tensors of float-type.
- return onnx.helper.make_sequence_type_proto(
- onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)
+ return onnx.helper.make_sequence_type_proto( # noqa: TID251
+ onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None) # noqa: TID251
)
if isinstance(val, numbers.Number):
nparray = np.array(val)
- elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)
- return onnx.helper.make_tensor_type_proto(elem_type, [])
+ elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251
+ return onnx.helper.make_tensor_type_proto(elem_type, []) # noqa: TID251
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
@@ -95,7 +92,7 @@ def values_to_value_infos(name_values):
skipping any None values.
"""
return [
- onnx.helper.make_value_info(name, value_to_type_proto(val))
+ onnx.helper.make_value_info(name, value_to_type_proto(val)) # noqa: TID251
for (name, val) in name_values
if val is not None
]
diff --git a/onnxscript/_internal/version_utils.py b/onnxscript/_internal/version_utils.py
index c66cb8d2ba..2b43c54f49 100644
--- a/onnxscript/_internal/version_utils.py
+++ b/onnxscript/_internal/version_utils.py
@@ -1,5 +1,12 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Version utils for testing."""
+from __future__ import annotations
+
+import warnings
+from typing import Callable, Sequence
+
import packaging.version
@@ -23,6 +30,19 @@ def torch_older_than(version: str) -> bool:
)
+def transformers_older_than(version: str) -> bool | None:
+ """Returns True if the transformers version is older than the given version."""
+ try:
+ import transformers # pylint: disable=import-outside-toplevel
+ except ImportError:
+ return None
+
+ return (
+ packaging.version.parse(transformers.__version__).release
+ < packaging.version.parse(version).release
+ )
+
+
def onnxruntime_older_than(version: str) -> bool:
"""Returns True if the onnxruntime version is older than the given version."""
import onnxruntime # pylint: disable=import-outside-toplevel
@@ -31,3 +51,48 @@ def onnxruntime_older_than(version: str) -> bool:
packaging.version.parse(onnxruntime.__version__).release
< packaging.version.parse(version).release
)
+
+
+def numpy_older_than(version: str) -> bool:
+ """Returns True if the numpy version is older than the given version."""
+ import numpy # pylint: disable=import-outside-toplevel
+
+ return (
+ packaging.version.parse(numpy.__version__).release
+ < packaging.version.parse(version).release
+ )
+
+
+def has_transformers():
+ """Tells if transformers is installed."""
+ try:
+ import transformers # pylint: disable=import-outside-toplevel
+
+ assert transformers
+ return True # noqa
+ except ImportError:
+ return False
+
+
+def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable: # type: ignore[arg-type]
+ """Catches warnings.
+
+ Args:
+ warns: warnings to ignore
+
+ Returns:
+ decorated function
+ """
+
+ def wrapper(fct):
+ if warns is None:
+ raise AssertionError(f"warns cannot be None for '{fct}'.")
+
+ def call_f(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", warns) # type: ignore[arg-type]
+ return fct(self)
+
+ return call_f
+
+ return wrapper
diff --git a/onnxscript/_legacy_ir/__init__.py b/onnxscript/_legacy_ir/__init__.py
deleted file mode 100644
index 74aa693593..0000000000
--- a/onnxscript/_legacy_ir/__init__.py
+++ /dev/null
@@ -1,339 +0,0 @@
-from __future__ import annotations
-
-import dataclasses
-from collections import deque
-from typing import List, Tuple, Union
-
-import numpy as np
-import onnx
-
-
-class Unknown:
- """A special value used to indicate that a value is not a statically known constant.
-
- We use this instead of None because None is a valid constant value (since ONNX
- supports the Optional type).
- """
-
- instance = None
-
- def __init__(self) -> None:
- if Unknown.instance is not None:
- raise ValueError("Unknown.instance is already set")
- Unknown.instance = self
-
-
-# Singleton instance of Unknown
-unknown = Unknown()
-NotConstant = unknown
-
-# ConcreteValue: This type represents constant values that an ONNX variable can take.
-# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals,
-# maps, etc.
-# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto.
-# A uniform representation would be helpful, but we should avoid unnecessary conversions for
-# large tensors. Should be cleaned up in the new IR.
-ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None]
-
-# SymbolicValue: This information is used to enable partial-evaluation and specialization
-# of sequence operations, as well as elimination of redundant Identity ops.
-# The symbolic value of a variable X can be:
-# - a string with the value "Y", indicating that "X" is a copy of "Y"
-# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values
-# Eg., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to
-# "SequenceConstruct(A, B, C)".
-# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
-# tensors, etc. However, we currently only handle lists of tensors.
-
-SymbolicValue = Union[str, List[str]]
-
-FunctionId = Tuple[str, str, str]
-
-
-def get_function_id(function: onnx.FunctionProto) -> FunctionId:
- return (function.domain, function.name, getattr(function, "overload", ""))
-
-
-def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId:
- return (node.domain, node.op_type, getattr(node, "overload", ""))
-
-
-@dataclasses.dataclass
-class StaticValueInfo:
- name: str
- value: ConcreteValue = NotConstant
- type: onnx.TypeProto | None = None
- symbolic_value: SymbolicValue | None = None
-
- def is_copy(self) -> bool:
- return isinstance(self.symbolic_value, str)
-
- def tensor_shape_proto(self) -> onnx.TensorShapeProto | None:
- """Returns the shape of a tensor or None.
-
- A return value of None could mean that the type is unknown or that the type is not a tensor
- or that the tensor shape (that is, even the rank) is unknown.
- """
- type = self.type
- if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
- return type.tensor_type.shape
- return None
-
- @property
- def shape(self) -> list[str | int | None] | None:
- """Returns the shape in a list.
-
- Str means that the shape is dynamic.
- """
- type = self.type
- if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
- dims = []
- for dim in type.tensor_type.shape.dim:
- if dim.HasField("dim_param"):
- dims.append(dim.dim_param)
- elif dim.HasField("dim_value"):
- dims.append(dim.dim_value)
- else:
- dims.append(None)
- return dims
- if self.value_as_np_array is not None:
- return list(self.value_as_np_array.shape)
- return None
-
- @property
- def element_type(self) -> int | None:
- """Returns the element type of a tensor, or None if type is not known or is not a tensor."""
- type = self.type
- if type and type.HasField("tensor_type"):
- return type.tensor_type.elem_type
- return None
-
- def identity_merge_from(self, other: StaticValueInfo) -> None:
- """Merge the value of other into self.
-
- This models the effect of an identity (copy) operation.
- This will update static-analysis information based on incoming value.
- """
- if not isinstance(other, StaticValueInfo):
- raise TypeError(f"Cannot merge {other} into {self}.")
- if other.value is not NotConstant:
- self.value = other.value
- # TODO: merge and combine best shape information from both types.
- if other.tensor_shape_proto() is not None and other.element_type is not None:
- self.type = other.type
- # We cannot copy symbolic value across different scopes.
-
- # WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo
- # does not fill in the following fields. These fields are filled in by the IRBuilder
- # which constructs the IR from the ONNX model.
- node: Node | None = None
- uses: list[Node] = dataclasses.field(default_factory=list)
- output_index: int | None = None
- is_output: bool = False
-
- @property
- def const_value(self) -> ConcreteValue:
- return self.value
-
- @property
- def value_as_np_array(self) -> np.ndarray | None:
- if isinstance(self.value, np.ndarray):
- return self.value
- if isinstance(self.value, onnx.TensorProto):
- return onnx.numpy_helper.to_array(self.value)
- return None
-
- def def_node(self) -> Node | None:
- return self.node
-
- def def_index(self) -> int:
- return self.output_index # type: ignore[return-value]
-
- def is_same_as(self, other: StaticValueInfo) -> bool:
- """Returns true if this value represents the same IR object as the other value.
-
- This is *not* value-equality, but rather object-equality.
- """
- return self is other
-
- def __str__(self) -> str:
- shape = self.shape
- if shape is not None:
- shape = [str(dim) for dim in shape]
- shape_str = f"[{', '.join(shape)}]" # type: ignore[arg-type]
- else:
- shape_str = "None"
- return (
- f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, "
- f"{'has const value' if self.value is not unknown else 'no const value'}.)"
- )
-
-
-Value = StaticValueInfo
-
-
-class Model:
- def __init__(self) -> None:
- self.gen_var_counter: int = 0
-
- def set(
- self,
- model_proto: onnx.ModelProto,
- graph: Graph,
- functions: list[Function],
- version_map: dict[str, int],
- ) -> None:
- """TODO. This is a temporary patch."""
- self.original_model_proto = model_proto
- self.graph = graph
- self.functions = functions
- self.version_map = version_map
-
- def make_new_name(self):
- # Temporary hack.
- self.gen_var_counter += 1
- return f"_gen_{self.gen_var_counter}"
-
- def __str__(self) -> str:
- # TODO: Naive string representation for debugging. Need to improve this.
- return "\n".join(
- [
- f"ModelGraph: {self.graph}",
- f"Functions: {self.functions}",
- f"VersionMap: {self.version_map}",
- ]
- )
-
-
-class Graph:
- def __init__(self, graph_proto: onnx.GraphProto):
- self.original_graph_proto = graph_proto
- self.nodes: deque[Node] = deque()
- self.values: dict[str, Value] = {}
-
- @property
- def name(self) -> str:
- return self.original_graph_proto.name
-
- def __str__(self) -> str:
- return "\n".join(
- [
- "Graph",
- f"Nodes: {[str(n) for n in self.nodes]}",
- f"Values: {[str(v) for v in self.values]}",
- ]
- )
-
- @property
- def input_names(self) -> list[str]:
- return [_.name for _ in self.original_graph_proto.input]
-
- @property
- def output_names(self) -> list[str]:
- return [_.name for _ in self.original_graph_proto.output]
-
-
-class Function:
- def __init__(self, function_proto: onnx.FunctionProto):
- self.original_function_proto = function_proto
- self.nodes = deque() # type: ignore[var-annotated]
- self.values = {} # type: ignore[var-annotated]
-
- @property
- def id(self) -> FunctionId:
- return (self.domain, self.name, self.overload)
-
- @property
- def domain(self) -> str:
- return self.original_function_proto.domain
-
- @property
- def name(self) -> str:
- return self.original_function_proto.name
-
- @property
- def overload(self) -> str:
- return getattr(self.original_function_proto, "overload", "")
-
- def __str__(self) -> str:
- return "\n".join(
- [
- "Function",
- f"Nodes: {[str(n) for n in self.nodes]}",
- f"Values: {[str(v) for v in self.values]}",
- ]
- )
-
-
-class RefAttr:
- def __init__(self, name: str, ref_attr_name: str, type) -> None:
- self.name = name
- self.ref_attr_name = ref_attr_name
- self.type = type
-
- def to_proto(self) -> onnx.AttributeProto:
- attr_proto = onnx.AttributeProto()
- attr_proto.name = self.name
- attr_proto.ref_attr_name = self.ref_attr_name
- attr_proto.type = self.type
- return attr_proto
-
-
-class Node:
- def __init__(
- self,
- node_proto: onnx.NodeProto,
- populate_io: bool = False,
- ) -> None:
- self.original_node_proto = node_proto
- self.domain: str = node_proto.domain
- self.version: int | None = None
- self.op_type: str = node_proto.op_type
- if populate_io:
- self.inputs: list[Value | None] = [Value(i) for i in node_proto.input]
- self.outputs: list[Value | None] = [Value(i) for i in node_proto.output]
- else:
- self.inputs: list[Value | None] = [] # type: ignore[no-redef]
- self.outputs: list[Value | None] = [] # type: ignore[no-redef]
- # TODO: attributes are never populated.
- self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}
-
- def __repr__(self) -> str:
- return (
- f"{self.op_type}({','.join(self.original_node_proto.input)})"
- f"->{','.join(self.original_node_proto.output)}"
- )
-
- @property
- def name(self) -> str:
- return self.original_node_proto.name
-
- @property
- def input_names(self):
- return self.original_node_proto.input
-
- @property
- def output_names(self):
- return self.original_node_proto.output
-
- @property
- def attribute(self):
- return self.original_node_proto.attribute
-
- def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
- if self.domain != "" and self.domain in version_map:
- self.version = version_map[self.domain]
-
- def get_attribute(self, name: str) -> int | float | None:
- return self.attributes.get(name, None) # type: ignore[return-value]
-
- def __str__(self) -> str:
- return "\n".join(
- [
- "Node",
- f"OpType: {self.op_type}",
- f"Inputs: {self.inputs}",
- f"Outputs: {self.outputs}",
- f"Attributes: {self.attributes}",
- ]
- )
diff --git a/onnxscript/_legacy_ir/visitor.py b/onnxscript/_legacy_ir/visitor.py
deleted file mode 100644
index 3044fdd77e..0000000000
--- a/onnxscript/_legacy_ir/visitor.py
+++ /dev/null
@@ -1,922 +0,0 @@
-from __future__ import annotations
-
-import dataclasses
-import logging
-from typing import Any, Sequence
-
-import numpy as np
-import onnx
-
-import onnxscript._legacy_ir as ir
-from onnxscript.utils.utils import (
- get_initializer_type,
- is_control_flow_op,
- normalize_domain,
-)
-
-logger = logging.getLogger(__name__)
-
-
-def _override_inferred_value_type_with_symbolic_value_type(
- symbolic_value: ir.Value | None,
- inferred_value: ir.Value | None,
-) -> ir.Value | None:
- if inferred_value is not None and symbolic_value is not None:
- inferred_value.type = symbolic_value.type
- if inferred_value is None:
- inferred_value = symbolic_value
- return inferred_value
-
-
-def is_local_function_node(
- node: onnx.NodeProto, functions: dict[ir.FunctionId, onnx.FunctionProto]
-) -> bool:
- return ir.get_function_id_from_node(node) in functions
-
-
-class FunctionShapeEnv:
- def __init__(self):
- # Mapping from (domain, function_name, overload) to {value_name: ir_value}
- self._function_values: dict[ir.FunctionId, dict[str, ir.Value]] = {}
-
- def load_from_model_proto(self, model_proto: onnx.ModelProto) -> None:
- for value_info in model_proto.graph.value_info:
- self.load_from_value_info(value_info)
-
- def save_to_model_proto(self, model_proto: onnx.ModelProto) -> None:
- for (
- domain,
- function_name,
- overload,
- ), named_ir_values in self._function_values.items():
- for ir_value in named_ir_values.values():
- if (
- value_info := self.save_to_value_info(
- ir_value, domain, function_name, overload
- )
- ) is not None:
- model_proto.graph.value_info.append(value_info)
-
- def load_from_value_info(self, value_info: onnx.ValueInfoProto) -> None:
- function_id, ir_value = self.process_value_info(value_info)
- if function_id is not None:
- logger.debug(
- "Loads torch symbolic value info '%s'.",
- value_info.name,
- )
- self._function_values.setdefault(function_id, {})[ir_value.name] = ir_value
-
- def process_value_info(
- self, value_info: onnx.ValueInfoProto
- ) -> tuple[ir.FunctionId | None, ir.Value]:
- name = value_info.name
- if len(splits := name.split("/")) == 2:
- # Experimental function value info format.
- # To be deprecated after ONNX 1.16, where value_info is introduced in FunctionProto.
- function_id, value_name = splits
- splits = function_id.split("::")
- domain, function_name = splits[0], splits[1]
- # 'overload' is introduced in ONNX 1.16, consider it as empty string prior to that.
- # The code is for future proof, in case overload is encoded in this format.
- overload = ""
- if len(splits) == 3:
- overload = splits[2]
- function_id = (domain, function_name, overload)
- else:
- # Standard main graph value info format.
- function_id = None
- value_name = name
- return function_id, ir.Value(name=value_name, type=value_info.type)
-
- def save_to_value_info(
- self, value: ir.Value, domain: str, function_name: str, overload: str
- ) -> onnx.ValueInfoProto | None:
- if overload != "":
- raise NotImplementedError("Overload is not supported yet.")
- function_id = f"{domain}::{function_name}"
-
- if value.type is not None:
- return onnx.helper.make_value_info(f"{function_id}/{value.name}", value.type)
- return None
-
- def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | None:
- """Lookup ir value of 'value_name' inside 'function'."""
- function_id = ir.get_function_id(function)
- function_values = self._function_values.get(function_id)
- if function_values is None or (ir_value := function_values.get(value_name)) is None:
- logger.debug(
- "Lookup Missed %s torch symbolic value info in function %s::%s.",
- value_name,
- function.domain,
- function.name,
- )
- return None
- logger.debug(
- "Lookup found %s torch symbolic value info in function %s::%s.",
- value_name,
- function.domain,
- function.name,
- )
- return ir_value
-
- def bind(self, value: ir.Value, domain: str, function_name: str, overload: str) -> None:
- """Bind ir value 'value' to 'value_name' inside 'function'."""
- function_id = (domain, function_name, overload)
- self._function_values.setdefault(function_id, {})[value.name] = value
-
- def get_ir_values(self, function: onnx.FunctionProto) -> dict[str, ir.Value]:
- """Get all ir values inside 'function'."""
- function_id = ir.get_function_id(function)
- return self._function_values.get(function_id, {})
-
-
-class SubScope:
- values: dict[str, ir.Value]
- ref_attributes: dict[str, onnx.AttributeProto]
- owner: onnx.GraphProto | onnx.FunctionProto
-
- def __init__(self, owner: onnx.GraphProto | onnx.FunctionProto):
- self.values = {}
- self.ref_attributes = {}
- self.owner = owner
-
- def lookup(self, name: str) -> ir.Value | None:
- return self.values.get(name)
-
- def bind(self, name: str, value: ir.Value) -> None:
- self.values[name] = value
-
- def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None:
- return self.ref_attributes.get(ref_attr_name)
-
- def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None:
- self.ref_attributes[ref_attr_name] = attr
-
- def readable_strs(self, indent: int = 0) -> list[str]:
- indent_str = " " * indent
- strs = []
- if isinstance(self.owner, onnx.GraphProto):
- strs.append(f"Graph {self.owner.name}:")
- else:
- strs.append(f"Function {self.owner.name}:")
- strs.append(" ir.Values:")
- for name, value in self.values.items():
- strs.append(f" {name}: {value}")
- strs.append(" RefAttributes:")
- for name, attr in self.ref_attributes.items():
- strs.append(f" {name}: {attr}")
-
- return [f"{indent_str}{s}" for s in strs]
-
- def __str__(self) -> str:
- return "\n".join(self.readable_strs())
-
-
-@dataclasses.dataclass
-class Scope:
- _sub_scopes: list[SubScope] = dataclasses.field(default_factory=list)
-
- def lookup(self, name: str) -> ir.Value | None:
- """Lookup value by name from all SubScopes."""
- for sub_scope in reversed(self._sub_scopes):
- if (result := sub_scope.lookup(name)) is not None:
- return result
- return None
-
- def bind(self, name: str, value: ir.Value) -> None:
- """Bind value to name in the most recent SubScope."""
- if name == "":
- raise ValueError("Cannot bind to empty name.")
- if value is None:
- raise ValueError(f"Cannot bind None to value {name}.")
- self._sub_scopes[-1].bind(name, value)
-
- def lookup_or_create(self, name: str) -> ir.Value:
- """Lookup value by name from all SubScopes. If not found, create a new one in most recent SubScope."""
- if name == "":
- raise ValueError("Cannot lookup or create empty name.")
- for sub_scope in reversed(self._sub_scopes):
- if (result := sub_scope.lookup(name)) is not None:
- return result
- value = ir.Value(name=name)
- self.bind(name, value)
- return value
-
- def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None:
- for sub_scope in reversed(self._sub_scopes):
- if (result := sub_scope.lookup_ref_attribute(ref_attr_name)) is not None:
- return result
- return None
-
- def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None:
- self._sub_scopes[-1].bind_ref_attribute(ref_attr_name, attr)
-
- def enter_sub_scope(self, owner: onnx.GraphProto) -> None:
- self._sub_scopes.append(SubScope(owner))
-
- def exit_sub_scope(self) -> SubScope:
- return self._sub_scopes.pop()
-
- def current_function_scope(self) -> SubScope | None:
- if len(self._sub_scopes) == 0:
- return None
- if isinstance(self._sub_scopes[0].owner, onnx.FunctionProto):
- return self._sub_scopes[0]
- return None
-
- def current_function(self) -> onnx.FunctionProto | None:
- current_function_scope = self.current_function_scope()
- if current_function_scope is not None:
- return current_function_scope.owner
- return None
-
- def current_graph(self) -> onnx.GraphProto | None:
- for sub_scope in reversed(self._sub_scopes):
- if isinstance(sub_scope.owner, onnx.GraphProto):
- return sub_scope.owner
- return None
-
- def readable_strs(self, indent: int = 0) -> list[str]:
- indent_str = " " * indent
- strs = []
- for i, sub_scope in enumerate(self._sub_scopes):
- strs.append(f"SubScope {i}:")
- strs.extend(sub_scope.readable_strs(indent=indent + 2))
- return [f"{indent_str}{s}" for s in strs]
-
- def __str__(self) -> str:
- return "\n".join(self.readable_strs())
-
-
-@dataclasses.dataclass
-class ScopeStack:
- """Stack of scopes.
-
- Each Scope represents statically nested SubScopes (inner SubScopes can access names defined in outer SubScopes)
- produced by subgraphs occurring as attribute values, except for the first SubScope, which may be produced by a function.
- With a ScopeStack, there is no such possibility of referencing variables defined higher up in the stack by name.
- Instead, the stack represents a sequence of (nested) function calls: each entry in the stack (except the outermost)
- represents a call to a function.
-
- Thus, we would use a ScopeStack for a context-sensitive analysis (where we recursively process a called function).
- For a context-insensitive analysis, we would only need a Scope (where we recursively process subgraphs).
-
- To debug, `print(scope_stack)` will print the scope structure as well as the info stored
- in each scope.
- """
-
- _scopes: list[Scope] = dataclasses.field(default_factory=lambda: [Scope()])
-
- def current_scope(self) -> Scope:
- return self._scopes[-1]
-
- def lookup(self, name: str) -> ir.Value | None:
- """Lookup value by name from the current Scope."""
- return self.current_scope().lookup(name)
-
- def bind(self, name: str, value: ir.Value) -> None:
- """Bind value to name in the current Scope."""
- self.current_scope().bind(name, value)
-
- def lookup_or_create(self, name: str) -> ir.Value:
- """Lookup value by name from the current Scope. If not found, create a new one."""
- return self.current_scope().lookup_or_create(name)
-
- def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None:
- return self.current_scope().lookup_ref_attribute(ref_attr_name)
-
- def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None:
- self.current_scope().bind_ref_attribute(ref_attr_name, attr)
-
- def enter_graph_scope(self, graph: onnx.GraphProto) -> None:
- self.current_scope().enter_sub_scope(graph)
-
- def exit_graph_scope(self) -> SubScope:
- sub_scope = self.current_scope().exit_sub_scope()
- assert isinstance(sub_scope.owner, onnx.GraphProto), "Expected graph scope."
- return sub_scope
-
- def enter_function_scope(self, function: onnx.FunctionProto) -> None:
- self._scopes.append(Scope())
- self.current_scope().enter_sub_scope(function)
-
- def exit_function_scope(self) -> SubScope:
- sub_scope = self.current_scope().exit_sub_scope()
- assert isinstance(sub_scope.owner, onnx.FunctionProto), "Expected function scope."
- self._scopes.pop()
- return sub_scope
-
- def current_function(self) -> onnx.FunctionProto | None:
- return self.current_scope().current_function()
-
- def current_graph(self) -> onnx.GraphProto | None:
- return self.current_scope().current_graph()
-
- def __str__(self) -> str:
- strs = ["ScopeStack:"]
- for i, scope in enumerate(self._scopes):
- strs.append(f" Scope {i}:")
- strs.extend(scope.readable_strs(indent=2))
- return "\n".join(strs)
-
-
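-# Hypothetical usage sketch (not part of the original module) of the scoping
-# rules documented on ScopeStack:
-#
-#     stack = ScopeStack()
-#     stack.enter_graph_scope(onnx.GraphProto(name="main"))      # outer graph
-#     stack.bind("x", ir.Value(name="x"))
-#     stack.enter_graph_scope(onnx.GraphProto(name="then"))      # nested subgraph
-#     assert stack.lookup("x") is not None    # visible: static nesting
-#     stack.enter_function_scope(onnx.FunctionProto(name="f"))   # fresh Scope
-#     assert stack.lookup("x") is None        # hidden: function-call boundary
-
-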
-class ProtoVisitorCore:
- def visit_model(self, model: onnx.ModelProto):
- self.process_model(model)
- for opset in model.opset_import:
- self.process_opset_import(opset)
- self.visit_graph(model.graph)
- for function in model.functions:
- self.visit_function(function)
-
- def process_model(self, model: onnx.ModelProto):
- pass
-
- def process_opset_import(self, opset: onnx.OperatorSetIdProto):
- pass
-
- def visit_graph(self, graph: onnx.GraphProto):
- self.enter_scope(graph)
- self.process_graph(graph)
- for input in graph.input:
- self.process_graph_input(input)
- for init in graph.initializer:
- self.process_initializer(init)
- for value_info in graph.value_info:
- self.process_value_info(value_info)
- for node in graph.node:
- self.visit_node(node)
- for output in graph.output:
- self.process_graph_output(output)
- self.exit_scope(graph)
-
- def visit_function(self, function: onnx.FunctionProto):
- self.enter_function_scope(function)
- self.process_function(function)
- for input in function.input:
- self.process_function_input(input)
- for node in function.node:
- self.visit_node(node)
- for output in function.output:
- self.process_function_output(output)
- self.exit_function_scope(function)
-
- def process_function_input(self, input: str):
- pass
-
- def process_function_output(self, output: str):
- pass
-
- def process_function(self, function: onnx.FunctionProto):
- pass
-
- def enter_function_scope(self, function: onnx.FunctionProto):
- pass
-
- def exit_function_scope(self, function: onnx.FunctionProto) -> SubScope:
- pass
-
- def enter_scope(self, graph: onnx.GraphProto):
- pass
-
- def process_graph(self, graph: onnx.GraphProto):
- pass
-
- def exit_scope(self, graph: onnx.GraphProto) -> SubScope:
- pass
-
- def process_graph_input(self, input: onnx.ValueInfoProto):
- pass
-
- def process_initializer(self, init: onnx.TensorProto):
- pass
-
- def process_value_info(self, value_info: onnx.ValueInfoProto):
- pass
-
- def visit_node(self, node: onnx.NodeProto):
- self.process_node(node)
- for attr in node.attribute:
- self.visit_attribute(attr)
-
- def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
- pass
-
- def process_graph_output(self, output: onnx.ValueInfoProto):
- pass
-
- def visit_attribute(self, attr: onnx.AttributeProto):
- self.process_attribute(attr)
- if attr.HasField("g"):
- self.visit_graph(attr.g)
- elif len(attr.graphs) > 0:
- for graph in attr.graphs:
- self.visit_graph(graph)
-
- def process_attribute(self, attr: onnx.AttributeProto):
- pass
-
-
-class ProtoVisitor(ProtoVisitorCore):
- def __init__(
- self, external_data_folder: str = "", *, do_shape_inference: bool = False
- ) -> None:
- super().__init__()
- self.scopes = ScopeStack()
- self.function_shape_env = FunctionShapeEnv()
- self.version_map = {} # Map from domain to version
- self.do_shape_inference = do_shape_inference
- self.external_data_folder = external_data_folder
- self.modified = False
-
- def process_opset_import(self, opset: onnx.OperatorSetIdProto):
- domain = normalize_domain(opset.domain)
- self.version_map[domain] = opset.version
-
- def lookup_version(self, domain: str) -> int:
- domain = normalize_domain(domain)
- return self.version_map.get(domain, 1) # TODO: handle missing domain
-
- def lookup(self, name: str) -> ir.Value | None:
- if name == "":
- return None
- if (result := self.scopes.lookup(name)) is None:
- logger.debug("Lookup value %s not found.", name)
- raise ValueError(
- f"Undefined variable {name}.\n"
- f"Available variables: {self.scopes.current_scope()}"
- )
- logger.debug("Lookup value %s. Shape %s", name, result.tensor_shape_proto())
- return result
-
- def bind(self, name: str, value: ir.Value) -> None:
- logger.debug("Binding value %s. Shape %s", name, value.tensor_shape_proto())
- self.scopes.bind(name, value)
-
- def lookup_or_create(self, name: str) -> ir.Value:
- return self.scopes.lookup_or_create(name)
-
- def has_input(self, node: onnx.NodeProto, index: int) -> bool:
- return index < len(node.input) and node.input[index] != ""
-
- # TODO: Clean up handling of undefined variables. May fail in some of the methods below.
-
- def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None:
- if index < len(node.input):
- return self.lookup(node.input[index])
- return None
-
- def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None:
- info = self.get_input(node, index)
- return info.type if info is not None else None
-
- def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None:
- info = self.get_input(node, index)
- return info.element_type if info is not None else None
-
- def input_shape(self, node: onnx.NodeProto, index: int) -> onnx.TensorShapeProto | None:
- info = self.get_input(node, index)
- return info.tensor_shape_proto() if info is not None else None
-
- def input_const_value(self, node: onnx.NodeProto, index: int) -> Any:
- if not self.has_input(node, index):
- return None # This is treated as a known constant value "None"
- info = self.get_input(node, index)
- return info.value
-
- def has_output(self, node: onnx.NodeProto, index: int) -> bool:
- return index < len(node.output) and node.output[index] != ""
-
- def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None:
- if index < len(node.output):
- return self.lookup(node.output[index])
- return None
-
- def get_input_value(
- self, node: onnx.NodeProto, index: int, default: Any | None = None
- ) -> Any | None:
- info = self.get_input(node, index)
- if info is not None:
- return info.value
- return default
-
- def get_input_type(
- self, node: onnx.NodeProto, index: int, default: onnx.TypeProto | None = None
- ) -> onnx.TypeProto | None:
- info = self.get_input(node, index)
- if info is not None:
- return info.type
- return default
-
- def enter_scope(self, graph: onnx.GraphProto):
- logger.debug("enter_scope: graph %s", graph.name)
- self.scopes.enter_graph_scope(graph)
-
- def exit_scope(self, graph: onnx.GraphProto) -> SubScope:
- logger.debug("exit_scope: graph %s", graph.name)
- return self.scopes.exit_graph_scope()
-
- def enter_function_scope(self, function: onnx.FunctionProto):
- logger.debug("enter_function_scope: function %s", function.name)
- self.scopes.enter_function_scope(function)
- ir_values = self.function_shape_env.get_ir_values(function)
- for name, ir_value in ir_values.items():
- inferred_ir_value = self.lookup_or_create(name)
- updated_ir_value = _override_inferred_value_type_with_symbolic_value_type(
- ir_value, inferred_ir_value
- )
- self.bind(name, updated_ir_value)
-
- def exit_function_scope(self, function: onnx.FunctionProto) -> SubScope:
- logger.debug("exit_function_scope: function %s", function.name)
- # Sync ir value back to function_shape_env
- function_scope = self.scopes.exit_function_scope()
- for ir_value in function_scope.values.values():
- self.function_shape_env.bind(ir_value, *ir.get_function_id(function))
- return function_scope
-
- def process_initializer(self, init: onnx.TensorProto):
- array = onnx.numpy_helper.to_array(init, self.external_data_folder)
- self.bind(
- init.name,
- ir.Value(name=init.name, value=array, type=get_initializer_type(init)),
- )
-
- def process_graph_input(self, input: onnx.ValueInfoProto):
- self.bind(input.name, ir.Value(name=input.name, type=input.type))
-
- def process_value_info(self, value_info: onnx.ValueInfoProto):
- logger.debug("process_value_info: %s", value_info)
- value = self.lookup_or_create(value_info.name)
- value.type = value_info.type
- # Populate function shape environment
- self.function_shape_env.load_from_value_info(value_info)
-
- def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
- output_types = {}
- if self.do_shape_inference and not is_control_flow_op(node):
- # Control-flow ops are more complicated. Not supported here yet.
- # TODO: handle optional inputs
- def get_constant_value(i: int) -> onnx.TensorProto | None:
- value = self.input_const_value(node, i)
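- # Only propagate small constant tensors (fewer than 20 elements) to shape inference.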
- if isinstance(value, np.ndarray) and value.size < 20:
- return onnx.numpy_helper.from_array(value, node.input[i])
- return None
-
- input_types = {x: self.input_type(node, i) for i, x in enumerate(node.input)}
- input_data = {x: get_constant_value(i) for i, x in enumerate(node.input)}
- input_data = {k: v for k, v in input_data.items() if v is not None}
- if any(t is None for t in input_types.values()):
- logger.debug(
- "Skipping shape inference for node %s due to missing input type.",
- node.name,
- )
- else:
- # TODO: pass in constant values, ir_version
- try:
- schema = onnx.defs.get_schema(
- node.op_type, self.lookup_version(node.domain), node.domain
- )
- output_types = onnx.shape_inference.infer_node_outputs(
- schema, node, input_types, input_data
- )
- except Exception as e:
- logger.debug(
- "Skipping shape inference for node %s due to exception: %s",
- node.name,
- e,
- )
-
- for output in node.output:
- info = self.lookup_or_create(output)
- if output in output_types:
- # TODO: merge types
- info.type = output_types[output]
-
-
-class ProtoTransformer(ProtoVisitor):
- # TODO(lowpri) Practically this is useless:
- # subgraphs exist only in 'if' nodes, and 'if' nodes exist only in torchlib functions.
- # There is no pre-existing value_info in torchlib functions.
- # def exit_scope(self, graph: onnx.GraphProto) -> SubScope:
- # # Also sync updated ir values back to value_info in graph.
- # sub_scope = super().exit_scope(graph)
-
- def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None:
- replacement = self.process_node(node)
- logger.debug(
- "visit_node: %s::%s %s replacement %s",
- node.domain,
- node.op_type,
- node.name,
- "found" if replacement is not None else "missed",
- )
- if replacement is None:
- # No change. Process attributes.
- for attr in node.attribute:
- self.visit_attribute(attr)
- return None
- else:
- self.modified = True
- # We recursively visit the replacement nodes.
- result = []
- for newnode in replacement:
- n = self.visit_node(newnode)
- if n is not None:
- result.extend(n)
- else:
- result.append(newnode)
- return result
-
- def visit_graph(self, graph: onnx.GraphProto) -> dict[str, ir.Value]:
- self.enter_scope(graph)
- self.process_graph(graph)
- for input in graph.input:
- self.process_graph_input(input)
- for init in graph.initializer:
- self.process_initializer(init)
- for value_info in graph.value_info:
- self.process_value_info(value_info)
- updates = []
- nodes = graph.node
- for i, node in enumerate(nodes):
- replacement = self.visit_node(node)
- if replacement is not None:
- updates.append((i, replacement))
- for i, replacement in reversed(updates):
- old_node_name = nodes[i].name
- del nodes[i]
- for newnode in reversed(replacement):
- logger.debug(
- "Replacement node %s for %s. Size %s",
- newnode.name,
- old_node_name,
- newnode.ByteSize(),
- )
- nodes.insert(i, newnode)
- for output in graph.output:
- self.process_graph_output(output)
- return self.exit_scope(graph)
-
-
-class FunctionCallsiteAnalysis(ProtoVisitor):
- """Collects the callsites of each function."""
-
- def __init__(self):
- super().__init__()
- self.functions: dict[ir.FunctionId, onnx.FunctionProto] = {}
- self.function_calls: dict[ir.FunctionId, list[onnx.NodeProto]] = {}
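- # e.g. after visit_model, a function with two callsites maps its FunctionId
- # to a list containing both calling NodeProtos (illustrative).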
-
- def visit_function(self, function: onnx.FunctionProto):
- # Do not visit function via model.functions.
- # Only visit function at callsites.
- # The purpose of this analysis is to collect the callsites of each function.
- pass
-
- def visit_node(self, node: onnx.NodeProto) -> None:
- if is_local_function_node(node, self.functions):
- function_id = ir.get_function_id_from_node(node)
- self.function_calls.setdefault(function_id, []).append(node)
- for subnode in self.functions[function_id].node:
- self.visit_node(subnode)
-
- def visit_model(self, model: onnx.ModelProto) -> None:
- for function in model.functions:
- self.functions[ir.get_function_id(function)] = function
-
- super().visit_model(model)
-
-
-class FunctionRenamer:
- _POSTFIX_FORMAT = "{name}|{postfix}_{count}"
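- # e.g. renaming "my_func" twice yields "my_func|folded_0", then "my_func|folded_1".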
-
- def __init__(self, postfix="folded"):
- self._function_key_to_instance_count = {}
- self._postfix = postfix
-
- def rename(self, function: onnx.FunctionProto) -> None:
- domain = function.domain
- name = function.name
- key = (domain, name)
- self._function_key_to_instance_count.setdefault(key, 0)
- function.name = self._POSTFIX_FORMAT.format(
- name=name,
- postfix=self._postfix,
- count=self._function_key_to_instance_count[key],
- )
- self._function_key_to_instance_count[key] += 1
-
-
-class FunctionCallsiteProtoTransformer(ProtoTransformer):
- """Unlike other base visitors, this is a special visitor that visits functions at their callsite.
-
- This allows transforming and constructing specialized functions based on callsite context.
- """
-
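- # Illustrative flow: when a function has multiple callsites, each visit
- # transforms a copy of the FunctionProto; if the copy is modified, it is
- # renamed by FunctionRenamer (e.g. "my_func|folded_0") and the callsite
- # node's op_type is updated to match.
-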
- _functions: dict[ir.FunctionId, onnx.FunctionProto]
- _function_callsites: dict[ir.FunctionId, list[onnx.NodeProto]]
- _new_functions: list[onnx.FunctionProto]
- _function_renamer: FunctionRenamer
-
- def _gather_function_metadata(self, model: onnx.ModelProto):
- analysis = FunctionCallsiteAnalysis()
- analysis.visit_model(model)
- self._functions = analysis.functions
- self._function_callsites = analysis.function_calls
- self._new_functions = []
- self._function_renamer = FunctionRenamer()
-
- def process_function_outputs(self, function: onnx.FunctionProto) -> bool:
- """Process function outputs.
-
- This method is called when a function is visited at its callsite.
-
- Returns:
- True if the function outputs are modified.
- """
- del function # Unused
- return False
-
- def process_function_node_outputs(
- self,
- node: onnx.NodeProto,
- function_scope: SubScope,
- ) -> None:
- """Fetch value infos of the function outputs and re-bind them to the outputs of the function-call node."""
- function = function_scope.owner
- output_values = [function_scope.lookup(output) for output in function.output]
- for actual_name, formal_value in zip(node.output, output_values):
- if formal_value is None:
- raise RuntimeError(
- f"Missing output {actual_name} in function-call to {node.op_type}."
- )
- actual_value = self.lookup_or_create(actual_name)
- actual_value.identity_merge_from(formal_value)
- if logger.level <= logging.INFO:
- logger.info(
- "Binding outputs for function %s. %s => %s",
- function.name,
- actual_value,
- node.output,
- )
-
- def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None:
- return self.scopes.lookup_ref_attribute(ref_attr_name)
-
- def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None:
- self.scopes.bind_ref_attribute(ref_attr_name, attr)
-
- def visit_model(self, model: onnx.ModelProto):
- self._gather_function_metadata(model)
-
- self.process_model(model)
- for opset in model.opset_import:
- self.process_opset_import(opset)
- self.visit_graph(model.graph)
-
- for new_function in self._new_functions:
- model.functions.append(new_function)
-
- self.function_shape_env.save_to_model_proto(model)
-
- def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None:
- if is_local_function_node(node, self._functions):
- function_id = ir.get_function_id_from_node(node)
- if function_id not in self._functions:
- # Do not recursively visit new functions.
- return None
- replacement, _ = self.process_function_node(node)
- else:
- replacement = self.process_node(node)
- logger.debug(
- "visit_node: %s::%s %s replacement %s",
- node.domain,
- node.op_type,
- node.name,
- "found" if replacement is not None else "missed",
- )
- if replacement is None:
- # No change. Process attributes.
- for attr in node.attribute:
- self.visit_attribute(attr)
- return None
- else:
- self.modified = True
- # We recursively visit the replacement nodes.
- result = []
- for newnode in replacement:
- n = self.visit_node(newnode)
- if n is not None:
- result.extend(n)
- else:
- result.append(newnode)
- return result
-
- def process_function_node(
- self, node: onnx.NodeProto
- ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]:
- function_id = ir.get_function_id_from_node(node)
- function = self._functions[function_id]
-
- is_unique_callsite = len(self._function_callsites[function_id]) == 1
- if not is_unique_callsite:
- mutable_function = onnx.FunctionProto()
- mutable_function.CopyFrom(function)
- else:
- mutable_function = function
-
- logger.info("Visit function %s node %s", function_id, node.name)
- actual_input_value_infos = [self.lookup(input) for input in node.input]
- # Handle omitted inputs; these are considered optional inputs of the function.
- actual_input_value_infos.extend(
- [None] * (len(function.input) - len(actual_input_value_infos))
- )
- ref_attributes = {
- attr_proto.name: self.lookup_ref_attribute(attr_proto.ref_attr_name)
- for attr_proto in node.attribute
- if attr_proto.ref_attr_name
- }
-
- self.enter_function_scope(mutable_function)
- if logger.level <= logging.INFO:
- printable_actual_input_value_infos = [str(x) for x in actual_input_value_infos]
- logger.info(
- "Actual input value infos: %s",
- printable_actual_input_value_infos,
- )
- logger.info("Enter function scope: %s", self.scopes.current_scope())
-
- logger.debug("Binding inputs for function %s", function.name)
- for actual_input_value_info, formal_input in zip(
- actual_input_value_infos, function.input
- ):
- formal_info = ir.Value(formal_input)
- if actual_input_value_info is not None:
- formal_info.identity_merge_from(actual_input_value_info)
- self.bind(formal_input, formal_info)
-
- for attr_proto in function.attribute_proto:
- # Default value of function attributes.
- self.bind_ref_attribute(attr_proto.name, attr_proto)
-
- for attr_proto in node.attribute:
- if attr_proto.ref_attr_name:
- concrete_attribute = ref_attributes.get(attr_proto.name)
- if concrete_attribute is None:
- continue
- self.bind_ref_attribute(attr_proto.name, concrete_attribute)
- else:
- self.bind_ref_attribute(attr_proto.name, attr_proto)
-
- # Visit inner function nodes.
- node_updates: list[tuple[int, list[onnx.NodeProto]]] = []
- nodes = mutable_function.node
- for i, inner_node in enumerate(nodes):
- replacement = self.visit_node(inner_node)
- if replacement is not None:
- node_updates.append((i, replacement))
- for i, replacement in reversed(node_updates):
- old_node_name = nodes[i].name
- old_node_op_type = nodes[i].op_type
- del nodes[i]
- for newnode in reversed(replacement):
- logger.debug(
- "Replacement node inside function %s: %s for %s %s. Size %s",
- node.name,
- newnode.output,
- old_node_name,
- old_node_op_type,
- newnode.ByteSize(),
- )
- nodes.insert(i, newnode)
- added_domains = set()
- del mutable_function.opset_import[:]
- for inner_node in nodes:
- # Update opset_import if needed.
- if inner_node.domain not in added_domains:
- version = self.lookup_version(inner_node.domain)
- mutable_function.opset_import.append(
- onnx.OperatorSetIdProto(domain=inner_node.domain, version=version)
- )
- added_domains.add(inner_node.domain)
-
- output_updates = self.process_function_outputs(mutable_function)
-
- is_new_function = not is_unique_callsite and (node_updates or output_updates)
- if is_new_function:
- self._new_functions.append(mutable_function)
- self._function_renamer.rename(mutable_function)
- node.op_type = mutable_function.name
-
- function_scope = self.exit_function_scope(mutable_function)
-
- self.process_function_node_outputs(node, function_scope)
-
- logger.info("Exit function scope: %s", function_scope)
- logger.info("Exit function %s node %s", function_id, node.name)
-
- if is_new_function:
- return [node], mutable_function
- return None, None
diff --git a/onnxscript/_legacy_ir/visitor_test.py b/onnxscript/_legacy_ir/visitor_test.py
deleted file mode 100644
index e4559472e3..0000000000
--- a/onnxscript/_legacy_ir/visitor_test.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import unittest
-
-import onnx
-
-from onnxscript._legacy_ir import visitor
-
-
-class FunctionCallsiteProtoTransformerTest(unittest.TestCase):
- def test_function_optional_input_is_recorded_by_shape_env(self):
- model = onnx.parser.parse_model(
- """
-
- agraph (float[N] x) => (float[N] z) {
- z = custom.function(x)
- }
- <
- domain: "custom",
- opset_import: ["" : 18]
- >
- function (x, optional_y, optional_z) => (return_val)
- {
- return_val = custom.custom_op (x, optional_y, optional_z)
- }
- """
- )
-
- model_visitor = visitor.FunctionCallsiteProtoTransformer()
- model_visitor.visit_model(model)
- self.assertIsNotNone(
- model_visitor.function_shape_env.lookup(model.functions[0], "optional_y")
- )
- self.assertIsNotNone(
- model_visitor.function_shape_env.lookup(model.functions[0], "optional_z")
- )
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/_thirdparty/asciichartpy.py b/onnxscript/_thirdparty/asciichartpy.py
index 3cd91f84f5..88c46202ca 100644
--- a/onnxscript/_thirdparty/asciichartpy.py
+++ b/onnxscript/_thirdparty/asciichartpy.py
@@ -1,5 +1,5 @@
-# SPDX-License-Identifier: MIT
-# Modifications Copyright (c) Microsoft.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
#
# Copyright © 2016 Igor Kroitor
#
@@ -198,8 +198,8 @@ def plot(series, *, bin_edges=None, cfg=None):
height = cfg.get("height", interval)
ratio = height / interval if interval > 0 else 1
- min2 = int(floor(minimum * ratio))
- max2 = int(ceil(maximum * ratio))
+ min2 = floor(minimum * ratio)
+ max2 = ceil(maximum * ratio)
def clamp(n):
return min(max(n, minimum), maximum)
diff --git a/onnxscript/backend/__init__.py b/onnxscript/backend/__init__.py
index 862c45ce31..59e481eb93 100644
--- a/onnxscript/backend/__init__.py
+++ b/onnxscript/backend/__init__.py
@@ -1,4 +1,2 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
diff --git a/onnxscript/backend/onnx_backend.py b/onnxscript/backend/onnx_backend.py
index 83c9bca39a..ef93bb50b7 100644
--- a/onnxscript/backend/onnx_backend.py
+++ b/onnxscript/backend/onnx_backend.py
@@ -1,8 +1,6 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-
+# ruff: noqa: TID251
import os
import textwrap
@@ -291,7 +289,7 @@ def enumerate_onnx_tests(series, fct_filter=None) -> Iterator[OnnxBackendTest]:
sub = os.path.join(root, "data", series)
if not os.path.exists(sub):
raise FileNotFoundError(
- "Unable to find series of tests in {root!r}, subfolders:\n"
+ f"Unable to find series of tests in {root!r}, subfolders:\n"
+ "\n".join(os.listdir(root))
)
tests = os.listdir(sub)
diff --git a/onnxscript/backend/onnx_backend_test.py b/onnxscript/backend/onnx_backend_test.py
index b640331490..efd9d823d8 100644
--- a/onnxscript/backend/onnx_backend_test.py
+++ b/onnxscript/backend/onnx_backend_test.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
import os
import unittest
diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py
index 01ab09c8f2..cfea1a501c 100644
--- a/onnxscript/backend/onnx_export.py
+++ b/onnxscript/backend/onnx_export.py
@@ -1,21 +1,20 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
from typing import Any, Optional, Sequence
-import numpy
+import numpy as np
import onnx
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
-from onnx.helper import make_node
import onnxscript.onnx_types
import onnxscript.type_annotation
_SINGLE_INDENT = " "
+_SMALL_TENSOR_SIZE = 4
+
kwlist = {
"False",
"None",
@@ -70,10 +69,11 @@ def _get_const_repr(const_node):
if tensor_proto.data_type in {TensorProto.FLOAT, TensorProto.INT64}:
rank = len(tensor_proto.dims)
if rank == 0:
- array = onnx.numpy_helper.to_array(tensor_proto).reshape(1)
- return repr(array[0])
+ array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) # noqa: TID251
+ return str(array[0])
if rank == 1 and tensor_proto.dims[0] < 5:
- return repr(list(onnx.numpy_helper.to_array(tensor_proto)))
+ nparray = onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251
+ return repr(nparray.tolist())
return None
@@ -121,7 +121,7 @@ def renamer(name):
def _translate_type(onnx_type):
"""Converts a onnx type into a type defined by *onnxscript*."""
- return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type)
+ return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False)
def _translate_signature(inputs, outputs):
@@ -141,6 +141,15 @@ def input_sig(inp: ValueInfoProto | str):
return f"{result}:"
+def _translate_value_infos(value_infos: Sequence[ValueInfoProto]) -> str:
+ def _translate_value_info(value_info: ValueInfoProto) -> str:
+ return f"{_SINGLE_INDENT}'{_cleanup_variable_name(value_info.name)}': {_translate_type(value_info.type)},"
+
+ lines = [_translate_value_info(x) for x in value_infos]
+ lines_joined = "\n".join(lines)
+ return "{\n" + lines_joined + "\n}"
+
+
def _to_str(s):
if isinstance(s, bytes):
return s.decode("utf-8")
@@ -163,7 +172,7 @@ def _attribute_value(attr: onnx.AttributeProto):
if onnx.external_data_helper.uses_external_data(tensor_proto):
return tensor_proto
else:
- return onnx.numpy_helper.to_array(tensor_proto)
+ return onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251
# TODO:
# - onnx.AttributeProto.GRAPH
# - onnx.AttributeProto.SPARSE_TENSOR
@@ -249,11 +258,11 @@ def _cond_is_used_in_loop_body(graph: GraphProto) -> bool:
return False
-class Exporter:
+class _Exporter:
"""Class used for recursive traversal of Proto structures."""
def __init__(
- self, rename: bool, use_operators: bool = False, inline_const: bool = False
+ self, *, rename: bool, use_operators: bool, inline_const: bool, skip_initializers: bool
) -> None:
self.use_operators = use_operators
if rename:
@@ -268,6 +277,8 @@ def __init__(
# _name_remappings: used to undo the SSA-renaming in ONNX control-flow ops.
# We map the multiple SSA-variants back to the same Python variable name.
self._name_remappings: list[dict[str, str]] = []
+ self.skip_initializers = skip_initializers
+ self.skipped_initializers: dict[str, onnx.TensorProto] = {}
def _handle_attrname_conflict(self, renamer):
"""Add ref-attr-name-conflict handling logic to renaming function."""
@@ -311,7 +322,7 @@ def _translate_onnx_var_ref(self, var):
def _rename_domain(self, domain: str) -> str:
if domain in {"", "ai.onnx"}:
- return "opset" # TODO: Need checks to avoid name conflicts.
+ return "opset" # TODO: Need checks to avoid name conflicts.
return _cleanup_variable_name(domain) # type: ignore[return-value]
def _make_opset_name(self, domain, version):
@@ -340,18 +351,34 @@ def _translate_graph_body(self, graph, opsets, indent=0):
code = []
if hasattr(graph, "initializer"):
for init in graph.initializer:
- node = make_node(
+ if self.skip_initializers:
+ size = 1
+ for d in init.dims:
+ size *= d
+ if size > _SMALL_TENSOR_SIZE:
+ init_py_name = self._translate_onnx_var(init.name)
+ if init_py_name in self.skipped_initializers:
+ raise RuntimeError(
+ f"Initializer {init.name!r} is already present in skipped_initializers."
+ )
+ self.skipped_initializers[init_py_name] = init
+ continue
+ node = onnx.helper.make_node( # noqa: TID251
"Constant",
[],
[self._translate_onnx_var(init.name)], # type: ignore[list-item]
value=init,
)
- code.append(self._translate_node(node, opsets, indent=indent))
+ pyinit = self._translate_node(node, opsets, indent=indent)
+ if pyinit:
+ code.append(pyinit)
if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0:
 raise NotImplementedError("Unable to convert sparse_initializer into Python.")
for node in graph.node:
pynode = self._translate_node(node, opsets, indent=indent)
if pynode:
+ if node.name:
+ pynode += f" # {node.name}"
code.append(pynode)
final = "\n".join(code)
@@ -367,17 +394,17 @@ def _translate_attributes(self, node):
if isinstance(value, str):
attributes.append((at.name, f"{value!r}"))
continue
- if isinstance(value, numpy.ndarray):
+ if isinstance(value, np.ndarray):
onnx_dtype = at.t.data_type
if len(value.shape) == 0:
text = (
f'make_tensor("value", {onnx_dtype}, dims=[], '
- f"vals=[{value.tolist()!r}])"
+ f"vals=[{repr(value.tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')}])"
)
else:
text = (
f'make_tensor("value", {onnx_dtype}, dims={list(value.shape)!r}, '
- f"vals={value.ravel().tolist()!r})"
+ f"vals={repr(value.ravel().tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')})"
)
attributes.append((at.name, text))
continue
@@ -391,6 +418,7 @@ def _translate_attributes(self, node):
text += f", offset={metadata.offset!r}"
if metadata.length:
text += f", length={metadata.length!r}"
+ text += ")"
attributes.append((at.name, text))
continue
attributes.append((at.name, repr(value)))
@@ -400,7 +428,8 @@ def _translate_attributes(self, node):
def _translate_if(self, node, opsets, indent=0):
"""Translates a node If into python."""
sindent = _SINGLE_INDENT * indent
- code = [f"{sindent}if {node.input[0]}:"]
+ cond = self._translate_onnx_var_ref(node.input[0])
+ code = [f"{sindent}if {cond}:"]
if len(node.attribute) != 2:
raise RuntimeError(
f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}."
@@ -484,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0):
rows.extend(self._emit_assign(formal_ins, actual_ins, indent))
+ if node.name:
+ node_name = " # " + node.name
+ else:
+ node_name = ""
if use_iter_var and not use_loop_cond:
- rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
+ rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
# The following is a hacky way to suppress the generation of
# "cond_out = cond_in", which ONNX forces for a FOR loop.
# TODO: a cleaner solution for this.
self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in)
elif not use_iter_var and use_loop_cond:
- rows.append(f"{sindent}while {py_cond}:")
+ rows.append(f"{sindent}while {py_cond}:{node_name}")
elif use_iter_var and use_loop_cond:
# TODO: This needs fixing
- rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
+ rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:")
rows.append(f"{sindent}{_SINGLE_INDENT * 2}break")
else:
@@ -685,15 +718,68 @@ def _translate_graph(self, model: onnx.ModelProto, function_name: Optional[str])
def add(line: str) -> None:
result.append(line)
- add("@script()")
- add(f"def {function_name}{_translate_signature(graph.input, graph.output)}")
+ if self.skip_initializers:
+ indent_level = 2
+ indent = _SINGLE_INDENT
+ else:
+ indent_level = 1
+ indent = ""
+ add(f"{indent}@script()")
+ add(f"{indent}def {function_name}{_translate_signature(graph.input, graph.output)}")
+ indent = indent + _SINGLE_INDENT
doc = graph.doc_string
if doc:
- add(f' """{doc}"""')
- add(self._translate_graph_body(graph, opsets, indent=1))
+ add(f'{indent}"""{doc}"""')
+ add(self._translate_graph_body(graph, opsets, indent=indent_level))
return_values = ", ".join(self._translate_onnx_var(x) for x in graph.output)
- add(f" return {return_values}")
- return "\n".join(result)
+ add(f"{indent}return {return_values}")
+ script = "\n".join(result)
+ if self.skipped_initializers:
+ value_infos = _translate_value_infos(graph.value_info)
+ return self._substitute_initializers(script, function_name, value_infos)
+ return script
+
+ def _substitute_initializers(
+ self, script: str, script_function_name: str, value_infos: str
+ ) -> str:
+ init_names = self.skipped_initializers.keys()
+ # Formal parameters representing initializers (single level indentation)
+ __ = _SINGLE_INDENT
+ initializers_as_params = "\n".join(f"{__}{x}," for x in init_names)
+
+ def generate_rand(name: str, value: TensorProto) -> str:
+ shape = ",".join(str(d) for d in value.dims)
+ if value.data_type == TensorProto.FLOAT:
+ return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
+ if value.data_type == TensorProto.INT8:
+ return f"{__}{name} = np.random.randint(-128, 127, size=({shape},), dtype=np.int8)"
+ raise NotImplementedError(
+ f"Unable to generate random initializer for data type {value.data_type}."
+ )
+
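+ # e.g. (illustrative) a float32 initializer "w" with dims [2, 3] produces the
+ # line: w = np.random.rand(2,3).astype(np.float32)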
+ random_initializer_values = "\n".join(
+ generate_rand(key, value) for key, value in self.skipped_initializers.items()
+ )
+ # Actual parameter values for initializers (double level indentation)
+ indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names)
+ return f"""
+value_infos = {value_infos}
+
+def make_model(
+{initializers_as_params}
+):
+{script}
+
+{__}model = {script_function_name}.to_model_proto(value_infos=value_infos)
+{__}return model
+
+def make_model_with_random_weights():
+{random_initializer_values}
+{__}model = make_model(
+{indented_initializers_as_params}
+{__})
+{__}return model
+"""
def _import_onnx_types(
self, proto: onnx.ModelProto | onnx.GraphProto | onnx.FunctionProto
@@ -724,7 +810,7 @@ def add(line: str) -> None:
result.append(line)
# Generic imports.
- add("import numpy")
+ add("import numpy as np")
add("from onnx import TensorProto")
add("from onnx.helper import make_tensor")
add("from onnxscript import script, external_tensor")
@@ -779,9 +865,11 @@ def visit_graph(graph: onnx.GraphProto) -> None:
def export2python(
model_onnx,
function_name: Optional[str] = None,
+ *,
rename: bool = False,
use_operators: bool = False,
inline_const: bool = False,
+ skip_initializers: bool = False,
):
"""Exports an ONNX model to the *python* syntax.
@@ -791,6 +879,9 @@ def export2python(
function_name: main function name
use_operators: use Python operators.
inline_const: replace ONNX constants inline if compact
+ skip_initializers: if True, the generated script will not include initializers.
+ Instead, the script defines a function that builds the model from given
+ initializer values, along with one that builds it with random initializer values.
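+ For example (illustrative), the generated script then defines
+ make_model(...) and make_model_with_random_weights().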
Returns:
python code
@@ -799,11 +890,11 @@ def export2python(
.. runpython::
:showcode:
:process:
- import numpy
+ import numpy as np
from sklearn.cluster import KMeans
from mlprodict.onnx_conv import to_onnx
from mlprodict.onnx_tools.onnx_export import export2python
- X = numpy.arange(20).reshape(10, 2).astype(numpy.float32)
+ X = np.arange(20).reshape(10, 2).astype(np.float32)
tr = KMeans(n_clusters=2)
tr.fit(X)
onx = to_onnx(tr, X, target_opset=14)
@@ -816,5 +907,10 @@ def export2python(
if not isinstance(model_onnx, (ModelProto, FunctionProto)):
raise TypeError(f"The function expects a ModelProto not {type(model_onnx)!r}.")
- exporter = Exporter(rename, use_operators, inline_const)
+ exporter = _Exporter(
+ rename=rename,
+ use_operators=use_operators,
+ inline_const=inline_const,
+ skip_initializers=skip_initializers,
+ )
return exporter.export(model_onnx, function_name)
diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py
index efcc8ae8a2..1f913ed897 100644
--- a/onnxscript/backend/onnx_export_test.py
+++ b/onnxscript/backend/onnx_export_test.py
@@ -1,13 +1,13 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import importlib
+import os
import pathlib
import re
+import sys
import unittest
from typing import Pattern
@@ -45,14 +45,8 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
SKIP_TESTS = (
- skip(
- r"^test_ai_onnx_ml_array_feature_extractor",
- "ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'",
- ),
- skip(
- r"^test_ai_onnx_ml_binarizer",
- "ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'",
- ),
+ skip(r"^test_ai_onnx_ml_array_feature_extractor", "ORT doesn't support this op"),
+ skip(r"^test_ai_onnx_ml_binarizer", "ORT doesn't support this op"),
skip(r"^test_center_crop_pad_crop_negative_axes_hwc", "fixme: ORT segfaults"),
skip(r"_scan_", "Operator Scan is not supported by onnxscript"),
skip(r"^test_scan", "Operator Scan is not supported by onnxscript"),
@@ -89,8 +83,27 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
"Change when the converter supports support something like 'while i < n and cond:'",
),
skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"),
+ skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"),
+ skip(r"^test_attention", "ONNX Runtime 1.23 fails on these tests"),
)
+if sys.platform == "win32":
+ SKIP_TESTS = (
+ *SKIP_TESTS,
+ skip(r"^test_gemm_beta", "cannot import module, import_module does not work"),
+ skip(
+ r"^test_averagepool_2d_default",
+ "cannot import module, import_module does not work",
+ ),
+ skip("^test_bitwise_not_3d", "cannot import module, import_module does not work"),
+ skip(
+ "^test_resize_upsample_scales_linear_half_pixel_symmetric",
+ "cannot import module, import_module does not work",
+ ),
+ # tests are too unstable on Windows, not always the same ones are failing.
+ skip("test_", "cannot import module"),
+ )
+
def load_function(obj):
return ort.InferenceSession(obj.SerializeToString(), providers=("CPUExecutionProvider",))
@@ -108,16 +121,24 @@ def run_function(obj, *inputs):
def extract_functions(name: str, content: str, test_folder: pathlib.Path):
if not test_folder.exists():
test_folder.mkdir(exist_ok=True, parents=True)
- init = test_folder / "__init__.py"
- init.touch(exist_ok=True)
- file = test_folder / f"{name}.py"
- file.write_text(content, encoding="utf-8")
+ init = str(test_folder / "__init__.py")
+ with open(init, "w", encoding="utf-8") as f:
+ f.write("\n")
+ filename = str(test_folder / f"{name}.py")
+ with open(filename, "w", encoding="utf-8") as f:
+ f.write(content + "\n")
+ assert os.path.exists(filename), (
+ f"{filename!r} ({os.path.abspath(filename)!r}) does not exist."
+ )
import_name = f"tests.{test_folder.parts[-1]}.{name}"
try:
mod = importlib.import_module(import_name)
except (SyntaxError, ImportError) as e:
raise AssertionError(
- f"Unable to import {import_name!r} (file: {file!r})\n----\n{content}"
+ f"Unable to import {import_name!r} (e={e}) (file: {filename!r}, "
+ f"absolute path: {os.path.abspath(filename)!r}, "
+ f"current folder: {os.getcwd()}"
+ f"\n---- CONTENT --\n{content}"
) from e
functions = {
k: v for k, v in mod.__dict__.items() if isinstance(v, onnxscript.OnnxFunction)
@@ -137,7 +158,7 @@ class TestOnnxBackEnd(unittest.TestCase):
test_folder = root_folder / "tests" / "onnx_backend_test_code"
temp_folder = root_folder / "tests" / "export"
- def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options):
+ def _proto_to_os_and_back(self, proto: onnx.FunctionProto, **export_options):
"""Convert a proto to onnxscript code and convert it back to a proto."""
code = onnx_export.export2python(proto, **export_options)
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)
@@ -267,16 +288,6 @@ def _load_function(_):
return session
def _run_function(obj, *inputs):
- print(" run ONNX")
- for i, inp in enumerate(inputs):
- if inp is None:
- print(f" input {i}: None")
- else:
- print(
- f" input {i}: "
- f"dtype={inp.dtype!r} shape={inp.shape!r}"
- f"{inp.ravel().tolist()!r}"
- )
try:
return run_function(obj, *inputs)
except Exception as e:
diff --git a/onnxscript/converter.py b/onnxscript/converter.py
index 515829488d..dfcddefbd3 100644
--- a/onnxscript/converter.py
+++ b/onnxscript/converter.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
import ast
@@ -25,9 +23,6 @@
from onnxscript import type_annotation as ta
from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation
-PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39
-
-
logger = logging.getLogger("onnxscript")
@@ -303,7 +298,7 @@ def generate_unique_name(self, candidate: str = "tmp") -> str:
return r
def _make_onnx_attr(
- self, attrname: str, attrval: Any, attrtype: Optional[int] = None
+ self, attrname: str, attrval: Any, attrtype: int | None = None
) -> irbuilder.IRAttributeValue:
def tensor_name_generator() -> str:
"""Return name to be used for tensor, if we need to create one."""
@@ -429,9 +424,7 @@ def _emit_copy(self, original_var: str, suggested_name: str) -> str:
def _is_constant_expr(self, node: ast.AST) -> None:
if isinstance(node, ast.UnaryOp):
- if self._is_constant_expr(node.operand):
- return True
- return False
+ return self._is_constant_expr(node.operand)
if isinstance(
node,
(
@@ -439,14 +432,10 @@ def _is_constant_expr(self, node: ast.AST) -> None:
ast.BinOp,
ast.UnaryOp,
ast.Compare,
- ast.Num,
- ast.Str,
ast.Attribute,
ast.List,
ast.Load,
- ast.NameConstant,
ast.Constant,
- ast.Str,
),
):
return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node))
@@ -527,10 +516,10 @@ def _translate_attr(
# in a NodeProto.
if val is None:
if attr_meta and attr_meta.required:
- self.fail(expr, "Attribute '{attr_name}' is required.")
+ self.fail(expr, f"Attribute '{attr_name}' is required.")
return None
- attr_type = attr_meta.type if attr_meta else None
- attr = self._make_onnx_attr(attr_name, val, attr_type)
+ attr_type = int(attr_meta.type) if attr_meta else None
+ attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type)
if attr_meta and (attr.type != attr_meta.type):
self.fail(
expr,
@@ -582,9 +571,9 @@ def _translate_expr(
def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]:
"""Translation of an expression where "None" is permitted (eg., for an optional argument).
- None is represented as a NameConstant in Python 3.7 and Constant in Python 3.9.
+ None is represented as a Constant in Python 3.9+.
"""
- if isinstance(node, (ast.NameConstant, ast.Constant)) and (node.value is None):
+ if isinstance(node, ast.Constant) and (node.value is None):
return None
return self._translate_expr(node)
@@ -633,7 +622,7 @@ def _translate_subscript_expr(
target = f"{var_name}_subscripted"
target = self.generate_unique_name(target)
indices = ast_utils.normalize_subscript_expr(node)
- info = self._source_of(node.slice if PY_VERSION_GE_39 else node)
+ info = self._source_of(node.slice)
# Create cached int constants:
# TODO: Do this at a graph-scope level.
@@ -804,6 +793,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
non_scalar_indices.extend(scalar_indices)
if non_scalar_indices:
last_axis, _ = non_scalar_indices[-1]
+ else:
+ # TODO(justinchuby): Clarify what last_axis should be when non_scalar_indices is empty
+ last_axis = None
for axis, index_expr in non_scalar_indices:
index_value = self._translate_expr(index_expr)
axis_attr = self._make_onnx_attr("axis", axis)
@@ -943,7 +935,6 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable
opname = node.attr
if opname in module:
return values.Op(module, node.attr)
- warn(f"'{opname}' is not a known op in '{module}'")
return values.Op(module, node.attr)
if isinstance(node, ast.Name):
function_name = node.id
@@ -1243,14 +1234,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
if i != len(loop_stmt.body) - 1:
self.fail(s, "Instruction break must be the last one of the loop.")
- _current_scope = self._current_scope()
- if s.test.id not in _current_scope:
+ current_scope = self._current_scope()
+ if s.test.id not in current_scope:
self.fail(
loop_stmt,
f"Unable to find condition variable {s.test.id!r} in known "
- f"variables {list(_current_scope)!r}.",
+ f"variables {list(current_scope)!r}.",
)
- condition_name = _current_scope[s.test.id].value
+ condition_name = current_scope[s.test.id].value
operator_name = "Not"
continue
self._translate_stmt(s)
@@ -1259,14 +1250,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
if cond_while is not None:
# Loop while
- _current_scope = self._current_scope()
- if cond_while not in _current_scope:
+ current_scope = self._current_scope()
+ if cond_while not in current_scope:
self.fail(
loop_stmt,
f"Unable to find condition variable {cond_while!r} in known "
- f"variables {list(_current_scope)!r}.",
+ f"variables {list(current_scope)!r}.",
)
- o_cond_var = _current_scope[cond_while].value
+ o_cond_var = current_scope[cond_while].value
self.emit(
[o_cond_out],
diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py
index 58ed379686..9a7ca504a7 100644
--- a/onnxscript/converter_test.py
+++ b/onnxscript/converter_test.py
@@ -1,13 +1,10 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
import ast
import inspect
import os
import pathlib
-import sys
import textwrap
import types
import typing
@@ -194,6 +191,27 @@ def cast_add(x, y):
self.assertEqual(y_value_info.type.tensor_type.elem_type, onnx.TensorProto.INT64)
self.assertEqual(output_value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT)
+ def test_set_value_info(self):
+ @script()
+ def double_square(x):
+ square = op.Mul(x, x)
+ return op.Add(square, square)
+
+ # Converting "double_square" to a ModelProto will generate an incomplete ModelProto,
+ # with input-types undefined (since the script has no type-annotation).
+ model = double_square.to_model_proto()
+ graph = model.graph
+ self.assertEqual(len(graph.value_info), 0)
+ model = double_square.to_model_proto(
+ io_types=FLOAT["N"], value_infos={"square": FLOAT["N"]}
+ )
+ graph = model.graph
+ self.assertEqual(len(graph.value_info), 1)
+ value_info = graph.value_info[0]
+ self.assertEqual(value_info.name, "square")
+ self.assertEqual(value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT)
+ self.assertEqual(value_info.type.tensor_type.shape.dim[0].dim_param, "N")
+
def test_onnxfns1(self):
from tests.models import onnxfns1
@@ -439,8 +457,7 @@ def f1(A: FLOAT[...]) -> FLOAT[...]:
r = A[index]
return r
- ast_name = "_ast" if sys.version_info[:2] < (3, 9) else "ast"
- self.check_failure(f1, f"Left term must be a tuple not ''")
+ self.check_failure(f1, "Left term must be a tuple not ''")
def check_run(self, onnxfn, inputs, expected_output):
# Test by converting to model and running with ORT
@@ -675,6 +692,24 @@ def sum(n: INT64) -> INT64:
self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64))
self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64))
+ def test_function_opset_import(self):
+ """Test that model inherits opset version from the function."""
+ from onnxscript import opset19
+
+ @script()
+ def double(x):
+ return opset19.Add(x, x)
+
+ @script()
+ def model(x):
+ return double(x)
+
+ model_proto = model.to_model_proto()
+ onnx_opset_import = [opset for opset in model_proto.opset_import if opset.domain == ""]
+
+ self.assertEqual(len(onnx_opset_import), 1)
+ self.assertEqual(onnx_opset_import[0].version, 19)
+
if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/onnxscript/diagnostics/infra/__init__.py b/onnxscript/diagnostics/infra/__init__.py
deleted file mode 100644
index 1d771666f2..0000000000
--- a/onnxscript/diagnostics/infra/__init__.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from ._infra import (
- DiagnosticOptions,
- Graph,
- Invocation,
- Level,
- Location,
- Rule,
- RuleCollection,
- Stack,
- StackFrame,
- Tag,
- ThreadFlowLocation,
- levels,
-)
-from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnosticError
-
-__all__ = [
- "Diagnostic",
- "DiagnosticContext",
- "DiagnosticOptions",
- "Graph",
- "Invocation",
- "Level",
- "levels",
- "Location",
- "Rule",
- "RuleCollection",
- "RuntimeErrorWithDiagnosticError",
- "Stack",
- "StackFrame",
- "Tag",
- "ThreadFlowLocation",
-]
diff --git a/onnxscript/diagnostics/infra/_infra.py b/onnxscript/diagnostics/infra/_infra.py
deleted file mode 100644
index f225a191fe..0000000000
--- a/onnxscript/diagnostics/infra/_infra.py
+++ /dev/null
@@ -1,319 +0,0 @@
-"""This file defines an additional layer of abstraction on top of the SARIF OM."""
-
-from __future__ import annotations
-
-import dataclasses
-import enum
-import pprint
-from typing import FrozenSet, List, Mapping, Optional, Sequence, Tuple
-
-from onnxscript.diagnostics.infra import formatter, sarif
-
-
-class Level(enum.IntEnum):
- """The level of a diagnostic.
-
- This class is used to represent the level of a diagnostic. The levels are defined
- by the SARIF specification, and are not modifiable. For alternative categories,
- please use infra.Tag instead. When selecting a level, please consider the following
- guidelines:
-
- - NONE: Informational result that does not indicate the presence of a problem.
- - NOTE: An opportunity for improvement was found.
- - WARNING: A potential problem was found.
- - ERROR: A serious problem was found.
-
- This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer
- value maps to the logging levels in Python's logging module. The mapping is as
- follows:
-
- Level.NONE = logging.DEBUG = 10
- Level.NOTE = logging.INFO = 20
- Level.WARNING = logging.WARNING = 30
- Level.ERROR = logging.ERROR = 40
- """
-
- NONE = 10
- NOTE = 20
- WARNING = 30
- ERROR = 40
-
-
-levels = Level
-
-
-class Tag(enum.Enum):
- """The tag of a diagnostic. This class can be inherited to define custom tags."""
-
-
-class PatchedPropertyBag(sarif.PropertyBag):
- """Key/value pairs that provide additional information about the object.
-
- The SARIF spec defines PropertyBag as "A property bag is an object (§3.6)
- containing an unordered set of properties with arbitrary names." However, this is
- not reflected in the JSON file, and is therefore not captured by the Python representation.
- This patch adds additional **kwargs to the `__init__` method to allow recording
- arbitrary key/value pairs.
- """
-
- def __init__(self, tags: Optional[List[str]] = None, **kwargs):
- super().__init__(tags=tags)
- self.__dict__.update(kwargs)
-
-
-@dataclasses.dataclass(frozen=True)
-class Rule:
- id: str
- name: str
- message_default_template: str
- short_description: Optional[str] = None
- full_description: Optional[str] = None
- full_description_markdown: Optional[str] = None
- help_uri: Optional[str] = None
-
- @classmethod
- def from_sarif(cls, **kwargs):
- """Returns a rule from the SARIF reporting descriptor."""
- short_description = kwargs.get("short_description", {}).get("text")
- full_description = kwargs.get("full_description", {}).get("text")
- full_description_markdown = kwargs.get("full_description", {}).get("markdown")
- help_uri = kwargs.get("help_uri")
-
- rule = cls(
- id=kwargs["id"],
- name=kwargs["name"],
- message_default_template=kwargs["message_strings"]["default"]["text"],
- short_description=short_description,
- full_description=full_description,
- full_description_markdown=full_description_markdown,
- help_uri=help_uri,
- )
- return rule
-
- def sarif(self) -> sarif.ReportingDescriptor:
- """Returns a SARIF reporting descriptor of this Rule."""
- short_description = (
- sarif.MultiformatMessageString(text=self.short_description)
- if self.short_description is not None
- else None
- )
- full_description = (
- sarif.MultiformatMessageString(
- text=self.full_description, markdown=self.full_description_markdown
- )
- if self.full_description is not None
- else None
- )
- return sarif.ReportingDescriptor(
- id=self.id,
- name=self.name,
- short_description=short_description,
- full_description=full_description,
- help_uri=self.help_uri,
- )
-
- def format(self, level: Level, *args, **kwargs) -> Tuple[Rule, Level, str]:
- """Returns a tuple of (rule, level, message) for a diagnostic.
-
- This method is used to format the message of a diagnostic. The message is
- formatted using the default template of this rule, and the arguments passed in
- as `*args` and `**kwargs`. The level is used to override the default level of
- this rule.
- """
- return (self, level, self.format_message(*args, **kwargs))
-
- def format_message(self, *args, **kwargs) -> str:
- """Returns the formatted default message of this Rule.
-
- This method should be overridden (with code generation) by subclasses to reflect
- the exact arguments needed by the message template. This is a helper method to
- create the default message for a diagnostic.
- """
- return self.message_default_template.format(*args, **kwargs)
-
- def pretty_print(self):
- pass
-
-
-@dataclasses.dataclass
-class Location:
- uri: Optional[str] = None
- line: Optional[int] = None
- message: Optional[str] = None
- start_column: Optional[int] = None
- end_column: Optional[int] = None
- snippet: Optional[str] = None
- function: Optional[str] = None
-
- def sarif(self) -> sarif.Location:
- """Returns the SARIF representation of this location."""
- return sarif.Location(
- physical_location=sarif.PhysicalLocation(
- artifact_location=sarif.ArtifactLocation(uri=self.uri),
- region=sarif.Region(
- start_line=self.line,
- start_column=self.start_column,
- end_column=self.end_column,
- snippet=sarif.ArtifactContent(text=self.snippet),
- ),
- ),
- message=sarif.Message(text=self.message) if self.message is not None else None,
- )
-
- def pretty_print(self):
- """Prints the location in a traceback style format."""
- unknown = ""
- snippet = self.snippet or unknown
- uri = self.uri or unknown
- function = self.function or unknown
- lineno = self.line if self.line is not None else unknown
- message = f" # {self.message}" if self.message is not None else ""
- print(f' File "{uri}", line {lineno}, in {function}\n {snippet}{message}')
-
-
-@dataclasses.dataclass
-class StackFrame:
- location: Location
-
- def sarif(self) -> sarif.StackFrame:
- """Returns the SARIF representation of this stack frame."""
- return sarif.StackFrame(location=self.location.sarif())
-
- def pretty_print(self):
- """Prints the stack frame in a human-readable format."""
- self.location.pretty_print()
-
-
-@dataclasses.dataclass
-class Stack:
- """Records a stack trace. The frames are in order from newest to oldest stack frame."""
-
- frames: List[StackFrame] = dataclasses.field(default_factory=list)
- message: Optional[str] = None
-
- def sarif(self) -> sarif.Stack:
- """Returns the SARIF representation of this stack."""
- return sarif.Stack(
- frames=[frame.sarif() for frame in self.frames],
- message=sarif.Message(text=self.message) if self.message is not None else None,
- )
-
- def pretty_print(self):
- """Prints the stack in a human-readable format."""
- formatter.pretty_print_title(f"Stack: {self.message}", fill_char="-")
- for frame in reversed(self.frames):
- frame.pretty_print()
-
-
-@dataclasses.dataclass
-class ThreadFlowLocation:
- """Records code location and the initial state."""
-
- location: Location
- state: Mapping[str, str]
- index: int
- stack: Optional[Stack] = None
-
- def sarif(self) -> sarif.ThreadFlowLocation:
- """Returns the SARIF representation of this thread flow location."""
- return sarif.ThreadFlowLocation(
- location=self.location.sarif(),
- state=self.state,
- stack=self.stack.sarif() if self.stack is not None else None,
- )
-
- def pretty_print(self, verbose: bool = False):
- """Prints the thread flow location in a human-readable format."""
- formatter.pretty_print_title(f"Step {self.index}", fill_char="-")
- self.location.pretty_print()
- if verbose:
- print(f"State: {pprint.pformat(self.state)}")
- if self.stack is not None:
- self.stack.pretty_print()
-
-
-@dataclasses.dataclass
-class Graph:
- """A graph of diagnostics.
-
- This class stores the string representation of a model graph.
- The `nodes` and `edges` fields are unused in the current implementation.
- """
-
- graph: str
- name: str
- description: Optional[str] = None
-
- def sarif(self) -> sarif.Graph:
- """Returns the SARIF representation of this graph."""
- return sarif.Graph(
- description=sarif.Message(text=self.graph),
- properties=PatchedPropertyBag(name=self.name, description=self.description),
- )
-
- def pretty_print(
- self,
- verbose: bool = False,
- ):
- """Prints the diagnostics in a human-readable format.
-
- Args:
- verbose: If True, prints all information. Otherwise, only prints compact
- information. E.g., graph name and description.
- log_level: The minimum level of diagnostics to print.
- """
- formatter.pretty_print_title(f"Graph: {self.name}", fill_char="-")
- print(self.description)
- if verbose:
- print(self.graph)
-
-
-@dataclasses.dataclass
-class RuleCollection:
- _rule_id_name_set: FrozenSet[Tuple[str, str]] = dataclasses.field(init=False)
-
- def __post_init__(self) -> None:
- self._rule_id_name_set = frozenset(
- {
- (field.default.id, field.default.name)
- for field in dataclasses.fields(self)
- if isinstance(field.default, Rule)
- }
- )
-
- def __contains__(self, rule: Rule) -> bool:
- """Checks if the rule is in the collection."""
- return (rule.id, rule.name) in self._rule_id_name_set
-
- @classmethod
- def custom_collection_from_list(
- cls, new_collection_class_name: str, rules: Sequence[Rule]
- ) -> RuleCollection:
- """Creates a custom class inherited from RuleCollection with the list of rules."""
- return dataclasses.make_dataclass(
- new_collection_class_name,
- [
- (
- formatter.kebab_case_to_snake_case(rule.name),
- type(rule),
- dataclasses.field(default=rule),
- )
- for rule in rules
- ],
- bases=(cls,),
- )()
-
-
-class Invocation:
- # TODO: Implement this.
- # Tracks top level call arguments and diagnostic options.
- def __init__(self) -> None:
- raise NotImplementedError()
-
-
-@dataclasses.dataclass
-class DiagnosticOptions:
- """Options for diagnostic context."""
-
- log_verbose: bool = dataclasses.field(default=False)
- log_level: Level = dataclasses.field(default=Level.ERROR)
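
For reference against the deleted `_infra.py` above, a minimal sketch of the core
primitives it defined: declaring a Rule, formatting a message from its template,
and rendering a Location to SARIF. The rule id, names, and snippet are illustrative
only; the APIs are those re-exported by `onnxscript.diagnostics.infra`:

    from onnxscript.diagnostics import infra

    # Hypothetical rule; id/name/template are illustrative.
    rule = infra.Rule(
        id="EX0001",
        name="example-rule",
        message_default_template="Unsupported op: {0}",
    )
    # Rule.format returns a (rule, level, message) triple for a diagnostic.
    _, level, message = rule.format(infra.Level.WARNING, "Foo")
    assert message == "Unsupported op: Foo"

    # Location captures a source position and renders itself as a SARIF Location.
    loc = infra.Location(uri="model.py", line=10, function="f", snippet="y = Foo(x)")
    sarif_location = loc.sarif()
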
diff --git a/onnxscript/diagnostics/infra/context.py b/onnxscript/diagnostics/infra/context.py
deleted file mode 100644
index 26d0c1bd27..0000000000
--- a/onnxscript/diagnostics/infra/context.py
+++ /dev/null
@@ -1,347 +0,0 @@
-"""A diagnostic context based on SARIF."""
-
-from __future__ import annotations
-
-import contextlib
-import dataclasses
-import gzip
-import logging
-import typing
-from typing import Callable, Generator, List, Literal, Mapping, Optional
-
-from onnxscript.diagnostics import infra
-from onnxscript.diagnostics.infra import formatter, sarif, utils
-from onnxscript.diagnostics.infra.sarif import version as sarif_version
-
-if typing.TYPE_CHECKING:
- from typing_extensions import Self
-
-
-@dataclasses.dataclass
-class Diagnostic:
- rule: infra.Rule
- level: infra.Level
- message: Optional[str] = None
- locations: List[infra.Location] = dataclasses.field(default_factory=list)
- stacks: List[infra.Stack] = dataclasses.field(default_factory=list)
- graphs: List[infra.Graph] = dataclasses.field(default_factory=list)
- thread_flow_locations: List[infra.ThreadFlowLocation] = dataclasses.field(
- default_factory=list
- )
- additional_message: Optional[str] = None
- tags: List[infra.Tag] = dataclasses.field(default_factory=list)
- source_exception: Optional[Exception] = None
- """The exception that caused this diagnostic to be created."""
-
- def __post_init__(self) -> None:
- pass
-
- def sarif(self) -> sarif.Result:
- """Returns the SARIF Result representation of this diagnostic."""
- message = self.message or self.rule.message_default_template
- if self.additional_message:
- message_markdown = (
- f"{message}\n\n## Additional Message:\n\n{self.additional_message}"
- )
- else:
- message_markdown = message
-
- kind: Literal["informational", "fail"] = (
- "informational" if self.level == infra.Level.NONE else "fail"
- )
-
- sarif_result = sarif.Result(
- message=sarif.Message(text=message, markdown=message_markdown),
- level=self.level.name.lower(), # type: ignore[arg-type]
- rule_id=self.rule.id,
- kind=kind,
- )
- sarif_result.locations = [location.sarif() for location in self.locations]
- sarif_result.stacks = [stack.sarif() for stack in self.stacks]
- sarif_result.graphs = [graph.sarif() for graph in self.graphs]
- sarif_result.code_flows = [
- sarif.CodeFlow(
- thread_flows=[
- sarif.ThreadFlow(
- locations=[loc.sarif() for loc in self.thread_flow_locations]
- )
- ]
- )
- ]
- sarif_result.properties = sarif.PropertyBag(tags=[tag.value for tag in self.tags])
- return sarif_result
-
- def with_location(self: Self, location: infra.Location) -> Self:
- """Adds a location to the diagnostic."""
- self.locations.append(location)
- return self
-
- def with_thread_flow_location(self: Self, location: infra.ThreadFlowLocation) -> Self:
- """Adds a thread flow location to the diagnostic."""
- self.thread_flow_locations.append(location)
- return self
-
- def with_stack(self: Self, stack: infra.Stack) -> Self:
- """Adds a stack to the diagnostic."""
- self.stacks.append(stack)
- return self
-
- def with_graph(self: Self, graph: infra.Graph) -> Self:
- """Adds a graph to the diagnostic."""
- self.graphs.append(graph)
- return self
-
- def with_additional_message(self: Self, message: str) -> Self:
- """Adds an additional message to the diagnostic."""
- if self.additional_message is None:
- self.additional_message = message
- else:
- self.additional_message = f"{self.additional_message}\n{message}"
- return self
-
- def with_source_exception(self: Self, exception: Exception) -> Self:
- """Adds the source exception to the diagnostic."""
- self.source_exception = exception
- return self
-
- def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
- """Records the current Python call stack."""
- frames_to_skip += 1 # Skip this function.
- stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
- self.with_stack(stack)
- if len(stack.frames) > 0:
- self.with_location(stack.frames[0].location)
- return stack
-
- def record_python_call(
- self,
- fn: Callable,
- state: Mapping[str, str],
- message: Optional[str] = None,
- frames_to_skip: int = 0,
- ) -> infra.ThreadFlowLocation:
- """Records a python call as one thread flow step."""
- frames_to_skip += 1 # Skip this function.
- stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
- location = utils.function_location(fn)
- location.message = message
- # Add function location to the top of the stack.
- stack.frames.insert(0, infra.StackFrame(location=location))
- thread_flow_location = infra.ThreadFlowLocation(
- location=location,
- state=state,
- index=len(self.thread_flow_locations),
- stack=stack,
- )
- self.with_thread_flow_location(thread_flow_location)
- return thread_flow_location
-
- def pretty_print(self, verbose: bool = False, log_level: infra.Level = infra.Level.ERROR):
- """Prints the diagnostics in a human-readable format.
-
- Args:
- verbose: If True, prints all information. E.g. stack frames, graphs, etc.
- Otherwise, only prints compact information. E.g., rule name and display message.
- log_level: The minimum level of diagnostics to print.
- """
- if self.level.value < log_level.value:
- return
- formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}")
- print(self.message)
- print(self.additional_message)
-
- if not verbose:
- print("\n")
- return
-
- formatter.pretty_print_title("Locations", fill_char="-")
- for location in self.locations:
- location.pretty_print()
- for stack in self.stacks:
- stack.pretty_print()
- formatter.pretty_print_title("Thread Flow Locations", fill_char="-")
- for thread_flow_location in self.thread_flow_locations:
- thread_flow_location.pretty_print(verbose=verbose)
- for graph in self.graphs:
- graph.pretty_print(verbose=verbose)
-
- print()
-
- # TODO: print help url to rule at the end.
-
-
-class RuntimeErrorWithDiagnosticError(RuntimeError):
- """Runtime error with enclosed diagnostic information."""
-
- def __init__(self, diagnostic: Diagnostic):
- super().__init__(diagnostic.message)
- self.diagnostic = diagnostic
-
-
-@dataclasses.dataclass
-class DiagnosticContext:
- name: str
- version: str
- options: infra.DiagnosticOptions = dataclasses.field(
- default_factory=infra.DiagnosticOptions
- )
- diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
- logger: logging.Logger = dataclasses.field(
- init=True, default_factory=lambda: logging.getLogger().getChild("diagnostics")
- )
- # TODO(bowbao): Implement this.
- # _invocation: infra.Invocation = dataclasses.field(init=False)
- _inflight_diagnostics: List[Diagnostic] = dataclasses.field(
- init=False, default_factory=list
- )
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- return None
-
- def sarif(self) -> sarif.Run:
- """Returns the SARIF Run object."""
- unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
- return sarif.Run(
- tool=sarif.Tool(
- driver=sarif.ToolComponent(
- name=self.name,
- version=self.version,
- rules=[rule.sarif() for rule in unique_rules],
- )
- ),
- results=[diagnostic.sarif() for diagnostic in self.diagnostics],
- )
-
- def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined]
- """Returns the SARIF Log object."""
- return sarif.SarifLog(
- version=sarif_version.SARIF_VERSION,
- schema_uri=sarif_version.SARIF_SCHEMA_LINK,
- runs=[self.sarif()],
- )
-
- def to_json(self) -> str:
- return formatter.sarif_to_json(self.sarif_log())
-
- def dump(self, file_path: str, compress: bool = False) -> None:
- """Dumps the SARIF log to a file."""
- if compress:
- with gzip.open(file_path, "wt", encoding="utf-8") as f:
- f.write(self.to_json())
- else:
- with open(file_path, "w", encoding="utf-8") as f:
- f.write(self.to_json())
-
- def log(self, diagnostic: Diagnostic) -> None:
- """Adds a diagnostic to the context.
-
- Use this method to add diagnostics that are not created by the context.
-
- Args:
- diagnostic: The diagnostic to add.
- """
- if not isinstance(diagnostic, Diagnostic):
- raise TypeError(
- f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
- )
- self.diagnostics.append(diagnostic)
- self.logger.log(diagnostic.level, diagnostic.message)
- self.logger.log(diagnostic.level, diagnostic.additional_message)
-
- def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
- self.log(diagnostic)
- if diagnostic.level == infra.Level.ERROR:
- raise RuntimeErrorWithDiagnosticError(diagnostic) from diagnostic.source_exception
-
- @contextlib.contextmanager
- def add_inflight_diagnostic(
- self, diagnostic: Diagnostic
- ) -> Generator[Diagnostic, None, None]:
- """Adds a diagnostic to the context.
-
- Use this method to add diagnostics that are not created by the context.
-
- Args:
- diagnostic: The diagnostic to add.
- """
- self._inflight_diagnostics.append(diagnostic)
- try:
- yield diagnostic
- finally:
- self._inflight_diagnostics.pop()
-
- def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
- """Pushes a diagnostic to the inflight diagnostics stack.
-
- Args:
- diagnostic: The diagnostic to push.
-
- Raises:
- ValueError: If the rule is not supported by the tool.
- """
- self._inflight_diagnostics.append(diagnostic)
-
- def pop_inflight_diagnostic(self) -> Diagnostic:
- """Pops the last diagnostic from the inflight diagnostics stack.
-
- Returns:
- The popped diagnostic.
- """
- return self._inflight_diagnostics.pop()
-
- def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
- if rule is None:
- # TODO(bowbao): Create builtin-rules and create diagnostic using that.
- if len(self._inflight_diagnostics) <= 0:
- raise AssertionError("No inflight diagnostics")
-
- return self._inflight_diagnostics[-1]
- else:
- # TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]]
- for diagnostic in reversed(self._inflight_diagnostics):
- if diagnostic.rule == rule:
- return diagnostic
- raise AssertionError(f"No inflight diagnostic for rule {rule.name}")
-
- def pretty_print(
- self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None
- ) -> None:
- """Prints the diagnostics in a human-readable format.
-
- Args:
- verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print.
- If not specified, uses the value of 'self.options.log_verbose'.
- log_level: The minimum level of diagnostics to print.
- If not specified, uses the value of 'self.options.log_level'.
- """
- if verbose is None:
- verbose = self.options.log_verbose
- if log_level is None:
- log_level = self.options.log_level
-
- formatter.pretty_print_title(f"Diagnostic Run {self.name} version {self.version}")
- print(f"verbose: {verbose}, log level: {log_level}")
- diagnostic_stats = {level: 0 for level in infra.Level}
- for diagnostic in self.diagnostics:
- diagnostic_stats[diagnostic.level] += 1
- formatter.pretty_print_title(
- " ".join(f"{diagnostic_stats[level]} {level.name}" for level in infra.Level)
- )
-
- for diagnostic in self.diagnostics:
- diagnostic.pretty_print(verbose, log_level)
-
- unprinted_diagnostic_stats = [
- (level, count)
- for level, count in diagnostic_stats.items()
- if count > 0 and level.value < log_level.value
- ]
- if unprinted_diagnostic_stats:
- print(
- f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} "
- "were not printed due to the log level."
- )
- print()
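
Likewise, a sketch of the removed DiagnosticContext workflow from context.py:
create a context, log a diagnostic built from a rule, and dump a SARIF log to
disk. The rule and output file name here are illustrative:

    from onnxscript.diagnostics import infra

    rule = infra.Rule(
        id="EX0002",
        name="example-rule",
        message_default_template="Found issue in {0}",
    )
    with infra.DiagnosticContext("onnxscript", "0.1") as ctx:
        diag = infra.Diagnostic(
            rule=rule,
            level=infra.Level.WARNING,
            message=rule.format_message("my_model"),
        )
        ctx.log(diag)             # a WARNING is recorded but does not raise
        # ctx.log_and_raise_if_error(diag) would raise for Level.ERROR diagnostics
        ctx.dump("report.sarif")  # serializes the log via formatter.sarif_to_json
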
diff --git a/onnxscript/diagnostics/infra/decorator.py b/onnxscript/diagnostics/infra/decorator.py
deleted file mode 100644
index e72da19c42..0000000000
--- a/onnxscript/diagnostics/infra/decorator.py
+++ /dev/null
@@ -1,151 +0,0 @@
-from __future__ import annotations
-
-import functools
-import traceback
-from typing import Any, Callable, Dict, Optional, Tuple, Type
-
-from onnxscript._internal import runtime_typing
-from onnxscript.diagnostics import infra
-from onnxscript.diagnostics.infra import formatter, utils
-
-MessageFormatterType = Callable[..., str]
-
-
-@runtime_typing.checked
-def format_message_in_text(
- fn: Callable, # pylint: disable=unused-argument
- *args: Any,
- **kwargs: Any,
-) -> str:
- return f"{formatter.display_name(fn)}. "
-
-
-@runtime_typing.checked
-def format_exception_in_markdown(exception: Exception) -> str:
- msg_list = ["### Exception log", "```"]
- msg_list.extend(
- traceback.format_exception(type(exception), exception, exception.__traceback__)
- )
- msg_list.append("```")
- return "\n".join(msg_list)
-
-
-@runtime_typing.checked
-def format_function_signature_in_markdown(
- fn: Callable,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- format_argument: Callable[[Any], str] = formatter.format_argument,
-) -> str:
- msg_list = [f"### Function Signature {formatter.display_name(fn)}"]
-
- state = utils.function_state(fn, args, kwargs)
-
- for k, v in state.items():
- msg_list.append(f"- {k}: {format_argument(v)}")
-
- return "\n".join(msg_list)
-
-
-@runtime_typing.checked
-def format_return_values_in_markdown(
- return_values: Any,
- format_argument: Callable[[Any], str] = formatter.format_argument,
-) -> str:
- return f"- Return value: {format_argument(return_values)}"
-
-
-ModifierCallableType = Callable[
- [infra.Diagnostic, Callable, Tuple[Any, ...], Dict[str, Any], Any], None
-]
-
-
-@runtime_typing.checked
-def diagnose_call(
- rule: infra.Rule,
- *,
- level: infra.Level = infra.Level.NONE,
- diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic,
- format_argument: Callable[[Any], str] = formatter.format_argument,
- diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
-) -> Callable:
- def decorator(fn):
- @functools.wraps(fn)
- def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
- common_error_message = "diagnose_call can only be applied to callables"
- if not callable(fn):
- raise AssertionError( # noqa: TRY004
- f"{common_error_message}. Got {type(fn)} instead of callable."
- )
- arg0 = args[0] if len(args) > 0 else None
- if isinstance(ctx := arg0, infra.DiagnosticContext):
- pass
- elif isinstance(
- ctx := getattr(arg0, "diagnostic_context", None),
- infra.DiagnosticContext,
- ):
- pass
- else:
- # NOTE: At decoration time we cannot tell whether a callable is a
- # function or a method; both are regarded as plain functions at that point.
- raise AssertionError( # noqa: TRY004
- f"{common_error_message}. For {fn}, "
- f"If it is a function, a DiagnosticContext instance must be present as "
- f"the first argument. "
- f"If it is a method, a DiagnosticContext instance must be present as "
- f"the attribute 'diagnostic_context' of the 'self' argument."
- )
-
- diag = diagnostic_type(
- rule,
- level,
- diagnostic_message_formatter(fn, *args, **kwargs),
- )
-
- # pop the decorator frame
- # TODO(bowbao): By default a diagnostic doesn't carry a stack, so we
- # need to check before popping. Consider not capturing the stack by
- # default in diagnostic initialization to make this cleaner.
- stack: Optional[infra.Stack] = None
- if len(diag.stacks) > 0:
- stack = diag.stacks[0]
- stack.frames.pop(0)
-
- # set function location
- fn_location = utils.function_location(fn)
- diag.locations.insert(0, fn_location)
- # Add function location to the top of the stack.
- if stack is not None:
- stack.frames.insert(0, infra.StackFrame(location=fn_location))
-
- additional_messages = [
- format_function_signature_in_markdown(fn, args, kwargs, format_argument),
- ]
-
- return_values: Any = None
- with ctx.add_inflight_diagnostic(diag) as diag:
- try:
- return_values = fn(*args, **kwargs)
- additional_messages.append(
- format_return_values_in_markdown(return_values, format_argument)
- )
- except Exception as e: # pylint: disable=broad-exception-caught
- # Record exception.
- diag.level = infra.levels.ERROR
- # TODO(bowbao): Message emitting api.
- diag.message = diag.message or ""
- diag.message += f"Raised from:\n {type(e).__name__}: {e}"
- diag.with_source_exception(e)
- additional_messages.append(format_exception_in_markdown(e))
- else:
- return return_values
- finally:
- diag.with_additional_message("\n".join(additional_messages).strip())
- ctx.log_and_raise_if_error(diag)
-
- return wrapper
-
- return decorator
-
-
-# TODO(bowbao): decorator to report only when failed.
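
And a sketch of the removed @diagnose_call decorator: the wrapped callable must
receive a DiagnosticContext as its first argument (or, for methods, expose one
as self.diagnostic_context). On success a diagnostic at the given level is
logged; on exception the diagnostic is escalated to ERROR and re-raised wrapped
in RuntimeErrorWithDiagnosticError. The rule and function are illustrative:

    from onnxscript.diagnostics import infra
    from onnxscript.diagnostics.infra.decorator import diagnose_call

    rule = infra.Rule(
        id="EX0003",
        name="traced-call",
        message_default_template="Tracing a call.",
    )

    @diagnose_call(rule, level=infra.Level.NOTE)
    def convert(ctx: infra.DiagnosticContext, model_name: str) -> str:
        return model_name.upper()

    ctx = infra.DiagnosticContext("onnxscript", "0.1")
    result = convert(ctx, "my_model")  # logs one NOTE diagnostic on ctx
    assert result == "MY_MODEL" and len(ctx.diagnostics) == 1
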
diff --git a/onnxscript/diagnostics/infra/formatter.py b/onnxscript/diagnostics/infra/formatter.py
deleted file mode 100644
index c54e81fed4..0000000000
--- a/onnxscript/diagnostics/infra/formatter.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from __future__ import annotations
-
-import dataclasses
-import json
-import re
-from typing import Any, Callable, Dict, List, Optional, Union
-
-from onnxscript._internal import runtime_typing
-from onnxscript.diagnostics.infra import sarif
-
-# A list of types in the SARIF module to support pretty printing.
-# This is solely for type annotation for the functions below.
-_SarifClass = Union[
- sarif.SarifLog,
- sarif.Run,
- sarif.ReportingDescriptor,
- sarif.Result,
-]
-
-
-@runtime_typing.checked
-def snake_case_to_camel_case(s: str) -> str:
- splits = s.split("_")
- if len(splits) <= 1:
- return s
- return "".join([splits[0], *map(str.capitalize, splits[1:])])
-
-
-@runtime_typing.checked
-def camel_case_to_snake_case(s: str) -> str:
- return re.sub(r"([A-Z])", r"_\1", s).lower()
-
-
-@runtime_typing.checked
-def kebab_case_to_snake_case(s: str) -> str:
- return s.replace("-", "_")
-
-
-@runtime_typing.checked
-def _convert_key(
- object: Union[Dict[str, Any], Any], convert: Callable[[str], str]
-) -> Union[Dict[str, Any], Any]:
- """Convert and update keys in a dictionary with "convert".
-
- Any value that is a dictionary will be recursively updated.
- Any value that is a list will be recursively searched.
-
- Args:
- object: The object to update.
- convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`.
-
- Returns:
- The updated object.
- """
- if not isinstance(object, Dict):
- return object
- new_dict = {}
- for k, v in object.items():
- new_k = convert(k)
- if isinstance(v, Dict):
- new_v = _convert_key(v, convert)
- elif isinstance(v, List):
- new_v = [_convert_key(elem, convert) for elem in v]
- else:
- new_v = v
- if new_v is None:
- # Otherwise the sarif log is unnecessarily bloated with "null"s.
- continue
- if new_v == -1:
- # WAR: -1 as default value shouldn't be logged into sarif.
- continue
-
- new_dict[new_k] = new_v
-
- return new_dict
-
-
-@runtime_typing.checked
-def sarif_to_json(attr_cls_obj: _SarifClass, indent: Optional[str] = " ") -> str:
- sarif_dict = dataclasses.asdict(attr_cls_obj)
- sarif_dict = _convert_key(sarif_dict, snake_case_to_camel_case)
- return json.dumps(sarif_dict, indent=indent, separators=(",", ":"))
-
-
-@runtime_typing.checked
-def pretty_print_title(
- title: str, width: int = 80, fill_char: str = "=", print_output: bool = True
-) -> str:
- """Pretty prints title in below format:
-
- ==================== title ====================
- """
- msg = f" {title} ".center(width, fill_char)
- if print_output:
- print(msg)
- return msg
-
-
-@runtime_typing.checked
-def pretty_print_item_title(
- title: str, fill_char: str = "=", print_output: bool = True
-) -> str:
- """Pretty prints title in below format:
-
- title
- =====
- """
- msg_list = []
- msg_list.append(title)
- msg_list.append(fill_char * len(title))
-
- msg = "\n".join(msg_list)
- if print_output:
- print(msg)
- return msg
-
-
-@runtime_typing.checked
-def format_argument(obj: Any) -> str:
- return f"{type(obj)}"
-
-
-@runtime_typing.checked
-def display_name(fn: Callable) -> str:
- if hasattr(fn, "__qualname__"):
- return fn.__qualname__
- elif hasattr(fn, "__name__"):
- return fn.__name__
- else:
- return str(fn)
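
The removed formatter helpers, with expected outputs derived directly from the
deleted implementations above:

    from onnxscript.diagnostics.infra import formatter

    assert formatter.snake_case_to_camel_case("start_line") == "startLine"
    assert formatter.camel_case_to_snake_case("startLine") == "start_line"
    assert formatter.kebab_case_to_snake_case("example-rule") == "example_rule"

    # A centered banner, e.g. "=============== Results ===============" padded
    # with the fill character to the requested width (default 80).
    banner = formatter.pretty_print_title("Results", print_output=False)
    assert "Results" in banner and len(banner) == 80
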
diff --git a/onnxscript/diagnostics/infra/sarif/__init__.py b/onnxscript/diagnostics/infra/sarif/__init__.py
deleted file mode 100644
index e610c3b754..0000000000
--- a/onnxscript/diagnostics/infra/sarif/__init__.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from onnxscript.diagnostics.infra.sarif._address import Address
-from onnxscript.diagnostics.infra.sarif._artifact import Artifact
-from onnxscript.diagnostics.infra.sarif._artifact_change import ArtifactChange
-from onnxscript.diagnostics.infra.sarif._artifact_content import ArtifactContent
-from onnxscript.diagnostics.infra.sarif._artifact_location import ArtifactLocation
-from onnxscript.diagnostics.infra.sarif._attachment import Attachment
-from onnxscript.diagnostics.infra.sarif._code_flow import CodeFlow
-from onnxscript.diagnostics.infra.sarif._configuration_override import (
- ConfigurationOverride,
-)
-from onnxscript.diagnostics.infra.sarif._conversion import Conversion
-from onnxscript.diagnostics.infra.sarif._edge import Edge
-from onnxscript.diagnostics.infra.sarif._edge_traversal import EdgeTraversal
-from onnxscript.diagnostics.infra.sarif._exception import Exception
-from onnxscript.diagnostics.infra.sarif._external_properties import ExternalProperties
-from onnxscript.diagnostics.infra.sarif._external_property_file_reference import (
- ExternalPropertyFileReference,
-)
-from onnxscript.diagnostics.infra.sarif._external_property_file_references import (
- ExternalPropertyFileReferences,
-)
-from onnxscript.diagnostics.infra.sarif._fix import Fix
-from onnxscript.diagnostics.infra.sarif._graph import Graph
-from onnxscript.diagnostics.infra.sarif._graph_traversal import GraphTraversal
-from onnxscript.diagnostics.infra.sarif._invocation import Invocation
-from onnxscript.diagnostics.infra.sarif._location import Location
-from onnxscript.diagnostics.infra.sarif._location_relationship import (
- LocationRelationship,
-)
-from onnxscript.diagnostics.infra.sarif._logical_location import LogicalLocation
-from onnxscript.diagnostics.infra.sarif._message import Message
-from onnxscript.diagnostics.infra.sarif._multiformat_message_string import (
- MultiformatMessageString,
-)
-from onnxscript.diagnostics.infra.sarif._node import Node
-from onnxscript.diagnostics.infra.sarif._notification import Notification
-from onnxscript.diagnostics.infra.sarif._physical_location import PhysicalLocation
-from onnxscript.diagnostics.infra.sarif._property_bag import PropertyBag
-from onnxscript.diagnostics.infra.sarif._rectangle import Rectangle
-from onnxscript.diagnostics.infra.sarif._region import Region
-from onnxscript.diagnostics.infra.sarif._replacement import Replacement
-from onnxscript.diagnostics.infra.sarif._reporting_configuration import (
- ReportingConfiguration,
-)
-from onnxscript.diagnostics.infra.sarif._reporting_descriptor import ReportingDescriptor
-from onnxscript.diagnostics.infra.sarif._reporting_descriptor_reference import (
- ReportingDescriptorReference,
-)
-from onnxscript.diagnostics.infra.sarif._reporting_descriptor_relationship import (
- ReportingDescriptorRelationship,
-)
-from onnxscript.diagnostics.infra.sarif._result import Result
-from onnxscript.diagnostics.infra.sarif._result_provenance import ResultProvenance
-from onnxscript.diagnostics.infra.sarif._run import Run
-from onnxscript.diagnostics.infra.sarif._run_automation_details import (
- RunAutomationDetails,
-)
-from onnxscript.diagnostics.infra.sarif._sarif_log import SarifLog
-from onnxscript.diagnostics.infra.sarif._special_locations import SpecialLocations
-from onnxscript.diagnostics.infra.sarif._stack import Stack
-from onnxscript.diagnostics.infra.sarif._stack_frame import StackFrame
-from onnxscript.diagnostics.infra.sarif._suppression import Suppression
-from onnxscript.diagnostics.infra.sarif._thread_flow import ThreadFlow
-from onnxscript.diagnostics.infra.sarif._thread_flow_location import ThreadFlowLocation
-from onnxscript.diagnostics.infra.sarif._tool import Tool
-from onnxscript.diagnostics.infra.sarif._tool_component import ToolComponent
-from onnxscript.diagnostics.infra.sarif._tool_component_reference import (
- ToolComponentReference,
-)
-from onnxscript.diagnostics.infra.sarif._translation_metadata import TranslationMetadata
-from onnxscript.diagnostics.infra.sarif._version_control_details import (
- VersionControlDetails,
-)
-from onnxscript.diagnostics.infra.sarif._web_request import WebRequest
-from onnxscript.diagnostics.infra.sarif._web_response import WebResponse
-
-# flake8: noqa
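
Finally, a sketch of hand-assembling a minimal SARIF log from the generated
dataclasses removed above, serialized with the removed formatter.sarif_to_json
helper. The constructors mirror the calls in the deleted context.py; the rule id
and message texts are illustrative:

    from onnxscript.diagnostics.infra import formatter, sarif
    from onnxscript.diagnostics.infra.sarif import version as sarif_version

    result = sarif.Result(
        message=sarif.Message(text="example finding"),
        level="warning",
        rule_id="EX0001",
        kind="fail",
    )
    run = sarif.Run(
        tool=sarif.Tool(driver=sarif.ToolComponent(name="onnxscript", version="0.1")),
        results=[result],
    )
    log = sarif.SarifLog(
        version=sarif_version.SARIF_VERSION,
        schema_uri=sarif_version.SARIF_SCHEMA_LINK,
        runs=[run],
    )
    # Keys come out camelCase; None values and -1 defaults are omitted.
    print(formatter.sarif_to_json(log))
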
diff --git a/onnxscript/diagnostics/infra/sarif/_address.py b/onnxscript/diagnostics/infra/sarif/_address.py
deleted file mode 100644
index c4b691f348..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_address.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag
-
-
-@dataclasses.dataclass
-class Address:
- """A physical or virtual address, or a range of addresses, in an 'addressable region' (memory or a binary file)."""
-
- absolute_address: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "absoluteAddress"}
- )
- fully_qualified_name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "fullyQualifiedName"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- kind: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "kind"}
- )
- length: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "length"}
- )
- name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "name"}
- )
- offset_from_parent: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "offsetFromParent"}
- )
- parent_index: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "parentIndex"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- relative_address: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "relativeAddress"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_artifact.py b/onnxscript/diagnostics/infra/sarif/_artifact.py
deleted file mode 100644
index afec8b5e97..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_artifact.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_content,
- _artifact_location,
- _message,
- _property_bag,
-)
-
-
-@dataclasses.dataclass
-class Artifact:
- """A single artifact. In some cases, this artifact might be nested within another artifact."""
-
- contents: Optional[_artifact_content.ArtifactContent] = dataclasses.field(
- default=None, metadata={"schema_property_name": "contents"}
- )
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- encoding: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "encoding"}
- )
- hashes: Any = dataclasses.field(default=None, metadata={"schema_property_name": "hashes"})
- last_modified_time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "lastModifiedTimeUtc"}
- )
- length: int = dataclasses.field(default=-1, metadata={"schema_property_name": "length"})
- location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "location"}
- )
- mime_type: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "mimeType"}
- )
- offset: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "offset"}
- )
- parent_index: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "parentIndex"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- roles: Optional[
- List[
- Literal[
- "analysisTarget",
- "attachment",
- "responseFile",
- "resultFile",
- "standardStream",
- "tracedFile",
- "unmodified",
- "modified",
- "added",
- "deleted",
- "renamed",
- "uncontrolled",
- "driver",
- "extension",
- "translation",
- "taxonomy",
- "policy",
- "referencedOnCommandLine",
- "memoryContents",
- "directory",
- "userSpecifiedConfiguration",
- "toolSpecifiedConfiguration",
- "debugOutputFile",
- ]
- ]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "roles"})
- source_language: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "sourceLanguage"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_change.py b/onnxscript/diagnostics/infra/sarif/_artifact_change.py
deleted file mode 100644
index 3db2c0444b..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_artifact_change.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_location,
- _property_bag,
- _replacement,
-)
-
-
-@dataclasses.dataclass
-class ArtifactChange:
- """A change to a single artifact."""
-
- artifact_location: _artifact_location.ArtifactLocation = dataclasses.field(
- metadata={"schema_property_name": "artifactLocation"}
- )
- replacements: List[_replacement.Replacement] = dataclasses.field(
- metadata={"schema_property_name": "replacements"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_content.py b/onnxscript/diagnostics/infra/sarif/_artifact_content.py
deleted file mode 100644
index 4038066198..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_artifact_content.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _multiformat_message_string,
- _property_bag,
-)
-
-
-@dataclasses.dataclass
-class ArtifactContent:
- """Represents the contents of an artifact."""
-
- binary: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "binary"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- rendered: Optional[_multiformat_message_string.MultiformatMessageString] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "rendered"})
- )
- text: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "text"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_location.py b/onnxscript/diagnostics/infra/sarif/_artifact_location.py
deleted file mode 100644
index ed6f9b3916..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_artifact_location.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag
-
-
-@dataclasses.dataclass
-class ArtifactLocation:
- """Specifies the location of an artifact."""
-
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "uri"}
- )
- uri_base_id: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "uriBaseId"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_attachment.py b/onnxscript/diagnostics/infra/sarif/_attachment.py
deleted file mode 100644
index b58b858e0c..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_attachment.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_location,
- _message,
- _property_bag,
- _rectangle,
- _region,
-)
-
-
-@dataclasses.dataclass
-class Attachment:
- """An artifact relevant to a result."""
-
- artifact_location: _artifact_location.ArtifactLocation = dataclasses.field(
- metadata={"schema_property_name": "artifactLocation"}
- )
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- rectangles: Optional[List[_rectangle.Rectangle]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "rectangles"}
- )
- regions: Optional[List[_region.Region]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "regions"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_code_flow.py b/onnxscript/diagnostics/infra/sarif/_code_flow.py
deleted file mode 100644
index 69615f18f2..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_code_flow.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag, _thread_flow
-
-
-@dataclasses.dataclass
-class CodeFlow:
- """A set of threadFlows which together describe a pattern of code execution relevant to detecting a result."""
-
- thread_flows: List[_thread_flow.ThreadFlow] = dataclasses.field(
- metadata={"schema_property_name": "threadFlows"}
- )
- message: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_configuration_override.py b/onnxscript/diagnostics/infra/sarif/_configuration_override.py
deleted file mode 100644
index c2fa3ae0a6..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_configuration_override.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _property_bag,
- _reporting_configuration,
- _reporting_descriptor_reference,
-)
-
-
-@dataclasses.dataclass
-class ConfigurationOverride:
- """Information about how a specific rule or notification was reconfigured at runtime."""
-
- configuration: _reporting_configuration.ReportingConfiguration = dataclasses.field(
- metadata={"schema_property_name": "configuration"}
- )
- descriptor: _reporting_descriptor_reference.ReportingDescriptorReference = (
- dataclasses.field(metadata={"schema_property_name": "descriptor"})
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_conversion.py b/onnxscript/diagnostics/infra/sarif/_conversion.py
deleted file mode 100644
index 6078c525f0..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_conversion.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_location,
- _invocation,
- _property_bag,
- _tool,
-)
-
-
-@dataclasses.dataclass
-class Conversion:
- """Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format."""
-
- tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"})
- analysis_tool_log_files: Optional[List[_artifact_location.ArtifactLocation]] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "analysisToolLogFiles"}
- )
- )
- invocation: Optional[_invocation.Invocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "invocation"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_edge.py b/onnxscript/diagnostics/infra/sarif/_edge.py
deleted file mode 100644
index 1142e61dca..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_edge.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag
-
-
-@dataclasses.dataclass
-class Edge:
- """Represents a directed edge in a graph."""
-
- id: str = dataclasses.field(metadata={"schema_property_name": "id"})
- source_node_id: str = dataclasses.field(metadata={"schema_property_name": "sourceNodeId"})
- target_node_id: str = dataclasses.field(metadata={"schema_property_name": "targetNodeId"})
- label: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "label"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_edge_traversal.py b/onnxscript/diagnostics/infra/sarif/_edge_traversal.py
deleted file mode 100644
index dbaba449e4..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_edge_traversal.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag
-
-
-@dataclasses.dataclass
-class EdgeTraversal:
- """Represents the traversal of a single edge during a graph traversal."""
-
- edge_id: str = dataclasses.field(metadata={"schema_property_name": "edgeId"})
- final_state: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "finalState"}
- )
- message: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- step_over_edge_count: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stepOverEdgeCount"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_exception.py b/onnxscript/diagnostics/infra/sarif/_exception.py
deleted file mode 100644
index 71c0db73a8..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_exception.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _exception, _property_bag, _stack
-
-
-@dataclasses.dataclass
-class Exception:
- """Describes a runtime exception encountered during the execution of an analysis tool."""
-
- inner_exceptions: Optional[List[_exception.Exception]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "innerExceptions"}
- )
- kind: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "kind"}
- )
- message: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- stack: Optional[_stack.Stack] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stack"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_external_properties.py b/onnxscript/diagnostics/infra/sarif/_external_properties.py
deleted file mode 100644
index d63a16aff8..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_external_properties.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _address,
- _artifact,
- _conversion,
- _graph,
- _invocation,
- _logical_location,
- _property_bag,
- _result,
- _thread_flow_location,
- _tool_component,
- _web_request,
- _web_response,
-)
-
-
-@dataclasses.dataclass
-class ExternalProperties:
- """The top-level element of an external property file."""
-
- addresses: Optional[List[_address.Address]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "addresses"}
- )
- artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "artifacts"}
- )
- conversion: Optional[_conversion.Conversion] = dataclasses.field(
- default=None, metadata={"schema_property_name": "conversion"}
- )
- driver: Optional[_tool_component.ToolComponent] = dataclasses.field(
- default=None, metadata={"schema_property_name": "driver"}
- )
- extensions: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "extensions"}
- )
- externalized_properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "externalizedProperties"}
- )
- graphs: Optional[List[_graph.Graph]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "graphs"}
- )
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- invocations: Optional[List[_invocation.Invocation]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "invocations"}
- )
- logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "logicalLocations"}
- )
- policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "policies"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- results: Optional[List[_result.Result]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "results"}
- )
- run_guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "runGuid"}
- )
- schema: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "schema"}
- )
- taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "taxonomies"}
- )
- thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "threadFlowLocations"}
- )
- )
- translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "translations"}
- )
- version: Optional[Literal["2.1.0"]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "version"}
- )
- web_requests: Optional[List[_web_request.WebRequest]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webRequests"}
- )
- web_responses: Optional[List[_web_response.WebResponse]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webResponses"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py b/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py
deleted file mode 100644
index b5bfec0320..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag
-
-
-@dataclasses.dataclass
-class ExternalPropertyFileReference:
- """Contains information that enables a SARIF consumer to locate the external property file that contains the value of an externalized property associated with the run."""
-
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- item_count: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "itemCount"}
- )
- location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "location"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py b/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py
deleted file mode 100644
index d596a7a87a..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _external_property_file_reference,
- _property_bag,
-)
-
-
-@dataclasses.dataclass
-class ExternalPropertyFileReferences:
- """References to external property files that should be inlined with the content of a root log file."""
-
- addresses: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "addresses"})
- artifacts: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "artifacts"})
- conversion: Optional[_external_property_file_reference.ExternalPropertyFileReference] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "conversion"})
- )
- driver: Optional[_external_property_file_reference.ExternalPropertyFileReference] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "driver"})
- )
- extensions: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "extensions"})
- externalized_properties: Optional[
- _external_property_file_reference.ExternalPropertyFileReference
- ] = dataclasses.field(
- default=None, metadata={"schema_property_name": "externalizedProperties"}
- )
- graphs: Optional[List[_external_property_file_reference.ExternalPropertyFileReference]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "graphs"})
- )
- invocations: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "invocations"})
- logical_locations: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "logicalLocations"})
- policies: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "policies"})
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- results: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "results"})
- taxonomies: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "taxonomies"})
- thread_flow_locations: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(
- default=None, metadata={"schema_property_name": "threadFlowLocations"}
- )
- translations: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "translations"})
- web_requests: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "webRequests"})
- web_responses: Optional[
- List[_external_property_file_reference.ExternalPropertyFileReference]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "webResponses"})
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_fix.py b/onnxscript/diagnostics/infra/sarif/_fix.py
deleted file mode 100644
index 042f70f47a..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_fix.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _artifact_change, _message, _property_bag
-
-
-@dataclasses.dataclass
-class Fix:
- """A proposed fix for the problem represented by a result object. A fix specifies a set of artifacts to modify. For each artifact, it specifies a set of bytes to remove, and provides a set of new bytes to replace them."""
-
- artifact_changes: List[_artifact_change.ArtifactChange] = dataclasses.field(
- metadata={"schema_property_name": "artifactChanges"}
- )
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_graph.py b/onnxscript/diagnostics/infra/sarif/_graph.py
deleted file mode 100644
index f068e663de..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_graph.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _edge, _message, _node, _property_bag
-
-
-@dataclasses.dataclass
-class Graph:
- """A network of nodes and directed edges that describes some aspect of the structure of the code (for example, a call graph)."""
-
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- edges: Optional[List[_edge.Edge]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "edges"}
- )
- nodes: Optional[List[_node.Node]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "nodes"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_graph_traversal.py b/onnxscript/diagnostics/infra/sarif/_graph_traversal.py
deleted file mode 100644
index ec9c92a9f8..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_graph_traversal.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _edge_traversal, _message, _property_bag
-
-
-@dataclasses.dataclass
-class GraphTraversal:
- """Represents a path through a graph."""
-
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- edge_traversals: Optional[List[_edge_traversal.EdgeTraversal]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "edgeTraversals"}
- )
- immutable_state: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "immutableState"}
- )
- initial_state: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "initialState"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- result_graph_index: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "resultGraphIndex"}
- )
- run_graph_index: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "runGraphIndex"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_invocation.py b/onnxscript/diagnostics/infra/sarif/_invocation.py
deleted file mode 100644
index 6f96c9a86c..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_invocation.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_location,
- _configuration_override,
- _notification,
- _property_bag,
-)
-
-
-@dataclasses.dataclass
-class Invocation:
- """The runtime environment of the analysis tool run."""
-
- execution_successful: bool = dataclasses.field(
- metadata={"schema_property_name": "executionSuccessful"}
- )
- account: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "account"}
- )
- arguments: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "arguments"}
- )
- command_line: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "commandLine"}
- )
- end_time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "endTimeUtc"}
- )
- environment_variables: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "environmentVariables"}
- )
- executable_location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "executableLocation"}
- )
- exit_code: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "exitCode"}
- )
- exit_code_description: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "exitCodeDescription"}
- )
- exit_signal_name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "exitSignalName"}
- )
- exit_signal_number: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "exitSignalNumber"}
- )
- machine: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "machine"}
- )
- notification_configuration_overrides: Optional[
- List[_configuration_override.ConfigurationOverride]
- ] = dataclasses.field(
- default=None,
- metadata={"schema_property_name": "notificationConfigurationOverrides"},
- )
- process_id: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "processId"}
- )
- process_start_failure_message: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "processStartFailureMessage"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- response_files: Optional[List[_artifact_location.ArtifactLocation]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "responseFiles"}
- )
- rule_configuration_overrides: Optional[
- List[_configuration_override.ConfigurationOverride]
- ] = dataclasses.field(
- default=None, metadata={"schema_property_name": "ruleConfigurationOverrides"}
- )
- start_time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "startTimeUtc"}
- )
- stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stderr"}
- )
- stdin: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stdin"}
- )
- stdout: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stdout"}
- )
- stdout_stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stdoutStderr"}
- )
- tool_configuration_notifications: Optional[List[_notification.Notification]] = (
- dataclasses.field(
- default=None,
- metadata={"schema_property_name": "toolConfigurationNotifications"},
- )
- )
- tool_execution_notifications: Optional[List[_notification.Notification]] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "toolExecutionNotifications"}
- )
- )
- working_directory: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "workingDirectory"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_location.py b/onnxscript/diagnostics/infra/sarif/_location.py
deleted file mode 100644
index 319856f8df..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_location.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _location_relationship,
- _logical_location,
- _message,
- _physical_location,
- _property_bag,
- _region,
-)
-
-
-@dataclasses.dataclass
-class Location:
- """A location within a programming artifact."""
-
- annotations: Optional[List[_region.Region]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "annotations"}
- )
- id: int = dataclasses.field(default=-1, metadata={"schema_property_name": "id"})
- logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "logicalLocations"}
- )
- message: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- physical_location: Optional[_physical_location.PhysicalLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "physicalLocation"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- relationships: Optional[List[_location_relationship.LocationRelationship]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "relationships"})
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_location_relationship.py b/onnxscript/diagnostics/infra/sarif/_location_relationship.py
deleted file mode 100644
index 35ca00c8a6..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_location_relationship.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag
-
-
-@dataclasses.dataclass
-class LocationRelationship:
- """Information about the relation of one location to another."""
-
- target: int = dataclasses.field(metadata={"schema_property_name": "target"})
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- kinds: List[str] = dataclasses.field(
- default_factory=lambda: ["relevant"], metadata={"schema_property_name": "kinds"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
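The kinds field above uses default_factory rather than a literal list default because dataclasses reject mutable defaults at class-definition time; the factory gives every instance its own fresh list. A quick self-contained illustration of the same pattern:

import dataclasses
from typing import List

@dataclasses.dataclass
class Example:
    # Writing `kinds: List[str] = ["relevant"]` would raise a ValueError
    # ("mutable default ... is not allowed") when the class is built;
    # default_factory constructs a new list per instance instead.
    kinds: List[str] = dataclasses.field(default_factory=lambda: ["relevant"])

a, b = Example(), Example()
a.kinds.append("includes")
assert b.kinds == ["relevant"]  # instances do not share the list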
diff --git a/onnxscript/diagnostics/infra/sarif/_logical_location.py b/onnxscript/diagnostics/infra/sarif/_logical_location.py
deleted file mode 100644
index 7f2880eef2..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_logical_location.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag
-
-
-@dataclasses.dataclass
-class LogicalLocation:
- """A logical location of a construct that produced a result."""
-
- decorated_name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "decoratedName"}
- )
- fully_qualified_name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "fullyQualifiedName"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- kind: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "kind"}
- )
- name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "name"}
- )
- parent_index: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "parentIndex"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_message.py b/onnxscript/diagnostics/infra/sarif/_message.py
deleted file mode 100644
index 0c9adce220..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_message.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag
-
-
-@dataclasses.dataclass
-class Message:
- """Encapsulates a message intended to be read by the end user."""
-
- arguments: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "arguments"}
- )
- id: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "id"}
- )
- markdown: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "markdown"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- text: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "text"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py b/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py
deleted file mode 100644
index 154b9cc416..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag
-
-
-@dataclasses.dataclass
-class MultiformatMessageString:
- """A message string or message format string rendered in multiple formats."""
-
- text: str = dataclasses.field(metadata={"schema_property_name": "text"})
- markdown: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "markdown"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_node.py b/onnxscript/diagnostics/infra/sarif/_node.py
deleted file mode 100644
index 0f11e37318..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_node.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _location, _message, _node, _property_bag
-
-
-@dataclasses.dataclass
-class Node:
- """Represents a node in a graph."""
-
- id: str = dataclasses.field(metadata={"schema_property_name": "id"})
- children: Optional[List[_node.Node]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "children"}
- )
- label: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "label"}
- )
- location: Optional[_location.Location] = dataclasses.field(
- default=None, metadata={"schema_property_name": "location"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_notification.py b/onnxscript/diagnostics/infra/sarif/_notification.py
deleted file mode 100644
index f41a9f8d5b..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_notification.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _exception,
- _location,
- _message,
- _property_bag,
- _reporting_descriptor_reference,
-)
-
-
-@dataclasses.dataclass
-class Notification:
- """Describes a condition relevant to the tool itself, as opposed to being relevant to a target being analyzed by the tool."""
-
- message: _message.Message = dataclasses.field(metadata={"schema_property_name": "message"})
- associated_rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "associatedRule"})
- )
- descriptor: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "descriptor"})
- )
- exception: Optional[_exception.Exception] = dataclasses.field(
- default=None, metadata={"schema_property_name": "exception"}
- )
- level: Literal["none", "note", "warning", "error"] = dataclasses.field(
- default="warning", metadata={"schema_property_name": "level"}
- )
- locations: Optional[List[_location.Location]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "locations"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- thread_id: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "threadId"}
- )
- time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "timeUtc"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_physical_location.py b/onnxscript/diagnostics/infra/sarif/_physical_location.py
deleted file mode 100644
index 357e85af4e..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_physical_location.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _address,
- _artifact_location,
- _property_bag,
- _region,
-)
-
-
-@dataclasses.dataclass
-class PhysicalLocation:
- """A physical location relevant to a result. Specifies a reference to a programming artifact together with a range of bytes or characters within that artifact."""
-
- address: Optional[_address.Address] = dataclasses.field(
- default=None, metadata={"schema_property_name": "address"}
- )
- artifact_location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "artifactLocation"}
- )
- context_region: Optional[_region.Region] = dataclasses.field(
- default=None, metadata={"schema_property_name": "contextRegion"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- region: Optional[_region.Region] = dataclasses.field(
- default=None, metadata={"schema_property_name": "region"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_property_bag.py b/onnxscript/diagnostics/infra/sarif/_property_bag.py
deleted file mode 100644
index 0b95c6e6e5..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_property_bag.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-
-@dataclasses.dataclass
-class PropertyBag:
- """Key/value pairs that provide additional information about the object."""
-
- tags: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "tags"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_rectangle.py b/onnxscript/diagnostics/infra/sarif/_rectangle.py
deleted file mode 100644
index a7c9aecd1a..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_rectangle.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag
-
-
-@dataclasses.dataclass
-class Rectangle:
- """An area within an image."""
-
- bottom: Optional[float] = dataclasses.field(
- default=None, metadata={"schema_property_name": "bottom"}
- )
- left: Optional[float] = dataclasses.field(
- default=None, metadata={"schema_property_name": "left"}
- )
- message: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- right: Optional[float] = dataclasses.field(
- default=None, metadata={"schema_property_name": "right"}
- )
- top: Optional[float] = dataclasses.field(
- default=None, metadata={"schema_property_name": "top"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_region.py b/onnxscript/diagnostics/infra/sarif/_region.py
deleted file mode 100644
index 35a4b7f316..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_region.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_content,
- _message,
- _property_bag,
-)
-
-
-@dataclasses.dataclass
-class Region:
- """A region within an artifact where a result was detected."""
-
- byte_length: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "byteLength"}
- )
- byte_offset: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "byteOffset"}
- )
- char_length: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "charLength"}
- )
- char_offset: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "charOffset"}
- )
- end_column: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "endColumn"}
- )
- end_line: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "endLine"}
- )
- message: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- snippet: Optional[_artifact_content.ArtifactContent] = dataclasses.field(
- default=None, metadata={"schema_property_name": "snippet"}
- )
- source_language: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "sourceLanguage"}
- )
- start_column: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "startColumn"}
- )
- start_line: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "startLine"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_replacement.py b/onnxscript/diagnostics/infra/sarif/_replacement.py
deleted file mode 100644
index 125ed75708..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_replacement.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag, _region
-
-
-@dataclasses.dataclass
-class Replacement:
- """The replacement of a single region of an artifact."""
-
- deleted_region: _region.Region = dataclasses.field(
- metadata={"schema_property_name": "deletedRegion"}
- )
- inserted_content: Optional[_artifact_content.ArtifactContent] = dataclasses.field(
- default=None, metadata={"schema_property_name": "insertedContent"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py b/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py
deleted file mode 100644
index e3da0a77b8..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag
-
-
-@dataclasses.dataclass
-class ReportingConfiguration:
- """Information about a rule or notification that can be configured at runtime."""
-
- enabled: bool = dataclasses.field(
- default=True, metadata={"schema_property_name": "enabled"}
- )
- level: Literal["none", "note", "warning", "error"] = dataclasses.field(
- default="warning", metadata={"schema_property_name": "level"}
- )
- parameters: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "parameters"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- rank: float = dataclasses.field(default=-1.0, metadata={"schema_property_name": "rank"})
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py
deleted file mode 100644
index 85e14f3763..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _multiformat_message_string,
- _property_bag,
- _reporting_configuration,
- _reporting_descriptor_relationship,
-)
-
-
-@dataclasses.dataclass
-class ReportingDescriptor:
- """Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime reporting."""
-
- id: str = dataclasses.field(metadata={"schema_property_name": "id"})
- default_configuration: Optional[_reporting_configuration.ReportingConfiguration] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "defaultConfiguration"}
- )
- )
- deprecated_guids: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "deprecatedGuids"}
- )
- deprecated_ids: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "deprecatedIds"}
- )
- deprecated_names: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "deprecatedNames"}
- )
- full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"})
- )
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- help: Optional[_multiformat_message_string.MultiformatMessageString] = dataclasses.field(
- default=None, metadata={"schema_property_name": "help"}
- )
- help_uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "helpUri"}
- )
- message_strings: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "messageStrings"}
- )
- name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "name"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- relationships: Optional[
- List[_reporting_descriptor_relationship.ReportingDescriptorRelationship]
- ] = dataclasses.field(default=None, metadata={"schema_property_name": "relationships"})
- short_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"})
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py
deleted file mode 100644
index f4e6f2260d..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag, _tool_component_reference
-
-
-@dataclasses.dataclass
-class ReportingDescriptorReference:
- """Information about how to locate a relevant reporting descriptor."""
-
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- id: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "id"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- tool_component: Optional[_tool_component_reference.ToolComponentReference] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "toolComponent"})
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py
deleted file mode 100644
index 52db517db5..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _message,
- _property_bag,
- _reporting_descriptor_reference,
-)
-
-
-@dataclasses.dataclass
-class ReportingDescriptorRelationship:
- """Information about the relation of one reporting descriptor to another."""
-
- target: _reporting_descriptor_reference.ReportingDescriptorReference = dataclasses.field(
- metadata={"schema_property_name": "target"}
- )
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- kinds: List[str] = dataclasses.field(
- default_factory=lambda: ["relevant"], metadata={"schema_property_name": "kinds"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_result.py b/onnxscript/diagnostics/infra/sarif/_result.py
deleted file mode 100644
index 3dfa564b54..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_result.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_location,
- _attachment,
- _code_flow,
- _fix,
- _graph,
- _graph_traversal,
- _location,
- _message,
- _property_bag,
- _reporting_descriptor_reference,
- _result_provenance,
- _stack,
- _suppression,
- _web_request,
- _web_response,
-)
-
-
-@dataclasses.dataclass
-class Result:
- """A result produced by an analysis tool."""
-
- message: _message.Message = dataclasses.field(metadata={"schema_property_name": "message"})
- analysis_target: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "analysisTarget"}
- )
- attachments: Optional[List[_attachment.Attachment]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "attachments"}
- )
- baseline_state: Optional[Literal["new", "unchanged", "updated", "absent"]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "baselineState"})
- )
- code_flows: Optional[List[_code_flow.CodeFlow]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "codeFlows"}
- )
- correlation_guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "correlationGuid"}
- )
- fingerprints: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "fingerprints"}
- )
- fixes: Optional[List[_fix.Fix]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "fixes"}
- )
- graph_traversals: Optional[List[_graph_traversal.GraphTraversal]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "graphTraversals"}
- )
- graphs: Optional[List[_graph.Graph]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "graphs"}
- )
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- hosted_viewer_uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "hostedViewerUri"}
- )
- kind: Literal["notApplicable", "pass", "fail", "review", "open", "informational"] = (
- dataclasses.field(default="fail", metadata={"schema_property_name": "kind"})
- )
- level: Literal["none", "note", "warning", "error"] = dataclasses.field(
- default="warning", metadata={"schema_property_name": "level"}
- )
- locations: Optional[List[_location.Location]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "locations"}
- )
- occurrence_count: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "occurrenceCount"}
- )
- partial_fingerprints: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "partialFingerprints"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- provenance: Optional[_result_provenance.ResultProvenance] = dataclasses.field(
- default=None, metadata={"schema_property_name": "provenance"}
- )
- rank: float = dataclasses.field(default=-1.0, metadata={"schema_property_name": "rank"})
- related_locations: Optional[List[_location.Location]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "relatedLocations"}
- )
- rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "rule"})
- )
- rule_id: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "ruleId"}
- )
- rule_index: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "ruleIndex"}
- )
- stacks: Optional[List[_stack.Stack]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stacks"}
- )
- suppressions: Optional[List[_suppression.Suppression]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "suppressions"}
- )
- taxa: Optional[List[_reporting_descriptor_reference.ReportingDescriptorReference]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "taxa"})
- )
- web_request: Optional[_web_request.WebRequest] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webRequest"}
- )
- web_response: Optional[_web_response.WebResponse] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webResponse"}
- )
- work_item_uris: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "workItemUris"}
- )
-
-
-# flake8: noqa
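Because message is the only Result field without a default, it is the single required constructor argument; everything else, including the schema defaults kind="fail" and level="warning", can be omitted. A small usage sketch, assuming the modules are still importable; the rule id and message text are made up for illustration:

from onnxscript.diagnostics.infra.sarif import _message, _result

result = _result.Result(
    message=_message.Message(text="Shape inference failed for node 'Concat_3'."),
    rule_id="shape-inference-failure",  # hypothetical rule id
    level="error",  # overrides the schema default "warning"
)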
diff --git a/onnxscript/diagnostics/infra/sarif/_result_provenance.py b/onnxscript/diagnostics/infra/sarif/_result_provenance.py
deleted file mode 100644
index 74ea9e1e9f..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_result_provenance.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _physical_location, _property_bag
-
-
-@dataclasses.dataclass
-class ResultProvenance:
- """Contains information about how and when a result was detected."""
-
- conversion_sources: Optional[List[_physical_location.PhysicalLocation]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "conversionSources"})
- )
- first_detection_run_guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "firstDetectionRunGuid"}
- )
- first_detection_time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "firstDetectionTimeUtc"}
- )
- invocation_index: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "invocationIndex"}
- )
- last_detection_run_guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "lastDetectionRunGuid"}
- )
- last_detection_time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "lastDetectionTimeUtc"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_run.py b/onnxscript/diagnostics/infra/sarif/_run.py
deleted file mode 100644
index 8df4f9b577..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_run.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _address,
- _artifact,
- _conversion,
- _external_property_file_references,
- _graph,
- _invocation,
- _logical_location,
- _property_bag,
- _result,
- _run_automation_details,
- _special_locations,
- _thread_flow_location,
- _tool,
- _tool_component,
- _version_control_details,
- _web_request,
- _web_response,
-)
-
-
-@dataclasses.dataclass
-class Run:
- """Describes a single run of an analysis tool, and contains the reported output of that run."""
-
- tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"})
- addresses: Optional[List[_address.Address]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "addresses"}
- )
- artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "artifacts"}
- )
- automation_details: Optional[_run_automation_details.RunAutomationDetails] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "automationDetails"})
- )
- baseline_guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "baselineGuid"}
- )
- column_kind: Optional[Literal["utf16CodeUnits", "unicodeCodePoints"]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "columnKind"}
- )
- conversion: Optional[_conversion.Conversion] = dataclasses.field(
- default=None, metadata={"schema_property_name": "conversion"}
- )
- default_encoding: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "defaultEncoding"}
- )
- default_source_language: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "defaultSourceLanguage"}
- )
- external_property_file_references: Optional[
- _external_property_file_references.ExternalPropertyFileReferences
- ] = dataclasses.field(
- default=None,
- metadata={"schema_property_name": "externalPropertyFileReferences"},
- )
- graphs: Optional[List[_graph.Graph]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "graphs"}
- )
- invocations: Optional[List[_invocation.Invocation]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "invocations"}
- )
- language: str = dataclasses.field(
- default="en-US", metadata={"schema_property_name": "language"}
- )
- logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "logicalLocations"}
- )
- newline_sequences: List[str] = dataclasses.field(
- default_factory=lambda: ["\r\n", "\n"],
- metadata={"schema_property_name": "newlineSequences"},
- )
- original_uri_base_ids: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "originalUriBaseIds"}
- )
- policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "policies"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- redaction_tokens: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "redactionTokens"}
- )
- results: Optional[List[_result.Result]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "results"}
- )
- run_aggregates: Optional[List[_run_automation_details.RunAutomationDetails]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "runAggregates"})
- )
- special_locations: Optional[_special_locations.SpecialLocations] = dataclasses.field(
- default=None, metadata={"schema_property_name": "specialLocations"}
- )
- taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "taxonomies"}
- )
- thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "threadFlowLocations"}
- )
- )
- translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "translations"}
- )
- version_control_provenance: Optional[
- List[_version_control_details.VersionControlDetails]
- ] = dataclasses.field(
- default=None, metadata={"schema_property_name": "versionControlProvenance"}
- )
- web_requests: Optional[List[_web_request.WebRequest]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webRequests"}
- )
- web_responses: Optional[List[_web_response.WebResponse]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webResponses"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_run_automation_details.py b/onnxscript/diagnostics/infra/sarif/_run_automation_details.py
deleted file mode 100644
index f41dfcc284..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_run_automation_details.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag
-
-
-@dataclasses.dataclass
-class RunAutomationDetails:
- """Information that describes a run's identity and role within an engineering system process."""
-
- correlation_guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "correlationGuid"}
- )
- description: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "description"}
- )
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- id: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "id"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_sarif_log.py b/onnxscript/diagnostics/infra/sarif/_sarif_log.py
deleted file mode 100644
index aa39c52f15..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_sarif_log.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import _external_properties, _property_bag, _run
-
-
-@dataclasses.dataclass
-class SarifLog:
- """Static Analysis Results Format (SARIF) Version 2.1.0 JSON Schema: a standard format for the output of static analysis tools."""
-
- runs: List[_run.Run] = dataclasses.field(metadata={"schema_property_name": "runs"})
- version: Literal["2.1.0"] = dataclasses.field(metadata={"schema_property_name": "version"})
- schema_uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "$schema"}
- )
- inline_external_properties: Optional[List[_external_properties.ExternalProperties]] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "inlineExternalProperties"}
- )
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
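Tying the pieces together: runs and version are the two required SarifLog fields, a Run requires only its Tool, and a Tool requires only the driver component. A hedged end-to-end sketch; it assumes ToolComponent's required name field, which the SARIF schema mandates but whose definition falls outside this hunk:

import dataclasses
import json

from onnxscript.diagnostics.infra.sarif import _run, _sarif_log, _tool, _tool_component

log = _sarif_log.SarifLog(
    runs=[
        _run.Run(
            tool=_tool.Tool(
                driver=_tool_component.ToolComponent(name="onnxscript-diagnostics")
            ),
            results=[],
        )
    ],
    version="2.1.0",
)
# asdict() keeps the snake_case Python field names; map them through the
# schema_property_name metadata (see the earlier to_sarif_dict sketch) to
# get spec-conformant camelCase keys.
print(json.dumps(dataclasses.asdict(log), indent=2))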
diff --git a/onnxscript/diagnostics/infra/sarif/_special_locations.py b/onnxscript/diagnostics/infra/sarif/_special_locations.py
deleted file mode 100644
index ee78979514..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_special_locations.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag
-
-
-@dataclasses.dataclass
-class SpecialLocations:
- """Defines locations of special significance to SARIF consumers."""
-
- display_base: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "displayBase"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_stack.py b/onnxscript/diagnostics/infra/sarif/_stack.py
deleted file mode 100644
index e250b75df4..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_stack.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _message, _property_bag, _stack_frame
-
-
-@dataclasses.dataclass
-class Stack:
- """A call stack that is relevant to a result."""
-
- frames: List[_stack_frame.StackFrame] = dataclasses.field(
- metadata={"schema_property_name": "frames"}
- )
- message: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_stack_frame.py b/onnxscript/diagnostics/infra/sarif/_stack_frame.py
deleted file mode 100644
index 24d9fe8201..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_stack_frame.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _location, _property_bag
-
-
-@dataclasses.dataclass
-class StackFrame:
- """A function call within a stack trace."""
-
- location: Optional[_location.Location] = dataclasses.field(
- default=None, metadata={"schema_property_name": "location"}
- )
- module: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "module"}
- )
- parameters: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "parameters"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- thread_id: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "threadId"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_suppression.py b/onnxscript/diagnostics/infra/sarif/_suppression.py
deleted file mode 100644
index ae477178b0..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_suppression.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import _location, _property_bag
-
-
-@dataclasses.dataclass
-class Suppression:
- """A suppression that is relevant to a result."""
-
- kind: Literal["inSource", "external"] = dataclasses.field(
- metadata={"schema_property_name": "kind"}
- )
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- justification: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "justification"}
- )
- location: Optional[_location.Location] = dataclasses.field(
- default=None, metadata={"schema_property_name": "location"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- state: Optional[Literal["accepted", "underReview", "rejected"]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "state"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_thread_flow.py b/onnxscript/diagnostics/infra/sarif/_thread_flow.py
deleted file mode 100644
index d3d1693677..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_thread_flow.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _message,
- _property_bag,
- _thread_flow_location,
-)
-
-
-@dataclasses.dataclass
-class ThreadFlow:
- """Describes a sequence of code locations that specify a path through a single thread of execution such as an operating system thread or fiber."""
-
- locations: List[_thread_flow_location.ThreadFlowLocation] = dataclasses.field(
- metadata={"schema_property_name": "locations"}
- )
- id: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "id"}
- )
- immutable_state: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "immutableState"}
- )
- initial_state: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "initialState"}
- )
- message: Optional[_message.Message] = dataclasses.field(
- default=None, metadata={"schema_property_name": "message"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py b/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py
deleted file mode 100644
index 949c42d80e..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _location,
- _property_bag,
- _reporting_descriptor_reference,
- _stack,
- _web_request,
- _web_response,
-)
-
-
-@dataclasses.dataclass
-class ThreadFlowLocation:
- """A location visited by an analysis tool while simulating or monitoring the execution of a program."""
-
- execution_order: int = dataclasses.field(
- default=-1, metadata={"schema_property_name": "executionOrder"}
- )
- execution_time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "executionTimeUtc"}
- )
- importance: Literal["important", "essential", "unimportant"] = dataclasses.field(
- default="important", metadata={"schema_property_name": "importance"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- kinds: Optional[List[str]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "kinds"}
- )
- location: Optional[_location.Location] = dataclasses.field(
- default=None, metadata={"schema_property_name": "location"}
- )
- module: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "module"}
- )
- nesting_level: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "nestingLevel"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- stack: Optional[_stack.Stack] = dataclasses.field(
- default=None, metadata={"schema_property_name": "stack"}
- )
- state: Any = dataclasses.field(default=None, metadata={"schema_property_name": "state"})
- taxa: Optional[List[_reporting_descriptor_reference.ReportingDescriptorReference]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "taxa"})
- )
- web_request: Optional[_web_request.WebRequest] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webRequest"}
- )
- web_response: Optional[_web_response.WebResponse] = dataclasses.field(
- default=None, metadata={"schema_property_name": "webResponse"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_tool.py b/onnxscript/diagnostics/infra/sarif/_tool.py
deleted file mode 100644
index 79589ddf77..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_tool.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import List, Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag, _tool_component
-
-
-@dataclasses.dataclass
-class Tool:
- """The analysis tool that was run."""
-
- driver: _tool_component.ToolComponent = dataclasses.field(
- metadata={"schema_property_name": "driver"}
- )
- extensions: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "extensions"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_tool_component.py b/onnxscript/diagnostics/infra/sarif/_tool_component.py
deleted file mode 100644
index 47925ed748..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_tool_component.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, List, Literal, Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _artifact_location,
- _multiformat_message_string,
- _property_bag,
- _reporting_descriptor,
- _tool_component_reference,
- _translation_metadata,
-)
-
-
-@dataclasses.dataclass
-class ToolComponent:
- """A component, such as a plug-in or the driver, of the analysis tool that was run."""
-
- name: str = dataclasses.field(metadata={"schema_property_name": "name"})
- associated_component: Optional[_tool_component_reference.ToolComponentReference] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "associatedComponent"}
- )
- )
- contents: List[Literal["localizedData", "nonLocalizedData"]] = dataclasses.field(
- default_factory=lambda: ["localizedData", "nonLocalizedData"],
- metadata={"schema_property_name": "contents"},
- )
- dotted_quad_file_version: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "dottedQuadFileVersion"}
- )
- download_uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "downloadUri"}
- )
- full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"})
- )
- full_name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "fullName"}
- )
- global_message_strings: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "globalMessageStrings"}
- )
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- information_uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "informationUri"}
- )
- is_comprehensive: Optional[bool] = dataclasses.field(
- default=None, metadata={"schema_property_name": "isComprehensive"}
- )
- language: str = dataclasses.field(
- default="en-US", metadata={"schema_property_name": "language"}
- )
- localized_data_semantic_version: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "localizedDataSemanticVersion"}
- )
- locations: Optional[List[_artifact_location.ArtifactLocation]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "locations"}
- )
- minimum_required_localized_data_semantic_version: Optional[str] = dataclasses.field(
- default=None,
- metadata={"schema_property_name": "minimumRequiredLocalizedDataSemanticVersion"},
- )
- notifications: Optional[List[_reporting_descriptor.ReportingDescriptor]] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "notifications"})
- )
- organization: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "organization"}
- )
- product: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "product"}
- )
- product_suite: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "productSuite"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- release_date_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "releaseDateUtc"}
- )
- rules: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "rules"}
- )
- semantic_version: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "semanticVersion"}
- )
- short_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"})
- )
- supported_taxonomies: Optional[List[_tool_component_reference.ToolComponentReference]] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "supportedTaxonomies"}
- )
- )
- taxa: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field(
- default=None, metadata={"schema_property_name": "taxa"}
- )
- translation_metadata: Optional[_translation_metadata.TranslationMetadata] = (
- dataclasses.field(
- default=None, metadata={"schema_property_name": "translationMetadata"}
- )
- )
- version: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "version"}
- )
-
-
-# flake8: noqa
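One wrinkle worth noting in the generated code above: list-valued defaults such as `contents` must go through `default_factory`, because dataclasses reject a bare mutable default (`contents: list = []` raises ValueError at class-creation time) to keep one list from being shared across instances. A short illustration of the factory behavior:

    import dataclasses

    @dataclasses.dataclass
    class Component:
        # Each instance gets a fresh list from the factory.
        contents: list = dataclasses.field(
            default_factory=lambda: ["localizedData", "nonLocalizedData"]
        )

    a, b = Component(), Component()
    a.contents.append("extra")
    print(b.contents)  # ['localizedData', 'nonLocalizedData'] -- b is unaffected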
diff --git a/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py b/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py
deleted file mode 100644
index 09cc2b9087..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _property_bag
-
-
-@dataclasses.dataclass
-class ToolComponentReference:
- """Identifies a particular toolComponent object, either the driver or an extension."""
-
- guid: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "guid"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "name"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_translation_metadata.py b/onnxscript/diagnostics/infra/sarif/_translation_metadata.py
deleted file mode 100644
index f05125a599..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_translation_metadata.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import (
- _multiformat_message_string,
- _property_bag,
-)
-
-
-@dataclasses.dataclass
-class TranslationMetadata:
- """Provides additional metadata related to translation."""
-
- name: str = dataclasses.field(metadata={"schema_property_name": "name"})
- download_uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "downloadUri"}
- )
- full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"})
- )
- full_name: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "fullName"}
- )
- information_uri: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "informationUri"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- short_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
- dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"})
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_version_control_details.py b/onnxscript/diagnostics/infra/sarif/_version_control_details.py
deleted file mode 100644
index f56498bb69..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_version_control_details.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Optional
-
-from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag
-
-
-@dataclasses.dataclass
-class VersionControlDetails:
- """Specifies the information necessary to retrieve a desired revision from a version control system."""
-
- repository_uri: str = dataclasses.field(metadata={"schema_property_name": "repositoryUri"})
- as_of_time_utc: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "asOfTimeUtc"}
- )
- branch: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "branch"}
- )
- mapped_to: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
- default=None, metadata={"schema_property_name": "mappedTo"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- revision_id: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "revisionId"}
- )
- revision_tag: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "revisionTag"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_web_request.py b/onnxscript/diagnostics/infra/sarif/_web_request.py
deleted file mode 100644
index b574882f9b..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_web_request.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, Optional
-
-from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag
-
-
-@dataclasses.dataclass
-class WebRequest:
- """Describes an HTTP request."""
-
- body: Optional[_artifact_content.ArtifactContent] = dataclasses.field(
- default=None, metadata={"schema_property_name": "body"}
- )
- headers: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "headers"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- method: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "method"}
- )
- parameters: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "parameters"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- protocol: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "protocol"}
- )
- target: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "target"}
- )
- version: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "version"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/_web_response.py b/onnxscript/diagnostics/infra/sarif/_web_response.py
deleted file mode 100644
index 3753036ab1..0000000000
--- a/onnxscript/diagnostics/infra/sarif/_web_response.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
-# with extension for dataclasses and type annotation.
-
-from __future__ import annotations
-
-import dataclasses
-from typing import Any, Optional
-
-from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag
-
-
-@dataclasses.dataclass
-class WebResponse:
- """Describes the response to an HTTP request."""
-
- body: Optional[_artifact_content.ArtifactContent] = dataclasses.field(
- default=None, metadata={"schema_property_name": "body"}
- )
- headers: Any = dataclasses.field(
- default=None, metadata={"schema_property_name": "headers"}
- )
- index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"})
- no_response_received: Optional[bool] = dataclasses.field(
- default=None, metadata={"schema_property_name": "noResponseReceived"}
- )
- properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
- default=None, metadata={"schema_property_name": "properties"}
- )
- protocol: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "protocol"}
- )
- reason_phrase: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "reasonPhrase"}
- )
- status_code: Optional[int] = dataclasses.field(
- default=None, metadata={"schema_property_name": "statusCode"}
- )
- version: Optional[str] = dataclasses.field(
- default=None, metadata={"schema_property_name": "version"}
- )
-
-
-# flake8: noqa
diff --git a/onnxscript/diagnostics/infra/sarif/version.py b/onnxscript/diagnostics/infra/sarif/version.py
deleted file mode 100644
index 020a28bf76..0000000000
--- a/onnxscript/diagnostics/infra/sarif/version.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from typing import Final
-
-SARIF_VERSION: Final = "2.1.0"
-SARIF_SCHEMA_LINK: Final = (
- "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
-)
-# flake8: noqa
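These two constants exist to stamp the SARIF envelope: a log's top level carries the spec version plus a `$schema` URI. A minimal sketch of the envelope they would be written into (the dict layout follows the SARIF 2.1.0 spec; this is not code from the deleted module):

    import json

    SARIF_VERSION = "2.1.0"
    SARIF_SCHEMA_LINK = (
        "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
    )

    log = {
        "$schema": SARIF_SCHEMA_LINK,
        "version": SARIF_VERSION,
        "runs": [],  # one entry per analysis-tool run
    }
    print(json.dumps(log, indent=2))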
diff --git a/onnxscript/diagnostics/infra/utils.py b/onnxscript/diagnostics/infra/utils.py
deleted file mode 100644
index bc8f5f9c78..0000000000
--- a/onnxscript/diagnostics/infra/utils.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from __future__ import annotations
-
-import functools
-import inspect
-import traceback
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
-
-from onnxscript._internal import runtime_typing
-from onnxscript.diagnostics.infra import _infra, formatter
-
-
-@runtime_typing.checked
-def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
- """Returns a StackFrame for the given traceback.FrameSummary."""
- snippet = frame.line
-
- return _infra.StackFrame(
- location=_infra.Location(
- uri=frame.filename,
- line=frame.lineno,
- snippet=snippet,
- function=frame.name,
- message=snippet,
- )
- )
-
-
-@runtime_typing.checked
-def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
- """Returns the current Python call stack."""
- if frames_to_skip < 0:
- raise ValueError("frames_to_skip must be non-negative")
- if frames_to_log < 0:
- raise ValueError("frames_to_log must be non-negative")
- frames_to_skip += 2 # Skip this function and beartype.
- stack = _infra.Stack()
- # Frames are returned in order of oldest to newest.
- frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log)
- frames.reverse()
- stack.frames = [python_frame(frame) for frame in frames[frames_to_skip:]]
- stack.message = "Python call stack"
- return stack
-
-
-@functools.lru_cache
-def _function_source_info(fn: Callable) -> Tuple[Sequence[str], int, Optional[str]]:
- """Returns the source lines, line number, and source file path for the given function.
-
- Essentially, inspect.getsourcelines() and inspect.getsourcefile() combined.
- Caching is applied to reduce the performance impact of this function.
- """
- source_lines, lineno = inspect.getsourcelines(fn)
- return source_lines, lineno, inspect.getsourcefile(fn)
-
-
-@runtime_typing.checked
-def function_location(fn: Callable) -> _infra.Location:
- """Returns a Location for the given function."""
- source_lines, lineno, uri = _function_source_info(fn)
- snippet = source_lines[0].strip() if len(source_lines) > 0 else ""
- return _infra.Location(
- uri=uri,
- line=lineno,
- snippet=snippet,
- message=formatter.display_name(fn),
- )
-
-
-@runtime_typing.checked
-def function_state(
- fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
-) -> Mapping[str, Any]:
- bind = inspect.signature(fn).bind(*args, **kwargs)
- return bind.arguments
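The deleted `function_state` helper relied on `inspect.signature(...).bind(...)`, which maps positional and keyword arguments back to parameter names without calling the function; that mapping is what the diagnostics recorded. A self-contained demonstration (`conv` is a made-up function, not torch_lib API):

    import inspect

    def conv(x, stride=1, *, padding=0):
        return x

    bound = inspect.signature(conv).bind(3, 2, padding=1)
    print(dict(bound.arguments))  # {'x': 3, 'stride': 2, 'padding': 1}

`python_call_stack` played a similar game with `traceback.extract_stack`: frames arrive oldest-first, so it reversed them and sliced off `frames_to_skip` entries to hide its own machinery from the reported stack.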
diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py
index a936824cab..1d87ee135e 100644
--- a/onnxscript/evaluator.py
+++ b/onnxscript/evaluator.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
import abc
@@ -22,7 +20,6 @@
import numpy as np
import onnx
import onnx.defs
-import onnx.helper
import onnx.reference
from typing_extensions import TypeAlias
@@ -292,16 +289,16 @@ def eval_function(
has_array = False
for arg, param_schema in tagged_args:
if param_schema.is_input:
- adapted_arg, _has_array = _adapt_to_eager_mode(arg)
- has_array = has_array or _has_array
+ adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
+ has_array = has_array or has_array_
adapted_args.append(adapted_arg)
else:
adapted_args.append(arg)
for key, (arg, param_schema) in tagged_kwargs.items():
if param_schema.is_input:
- adapted_arg, _has_array = _adapt_to_eager_mode(arg)
- has_array = has_array or _has_array
+ adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
+ has_array = has_array or has_array_
adapted_kwargs[key] = adapted_arg
else:
adapted_kwargs[key] = arg
@@ -369,7 +366,7 @@ def _onnxscript_to_numpy_value(v):
if isinstance(v, list):
return [_onnxscript_to_numpy_value(x) for x in v]
if isinstance(v, tuple):
- if len(v) > 0 and type(v[0]) is int: # noqa: E721 # pylint: disable=unidiomatic-typecheck
+ if len(v) > 0 and type(v[0]) is int: # pylint: disable=unidiomatic-typecheck
return np.array(v, dtype=np.int64)
return np.array(v)
if v is None:
@@ -389,8 +386,10 @@ def _numpy_to_onnxscript_value(
):
"""Converts an ORT encoding of an ONNX value into the encoding used by onnxscript."""
if isinstance(v, np.ndarray):
- return tensor.Tensor(v)
- if np.issctype(type(v)): # noqa: NPY201
+ # ORT may reuse buffers when the output numpy array is provided back as input.
+ # We need to make a copy to ensure that the tensor is not modified in-place.
+ return tensor.Tensor(v.copy())
+ if issubclass(type(v), np.generic):
# Numpy scalar types that are not ndarray
# https://numpy.org/doc/stable/reference/arrays.scalars.html
return tensor.Tensor(np.array(v))
@@ -421,7 +420,7 @@ def make_tensor_name() -> str:
return f"attr_{key}"
return autocast.pyvalue_to_onnx_attribute(
- key, value, make_tensor_name, schema.attributes[key].type
+ key, value, make_tensor_name, int(schema.attributes[key].type)
)
# Construct ONNX model with a single op call:
@@ -430,21 +429,22 @@ def make_tensor_name() -> str:
num_outputs = compute_num_outputs(schema, args, kwargs)
outputs = [f"output{i}" for i in range(num_outputs)]
- node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)
+ node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) # noqa: TID251
node.attribute.extend(
make_attr(key, value) for key, value in kwargs.items() if value is not None
)
input_value_infos = utils.values_to_value_infos(zip(inputs, args))
implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
output_value_infos = [
- onnx.helper.make_value_info(name, onnx.TypeProto()) for name in outputs
+ onnx.helper.make_value_info(name, onnx.TypeProto()) # noqa: TID251
+ for name in outputs
]
- graph = onnx.helper.make_graph(
+ graph = onnx.helper.make_graph( # noqa: TID251
[node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
)
- opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)
- model = onnx.helper.make_model(
+ opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) # noqa: TID251
+ model = onnx.helper.make_model( # noqa: TID251
graph,
opset_imports=[opset_id],
ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),
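The `v.copy()` change above is a defensive-copy fix: when a runtime hands back an output ndarray whose buffer it may later reuse, a tensor wrapper that merely aliases the buffer can change underneath the caller. A minimal numpy illustration of the failure mode (it does not exercise onnxruntime itself):

    import numpy as np

    buffer = np.zeros(3)
    aliased = buffer       # shares memory with the producer's buffer
    safe = buffer.copy()   # independent memory

    buffer[0] = 42.0       # the producer reuses its buffer
    print(aliased[0])      # 42.0 -- the aliased value changed silently
    print(safe[0])         # 0.0  -- the copy is unaffected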
diff --git a/onnxscript/evaluator_test.py b/onnxscript/evaluator_test.py
index a5ad41a78f..d42b1bab75 100644
--- a/onnxscript/evaluator_test.py
+++ b/onnxscript/evaluator_test.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
import unittest
import numpy as np
diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py
index 37232c84eb..c5b87898c9 100644
--- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py
+++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from __future__ import annotations
import copy
@@ -151,11 +153,9 @@ def __repr__(self):
" Type Constraints: ",
]
# Trick to get unique type constraints but maintain the order.
- ordered_unique_type_constraints = {
- v: None for v in self.input_type_constraints.values()
- }
+ ordered_unique_type_constraints = dict.fromkeys(self.input_type_constraints.values())
ordered_unique_type_constraints.update(
- {v: None for v in self.output_type_constraints.values()}
+ dict.fromkeys(self.output_type_constraints.values())
)
repr_strs += [
f" {type_constraint.name}: {type_constraint.type_strs}"
@@ -175,9 +175,9 @@ def __repr__(self):
repr_strs += [
" Intermediate Type Constraints: ",
]
- ordered_unique_type_constraints = {
- v: None for v in self.intermediate_type_constraints.values()
- }
+ ordered_unique_type_constraints = dict.fromkeys(
+ self.intermediate_type_constraints.values()
+ )
repr_strs += [
f" {type_constraint.name}: {type_constraint.type_strs}"
for type_constraint in ordered_unique_type_constraints
@@ -210,15 +210,15 @@ def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConst
)
# Rename type constraints to T0, T1, T2, ...
- _seen_type_constraints: Set[TypeConstraint] = set()
+ seen_type_constraints: Set[TypeConstraint] = set()
for type_constraint in (
*input_type_constraints.values(),
*output_type_constraints.values(),
*intermediate_type_constraints.values(),
):
- if type_constraint is not None and type_constraint not in _seen_type_constraints:
- type_constraint.name = f"T{len(_seen_type_constraints)}"
- _seen_type_constraints.add(type_constraint)
+ if type_constraint is not None and type_constraint not in seen_type_constraints:
+ type_constraint.name = f"T{len(seen_type_constraints)}"
+ seen_type_constraints.add(type_constraint)
return OnnxFunctionTypeConstraints(
input_type_constraints, output_type_constraints, intermediate_type_constraints
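The `dict.fromkeys` rewrites above lean on the fact that dicts preserve insertion order (guaranteed since Python 3.7), which makes `dict.fromkeys` a one-line ordered deduplicator, unlike `set()`, which loses order:

    items = ["T0", "T1", "T0", "T2", "T1"]
    print(list(dict.fromkeys(items)))  # ['T0', 'T1', 'T2']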
diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
index 25586085ef..a8d15c242a 100644
--- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
+++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Test cases for type constraint deduction functionality."""
from __future__ import annotations
diff --git a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py
index 44c3980668..eb2d8015a4 100644
--- a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py
+++ b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py
@@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None:
functions[module_name] = {}
op_name = get_op_name(func)
if op_name in functions[module_name]:
- logging.warning(
+ logging.warning( # noqa: LOG015
"Duplicated function: %s, overload: %s", op_name, func.func.name.overload_name
)
continue
diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py
index e96d24ed4a..ebbdd43bd8 100644
--- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py
+++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py
@@ -258,7 +258,7 @@ def _get_func_schema_in_namespace(namespaces: List[_OpNamespace]) -> Dict[str, F
# to "resize(Tensor a, SymInt[] shape) -> Tensor"
if "!" in op_overload_packet.schema:
op_overload_packet.schema = re.sub( # type: ignore[attr-defined]
- "[(][A-Za-z]![)]", "", op_overload_packet.schema
+ r"[(][A-Za-z]![)]", "", op_overload_packet.schema
)
# FIXME: remove below code if the issue below is fixed.
@@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None:
if module_name not in functions:
functions[module_name] = {}
if op_name in functions[module_name]:
- logging.warning(
+ logging.warning( # noqa: LOG015
"Duplicated function: %s, overload: %s",
op_name,
func_schema.name.overload_name,
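The `r` prefix added above does not change what this particular pattern matches (it contains no backslash escapes); it mainly keeps linters from flagging the regex string. The substitution itself strips alias markers such as `(a!)` from a schema string, as the hunk's comment describes:

    import re

    schema = "resize(Tensor(a!) a, SymInt[] shape) -> Tensor"
    print(re.sub(r"[(][A-Za-z]![)]", "", schema))
    # resize(Tensor a, SymInt[] shape) -> Tensor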
diff --git a/onnxscript/function_libs/torch_lib/__init__.py b/onnxscript/function_libs/torch_lib/__init__.py
index 4c4966c2b4..18e9054a6f 100644
--- a/onnxscript/function_libs/torch_lib/__init__.py
+++ b/onnxscript/function_libs/torch_lib/__init__.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""A torch function library for onnxscript.
The modules are named after the torch module names for grouping:
diff --git a/onnxscript/function_libs/torch_lib/_constants.py b/onnxscript/function_libs/torch_lib/_constants.py
index 58cc2c0680..f4e14061ec 100644
--- a/onnxscript/function_libs/torch_lib/_constants.py
+++ b/onnxscript/function_libs/torch_lib/_constants.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Shared constants for the library."""
DOMAIN = "pkg.onnxscript.torch_lib"
diff --git a/onnxscript/function_libs/torch_lib/_flags.py b/onnxscript/function_libs/torch_lib/_flags.py
index 560cd5baaa..79593f3464 100644
--- a/onnxscript/function_libs/torch_lib/_flags.py
+++ b/onnxscript/function_libs/torch_lib/_flags.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Experimental flags.
NOTE: These flags are experimental only. Any flag here can be removed at any
@@ -15,6 +17,7 @@ def _load_boolean_flag(
*,
this_will: str,
deprecated: bool = False,
+ default: bool = False,
) -> bool:
"""Load a boolean flag from environment variable.
@@ -22,7 +25,9 @@ def _load_boolean_flag(
name: The name of the environment variable.
this_will: A string that describes what this flag will do.
deprecated: Whether this flag is deprecated.
+        default: The default value to use when the environment variable is not defined.
"""
+ undefined = os.getenv(name) is None
state = os.getenv(name) == "1"
if state:
if deprecated:
@@ -32,6 +37,8 @@ def _load_boolean_flag(
)
else:
logger.warning("Experimental flag %s is enabled. This will %s.", name, this_will)
+ if undefined:
+ state = default
return state
@@ -42,8 +49,5 @@ def _load_boolean_flag(
EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_PREFER_TRACING",
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
-)
-EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
- "TORCHLIB_EXPERIMENTAL_USE_IR",
- this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
+ default=True,
)
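With this change `_load_boolean_flag` distinguishes three states: variable unset (fall back to `default`), set to "1" (enabled, with a warning), and set to anything else (disabled). A condensed sketch of that decision table, minus the logging (`load_flag` is an illustrative stand-in, not the library function):

    import os

    def load_flag(name: str, default: bool = False) -> bool:
        raw = os.getenv(name)
        if raw is None:
            return default    # unset: use the compiled-in default
        return raw == "1"     # set: only the exact string "1" enables

    os.environ.pop("TORCHLIB_EXPERIMENTAL_PREFER_TRACING", None)
    print(load_flag("TORCHLIB_EXPERIMENTAL_PREFER_TRACING", default=True))  # True
    os.environ["TORCHLIB_EXPERIMENTAL_PREFER_TRACING"] = "0"
    print(load_flag("TORCHLIB_EXPERIMENTAL_PREFER_TRACING", default=True))  # False

Since `EXPERIMENTAL_PREFER_TRACING` now passes `default=True`, tracing is the normal path unless the variable is explicitly set to something other than "1".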
diff --git a/onnxscript/function_libs/torch_lib/graph_building/__init__.py b/onnxscript/function_libs/torch_lib/graph_building/__init__.py
index e70f7f4c27..70a35d729f 100644
--- a/onnxscript/function_libs/torch_lib/graph_building/__init__.py
+++ b/onnxscript/function_libs/torch_lib/graph_building/__init__.py
@@ -1,35 +1,5 @@
-"""APIs for building an ONNX graph from a PyTorch model.
-
-This module exposes only three classes that will be used to build an ONNX graph
-by the ONNX exporter in PyTorch:
-
-- :class:`TorchScriptTensor`: Represents a symbolic value in the ONNX graph.
-- :class:`TorchScriptGraph`: Stores the graph being built.
-- :class:`TorchScriptTracingEvaluator`: An evaluator that will record all operators
- applied on the ``TorchScriptTensor``. It has a reference to the ``TorchScriptGraph``
- being built, will write to it, and will handle eager evaluations of Torch Lib
- functions when desired.
-
-The usage is in https://github.com/pytorch/pytorch/blob/136f8378e1b5a8cb7127977b8d068fbf9c3e1247/torch/onnx/_internal/fx/fx_onnx_interpreter.py#L698-L702,
-and it is very simple::
-
- with onnxscript.evaluator.default_as(onnxscript_tracer): # onnxscript_tracer is a TorchScriptTracingEvaluator
- output: Union[
- onnxscript_graph_building.TorchScriptTensor,
- Tuple[onnxscript_graph_building.TorchScriptTensor, ...],
- ] = symbolic_fn(*onnx_args, **onnx_kwargs)
-
-Here, we set the default evaluator to be ``onnxscript_tracer`` so
-that ONNX Script will dispatch all operators calls to the evaluator. The ``symbolic_fn``
-can be a pure Python function (e.g. trace-only) or an ONNX Script function. Either way,
-they are recorded by ``onnxscript_tracer`` and onto the graph.
-
-The outputs, as ``TorchScriptTensor``, are then handed by to the exporter. On line
-https://github.com/pytorch/pytorch/blob/136f8378e1b5a8cb7127977b8d068fbf9c3e1247/torch/onnx/_internal/fx/fx_onnx_interpreter.py#L707
-the exporter fills in type and shape information from PyTorch by calling the setters
-on ``TorchScriptTensor.dtype`` and ``TorchScriptTensor.shape``.
-"""
-
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
from __future__ import annotations
__all__ = [
@@ -38,17 +8,17 @@
"TorchScriptTracingEvaluator",
]
-from onnxscript.function_libs.torch_lib import _flags
-if _flags.EXPERIMENTAL_USE_IR:
- from ._graph_building_ir import (
- TorchScriptGraph,
- TorchScriptTensor,
- TorchScriptTracingEvaluator,
- )
-else:
- from ._graph_building_torch import ( # type: ignore[assignment]
- TorchScriptGraph,
- TorchScriptTensor,
- TorchScriptTracingEvaluator,
- )
+class _RemovedClass:
+    """Placeholder for graph-building APIs removed from onnxscript."""
+
+ def __init__(self, *_, **__):
+ raise NotImplementedError(
+ "Support for dynamo_export has been dropped since onnxscript 0.4.0. "
+ "Please use `torch.onnx.export(..., dynamo=True)`, or downgrade to onnxscript<0.4"
+ )
+
+
+TorchScriptTensor = _RemovedClass
+TorchScriptGraph = _RemovedClass
+TorchScriptTracingEvaluator = _RemovedClass
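The replacement module keeps the three public names importable but makes any instantiation fail with a migration message: downstream code that only imports the names keeps working, while code that actually constructs the objects gets an actionable error rather than an ImportError at import time. The pattern in isolation (`RemovedAPI` is an illustrative name):

    class RemovedAPI:
        """Stands in for a class that was removed from the public API."""

        def __init__(self, *_, **__):
            raise NotImplementedError("This API was removed; see the release notes.")

    try:
        RemovedAPI()
    except NotImplementedError as err:
        print(err)  # This API was removed; see the release notes.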
diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py
deleted file mode 100644
index a26a612ba8..0000000000
--- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py
+++ /dev/null
@@ -1,755 +0,0 @@
-"""Graph building functions using the ONNX IR, compatible with the original TorchScriptGraph usage."""
-
-from __future__ import annotations
-
-import ctypes
-from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
-
-import numpy as np
-import onnx
-import onnx.checker
-import onnx.defs
-import onnx.helper
-import onnx.shape_inference
-import torch
-from typing_extensions import TypeAlias
-
-import onnxscript
-from onnxscript import evaluator, ir
-from onnxscript import tensor as onnxscript_tensor
-from onnxscript._internal import param_manipulation
-from onnxscript.function_libs.torch_lib import _flags
-from onnxscript.function_libs.torch_lib.ops import common as common_ops
-
-__all__ = [
- "TorchScriptTensor",
- "TorchScriptGraph",
- "TorchScriptTracingEvaluator",
-]
-
-
-ValidArgumentType: TypeAlias = Union[
- "TorchScriptTensor",
- Sequence["TorchScriptTensor"],
- Sequence[float],
- Sequence[int],
- complex,
- str,
- int,
- float,
- bool,
- None,
-]
-ValidInputType: TypeAlias = Union[
- "TorchScriptTensor",
- Sequence["TorchScriptTensor"],
- Sequence[float],
- Sequence[int],
- complex,
- str,
- int,
- float,
- bool,
- None,
-]
-
-_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
- torch.bfloat16: ir.DataType.BFLOAT16,
- torch.bool: ir.DataType.BOOL,
- torch.complex128: ir.DataType.COMPLEX128,
- torch.complex64: ir.DataType.COMPLEX64,
- torch.float16: ir.DataType.FLOAT16,
- torch.float32: ir.DataType.FLOAT,
- torch.float64: ir.DataType.DOUBLE,
- torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
- torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
- torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
- torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
- torch.int16: ir.DataType.INT16,
- torch.int32: ir.DataType.INT32,
- torch.int64: ir.DataType.INT64,
- torch.int8: ir.DataType.INT8,
- torch.uint8: ir.DataType.UINT8,
-}
-
-
-def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
- return _TORCH_DTYPE_TO_ONNX[dtype]
-
-
-class _TorchTensor(ir.Tensor): # pylint: disable=too-many-ancestors
- def __init__(self, tensor: torch.Tensor):
- super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
-
- def tobytes(self) -> bytes:
- # Support native PyTorch types so we can use types like bloat16
- assert isinstance(self.raw, torch.Tensor)
- tensor = self.raw.detach().cpu().contiguous()
- return bytes(
- (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
- tensor.data_ptr()
- )
- )
-
-
-class TorchScriptTensor(ir.Value, onnxscript_tensor.Tensor):
-    """An onnxscript tensor that wraps a torchscript Value."""
-
- def __init__(
- self,
- _=None, # Unused argument for backward compatibility
- producer=None,
- index=None,
- name: str | None = None,
- ):
- onnxscript_tensor.Tensor.__init__(self, None)
- ir.Value.__init__(self, producer, index=index, name=name)
- self._is_complex: bool = False
- self._concrete_value: np.ndarray | None = None
- self._device: torch.device | None = None
-
- @property
- def value(self) -> Optional[np.ndarray]:
- return self._concrete_value
-
- @value.setter
- def value(self, value: np.ndarray) -> None:
- self._concrete_value = value
-
- @property # type: ignore[override]
- def rank(self) -> int | None:
- if self.shape is None:
- return None
- return len(self.shape)
-
- @property # type: ignore[override]
- def shape(self) -> ir.Shape | None:
- return super().shape
-
- @shape.setter
- def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]):
- # Normalize torch symbolic dimension size to str.
- torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool)
- self._shape = ir.Shape(
- tuple(str(dim.node) if isinstance(dim, torch_sym_types) else dim for dim in shape) # type: ignore[union-attr]
- )
-
- @property
- def dtype(self) -> ir.DataType | None:
- return super().dtype
-
- @dtype.setter
- def dtype(self, dtype: torch.dtype | ir.DataType | None):
- if dtype is None:
- onnx_dtype = ir.DataType.UNDEFINED
- elif isinstance(dtype, ir.DataType):
- onnx_dtype = dtype
- else:
- onnx_dtype = _torch_dtype_to_onnx_dtype(dtype)
- if self._type is None:
- self._type = ir.TensorType(onnx_dtype)
- else:
- self._type.dtype = onnx_dtype
-
-    # TODO: Remove this when there are no mismatched output shapes between devices:
- # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457
- @property
- def device(self) -> torch.device | None:
- return self._device
-
- @device.setter
- def device(self, device: torch.device):
- self._device = device
-
- @property
- def is_complex(self) -> bool:
- return self._is_complex
-
- @is_complex.setter
- def is_complex(self, is_complex: bool):
- self._is_complex = is_complex
-
- @property
- def onnx_dtype(self) -> int:
- raise NotImplementedError("onnx_dtype is not supported for TorchScriptTensor.")
-
- def value_info(self) -> Optional[onnx.ValueInfoProto]:
- raise NotImplementedError("value_info is not supported for TorchScriptTensor.")
-
-
-class _Node(ir.Node):
- """A node that will produce TorchScriptTensor as outputs for compatibility."""
-
- def __init__(
- self,
- domain: str,
- op_type: str,
- inputs: Sequence[ir.Value | None],
- attributes: Sequence[ir.Attr | ir.RefAttr] = (),
- *,
- overload: str = "",
- num_outputs: int = 1,
- version: int | None = None,
- name: str | None = None,
- doc_string: str | None = None,
- ):
- super().__init__(
- domain=domain,
- op_type=op_type,
- inputs=inputs,
- attributes=attributes,
- overload=overload,
- num_outputs=num_outputs,
- version=version,
- name=name,
- doc_string=doc_string,
- )
- self._outputs: tuple[TorchScriptTensor, ...] = tuple(
- TorchScriptTensor(producer=self, index=i) for i in range(num_outputs)
- )
-
- @property # type: ignore[misc]
- def outputs(self) -> Sequence[TorchScriptTensor]:
- return self._outputs
-
-
-class TorchScriptTracingEvaluator(evaluator.Evaluator):
- """An onnxscript Evaluator that captures the graph."""
-
- def __init__(self, graph: TorchScriptGraph):
- self._graph: TorchScriptGraph = graph
-
- @property
- def graph(self) -> TorchScriptGraph:
- return self._graph
-
- def eval(self, schema, inputs: Sequence[ValidInputType], attributes):
- return self._graph.add_op_call(schema, inputs, attributes)
-
- def eval_function( # type: ignore[override]
- self,
- function: onnxscript.OnnxFunction,
- args: Sequence[ValidArgumentType],
- kwargs: Mapping[str, ValidArgumentType],
- ):
- if _flags.EXPERIMENTAL_PREFER_TRACING:
- # Special cases for handling IsScalar and Rank
- if function.name == "IsScalar":
- if len(args) != 1:
- raise TypeError(
- f"Expected 1 positional argument for function '{function}', got {len(args)}."
- )
- if isinstance(args[0], TorchScriptTensor):
- if args[0].rank is not None:
- return args[0].rank == 0
- else:
-                        # Fall through to add_function_call
- pass
- elif isinstance(args[0], Sequence): # noqa: SIM103
- return False
- else:
- # Python constants are scalars
- return True
- if function.name == "Rank":
- if len(args) != 1:
- raise TypeError(
- f"Expected 1 positional argument for function '{function}', got {len(args)}."
- )
- if isinstance(args[0], TorchScriptTensor):
- if args[0].rank is not None:
- return args[0].rank
- else:
-                        # Fall through to add_function_call
- pass
- elif isinstance(args[0], Sequence):
- if all(isinstance(arg, (int, float)) for arg in args[0]):
- return 1
- else:
-                        # Fall through to add_function_call
- pass
- else:
- # Python constants are scalars
- return 0
- elif function.experimental_traceable:
- # Trace the function call instead of adding the function as a node
- return function.function(*args, **kwargs)
-
- # args/kwargs are TorchScriptTensor/python built-in based
- param_schemas = function.param_schemas()
- (
- inputs,
- attributes,
- ) = param_manipulation.separate_input_attributes_from_arguments(
- param_schemas, args, kwargs, fill_defaults=True, allow_extra_kwargs=True
- )
-
- # Cast attributes to the correct type based on function signature
- op_schema = function.op_schema
- assert op_schema is not None
- for name, value in attributes.items():
- attribute = op_schema.attributes[name]
- if attribute.type == onnx.defs.OpSchema.AttrType.FLOAT:
- # Cast int to float if the attribute is FLOAT
- attributes[name] = float(value)
-
- # In PyTorch, an attribute annotated as `int[1]?` accepts an integer
- # or a sequence. When the attribute is an integer, it is treated as
- # a single element sequence. ONNX requires an attribute to either be
- # an integer or a sequence. So we promote the value to a sequence here.
- if attribute.type == onnx.defs.OpSchema.AttrType.INTS and isinstance(value, int):
- attributes[name] = (value,)
- if attribute.type == onnx.defs.OpSchema.AttrType.FLOATS and isinstance(
- value, float
- ):
- attributes[name] = (value,)
- return self._graph.add_function_call(function, inputs, attributes)
-
-
-def _build_attribute(
- key: str,
- value: Union[
- float,
- int,
- str,
- Sequence[float],
- Sequence[int],
- torch.Tensor,
- _TorchTensor,
- ir.TensorProtocol,
- ],
-):
- """Initializes the right attribute based on type of value."""
- if isinstance(value, float):
- return ir.AttrFloat32(key, value)
- if isinstance(value, int):
- return ir.AttrInt64(key, value)
- if isinstance(value, str):
- return ir.AttrString(key, value)
- if isinstance(value, torch.Tensor):
- return ir.AttrTensor(key, _TorchTensor(value))
- if isinstance(value, (_TorchTensor, ir.TensorProtocol)):
- return ir.AttrTensor(key, value)
- if isinstance(value, Sequence):
- if not value:
- # Treat empty sequences as empty list tensors
- # TODO(justinchuby): Revisit ways to determine the type of the empty list
- return ir.AttrInt64s(key, [])
- if isinstance(value[0], float):
- return ir.AttrFloat32s(key, list(value)) # type: ignore[arg-type]
- if isinstance(value[0], int):
- return ir.AttrInt64s(key, list(value)) # type: ignore
- raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'")
- raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'")
-
-
-def _create_op_call_in_graph(
- graph: ir.Graph,
- domain: str,
- op_type: str,
- *,
- inputs: Sequence[TorchScriptTensor],
- attributes: Mapping[str, Any],
- num_outputs: int = 1,
-) -> Sequence[TorchScriptTensor]:
- """Creates a node representing an onnx op in `graph`.
-
- Args:
- graph: The torch graph to add the node to.
-        graph: The IR graph to add the node to.
- op_type: The name of the op. E.g. "Add".
- inputs: The onnx inputs to the op.
- attributes: The onnx attributes to the op.
- num_outputs: The number of outputs the op has.
-
- Returns:
- The outputs of the created node.
- """
-    # Filter out None attributes; this is convenient on the client side
-    # because callers can pass None attributes and have them simply not show up.
- attributes = {k: v for k, v in attributes.items() if v is not None}
-
- node = _Node(
- domain,
- op_type,
- inputs=inputs,
- attributes=[_build_attribute(key, value) for key, value in attributes.items()],
- num_outputs=num_outputs,
- )
- graph.append(node)
-
- return node.outputs
-
-
-def _shared_functions() -> list[ir.Function]:
-    """Hack to always include the shared ops."""
-
- # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
- return [
- ir.serde.deserialize_function(common_ops.Rank.to_function_proto()),
- ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto()),
- ]
-
-
-class TorchScriptGraph:
- def __init__(
- self,
- parent_torch_script_graph: Optional[TorchScriptGraph] = None,
- domain_name: Optional[str] = None,
- ):
- self._graph = ir.Graph((), (), nodes=(), name="main_graph")
- # All the functions used, deduplicated by name
- # key: (name, domain)
- self._function_store: Dict[ir.OperatorIdentifier, ir.Function] = {}
- self._initializers: Dict[str, torch.Tensor] = {}
- # Mapping from initializer name to input(TorchScriptTensor).
- self._initializers_inputs: Dict[str, TorchScriptTensor] = {}
- # Mapping from initializer name to input(TorchScriptTensor) from parent graph.
- self._initializers_inputs_from_parent: Dict[str, TorchScriptTensor] = {}
- # Mapping from model local function type name to function graph.
- # Local function type name is expected to be unique. Converter creates
- # a unique name and a unique function graph for every module call.
- self._sub_torch_script_graphs: Dict[str, TorchScriptGraph] = {}
- # Parent graph. None if this is the top level graph.
- self._parent_torch_script_graph = parent_torch_script_graph
- # Domain name of the graph. None if this is the top level graph.
- self._domain_name: Optional[str] = domain_name
-
- if self._domain_name is None and self._parent_torch_script_graph is not None:
- raise RuntimeError(
- "Domain name is not set. It is required because this 'TorchScriptGraph' instance "
- "is a subgraph that represents an ONNX local function."
- )
-
- @property
- def initializers(self) -> Mapping[str, torch.Tensor]:
- return self._initializers
-
-    # NOTE: This setter is used in the torch converter when fake mode is active;
-    # we need to filter out the initializers that hold fake tensors, because
-    # we don't want to introduce fake tensors into onnxscript.
- @initializers.setter
- def initializers(self, initializers: Dict[str, torch.Tensor]):
- self._initializers = initializers
-
- @property
- def initializers_inputs(self) -> Mapping[str, TorchScriptTensor]:
- return self._initializers_inputs
-
- @property
- def initializers_inputs_from_parent(self) -> Mapping[str, TorchScriptTensor]:
- return self._initializers_inputs_from_parent
-
- @property
- def num_outputs(self) -> int:
- return len(self._graph.outputs)
-
- @property
- def domain_name(self) -> Optional[str]:
- return self._domain_name
-
- def add_input(
- self,
- input_name: Optional[str],
- shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- ) -> TorchScriptTensor | None:
- if input_name is None:
- # This input argument is None, which is mapped
- # to a NULL value in TorchScript type system.
- value = None
- else:
- value = TorchScriptTensor(name=input_name)
- value.shape = shape # type: ignore[arg-type,assignment]
- value.device = device
- if dtype is not None:
- value.dtype = dtype # type: ignore[assignment]
- # TODO(titaiwang): This approach loses the information that "same SymInts
- # indicates same shape", for example, [symint0, symint0, symint1]
- # would all be [None, None, None]
- # torch_value.setType(
- # torch_value.type().with_sizes(
- # [dim if isinstance(dim, int) else None for dim in shape] # type: ignore[union-attr]
- # )
- # )
- self._graph.inputs.append(value) # type: ignore[arg-type]
- return value
-
- def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor:
- if name in self._initializers_inputs:
-            # NOTE: Previously this raised when `name` was already set. This is
-            # relaxed because it is invoked multiple times when a submodule is
-            # called multiple times.
- if name in self._initializers and self._initializers[name] is not value:
- raise ValueError(
- f"Initializer '{name}' exists already with a different value."
- )
- return self._initializers_inputs[name] # type: ignore[return-value]
-
- if (
- self != self._parent_torch_script_graph
- and self._parent_torch_script_graph is not None
- ):
- # Only the root graph can have initializers. Add as initializer
- # to root graph, and add as input to current graph.
- self._initializers_inputs_from_parent[name] = (
- self._parent_torch_script_graph.add_initializer(name, value)
- )
- else:
- input = TorchScriptTensor(name=name)
- self._initializers_inputs[name] = input
- self._initializers[name] = value
- return input
-
- def register_outputs(
- self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]
- ):
- if isinstance(outputs, TorchScriptTensor):
- outputs = (outputs,)
- for output in outputs:
- assert isinstance(
- output, TorchScriptTensor
- ), f"output must be a TorchScriptTensor, not {type(output)}"
- self._graph.outputs.append(output)
-
- def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]:
- """Add a constant to the graph.
-
- Returns:
-            A single-element sequence containing the constant value.
- """
- if constant is None:
- return (None,)
-
- if isinstance(constant, bool):
- # Be sure to put bool before int, because bool is a subclass of int
- constant_tensor = torch.tensor(constant, dtype=torch.bool)
- elif isinstance(constant, float):
- constant_tensor = torch.tensor(constant, dtype=torch.float)
- elif isinstance(constant, int):
- constant_tensor = torch.tensor(constant, dtype=torch.int64)
- elif isinstance(constant, (tuple, list)) and all(
- isinstance(val, int) for val in constant
- ):
- constant_tensor = torch.tensor(constant, dtype=torch.int64)
- elif isinstance(constant, (tuple, list)) and all(
- isinstance(val, float) for val in constant
- ):
- constant_tensor = torch.tensor(constant, dtype=torch.float)
- elif isinstance(constant, complex):
- # NOTE: ONNX doesn't support tensor of complex64/complex128, so we
- # convert them to float32/float64 with real representation.
- constant_tensor = torch.view_as_real(torch.tensor(constant).resolve_conj())
- else:
- raise TypeError(
- f"Constant input '{constant}' of type '{type(constant)}' is not supported"
- )
- onnx_tensor = _TorchTensor(constant_tensor)
- value = _create_op_call_in_graph(
- self._graph,
- "",
- "Constant",
- inputs=(),
- attributes=dict(value=onnx_tensor),
- )
- return value
-
- def _add_ir_graph_op_call(
- self,
- *,
- domain: str,
- op_type: str,
- onnx_inputs: Sequence[ValidInputType],
- onnx_attributes: Mapping[str, ValidArgumentType],
- num_outputs: int,
- ) -> Sequence[TorchScriptTensor]:
- graph_inputs: list[TorchScriptTensor] = []
- assert isinstance(onnx_inputs, Sequence)
- for input in onnx_inputs:
- # NOTE(titaiwang): input could be empty list
- if (
- isinstance(input, Sequence)
- and input
- and all(isinstance(elem, TorchScriptTensor) for elem in input)
- ):
- # If all elements in the Sequence are TorchScriptTensor we know it
- # should be a Sequence input in ONNX.
- input_sequence = _create_op_call_in_graph(
- self._graph,
- "",
- "SequenceConstruct",
- inputs=input, # type: ignore
- attributes={},
- )
- graph_inputs.extend(input_sequence)
- elif not isinstance(input, TorchScriptTensor):
- graph_inputs.extend(self._add_constant_to_graph(input)) # type: ignore
- else:
- # TODO(justinchuby): What is this case?
- graph_inputs.append(input)
- for key, value in onnx_attributes.items():
- assert not isinstance(
- value, TorchScriptTensor
- ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
- tensors = _create_op_call_in_graph(
- self._graph,
- domain,
- op_type,
- inputs=graph_inputs,
- attributes=onnx_attributes,
- num_outputs=num_outputs,
- )
- assert tensors, "Expected at least one output from ONNX op call."
-        # NOTE: TorchScriptTensor is created here; however, neither dtype nor
-        # shape is set. The exporter is expected to modify the returned tensor
-        # and set this information.
- return tensors
-
- def _fetch_function_dict(
- self, opset_version: int
- ) -> Mapping[ir.OperatorIdentifier, ir.Function]:
- function_dict: Dict[ir.OperatorIdentifier, ir.Function] = {}
- # Fetch local function protos. E.g., local functions representing module calls.
- for (
- sub_graph_name,
- sub_torch_script_graph,
- ) in self._sub_torch_script_graphs.items():
- function_dict.update(sub_torch_script_graph._fetch_function_dict(opset_version)) # pylint: disable=protected-access
- domain = sub_torch_script_graph.domain_name
- assert domain is not None
- name_domain = (sub_graph_name, domain, "")
- assert (
- name_domain not in function_dict
- ), f"Sub graph name already exists. {name_domain}"
- function_dict[name_domain] = sub_torch_script_graph._to_function( # pylint: disable=protected-access
- opset_version, sub_graph_name
- )
- # Fetch torchlib function protos.
- for identifier, function in self._function_store.items():
- function_dict[identifier] = function
- return function_dict
-
- def add_op_call(
- self,
- onnx_op_schema: onnx.defs.OpSchema,
- onnx_inputs: Sequence[ValidInputType],
- onnx_attributes: Mapping[str, ValidArgumentType],
- ) -> Union[TorchScriptTensor, Sequence[TorchScriptTensor]]:
- # Compute outputs from the onnx_op op schema
- num_outputs = evaluator.compute_num_outputs(
- onnx_op_schema, onnx_inputs, onnx_attributes
- )
- result = self._add_ir_graph_op_call(
- domain="",
- op_type=onnx_op_schema.name,
- onnx_inputs=onnx_inputs,
- onnx_attributes=onnx_attributes,
- num_outputs=num_outputs,
- )
-
- if num_outputs == 1:
- return result[0]
-
- return result
-
- def add_function_call(
- self,
- onnx_function: onnxscript.OnnxFunction,
- onnx_inputs: Sequence[ValidInputType],
- onnx_attributes: Mapping[str, ValidArgumentType],
- ) -> Union[TorchScriptTensor, Sequence[TorchScriptTensor]]:
- ir_function = ir.serde.deserialize_function(onnx_function.to_function_proto())
- self._function_store[ir_function.identifier()] = ir_function
- num_outputs = len(onnx_function.function_ir.outputs)
- # Compute outputs from the function schema
- result = self._add_ir_graph_op_call(
- domain=ir_function.domain,
- op_type=ir_function.name,
- onnx_inputs=onnx_inputs,
- onnx_attributes=onnx_attributes,
- num_outputs=num_outputs,
- )
-
- if num_outputs == 1:
- return result[0]
-
- return result
-
- def add_module_call(
- self,
- name: str,
- sub_torch_script_graph: TorchScriptGraph,
- onnx_inputs: Sequence[ValidInputType],
- ) -> Union[TorchScriptTensor, Sequence[TorchScriptTensor]]:
- self._sub_torch_script_graphs[name] = sub_torch_script_graph
- domain_name = sub_torch_script_graph.domain_name
- assert domain_name is not None
-
- num_outputs = sub_torch_script_graph.num_outputs
- result = self._add_ir_graph_op_call(
- domain=domain_name,
- op_type=name,
- onnx_inputs=(
- *onnx_inputs,
- *sub_torch_script_graph.initializers_inputs_from_parent.values(),
- ),
- onnx_attributes={},
- num_outputs=num_outputs,
- )
-
- if num_outputs == 1:
- return result[0]
-
- return result
-
- def _to_function(self, opset_version: int, function_name: str) -> ir.Function:
- assert len(self.initializers) == 0, "Model local functions cannot have initializers."
-
- # Dissect the model proto and transform to function proto.
- domain = self.domain_name
- if domain is None:
- raise RuntimeError("Domain name is not set.")
- onnx_function = ir.Function(
- domain=domain,
- name=function_name,
- graph=self._graph,
- attributes=(),
- )
- onnx_function.opset_imports[""] = opset_version
-
- return onnx_function
-
- def to_model_proto(
- self, opset_version: int, include_initializers: bool = True
- ) -> onnx.ModelProto:
- function_dict: Mapping[ir.OperatorIdentifier, ir.Function] = self._fetch_function_dict(
- opset_version
- )
- unique_custom_domains: Dict[str, int] = {"": opset_version}
-
- for function in function_dict.values():
- # TODO(BowenBao): All local function domain versions are hardcoded as 1.
- unique_custom_domains[function.domain] = 1
-
- if include_initializers:
- self._graph.initializers.update(
- {name: _TorchTensor(value) for name, value in self._initializers.items()}
- )
- else:
- self._graph.initializers.clear()
-
- onnx_model = ir.Model(
- self._graph,
- ir_version=8,
- producer_name=f"pytorch {torch.__version__}",
- functions=[*function_dict.values(), *_shared_functions()],
- )
-
- onnx_model.opset_imports.update(unique_custom_domains)
- # Include the library shared opset domain
- # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
- onnx_model.opset_imports[common_ops.common_opset.domain] = (
- common_ops.common_opset.version
- )
- model_proto = ir.serde.serialize_model(onnx_model)
- return model_proto
diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
deleted file mode 100644
index c07ba3ce81..0000000000
--- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
+++ /dev/null
@@ -1,1093 +0,0 @@
-"""Graph building functions for torchscript graph backend."""
-
-from __future__ import annotations
-
-import os
-import tempfile
-import typing
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
-
-import numpy as np
-import onnx
-import onnx.checker
-import onnx.defs
-import onnx.helper
-import onnx.shape_inference
-import torch
-from typing_extensions import TypeAlias
-
-import onnxscript
-from onnxscript import evaluator
-from onnxscript import tensor as onnxscript_tensor
-from onnxscript._internal import param_manipulation, runtime_typing
-from onnxscript.function_libs.torch_lib import _flags
-from onnxscript.function_libs.torch_lib.ops import common as common_ops
-
-__all__ = [
- "TorchScriptTensor",
- "TorchScriptGraph",
- "TorchScriptTracingEvaluator",
-]
-
-
-ValidArgumentType: TypeAlias = Union[
- "TorchScriptTensor",
- Sequence["TorchScriptTensor"],
- Sequence[float],
- Sequence[int],
- complex,
- str,
- int,
- float,
- bool,
- None,
-]
-ValidInputType: TypeAlias = Union[
- "TorchScriptTensor",
- Sequence["TorchScriptTensor"],
- Sequence[float],
- Sequence[int],
- complex,
- str,
- int,
- float,
- bool,
- None,
-]
-ValidTorchValueType: TypeAlias = Union[
- torch.Value,
- Sequence[torch.Value],
- Sequence[float],
- Sequence[int],
- complex,
- str,
- int,
- float,
- bool,
- None,
-]
-
-# Keep well below the 2GB protobuf serialization limit, leaving ample room
-# for the rest of the proto fields.
-_LARGE_MODEL_SIZE_THRESHOLD = int(2**30 * 1.8)  # ~1.8GB
-
-# TODO(justinchuby): Build a context manager to handle source information.
-
-
-def _rename_intermediate_value(name: str) -> str:
- """Prepend `_val_` to a numeric tensor name make it valid in ONNX.
-
- The TorchScript graph creates numeric value names by default. e.g. `0`, `1`.
- These are not legal ONNX tensor names, since ONNX requires the names to be valid
- C variable names.
-
- It also improves readability by making the names less likely to be confused
- with shape values.
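-
-    For example (illustrative), `_rename_intermediate_value("7")` returns
-    `"_val_7"`, while an already-valid name like `"input_1"` is unchanged.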
- """
- if name.isdigit():
- # Prefix with `_` to avoid name collision
- return f"_val_{name}"
- return name
-
-
-def _function_id(domain: str | None, name: str) -> str:
- """Create a unique function id for a function in a domain.
-
-    Used for generating model-level unique ids for values inside a function.
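-
-    For example (illustrative), `_function_id("pkg.torch", "Linear")` yields
-    `"pkg.torch::Linear"`; a `None` domain yields `"::Linear"`.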
- """
- return f"{domain if domain is not None else ''}::{name}"
-
-
-class TorchScriptTensor(onnxscript_tensor.Tensor):
- """A onnxscript tensor that wraps a torchscript Value."""
-
- def __init__(
- self,
- value: torch.Value,
- ):
- super().__init__(None)
- self._torch_value: torch.Value = value
- self._concrete_value: Optional[np.ndarray] = None
- self._shape: Optional[Tuple[int | str | None, ...]] = None
- self._torch_dtype: Optional[torch.dtype] = None
- self._name: Optional[str] = None
- self._is_complex: bool = False
- self._device: Optional[torch.device] = None
-
- def __repr__(self):
- return f"TorchScriptTensor('{self._torch_value!r}')"
-
- @property # type: ignore[override]
- def value(self) -> Optional[np.ndarray]:
- return self._concrete_value
-
- @value.setter
- def value(self, value: np.ndarray):
- self._concrete_value = value
-
- @property
- @runtime_typing.checked
- def name(self) -> str:
- if self._name is not None:
- return self._name
- return self._torch_value.debugName()
-
- @name.setter
- @runtime_typing.checked
- def name(self, name: str):
- self._name = name
- self._torch_value.setDebugName(name)
-
- @property # type: ignore[override]
- def rank(self) -> int | None:
- if self._shape is not None:
- return len(self._shape)
-
- value_type = self._torch_value.type()
- if value_type is None:
- return None
- value_type = typing.cast(torch.TensorType, value_type)
- return value_type.dim()
-
- @property # type: ignore[override]
- def shape(self) -> Tuple[int | str | None, ...] | None:
- if self._shape is not None:
- return self._shape
-
- value_type = self._torch_value.type()
- if value_type is None:
- return None
- value_type = typing.cast(torch.TensorType, value_type)
- if isinstance(value_type, torch.OptionalType):
- shape = value_type.getElementType().varyingSizes() # type: ignore[attr-defined]
- else:
- shape = value_type.varyingSizes()
- if shape is None:
- return None
- return tuple(shape)
-
- @shape.setter
- def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]):
- # Normalize torch symbolic dimension size to str.
- torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool)
- self._shape = tuple(
- str(dim.node) if isinstance(dim, torch_sym_types) else dim # type: ignore[union-attr]
- for dim in shape
- )
-        # The JIT API does not support assigning symbolic shapes,
-        # hence symbolic dimensions are replaced with None.
- jit_shape = tuple(dim if isinstance(dim, int) else None for dim in shape)
- self._torch_value.setType(self._torch_value.type().with_sizes(list(jit_shape)))
-
- @property # type: ignore[override]
- def dtype(self) -> torch.dtype | None:
- # TODO: Return numpy dtype
- if self._torch_dtype is not None:
- return self._torch_dtype
- # Local import to avoid circular dependency
- from torch.onnx import _type_utils # pylint: disable=import-outside-toplevel
-
- torch_dtype = _type_utils.JitScalarType.from_value( # type: ignore[attr-defined]
- self._torch_value, default=_type_utils.JitScalarType.UNDEFINED
- )
- if torch_dtype == _type_utils.JitScalarType.UNDEFINED:
- return None
- self._torch_dtype = torch_dtype.dtype()
- return self._torch_dtype
-
- @dtype.setter
- def dtype(self, dtype: torch.dtype):
- self._torch_dtype = dtype
- self._torch_value.setType(self._torch_value.type().with_dtype(dtype))
-
- @property
- def is_complex(self) -> bool:
- return self._is_complex
-
- @is_complex.setter
- def is_complex(self, is_complex: bool):
- self._is_complex = is_complex
-
-    # TODO: Remove this when there is no output shape mismatch between devices:
- # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457
- @property
- def device(self) -> torch.device | None:
- return self._device
-
- @device.setter
- def device(self, device: torch.device):
- self._device = device
-
- @property
- def onnx_dtype(self):
- # Local import to avoid circular dependency
- from torch.onnx import _type_utils # pylint: disable=import-outside-toplevel
-
- return _type_utils.JitScalarType.from_value( # type: ignore[attr-defined]
- self._torch_value, _type_utils.JitScalarType.UNDEFINED
- ).onnx_type()
-
- def symbolic_value(self) -> torch.Value:
- """The symbolic Value in torch.Graph."""
- return self._torch_value
-
- def value_info(self) -> Optional[onnx.ValueInfoProto]:
- try:
- dtype = self.onnx_dtype.value
- except torch.onnx.errors.OnnxExporterError:
- return None
- if dtype == onnx.TensorProto.UNDEFINED:
- return None
- return onnx.helper.make_tensor_value_info(self.name, dtype, self.shape)
-
-
-@runtime_typing.checked
-def _unwrap_tensor_to_torch_value(
- value: Union[
- ValidArgumentType, Mapping[str, ValidArgumentType], Sequence[ValidArgumentType]
- ],
-) -> Union[
- ValidTorchValueType,
- Dict[str, ValidTorchValueType],
- List[ValidTorchValueType],
- Tuple[ValidTorchValueType, ...],
-]:
- """Unwrap the TorchScriptTensor to torch.Value."""
- if isinstance(value, TorchScriptTensor):
- return value.symbolic_value()
- if isinstance(value, dict):
- return {k: _unwrap_tensor_to_torch_value(v) for k, v in value.items()} # type: ignore[misc,return-value]
- if isinstance(value, list):
- return [_unwrap_tensor_to_torch_value(v) for v in value] # type: ignore[misc,return-value]
- if isinstance(value, tuple):
- return tuple(_unwrap_tensor_to_torch_value(v) for v in value) # type: ignore[misc,return-value]
-
- # A normal python value
- return value # type: ignore[return-value]
-
-
-@runtime_typing.checked
-def _wrap_torch_value_to_tensor(
- value: Union[
- torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType]
- ],
- *,
- shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
-) -> Union[
- ValidArgumentType,
- Dict[str, ValidArgumentType],
- List[ValidArgumentType],
- Tuple[ValidArgumentType, ...],
-]:
- """Wrap torch.Value to TorchScriptTensor."""
- if isinstance(value, torch.Value):
- tensor = TorchScriptTensor(value)
- if shape is not None:
- tensor.shape = shape
- if dtype is not None:
- tensor.dtype = dtype
- if device is not None:
- tensor.device = device
- return tensor
- if isinstance(value, dict):
- return {k: _wrap_torch_value_to_tensor(v) for k, v in value.items()} # type: ignore[misc,return-value]
- if isinstance(value, list):
- return [_wrap_torch_value_to_tensor(v) for v in value] # type: ignore[misc,return-value]
- if isinstance(value, tuple):
- return tuple(_wrap_torch_value_to_tensor(v) for v in value) # type: ignore[misc,return-value]
-
- return value # type: ignore[return-value]
-
-
-def _unwrap_tensors_to_torch_values(tensors):
- # TODO(justinchuby): Do we really need this?
- if isinstance(tensors, Sequence):
- return [_unwrap_tensor_to_torch_value(output) for output in tensors]
- return _unwrap_tensor_to_torch_value(tensors)
-
-
-class TorchScriptTracingEvaluator(evaluator.Evaluator):
- """An onnxscript Evaluator that captures the graph into torchscript."""
-
- def __init__(self, graph: TorchScriptGraph):
- self._graph: TorchScriptGraph = graph
-
- @property
- def graph(self) -> TorchScriptGraph:
- return self._graph
-
- def eval(self, schema, inputs, attributes):
- if _flags.EXPERIMENTAL_PREFER_TRACING:
- if schema.name == "CastLike":
- assert len(inputs) == 2
- # Skip CastLike if the input and output types are the same
- src_input = inputs[0]
- target_input = inputs[1]
- dtypes_available = (
- isinstance(src_input, TorchScriptTensor)
- and isinstance(target_input, TorchScriptTensor)
- and src_input.dtype is not None
- and target_input.dtype is not None
- )
- if dtypes_available:
- if src_input.dtype == target_input.dtype:
- # Same type. No cast needed
- return src_input
- else:
- # Create a Cast node
- return self._graph.add_op_call(
- onnx.defs.get_schema("Cast"),
- (src_input,),
- {"to": target_input.onnx_dtype},
- )
- return self._graph.add_op_call(schema, inputs, attributes)
-
- @runtime_typing.checked
- def eval_function( # type: ignore[override]
- self,
- function: onnxscript.OnnxFunction,
- args: Sequence[ValidArgumentType],
- kwargs: Mapping[str, ValidArgumentType],
- ):
- if _flags.EXPERIMENTAL_PREFER_TRACING:
- # Special cases for handling IsScalar and Rank
- if function.name == "IsScalar":
- if len(args) != 1:
- raise TypeError(
- f"Expected 1 positional argument for function '{function}', got {len(args)}."
- )
- if isinstance(args[0], TorchScriptTensor):
- if args[0].rank is not None:
- return args[0].rank == 0
- else:
-                    # Fall through to add_function_call
- pass
- elif isinstance(args[0], Sequence): # noqa: SIM103
- return False
- else:
- # Python constants are scalars
- return True
- if function.name == "Rank":
- if len(args) != 1:
- raise TypeError(
- f"Expected 1 positional argument for function '{function}', got {len(args)}."
- )
- if isinstance(args[0], TorchScriptTensor):
- if args[0].rank is not None:
- return args[0].rank
- else:
-                    # Fall through to add_function_call
- pass
- elif isinstance(args[0], Sequence):
- if all(isinstance(arg, (int, float)) for arg in args[0]):
- return 1
- else:
-                    # Fall through to add_function_call
- pass
- else:
- # Python constants are scalars
- return 0
- elif function.experimental_traceable:
- # Trace the function call instead of adding the function as a node
- return function.function(*args, **kwargs)
-
- # args/kwargs are TorchScriptTensor/python built-in based
- param_schemas = function.param_schemas()
- (
- inputs,
- attributes,
- ) = param_manipulation.separate_input_attributes_from_arguments(
- param_schemas, args, kwargs, fill_defaults=True, allow_extra_kwargs=True
- )
-
- # Cast attributes to the correct type based on function signature
- op_schema = function.op_schema
- assert op_schema is not None
- for name, value in attributes.items():
- attribute = op_schema.attributes[name]
- if attribute.type == onnx.defs.OpSchema.AttrType.FLOAT:
- # Cast int to float if the attribute is FLOAT
- attributes[name] = float(value)
-
- # In PyTorch, an attribute annotated as `int[1]?` accepts an integer
- # or a sequence. When the attribute is an integer, it is treated as
-            # a single-element sequence. ONNX requires an attribute to be strictly
-            # an integer or strictly a sequence, so we promote the value to a sequence here.
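-            # For example (illustrative), an INTS attribute passed as 2 becomes (2,),
-            # and a FLOATS attribute passed as 0.5 becomes (0.5,).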
- if attribute.type == onnx.defs.OpSchema.AttrType.INTS and isinstance(value, int):
- attributes[name] = (value,)
- if attribute.type == onnx.defs.OpSchema.AttrType.FLOATS and isinstance(
- value, float
- ):
- attributes[name] = (value,)
- return self._graph.add_function_call(function, inputs, attributes)
-
-
-@runtime_typing.checked
-def _add_attribute_to_torchscript_node(
- node: torch.Node,
- key: str,
- value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor],
-):
- """Initializes the right attribute based on type of value."""
- if isinstance(value, float):
- return node.f_(key, value)
- if isinstance(value, int):
- return node.i_(key, value)
- if isinstance(value, (str, bytes)):
- return node.s_(key, value) # type: ignore[arg-type]
- if isinstance(value, torch.Tensor):
- return node.t_(key, value)
- if isinstance(value, Sequence):
- if not value:
- # Treat empty sequences as empty list tensors
- # TODO(justinchuby): Revisit ways to determine the type of the empty list
- return node.is_(key, list(value)) # type: ignore[attr-defined]
- if isinstance(value[0], float):
- return node.fs_(key, list(value)) # type: ignore[arg-type]
- if isinstance(value[0], int):
- return node.is_(key, list(value)) # type: ignore[attr-defined]
- raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'")
- raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'")
-
-
-@runtime_typing.checked
-def _create_op_call_in_torch_graph(
- graph: torch.Graph,
- opname: str,
- *,
- inputs: Sequence[torch.Value],
- attributes: Mapping[str, Any],
- n_outputs: int = 1,
-) -> Tuple[torch.Value, ...]:
- """Creates a node representing an onnx op in `graph`.
-
- Args:
- graph: The torch graph to add the node to.
- opname: The name of the op to add. E.g. "onnx::Add".
- inputs: The onnx inputs to the op.
- attributes: The onnx attributes to the op.
- n_outputs: The number of outputs the op has.
-
- Returns:
- The outputs of the created node.
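-
-    For example (illustrative, with a hypothetical `graph` and input value `x`)::
-
-        (relu_out,) = _create_op_call_in_torch_graph(
-            graph, "onnx::Relu", inputs=(x,), attributes={}, n_outputs=1
-        )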
- """
-    # Filter out None attributes; this is convenient for callers because they
-    # can pass None attributes and have them omitted from the node.
- attributes = {k: v for k, v in attributes.items() if v is not None}
-
- node = graph.create(opname, inputs, n_outputs)
- node = graph.insertNode(node)
-    node_outputs = tuple(node.outputs())
-
-    assert len(node_outputs) == n_outputs
- # Add all attributes
- for key, value in sorted(attributes.items()):
- _add_attribute_to_torchscript_node(node, key, value)
-
-    return node_outputs
-
-
-def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
- """Estimate the size of a tensor in bytes.
-
- Args:
- tensor: The tensor to estimate the size of.
-
- Returns:
- The estimated size of the tensor in bytes.
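-
-    For example, a float32 tensor of shape (2, 3) has 6 * 4 == 24 bytes of raw data.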
- """
- return tensor.numel() * tensor.element_size()
-
-
-def _shared_functions() -> list[onnx.FunctionProto]:
- """Hack to always include the share ops."""
-
- # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
- return [
- common_ops.Rank.to_function_proto(),
- common_ops.IsScalar.to_function_proto(),
- ]
-
-
-class TorchScriptGraph:
- def __init__(
- self,
- parent_torch_script_graph: Optional[TorchScriptGraph] = None,
- domain_name: Optional[str] = None,
- ):
- self._torch_graph = torch.Graph()
- # All the functions used, deduplicated by name
- # key: (name, domain)
- self._function_store: Dict[Tuple[str, str], onnxscript.OnnxFunction] = {}
-        # Mapping from initializer name to data (torch.Tensor).
-        self._initializers: Dict[str, torch.Tensor] = {}
-        # Mapping from initializer name to input (TorchScriptTensor).
-        self._initializers_inputs: Dict[str, TorchScriptTensor] = {}
-        # Mapping from initializer name to input (TorchScriptTensor) from the parent graph.
- self._initializers_inputs_from_parent: Dict[str, TorchScriptTensor] = {}
- # Mapping from model local function type name to function graph.
-        # Local function type names are expected to be unique. The converter creates
-        # a unique name and a unique function graph for every module call.
- self._sub_torch_script_graphs: Dict[str, TorchScriptGraph] = {}
- # Parent graph. None if this is the top level graph.
- self._parent_torch_script_graph = parent_torch_script_graph
- # Domain name of the graph. None if this is the top level graph.
- self._domain_name: Optional[str] = domain_name
- # Mapping from `torch.Value` to `TorchScriptTensor`.
-        # Because `torch.Value` does not provide an API to set and retrieve symbolic shapes,
- # and because `TorchScriptTensor` is not accessible through the `torch.Graph` graph,
- # this mapping is used to keep track of the `TorchScriptTensor` associated with
- # `torch.Value`.
- # `TorchScriptTensor` records dtype and symbolic shapes.
- # This info is later serialized as `ValueInfoProto` inside ONNX, to
- # provide shape and dtype information for nodes within nested function calls.
- # https://github.com/onnx/onnx/issues/5487
- self._value_to_tensor: Dict[torch.Value, TorchScriptTensor] = {}
-
- if self._domain_name is None and self._parent_torch_script_graph is not None:
- raise RuntimeError(
- "Domain name is not set. It is required because this 'TorchScriptGraph' instance "
- "is a subgraph that represents an ONNX local function."
- )
-
- @property
- def torch_graph(self):
- return self._torch_graph
-
- @property
- def initializers(self) -> Mapping[str, torch.Tensor]:
- return self._initializers
-
-    # NOTE: This setter is used by the torch converter when fake mode is active:
-    # we need to filter out the initializers that are fake tensors, because
-    # we don't want to introduce fake tensors into onnxscript.
- @initializers.setter
- def initializers(self, initializers: Dict[str, torch.Tensor]):
- self._initializers = initializers
-
- @property
- def initializers_inputs(self) -> Mapping[str, TorchScriptTensor]:
- return self._initializers_inputs
-
- @property
- def initializers_inputs_from_parent(self) -> Mapping[str, TorchScriptTensor]:
- return self._initializers_inputs_from_parent
-
- @property
- def num_outputs(self) -> int:
- return len(list(self._torch_graph.outputs()))
-
- @property
- def domain_name(self) -> Optional[str]:
- return self._domain_name
-
- @runtime_typing.checked
- def add_input(
- self,
- input_name: Optional[str],
- shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- ) -> TorchScriptTensor:
- if input_name is None:
- # This input argument is None, which is mapped
-            # to a NULL value in the TorchScript type system.
- torch_value = _create_op_call_in_torch_graph(
- self._torch_graph, "prim::Constant", inputs=(), attributes={}
- )[0]
- torch_value.setType(torch.OptionalType.ofTensor())
- else:
- torch_value = self._torch_graph.addInput(input_name)
- torch_value.setType(torch_value.type().with_dtype(dtype)) # type: ignore[arg-type]
-            # TODO(titaiwang): This approach loses the information that identical
-            # SymInts indicate equal dimensions; for example, [symint0, symint0, symint1]
-            # becomes [None, None, None].
- torch_value.setType(
- torch_value.type().with_sizes(
- [dim if isinstance(dim, int) else None for dim in shape] # type: ignore[union-attr]
- )
- )
- tensor_value = _wrap_torch_value_to_tensor(
- torch_value, shape=shape, dtype=dtype, device=device
- )
- if isinstance(tensor_value, TorchScriptTensor):
-            # NOTE: Only track values that map to tensors.
-            # Values that map to Sequences/Dicts of tensors are not tracked.
- self._value_to_tensor[torch_value] = tensor_value
- return tensor_value # type: ignore[return-value]
-
- @runtime_typing.checked
- def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor:
- if name in self._initializers_inputs:
-            # NOTE: Previously this raised when `name` was already set. This is
-            # relaxed because this method is invoked repeatedly when a submodule
-            # is called multiple times.
- if name in self._initializers and self._initializers[name] is not value:
- raise ValueError(
- f"Initializer '{name}' exists already with a different value."
- )
- return self._initializers_inputs[name] # type: ignore[return-value]
-
- if (
- self != self._parent_torch_script_graph
- and self._parent_torch_script_graph is not None
- ):
- # Only the root graph can have initializers. Add as initializer
- # to root graph, and add as input to current graph.
- self._initializers_inputs_from_parent[name] = (
- self._parent_torch_script_graph.add_initializer(name, value)
- )
- else:
- self._initializers[name] = value
-
- torch_value = self._torch_graph.addInput(name)
- torch_value.setType(torch.TensorType.create_from_tensor(value))
- tensor_value = _wrap_torch_value_to_tensor(
- torch_value, shape=value.shape, dtype=value.dtype
- )
- if isinstance(tensor_value, TorchScriptTensor):
- self._value_to_tensor[torch_value] = tensor_value
- self._initializers_inputs[name] = tensor_value # type: ignore[assignment]
- return tensor_value # type: ignore[return-value]
-
- @runtime_typing.checked
- def register_outputs(
- self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]
- ):
- unwrapped_outputs = _unwrap_tensors_to_torch_values(outputs)
- if isinstance(unwrapped_outputs, torch.Value):
- self._torch_graph.registerOutput(unwrapped_outputs)
- return
- assert isinstance(unwrapped_outputs, Sequence)
- for ts_output in unwrapped_outputs:
- assert isinstance(
- ts_output, torch.Value
- ), f"ts_output must be a torch.Value, not {type(ts_output)}"
- self._torch_graph.registerOutput(ts_output)
- return
-
- def _add_constant_to_graph(self, constant) -> torch.Value:
- if constant is None:
- value = _create_op_call_in_torch_graph(
- self._torch_graph, "prim::Constant", inputs=(), attributes={}
- )[0]
- value.setType(torch.OptionalType.ofTensor())
- value.setDebugName(_rename_intermediate_value(value.debugName()))
- return value
-
- if isinstance(constant, bool):
- # Be sure to put bool before int, because bool is a subclass of int
- constant_tensor = torch.tensor(constant, dtype=torch.bool)
- elif isinstance(constant, float):
- constant_tensor = torch.tensor(constant, dtype=torch.float)
- elif isinstance(constant, int):
- constant_tensor = torch.tensor(constant, dtype=torch.int64)
- elif isinstance(constant, (tuple, list)) and all(
- isinstance(val, int) for val in constant
- ):
- constant_tensor = torch.tensor(constant, dtype=torch.int64)
- elif isinstance(constant, (tuple, list)) and all(
- isinstance(val, float) for val in constant
- ):
- constant_tensor = torch.tensor(constant, dtype=torch.float)
- elif isinstance(constant, complex):
- # NOTE: ONNX doesn't support tensor of complex64/complex128, so we
- # convert them to float32/float64 with real representation.
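-            # For example, the scalar 1+2j becomes the real tensor [1., 2.].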
- constant_tensor = torch.view_as_real(torch.tensor(constant).resolve_conj())
- else:
- raise TypeError(
- f"Constant input '{constant}' of type '{type(constant)}' is not supported"
- )
- value = _create_op_call_in_torch_graph(
- self._torch_graph,
- "onnx::Constant",
- inputs=(),
- attributes=dict(value=constant_tensor),
- )[0]
- value.setDebugName(_rename_intermediate_value(value.debugName()))
- return value
-
- @runtime_typing.checked
- def _add_torchscript_op_call(
- self,
- name: str,
- onnx_inputs: Sequence[ValidInputType],
- onnx_attributes: Mapping[str, ValidArgumentType],
- n_outputs: int,
- ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
- unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs)
- graph_inputs = []
- assert isinstance(unwrapped_inputs, Sequence)
- for input in unwrapped_inputs:
-            # NOTE(titaiwang): input could be an empty list
- if (
- isinstance(input, Sequence)
- and input
- and all(isinstance(elem, torch.Value) for elem in input)
- ):
-                # If all elements in the Sequence are torch.Values, we know it
-                # should be a Sequence input in ONNX.
- input_sequence = _create_op_call_in_torch_graph(
- self._torch_graph,
- "onnx::SequenceConstruct",
- inputs=input,
- attributes={},
- )[0]
- graph_inputs.append(input_sequence)
- elif not isinstance(input, torch.Value):
- graph_inputs.append(self._add_constant_to_graph(input))
- else:
- graph_inputs.append(input)
- for key, value in onnx_attributes.items():
- assert not isinstance(
- value, TorchScriptTensor
- ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
- result = _create_op_call_in_torch_graph(
- self._torch_graph,
- name,
- inputs=graph_inputs,
- attributes=onnx_attributes,
- n_outputs=n_outputs,
- )
- assert result, "Expected at least one output from ONNX op call."
-        # NOTE: A TorchScriptTensor is created here, but neither dtype nor shape is
-        # set. The exporter is expected to modify the returned tensor and set
-        # this information.
- if len(result) == 1:
- tensor = TorchScriptTensor(result[0])
- tensor.name = _rename_intermediate_value(tensor.name)
- self._value_to_tensor[result[0]] = tensor
- return tensor
- tensors = tuple(TorchScriptTensor(v) for v in result)
- self._value_to_tensor.update(dict(zip(result, tensors)))
- for tensor in tensors:
- tensor.name = _rename_intermediate_value(tensor.name)
- return tensors
-
- @runtime_typing.checked
- def fetch_function_proto_dict(
- self, opset_version: int
- ) -> Mapping[Tuple[str, str], onnx.FunctionProto]:
- function_proto_dict: Dict[Tuple[str, str], onnx.FunctionProto] = {}
- # Fetch local function protos. E.g., local functions representing module calls.
- for (
- sub_graph_name,
- sub_torch_script_graph,
- ) in self._sub_torch_script_graphs.items():
- function_proto_dict.update(
- sub_torch_script_graph.fetch_function_proto_dict(opset_version)
- )
- domain = sub_torch_script_graph.domain_name
- assert domain is not None
- name_domain = (
- sub_graph_name,
- domain,
- )
- assert (
- name_domain not in function_proto_dict
- ), f"Sub graph name already exists. {name_domain}"
- function_proto_dict[name_domain] = sub_torch_script_graph.to_function_proto(
- opset_version, sub_graph_name
- )
- # Fetch torchlib function protos.
- for name_domain, function in self._function_store.items():
- function_proto_dict[name_domain] = function.to_function_proto()
- return function_proto_dict
-
- @runtime_typing.checked
- def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
- existing_value_info = {info.name: info for info in onnx_model.graph.value_info}
-
- # Override value_info for top level graph inputs.
- for input in self.torch_graph.inputs():
- if input not in self._value_to_tensor:
- raise RuntimeError(f"Input '{input.debugName()}' has no type.")
- tensor = self._value_to_tensor[input]
- if (value_info := tensor.value_info()) is None:
- continue
- for i, input_info in enumerate(onnx_model.graph.input):
- if input_info.name == input.debugName():
- # See NOTE: _C.Value re-naming.
- value_info.name = input_info.name
- onnx_model.graph.input.insert(i, value_info)
- onnx_model.graph.input.remove(input_info)
- break
-
- # Override value_info for top level graph outputs.
- for output in self.torch_graph.outputs():
- if output not in self._value_to_tensor:
- raise RuntimeError(f"Output '{output.debugName()}' has no type.")
- tensor = self._value_to_tensor[output]
- if (value_info := tensor.value_info()) is None:
- continue
- for i, output_info in enumerate(onnx_model.graph.output):
- if output_info.name == output.debugName():
- # See NOTE: _C.Value re-naming.
- value_info.name = output_info.name
- onnx_model.graph.output.insert(i, value_info)
- onnx_model.graph.output.remove(output_info)
- break
-
- # Remove existing static/incomplete value info.
- del onnx_model.graph.value_info[:]
-
- # Insert value info for nodes within nested function calls.
-        # NOTE: This is an experimental feature and will be replaced by ValueInfo inside
-        # FunctionProto in ONNX 1.16. https://github.com/microsoft/onnxscript/issues/1268
-        # The naming strategy is subject to change. Since all local functions representing
-        # nn.Modules exported by the dynamo exporter have unique call sites, their function
-        # op_type names can serve as unique identifiers for value info.
-        # Stored inside the top-level GraphProto.
- new_value_info = self.generate_subgraphs_value_info_proto()
- # Insert value info for nodes in top level graph.
- new_value_info.update(self.generate_maingraph_value_info_proto())
- # Do not store input, output or initializer into value_info
- for input in onnx_model.graph.input:
- new_value_info.pop(input.name, None)
- for output in onnx_model.graph.output:
- new_value_info.pop(output.name, None)
- for tensor in onnx_model.graph.initializer: # type: ignore[assignment]
- new_value_info.pop(tensor.name, None)
- existing_value_info.update(new_value_info)
- onnx_model.graph.value_info.extend(existing_value_info.values())
-
- return onnx_model
-
- @runtime_typing.checked
- def add_op_call(
- self,
- onnx_op_schema: onnx.defs.OpSchema,
- onnx_inputs: Sequence[ValidInputType],
- onnx_attributes: Mapping[str, ValidArgumentType],
- ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
- # Compute outputs from the onnx_op op schema
- n_outputs = evaluator.compute_num_outputs(onnx_op_schema, onnx_inputs, onnx_attributes)
- result = self._add_torchscript_op_call(
- f"onnx::{onnx_op_schema.name}",
- onnx_inputs,
- onnx_attributes,
- n_outputs=n_outputs,
- )
-
- return result
-
- @runtime_typing.checked
- def add_function_call(
- self,
- onnx_function: onnxscript.OnnxFunction,
- onnx_inputs: Sequence[ValidInputType],
- onnx_attributes: Mapping[str, ValidArgumentType],
- ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
- identifier = (onnx_function.name, onnx_function.function_ir.domain)
- self._function_store[identifier] = onnx_function
-
- # Compute outputs from the function schema
- result = self._add_torchscript_op_call(
- f"{onnx_function.function_ir.domain}::{onnx_function.name}",
- onnx_inputs,
- onnx_attributes,
- n_outputs=len(onnx_function.function_ir.outputs),
- )
-
- return result
-
- @runtime_typing.checked
- def add_module_call(
- self,
- name: str,
- sub_torch_script_graph: TorchScriptGraph,
- onnx_inputs: Sequence[ValidInputType],
- ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
- self._sub_torch_script_graphs[name] = sub_torch_script_graph
- domain_name = sub_torch_script_graph.domain_name
- assert domain_name is not None
- return self._add_torchscript_op_call(
- f"{domain_name}::{name}",
- onnx_inputs=(
- *onnx_inputs,
- *sub_torch_script_graph.initializers_inputs_from_parent.values(),
- ),
- onnx_attributes={},
- n_outputs=sub_torch_script_graph.num_outputs,
- )
-
- def generate_function_value_info_proto(
- self, function_op_type: str
- ) -> Mapping[str, onnx.ValueInfoProto]:
- named_value_info: Dict[str, onnx.ValueInfoProto] = {}
- function_id = _function_id(self.domain_name, function_op_type)
- for torch_value, tensor in self._value_to_tensor.items():
- if (value_info := tensor.value_info()) is None:
- continue
- name = f"{function_id}/{torch_value.debugName()}"
- value_info.name = name
- named_value_info[name] = value_info
- named_value_info.update(self.generate_subgraphs_value_info_proto())
- return named_value_info
-
- @runtime_typing.checked
- def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]:
- """Unique naming strategies for values inside subgraphs, i.e. local functions.
-
- {function_domain::function_op_type}/{value_name}
-
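-        For example (illustrative), a value named `x` inside the local function
-        `pkg.torch::Linear` is recorded as `pkg.torch::Linear/x`.
-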
- NOTE: Mainly designed for specialized functions, which are local functions
- with only one call site. For non-specialized functions, it is assumed that
- the `value_info` carried in `TorchScriptTensor` represents the general
- compatible shape and type.
- """
- named_value_info: Dict[str, onnx.ValueInfoProto] = {}
- for name, sub_graph in self._sub_torch_script_graphs.items():
- named_value_info.update(sub_graph.generate_function_value_info_proto(name))
- return named_value_info
-
- @runtime_typing.checked
- def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]:
- """Returns value info proto for values in the main graph."""
- named_value_info: Dict[str, onnx.ValueInfoProto] = {}
- for torch_value, tensor in self._value_to_tensor.items():
- if (value_info := tensor.value_info()) is None:
- continue
- # NOTE: _C.Value re-naming.
- # _C.Value's debugName is unstable.
-            # When duplicated names are encountered, all names involved are updated by
-            # the TorchScript naming strategy. Hence the name previously stored in
-            # value_info can be outdated.
- value_info.name = torch_value.debugName()
- named_value_info[torch_value.debugName()] = value_info
- return named_value_info
-
- @runtime_typing.checked
- def to_function_proto(self, opset_version: int, function_name: str) -> onnx.FunctionProto:
- assert len(self.initializers) == 0, "Model local functions cannot have initializers."
- (
- proto,
- _,
- _,
- _,
- ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
- initializers={},
- onnx_opset_version=opset_version,
- dynamic_axes={},
- defer_weight_export=False,
- operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
- strip_doc_string=False,
- keep_initializers_as_inputs=False,
- custom_opsets={},
- add_node_names=True,
- onnx_file_path="", # Large model export. Out of scope.
- node_attr_to_name={}, # Current module as function feature does not utilize attributes.
- )
-
- onnx_model = onnx.load_from_string(proto)
-
- # Dissect the model proto and transform to function proto.
- domain = self.domain_name
- if domain is None:
- raise RuntimeError("Domain name is not set.")
- onnx_function = onnx.helper.make_function(
- domain=domain,
- fname=function_name,
- inputs=[input.name for input in onnx_model.graph.input],
- outputs=[output.name for output in onnx_model.graph.output],
- nodes=onnx_model.graph.node,
- opset_imports=onnx_model.opset_import,
- doc_string=onnx_model.doc_string,
- )
- return onnx_function
-
- @runtime_typing.checked
- def to_model_proto(
- self, opset_version: int, include_initializers: bool = True
- ) -> onnx.ModelProto:
- function_proto_dict: Mapping[Tuple[str, str], onnx.FunctionProto] = (
- self.fetch_function_proto_dict(opset_version)
- )
- unique_custom_domains: Dict[str, int] = {}
-
- for function_proto in function_proto_dict.values():
- # TODO(BowenBao): All local function domain versions are hardcoded as 1.
- unique_custom_domains[function_proto.domain] = 1
-
- initializers_size = sum(
- _tensor_rawdata_size(tensor) for tensor in self.initializers.values()
- )
-
- large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD
-
- export_kwargs: dict[str, Any] = dict(
- initializers=self.initializers
- if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS
- else {},
- onnx_opset_version=opset_version,
- dynamic_axes={},
- defer_weight_export=False,
- operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
- strip_doc_string=False,
- keep_initializers_as_inputs=_flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS,
- custom_opsets={},
- add_node_names=True,
- node_attr_to_name={},
- )
-
- # We decided to cache the model to disk when the model is large.
- # Alternatively, we could build the ONNX `TensorProto`s in memory
- # and append them to the model proto.
- # We did not do it because it is harder to get right (vs. PyTorch's battle-tested
- # implementation) and creating the `TensorProto`s naively (by converting to numpy)
- # is slow.
- cache_model_to_disk = large_model and include_initializers
-
- if cache_model_to_disk:
- with tempfile.TemporaryDirectory() as temp_dir:
- onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
- export_kwargs["onnx_file_path"] = onnx_file_path
- (
- proto,
- _,
- _,
- _,
- ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
- **export_kwargs
- )
- onnx_model = onnx.load_from_string(proto)
- onnx.load_external_data_for_model(onnx_model, temp_dir)
- else:
- (
- proto,
- _,
- _,
- _,
- ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
- **export_kwargs
- )
- onnx_model = onnx.load_from_string(proto)
-
- onnx_model.functions.extend(function_proto_dict.values())
- onnx_model.functions.extend(_shared_functions())
-
- # Override value_infos with symbolic shapes.
- onnx_model = self._override_with_symbolic_value_info_proto(onnx_model)
-
- # `_export_onnx` only exports opset_imports that is visible to it. It does not
- # export opset_imports for nested functions, since it does not have access to
- # them. We manually add them back and merge with existing opset_imports in the
- # model proto.
- while len(onnx_model.opset_import) > 0:
- opsetid = onnx_model.opset_import.pop()
- unique_custom_domains[opsetid.domain] = opsetid.version
- onnx_model.opset_import.extend(
- [
- onnx.helper.make_opsetid(domain, version)
- for domain, version in unique_custom_domains.items()
- ]
- )
- # Include the library shared opset domain
- # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
- onnx_model.opset_import.append(
- onnx.helper.make_opsetid(
- common_ops.common_opset.domain, common_ops.common_opset.version
- )
- )
- return onnx_model
diff --git a/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py
deleted file mode 100644
index 76464b70ef..0000000000
--- a/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py
+++ /dev/null
@@ -1,231 +0,0 @@
-"""Test cases for graph building functionality."""
-
-# mypy: disable-error-code="arg-type,type-arg,valid-type"
-from __future__ import annotations
-
-import os
-import sys
-import unittest
-
-import torch
-
-import onnxscript
-import onnxscript.testing
-from onnxscript import FLOAT, evaluator
-from onnxscript import opset18 as op
-from onnxscript._internal import version_utils
-from onnxscript.function_libs.torch_lib import graph_building, ops
-
-IS_WINDOWS = os.name == "nt"
-
-
-class TestTorchScriptTracingEvaluator(unittest.TestCase):
- def setUp(self):
- self.opset_version = 18
-        # TODO: Add a test for initializers. Currently skipped since `assert_isomorphic`
-        # does not check initializers.
- self.onnxscript_graph = graph_building.TorchScriptGraph()
- self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph)
-
- def test_torchscript_tensor_keeps_torch_device(self):
- x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
- x = self.onnxscript_graph.add_input(
- "x", x_tensor.shape, x_tensor.dtype, x_tensor.device
- )
- self.assertEqual(x.device, x_tensor.device)
-
- x.device = torch.device("cuda")
- self.assertEqual(x.device, torch.device("cuda"))
-
- def test_traced_constant_op_is_same_as_compiled_graph(self):
- """Test for op.Constant created in graph builder"""
- with evaluator.default_as(self.tracer):
- output = op.Constant(value_float=0.5)
-
- self.onnxscript_graph.register_outputs(output)
- traced = self.onnxscript_graph.to_model_proto(self.opset_version)
-
- @onnxscript.script()
- def expected_model():
- return op.Constant(value_float=0.5)
-
- expected = expected_model.to_model_proto()
-
- onnxscript.testing.assert_isomorphic(traced, expected)
-
- def test_traced_graph_on_single_node_is_same_as_compiled_graph(self):
- aten_relu = ops.nn.aten_relu
-
- x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
- x = self.onnxscript_graph.add_input("x", x_tensor.shape, x_tensor.dtype)
- with evaluator.default_as(self.tracer):
- output = aten_relu(x)
-
- self.onnxscript_graph.register_outputs(output)
- traced = self.onnxscript_graph.to_model_proto(self.opset_version)
-
- @onnxscript.script(default_opset=op)
- def expected_model(x: FLOAT[1, 2, 3]):
- return aten_relu(x)
-
- expected = expected_model.to_model_proto()
-
- onnxscript.testing.assert_isomorphic(traced, expected)
-
- @unittest.expectedFailure # The scripted version does not have output type
- def test_traced_graph_on_single_node_multi_output_is_same_as_compiled_graph(self):
- aten_topk = ops.core.aten_topk
-
- x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
- x = self.onnxscript_graph.add_input("x", x_tensor.shape, x_tensor.dtype)
- with evaluator.default_as(self.tracer):
- output = aten_topk(x, 2)
-
- self.onnxscript_graph.register_outputs(output)
- traced = self.onnxscript_graph.to_model_proto(self.opset_version)
-
- @onnxscript.script(default_opset=op)
- def expected_model(x: FLOAT[1, 2, 3]):
- values, indices = aten_topk(x, 2)
- return values, indices
-
- expected = expected_model.to_model_proto()
- onnxscript.testing.assert_isomorphic(traced, expected)
-
- def test_model_local_function_constructed_by_traced_graph_is_same_as_compiled_graph(
- self,
- ):
- aten_abs = ops.core.aten_abs
- aten_relu = ops.nn.aten_relu
-
- inner_graph = graph_building.TorchScriptGraph(domain_name="test_domain")
- inner_tracer = graph_building.TorchScriptTracingEvaluator(inner_graph)
-
- x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
- x = inner_graph.add_input("x", x_tensor.shape, x_tensor.dtype)
- with evaluator.default_as(inner_tracer):
- output = aten_abs(x)
- inner_graph.register_outputs(output)
-
- outer_graph = graph_building.TorchScriptGraph()
- outer_tracer = graph_building.TorchScriptTracingEvaluator(outer_graph)
- x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
- x = outer_graph.add_input("x", x_tensor.shape, x_tensor.dtype)
- with evaluator.default_as(outer_tracer):
- output = aten_relu(x)
- output = outer_graph.add_module_call("inner", inner_graph, (output,))
- outer_graph.register_outputs(output)
- traced = outer_graph.to_model_proto(self.opset_version)
-
- @onnxscript.script(
- opset=onnxscript.values.Opset("test_domain", 1),
- default_opset=op,
- )
- def inner(x: FLOAT[1, 2, 3]):
- return aten_abs(x)
-
- @onnxscript.script(default_opset=op)
- def outer(x: FLOAT[1, 2, 3]):
- output = aten_relu(x)
- return inner(output)
-
- expected = outer.to_model_proto()
- onnxscript.testing.assert_isomorphic(traced, expected)
-
- def test_add_input_with_optionaltype_does_not_raise_torch_internal_error(self):
- graph = graph_building.TorchScriptGraph()
- x = graph.add_input(input_name=None)
- with evaluator.default_as(self.tracer):
- _ = x.shape
-
-
-class TestTorchScriptGraph(unittest.TestCase):
- def test_add_initializer_raises_when_the_same_name_used_for_different_tensors(self):
- graph = graph_building.TorchScriptGraph()
- graph.add_initializer("x", torch.ones((1, 2, 3), dtype=torch.float32))
- with self.assertRaises(ValueError):
- graph.add_initializer("x", torch.ones((1, 2, 3), dtype=torch.float32))
-
- def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(self):
- graph = graph_building.TorchScriptGraph()
- x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
- graph.add_initializer("x", x_tensor)
- graph.add_initializer("x", x_tensor)
-
-
-class _MLP(torch.nn.Module):
- def __init__(self, input_size, hidden_size, output_size):
- super().__init__()
- self.fc1 = torch.nn.Linear(input_size, hidden_size)
- self.fc2 = torch.nn.Linear(hidden_size, output_size)
- self.relu = torch.nn.ReLU()
-
- def forward(self, x):
- out = self.fc1(x)
- out = self.relu(out)
- out = self.fc2(out)
- return out
-
-
-@unittest.skipIf(
- IS_WINDOWS and version_utils.torch_older_than("2.3"),
- "dynamo_export not supported on Windows in PyTorch<2.3",
-)
-@unittest.skipIf(
- sys.version_info > (3, 11),
- "dynamo_export not supported due to torch.compile not functional for python>3.11",
-)
-class TestModelSaving(unittest.TestCase):
- def test_save_initializer_to_files_for_large_model(self):
-        # Number of model parameters:
-        #   input_size x hidden_size + hidden_size +
-        #   hidden_size x output_size + output_size
-        # which is ~3GB in float32 with the sizes below.
- batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10
- model = _MLP(input_size, hidden_size, output_size)
- x = torch.randn(batch_size, input_size)
-
- model_proto = torch.onnx.dynamo_export(model, x).model_proto
- # Assert model is larger than 2GB (~=3GB)
- self.assertGreater(model_proto.ByteSize(), 2**31)
-
- def test_input_output_and_initializer_are_not_stored_in_value_info(self):
- batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
- model = _MLP(input_size, hidden_size, output_size)
- x = torch.randn(batch_size, input_size)
-
- model_proto = torch.onnx.dynamo_export(model, x).model_proto
- v_names = {v.name for v in model_proto.graph.value_info}
-
- for i in model_proto.graph.input:
- self.assertNotIn(i.name, v_names)
- for o in model_proto.graph.output:
- self.assertNotIn(o.name, v_names)
- for i in model_proto.graph.initializer:
- self.assertNotIn(i.name, v_names)
-
- @unittest.skipIf(
- not version_utils.torch_older_than("2.4"),
- "PyTorch 2.4-preview optimizes the functions away",
- )
- def test_experimental_function_value_info_are_stored_in_graph_value_info(self):
- batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
- model = _MLP(input_size, hidden_size, output_size)
- x = torch.randn(batch_size, input_size)
-
- model_proto = torch.onnx.dynamo_export(model, x).model_proto
- v_names = {v.name for v in model_proto.graph.value_info}
- torch_functions = [
- f for f in model_proto.functions if f.domain.startswith("pkg.torch")
- ]
- self.assertNotEqual(len(torch_functions), 0)
- for f in torch_functions:
- for n in f.node:
- for i in n.input:
- self.assertIn(f"{f.domain}::{f.name}/{i}", v_names)
- for o in n.output:
- self.assertIn(f"{f.domain}::{f.name}/{o}", v_names)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/function_libs/torch_lib/ops/__init__.py b/onnxscript/function_libs/torch_lib/ops/__init__.py
index 5a1cfd76c0..b7bedaa4b8 100644
--- a/onnxscript/function_libs/torch_lib/ops/__init__.py
+++ b/onnxscript/function_libs/torch_lib/ops/__init__.py
@@ -1,3 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
__all__ = [
"core",
"fft",
@@ -5,9 +7,21 @@
"nested",
"nn",
"prims",
+ "quantized_decomposed",
"sparse",
"special",
"vision",
]
-from . import core, fft, linalg, nested, nn, prims, sparse, special, vision
+from . import (
+ core,
+ fft,
+ linalg,
+ nested,
+ nn,
+ prims,
+ quantized_decomposed,
+ sparse,
+ special,
+ vision,
+)
diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index ecef6852b8..b3ebbc1c53 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -1,12 +1,22 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Common operators shared in the torchlib library."""
+# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import numpy.typing as npt
+import onnx
+
import onnxscript
import onnxscript.values
-from onnxscript import BOOL, INT64
+from onnxscript import BOOL, INT64, ir
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib import _constants, tensor_typing
from onnxscript.function_libs.torch_lib.tensor_typing import RealType
-from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT
+from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType
COMPLEX64_TYPE = COMPLEX64.dtype
COMPLEX128_TYPE = COMPLEX128.dtype
@@ -54,3 +64,38 @@ def cast_to(a: RealType, dtype: int) -> RealType:
result = op.Cast(a, to=dtype)
return result
+
+
+def constant(
+ array: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
+ dtype: int | onnx.TensorProto.DataType | ir.DataType,
+) -> TensorType:
+ """Utility for creating a constant tensor.
+
+ Args:
+ array: The array to convert to a constant tensor.
+ dtype: The data type of the tensor.
+
+ Returns:
+ A constant node.
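+
+    For example (illustrative), `constant([1, 2], dtype=ir.DataType.INT64)`
+    creates a 1-D INT64 Constant node.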
+ """
+ return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))
+
+
+def merge_dims(dims: Sequence[int | INT64]) -> INT64:
+ """Concatenate dimensions into a single value."""
+
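+    # Example (illustrative): merge_dims([2, dyn_dim, 3]) wraps the static dims in
+    # Constant nodes, reshapes a hypothetical dynamic INT64 value `dyn_dim` to [-1],
+    # and concatenates everything into a single 1-D INT64 tensor along axis 0.
+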
+ if not dims:
+ return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))
+
+ neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1]))
+
+ result_dims = [
+ op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d)
+ for d in dims
+ ]
+
+ # Set the output type to INT64 so op.Concat can be used
+ for dim in result_dims:
+ dim.dtype = ir.DataType.INT64
+ return op.Concat(*result_dims, axis=0)
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index a7a6073643..5127f3f9f6 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1,7 +1,5 @@
-# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
"""torch.ops.aten operators under the `core` module.
@@ -9,30 +7,31 @@
- All functions should not have the script() decorator. This is because
we want to delay the compilation of the function.
"""
+# pylint: disable=unused-argument
from __future__ import annotations
import math
from typing import Any, Optional, Sequence, Tuple, Union
+import numpy as np
+import torch
+
from onnxscript import (
- BFLOAT16,
BOOL,
COMPLEX64,
COMPLEX128,
DOUBLE,
FLOAT,
- FLOAT16,
INT8,
INT16,
INT32,
INT64,
UINT8,
- UINT16,
- UINT32,
- UINT64,
graph,
+ ir,
)
+from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
@@ -40,7 +39,6 @@
RealType,
TFloat,
TFloatHighPrecision,
- TFloatOrBFloat16,
TInt,
TReal,
TRealOrUInt8,
@@ -56,123 +54,102 @@
_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
_MATH_PI = math.pi
-IsScalar = common_ops.IsScalar
Rank = common_ops.Rank
-@torch_op("aten::_local_scalar_dense")
-def aten__local_scalar_dense(self: Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) -> FLOAT:
- """_local_scalar_dense(Tensor self) -> Scalar"""
-
- # Return the first element in tensor as a scalar.
- return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=FLOAT.dtype)
-
-
-@torch_op("aten::_local_scalar_dense")
-def aten__local_scalar_dense_int(self: IntType) -> INT64:
+@torch_op("aten::_local_scalar_dense", trace_only=True)
+def aten__local_scalar_dense(self: TensorType) -> TensorType:
"""_local_scalar_dense(Tensor self) -> Scalar"""
# Return the first element in tensor as a scalar.
- return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=INT64.dtype)
+ if self.dtype.is_floating_point():
+ dtype = ir.DataType.FLOAT
+ elif self.dtype == ir.DataType.BOOL:
+ dtype = ir.DataType.BOOL
+ else:
+ dtype = ir.DataType.INT64
+ return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=dtype)
@torch_op("aten::_log_softmax", trace_only=True)
-def aten__log_softmax_half(
- self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool
-) -> FLOAT:
+def aten__log_softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision:
"""_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
- # trace_only because we need to cast conditionally based on half_to_float
- if half_to_float:
+ self_is_scalar = len(self.shape) == 0
+ if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}:
self = op.Cast(self, to=FLOAT.dtype)
-
- return aten__log_softmax(self, dim, half_to_float)
-
-
-@torch_op("aten::_log_softmax", traceable=True)
-def aten__log_softmax(
- self: TFloatHighPrecision,
- dim: int,
- half_to_float: bool, # pylint: disable=unused-argument
-) -> TFloatHighPrecision:
- """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
-
- self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.LogSoftmax(self, axis=dim)
- if self_is_scalar: # squeeze to scalar due to input is scalar
- result = op.Squeeze(result)
+ if self_is_scalar:
+ result = op.Squeeze(result, op.Constant(value_ints=[0]))
return result
@torch_op("aten::_softmax", trace_only=True)
-def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT:
+def aten__softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision:
"""_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
- # trace_only because we need to cast conditionally based on half_to_float
- if half_to_float:
- self = op.Cast(self, to=FLOAT.dtype)
-
- return aten_softmax_no_dtype(self, dim)
+ self_is_scalar = len(self.shape) == 0
+ if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}:
+ self = op.Cast(self, to=FLOAT.dtype)
-@torch_op("aten::_softmax", trace_only=True)
-def aten__softmax(
- self: TFloatHighPrecision, dim: int, half_to_float: bool
-) -> TFloatHighPrecision:
- """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
-
- # trace_only to reuse aten_softmax_no_dtype
+ if self_is_scalar:
+ self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+ result = op.Softmax(self, axis=dim)
+ if self_is_scalar:
+ # Convert to scalar when input is scalar
+ result = op.Squeeze(result)
- del half_to_float # Unused
- return aten_softmax_no_dtype(self, dim)
+ return result
-@torch_op(("aten::abs", "_operator::abs"))
+@torch_op(("aten::abs", "_operator::abs"), trace_only=True)
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
"""abs(Tensor self) -> Tensor"""
return op.Abs(self)
-@torch_op("aten::abs", complex=True)
+@torch_op("aten::abs", complex=True, trace_only=True)
def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8:
"""abs(Tensor self) -> Tensor"""
- # self_real = self[..., 0]
- self_real = op.Slice(self, [0], [1], axes=[-1])
- # self_imag = self[..., 1]
- self_imag = op.Slice(self, [1], [2], axes=[-1])
- real_pow = op.Pow(self_real, 2)
- imag_pow = op.Pow(self_imag, 2)
- real_plus_imag = op.Add(real_pow, imag_pow)
- return op.Squeeze(op.Sqrt(real_plus_imag), axes=[-1])
+
+ return op.ReduceL2(self, [-1], keepdims=False)
-@torch_op("aten::acos")
+@torch_op("aten::acos", trace_only=True)
def aten_acos(self: TFloat) -> TFloat:
"""acos(Tensor self) -> Tensor"""
return op.Acos(self)
-@torch_op("aten::acosh")
+@torch_op("aten::acosh", trace_only=True)
def aten_acosh(self: TFloat) -> TFloat:
"""acosh(Tensor self) -> Tensor"""
return op.Acosh(self)
-@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"))
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
+def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
- # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
- alpha = op.CastLike(alpha, other)
- other = op.Mul(other, alpha)
+
+ if self.dtype == ir.DataType.BOOL:
+ # alpha can also be bool
+ if alpha == 0:
+ return op.Identity(self)
+ return op.Or(self, other)
+
+ if alpha != 1.0:
+ alpha = op.CastLike(alpha, other)
+ other = op.Mul(other, alpha)
return op.Add(self, other)
-@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"), trace_only=True, complex=True)
+@torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
@@ -232,7 +209,7 @@ def aten_addcmul(
return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))
-@torch_op("aten::addmm")
+@torch_op("aten::addmm", trace_only=True)
def aten_addmm(
self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
@@ -241,6 +218,9 @@ def aten_addmm(
# NOTE: ONNX Runtime does not support int inputs to Gemm as of 1.16.
# To support int inputs, consider an overriding implementation that casts to float and back.
+ alpha = float(alpha)
+ beta = float(beta)
+
# addmm only accepts 2d tensors: https://pytorch.org/docs/stable/generated/torch.addmm.html
return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
@@ -254,7 +234,7 @@ def aten_addmv(
return op.Add(op.Mul(self, beta), op.Mul(op.MatMul(mat, vec), alpha))
-@torch_op("aten::addr", traceable=True)
+@torch_op("aten::addr", trace_only=True)
def aten_addr(
self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
@@ -303,7 +283,7 @@ def aten_affine_grid_generator_backward(
raise NotImplementedError()
-@torch_op("aten::alias")
+@torch_op("aten::alias", trace_only=True)
def aten_alias(self: TTensor) -> TTensor:
"""alias(Tensor(a) self) -> Tensor(a)"""
@@ -334,11 +314,11 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::all", traceable=True)
+@torch_op("aten::all", trace_only=True)
def aten_all(self: TTensor) -> BOOL:
"""all(Tensor self) -> Tensor"""
- if IsScalar(self):
+ if len(self.shape) == 0:
result = op.Cast(self, to=BOOL.dtype)
else:
self_bool = op.Cast(self, to=BOOL.dtype)
@@ -348,19 +328,15 @@ def aten_all(self: TTensor) -> BOOL:
return result
-@torch_op("aten::all.dim", traceable=True)
+@torch_op("aten::all.dim", trace_only=True)
def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL:
"""all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"""
- if IsScalar(self):
- result = op.Cast(self, to=BOOL.dtype)
- else:
- self_bool = op.Cast(self, to=BOOL.dtype)
- self_int = op.Cast(self_bool, to=INT64.dtype)
- dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
- all_true = op.ReduceMin(self_int, dims, keepdims=keepdim)
- result = op.Cast(all_true, to=BOOL.dtype)
- return result
+ self_bool = op.Cast(self, to=BOOL.dtype)
+ self_int = op.Cast(self_bool, to=INT64.dtype)
+ dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+ all_true = op.ReduceMin(self_int, dims, keepdims=keepdim)
+ return op.Cast(all_true, to=BOOL.dtype)
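+ # ReduceMin over the 0/1 integer view implements `all`: a single False along
+ # `dim` pulls the minimum to 0, e.g. all([1, 0, 1]) -> min = 0 -> False.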
@torch_op("aten::all.dims", trace_only=True)
@@ -368,7 +344,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
"""all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor"""
if not dim:
- return aten_all_dims_no_dim(self, keepdim)
+ return _aten_all_dims_no_dim(self, keepdim)
for d in dim:
self = aten_all_dim(self, d, keepdim=True)
if not keepdim:
@@ -376,13 +352,10 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
return self
-@torch_op("aten::all.dims", traceable=True)
-def aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL:
+def _aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL:
"""all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor"""
- # dim is None and thus not supplied
-
- if IsScalar(self):
+ if len(self.shape) == 0:
result = op.Cast(self, to=BOOL.dtype)
else:
self_bool = op.Cast(self, to=BOOL.dtype)
@@ -398,7 +371,7 @@ def aten_allclose(
other: TReal,
rtol: float = 1e-05,
atol: float = 1e-08,
- equal_nan: bool = False, # pylint: disable=unused-argument
+ equal_nan: bool = False,
) -> BOOL:
"""allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool"""
@@ -456,11 +429,11 @@ def aten_angle(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::any", traceable=True)
+@torch_op("aten::any", trace_only=True)
def aten_any(self: TTensor) -> BOOL:
"""any(Tensor self) -> Tensor"""
- if IsScalar(self):
+ if len(self.shape) == 0:
result = op.Cast(self, to=BOOL.dtype)
else:
self_bool = op.Cast(self, to=BOOL.dtype)
@@ -471,21 +444,17 @@ def aten_any(self: TTensor) -> BOOL:
return result
-@torch_op("aten::any.dim", traceable=True)
+@torch_op("aten::any.dim", trace_only=True)
def aten_any_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL:
"""any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"""
- if IsScalar(self):
- result = op.Cast(self, to=BOOL.dtype)
- else:
- self_bool = op.Cast(self, to=BOOL.dtype)
- # op.ReduceMax() in the next step cannot process BOOL inputs, so convert to INT64
- self_int = op.Cast(self_bool, to=INT64.dtype)
- # Change dim from int to INT64[1]
- dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
- any_true = op.ReduceMax(self_int, dims, keepdims=keepdim)
- result = op.Cast(any_true, to=BOOL.dtype)
- return result
+ self_bool = op.Cast(self, to=BOOL.dtype)
+ # op.ReduceMax() in the next step cannot process BOOL inputs, so convert to INT64
+ self_int = op.Cast(self_bool, to=INT64.dtype)
+ # Change dim from int to INT64[1]
+ dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+ any_true = op.ReduceMax(self_int, dims, keepdims=keepdim)
+ return op.Cast(any_true, to=BOOL.dtype)
@torch_op("aten::any.dims", trace_only=True)
@@ -493,7 +462,7 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
"""any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor"""
if not dim:
- return aten_any_dims_no_dim(self, keepdim)
+ return _aten_any_dims_no_dim(self, keepdim)
for d in dim:
self = aten_any_dim(self, d, keepdim=True)
if not keepdim:
@@ -501,13 +470,8 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
return self
-@torch_op("aten::any.dims", traceable=True)
-def aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL:
- """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor"""
-
- # dim is None and thus not supplied
-
- if IsScalar(self):
+def _aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL:
+ if len(self.shape) == 0:
result = op.Cast(self, to=BOOL.dtype)
else:
self_bool = op.Cast(self, to=BOOL.dtype)
@@ -538,13 +502,16 @@ def _integral_to_be_adjusted(dtype: int) -> bool:
@torch_op("aten::arange", trace_only=True)
-def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1) -> TensorType:
+def aten_arange(
+ end: TRealUnlessFloat16OrInt8,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
- # NOTE: trace_only because both if branches need to be the same type, but we have
- # a cast in the if branch.
-
- if dtype == -1:
+ if dtype == -1 or dtype is None:
zero = op.CastLike(0.0, end)
one = op.CastLike(1.0, end)
result = op.Range(zero, end, one)
@@ -568,14 +535,16 @@ def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1)
@torch_op("aten::arange.start", trace_only=True)
def aten_arange_start(
- start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, dtype: int = -1
+ start: TRealUnlessFloat16OrInt8,
+ end: TRealUnlessFloat16OrInt8,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TensorType:
"""arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
- # NOTE: trace_only because both if branches need to be the same type, but we have
- # a cast in the if branch.
-
- if dtype == -1:
+ if dtype == -1 or dtype is None:
one = op.CastLike(1.0, end)
result = op.Range(start, end, one)
elif _range_supported(dtype):
@@ -596,7 +565,6 @@ def aten_arange_start(
return result
-@torch_op("aten::arange.start_step", private=True)
def _adjust_args_for_arange_int_dtype(
start: TRealUnlessFloat16OrInt8,
end: TRealUnlessFloat16OrInt8,
@@ -617,16 +585,51 @@ def _adjust_args_for_arange_int_dtype(
def aten_arange_start_step(
start: TRealUnlessFloat16OrInt8,
end: TRealUnlessFloat16OrInt8,
- step: TRealUnlessFloat16OrInt8,
+ step: TRealUnlessFloat16OrInt8 = 1.0,
dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TensorType:
"""arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
- # NOTE: trace_only because both if branches need to be the same type, but we have
- # a cast in the if branch.
-
if dtype == -1:
- result = op.Range(start, end, step)
+ # TODO: Because this is a trace_only function, the inputs are not promoted to
+ # Tensor until they reach the ONNX ops. However, a dynamic input is already a
+ # Tensor at this point.
+ # https://github.com/microsoft/onnxscript/issues/1914
+ if isinstance(start, (int, float)):
+ start_is_int = isinstance(start, int)
+ else:
+ start_is_int = start.dtype in {
+ INT16.dtype,
+ INT32.dtype,
+ INT64.dtype,
+ }
+ if isinstance(end, (int, float)):
+ end_is_int = isinstance(end, int)
+ else:
+ end_is_int = end.dtype in {
+ INT16.dtype,
+ INT32.dtype,
+ INT64.dtype,
+ }
+ if isinstance(step, (int, float)):
+ step_is_int = isinstance(step, int)
+ else:
+ step_is_int = step.dtype in {
+ INT16.dtype,
+ INT32.dtype,
+ INT64.dtype,
+ }
+ if start_is_int and end_is_int and step_is_int:
+ result = op.Range(start, end, step)
+ else:
+ # Promote all inputs to FLOAT so Range sees a uniform type
+ start = op.Cast(start, to=FLOAT.dtype)
+ end = op.Cast(end, to=FLOAT.dtype)
+ step = op.Cast(step, to=FLOAT.dtype)
+ result = op.Range(start, end, step)
elif _integral_to_be_adjusted(dtype):
# PyTorch arange op handles these integral types differently from INT64,
# so we have to adjust these arguments accordingly.
@@ -693,11 +696,23 @@ def aten_arctanh(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::argmax", traceable=True)
-def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
+@torch_op("aten::argmax", trace_only=True)
+def aten_argmax(
+ self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False
+) -> INT64:
+ """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
+
+ if dim is None:
+ result = _aten_argmax(self, keepdim)
+ else:
+ result = _aten_argmax_dim(self, dim, keepdim)
+ return result
+
+
+def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.ArgMax(self, keepdims=keepdim)
if self_is_scalar:
@@ -706,11 +721,10 @@ def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
return result
-@torch_op("aten::argmax", traceable=True)
-def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
+def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
@@ -721,11 +735,23 @@ def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fals
return result
-@torch_op("aten::argmin", traceable=True)
-def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
+@torch_op("aten::argmin", trace_only=True)
+def aten_argmin(
+ self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False
+) -> INT64:
+ """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
+
+ if dim is None:
+ result = _aten_argmin(self, keepdim)
+ else:
+ result = _aten_argmin_dim(self, dim, keepdim)
+ return result
+
+
+def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.ArgMin(self, keepdims=keepdim)
if self_is_scalar:
@@ -734,11 +760,10 @@ def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
return result
-@torch_op("aten::argmin", traceable=True)
-def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
+def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
@@ -763,7 +788,7 @@ def aten_argwhere(self: TensorType) -> TensorType:
@torch_op("aten::as_strided", trace_only=True)
def aten_as_strided(
- self: TTensor, size: INT64, stride: INT64, storage_offset: int = 0
+ self: TTensor, size: INT64, stride: Sequence[int], storage_offset: int = 0
) -> TTensor:
"""as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"""
@@ -851,55 +876,60 @@ def aten_as_strided_scatter(
raise NotImplementedError()
-@torch_op("aten::asin")
+@torch_op("aten::asin", trace_only=True)
def aten_asin(self: TFloat) -> TFloat:
"""asin(Tensor self) -> Tensor"""
return op.Asin(self)
-@torch_op("aten::asinh")
+@torch_op("aten::asinh", trace_only=True)
def aten_asinh(self: TFloat) -> TFloat:
"""asinh(Tensor self) -> Tensor"""
return op.Asinh(self)
-@torch_op("aten::atan")
+@torch_op("aten::atan", trace_only=True)
def aten_atan(self: TFloat) -> TFloat:
"""atan(Tensor self) -> Tensor"""
return op.Atan(self)
-@torch_op("aten::atan2")
+@torch_op("aten::atan2", trace_only=True)
def aten_atan2(self: TFloat, other: TFloat) -> TFloat:
"""atan2(Tensor self, Tensor other) -> Tensor"""
# self is y, and other is x on coordinate
slope = op.Div(self, other)
atan = op.Atan(slope)
+ zero = common_ops.constant(0.0, dtype=self.dtype)
+ pi = common_ops.constant(_MATH_PI, dtype=self.dtype)
+
+ second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi)
+ result = op.Where(op.Less(other, zero), second_third_quadrant, atan)
- second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI)
- result = op.Where(other < 0.0, second_third_quadrant, atan)
+ # Map NaN to 0 to match PyTorch behavior
+ result = op.Where(op.IsNaN(result), zero, result)
return result
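+ # Quadrant fix-up sanity check: atan2(1, -1) computes atan(-1) = -pi/4, then
+ # adds pi because x < 0 and y > 0, giving 3*pi/4. The final Where maps the
+ # NaN from 0/0 to 0 so atan2(0, 0) == 0, matching PyTorch.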
-@torch_op("aten::atanh")
+@torch_op("aten::atanh", trace_only=True)
def aten_atanh(self: TFloat) -> TFloat:
"""atanh(Tensor self) -> Tensor"""
return op.Atanh(self)
-@torch_op("aten::atleast_1d", traceable=True)
+@torch_op("aten::atleast_1d", trace_only=True)
def aten_atleast_1d(self: TTensor) -> TTensor:
"""atleast_1d(Tensor self) -> Tensor"""
- if IsScalar(self):
+ if len(self.shape) == 0:
self = op.Reshape(self, op.Constant(value_ints=[1]))
- return self
+ return op.Identity(self)
@torch_op("aten::atleast_1d.Sequence")
@@ -923,7 +953,7 @@ def aten_atleast_2d(self: TTensor) -> TTensor:
if Rank(self) <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
- return self
+ return op.Identity(self)
@torch_op("aten::atleast_2d.Sequence")
@@ -941,7 +971,7 @@ def reshape_to_2d(tensor):
return op.SequenceMap(self, body=reshape_to_2d)
-@torch_op("aten::atleast_3d", traceable=True)
+@torch_op("aten::atleast_3d", trace_only=True)
def aten_atleast_3d(self: TTensor) -> TTensor:
"""atleast_3d(Tensor self) -> Tensor"""
@@ -950,7 +980,7 @@ def aten_atleast_3d(self: TTensor) -> TTensor:
self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
elif rank == 2:
self = op.Unsqueeze(self, op.Constant(value_ints=[-1]))
- return self
+ return op.Identity(self)
@torch_op("aten::atleast_3d.Sequence")
@@ -970,20 +1000,25 @@ def reshape_to_3d(tensor):
return op.SequenceMap(self, body=reshape_to_3d)
-@torch_op("aten::baddbmm")
+@torch_op("aten::baddbmm", trace_only=True)
def aten_baddbmm(
self: TRealOrUInt8,
batch1: TRealUnlessInt16OrInt8,
batch2: TRealUnlessInt16OrInt8,
- beta: float = 1.0,
- alpha: float = 1.0,
+ beta: Optional[TFloat] = None,
+ alpha: Optional[TFloat] = None,
) -> TRealUnlessInt16OrInt8:
"""baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
+ # beta and alpha can be SymFloat
batch_mul = op.MatMul(batch1, batch2)
- alpha_cast = op.CastLike(alpha, self)
- mul_a = op.Mul(batch_mul, alpha_cast)
- beta_cast = op.CastLike(beta, self)
- mul_b = op.Mul(self, beta_cast)
+ if alpha is None or alpha == 1:
+ mul_a = batch_mul
+ else:
+ mul_a = op.Mul(batch_mul, op.CastLike(alpha, self))
+ if beta is None or beta == 1:
+ mul_b = self
+ else:
+ mul_b = op.Mul(self, op.CastLike(beta, self))
return op.Add(mul_a, mul_b)
@@ -1099,7 +1134,7 @@ def aten_batch_norm_update_stats(
raise NotImplementedError()
-@torch_op("aten::bernoulli")
+@torch_op("aten::bernoulli", trace_only=True)
def aten_bernoulli(self: TFloat) -> TFloat:
"""Proximal implementation of aten::bernoulli.default
@@ -1126,6 +1161,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
return op.CastLike(sampled, self)
+@torch_op("aten::bilinear", trace_only=True)
def aten_bilinear(
input1: TensorType,
input2: TensorType,
@@ -1134,7 +1170,23 @@ def aten_bilinear(
) -> TensorType:
"""bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
- raise NotImplementedError()
+ # Bilinear transformation: y = x1^T A x2 + b
+ # input1 shape: (..., in1_features)
+ # input2 shape: (..., in2_features)
+ # weight shape: (out_features, in1_features, in2_features)
+ # bias shape: (out_features) - optional
+ # output shape: (..., out_features)
+
+ # Use Einsum to compute the bilinear transformation
+ # "...i,oij,...j->...o" means:
+ # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
+ result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
+
+ # Add bias if provided
+ if bias is not None:
+ result = op.Add(result, bias)
+
+ return result
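+ # Shape sketch: input1 (N, in1), input2 (N, in2) and weight (out, in1, in2)
+ # contract over i and j to give (N, out); the optional (out,)-shaped bias
+ # then broadcasts over the batch dims.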
def aten_binary_cross_entropy_with_logits(
@@ -1167,176 +1219,179 @@ def aten_binomial(
@torch_op(
(
- "aten::bitwise_and",
"aten::bitwise_and.Tensor",
- "aten::bitwise_and.Scalar",
- "aten::bitwise_and.Scalar_Tensor",
"_operator::and_",
- )
+ ),
+ trace_only=True,
)
-def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
"""bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
- # logical_and implements the BOOL variant
- return op.BitwiseAnd(self, other)
+ assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+ dtype = self.dtype if self.dtype is not None else other.dtype
+ assert dtype is not None
+ if dtype.is_integer():
+ return op.BitwiseAnd(self, other)
+ if dtype == ir.DataType.BOOL:
+ return op.And(self, other)
+ raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
-@torch_op("aten::bitwise_left_shift")
-def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
- """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
- # assert other >= 0
- self = op.Cast(self, to=UINT16.dtype)
- other = op.Cast(other, to=UINT16.dtype)
-
- result = op.BitShift(self, other, direction="LEFT")
- return op.Cast(result, to=INT16.dtype)
+@torch_op("aten::bitwise_and.Scalar", trace_only=True)
+def aten_bitwise_and_scalar(self: TTensor, other: int) -> TTensor:
+ """bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor"""
+ other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+ return aten_bitwise_and(self, other_tensor)
-@torch_op("aten::bitwise_left_shift")
-def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
- """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
- # assert other >= 0
- self = op.Cast(self, to=UINT32.dtype)
- other = op.Cast(other, to=UINT32.dtype)
- result = op.BitShift(self, other, direction="LEFT")
+@torch_op("aten::bitwise_and.Scalar_Tensor", trace_only=True)
+def aten_bitwise_and_scalar_tensor(self: float, other: TTensor) -> TTensor:
+ """bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
- return op.Cast(result, to=INT32.dtype)
+ self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+ return aten_bitwise_and(self_tensor, other)
-@torch_op("aten::bitwise_left_shift")
-def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
+@torch_op(
+ (
+ "aten::bitwise_left_shift.Tensor",
+ "_operator::__lshift__",
+ ),
+ trace_only=True,
+)
+def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:
"""bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
+ assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+ dtype = self.dtype if self.dtype is not None else other.dtype
+ assert dtype is not None
+
# assert other >= 0
- self = op.Cast(self, to=UINT64.dtype)
- other = op.Cast(other, to=UINT64.dtype)
+ if dtype.bitwidth == 8:
+ unsigned_dtype = ir.DataType.UINT8
+ signed_dtype = ir.DataType.INT8
+ elif dtype.bitwidth == 16:
+ unsigned_dtype = ir.DataType.UINT16
+ signed_dtype = ir.DataType.INT16
+ elif dtype.bitwidth == 32:
+ unsigned_dtype = ir.DataType.UINT32
+ signed_dtype = ir.DataType.INT32
+ elif dtype.bitwidth == 64:
+ unsigned_dtype = ir.DataType.UINT64
+ signed_dtype = ir.DataType.INT64
+ else:
+ raise NotImplementedError(f"Not implemented for type {dtype}")
+
+ self = op.Cast(self, to=unsigned_dtype)
+ other = op.Cast(other, to=unsigned_dtype)
result = op.BitShift(self, other, direction="LEFT")
- return op.Cast(result, to=INT64.dtype)
+ return op.Cast(result, to=signed_dtype)
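+ # The round-trip through the unsigned type preserves two's-complement bits,
+ # e.g. for INT8: -1 -> 0xFF, 0xFF << 1 = 0xFE, cast back -> -2 == -1 << 1.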
-@torch_op("aten::bitwise_left_shift")
-def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
- """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
- # assert other >= 0
- self = op.Cast(self, to=UINT8.dtype)
- other = op.Cast(other, to=UINT8.dtype)
+@torch_op(
+ ("aten::bitwise_left_shift.Tensor_Scalar", "aten::__lshift__.Scalar"), trace_only=True
+)
+def aten_bitwise_left_shift_tensor_scalar(self: TInt, other: int) -> TInt:
+ """bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
+ other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+ return aten_bitwise_left_shift(self, other_tensor)
- result = op.BitShift(self, other, direction="LEFT")
- return op.Cast(result, to=INT8.dtype)
+@torch_op("aten::bitwise_left_shift.Scalar_Tensor", trace_only=True)
+def aten_bitwise_left_shift_scalar_tensor(self: int, other: TInt) -> TInt:
+ """bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+ self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+ return aten_bitwise_left_shift(self_tensor, other)
-@torch_op("aten::bitwise_not")
-def aten_bitwise_not(self: TInt) -> TInt:
+@torch_op("aten::bitwise_not", trace_only=True)
+def aten_bitwise_not(self: TTensor) -> TTensor:
"""bitwise_not(Tensor self) -> Tensor"""
- # logical_not implements the BOOL variant
- return op.BitwiseNot(self)
+ if self.dtype == ir.DataType.BOOL:
+ return op.Not(self)
+ if self.dtype.is_integer():
+ return op.BitwiseNot(self)
+ raise NotImplementedError(f"Not implemented for type {self.dtype}")
@torch_op(
(
- "aten::bitwise_or",
"aten::bitwise_or.Tensor",
- "aten::bitwise_or.Scalar",
- "aten::bitwise_or.Scalar_Tensor",
"_operator::or_",
- )
+ ),
+ trace_only=True,
)
-def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
"""bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
- # logical_or implements the BOOL variant
-
- return op.BitwiseOr(self, other)
+ assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+ dtype = self.dtype if self.dtype is not None else other.dtype
+ assert dtype is not None
-@torch_op("aten::bitwise_right_shift")
-def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
- """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
- negative = op.Less(self, 0)
- self = op.Cast(self, to=UINT16.dtype)
- other = op.Cast(other, to=UINT16.dtype)
+ if dtype.is_integer():
+ return op.BitwiseOr(self, other)
+ if dtype == ir.DataType.BOOL:
+ return op.Or(self, other)
+ raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
- # Simulate arithmetic shift using logical shift
- # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting
- mask = op.BitShift(
- op.Cast(op.Constant(value_int=0xFFFF), to=UINT16.dtype), other, direction="RIGHT"
- )
- mask = op.BitwiseNot(mask)
- # Do logical shift
- shifted = op.BitShift(self, other, direction="RIGHT")
- # Compute the arithmetic shifted value assuming the sign bit was set
- negative_shifted = op.BitwiseOr(shifted, mask)
- # Choose the shifted value based on the sign bit
- return op.Where(
- negative, op.Cast(negative_shifted, to=INT16.dtype), op.Cast(shifted, to=INT16.dtype)
- )
+@torch_op("aten::bitwise_or.Scalar", trace_only=True)
+def aten_bitwise_or_scalar(self: TTensor, other: int) -> TTensor:
+ """bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor"""
+ other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+ return aten_bitwise_or(self, other_tensor)
-@torch_op("aten::bitwise_right_shift")
-def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
- """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
- negative = op.Less(self, 0)
- self = op.Cast(self, to=UINT32.dtype)
- other = op.Cast(other, to=UINT32.dtype)
- # Simulate arithmetic shift using logical shift
- # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting
- mask = op.BitShift(
- op.Cast(op.Constant(value_int=0xFFFFFFFF), to=UINT32.dtype), other, direction="RIGHT"
- )
- mask = op.BitwiseNot(mask)
- # Do logical shift
- shifted = op.BitShift(self, other, direction="RIGHT")
- # Compute the arithmetic shifted value assuming the sign bit was set
- negative_shifted = op.BitwiseOr(shifted, mask)
- # Choose the shifted value based on the sign bit
- return op.Where(
- negative, op.Cast(negative_shifted, to=INT32.dtype), op.Cast(shifted, to=INT32.dtype)
- )
+@torch_op("aten::bitwise_or.Scalar_Tensor", trace_only=True)
+def aten_bitwise_or_scalar_tensor(self: int, other: TTensor) -> TTensor:
+ """bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+ self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+ return aten_bitwise_or(self_tensor, other)
-@torch_op("aten::bitwise_right_shift")
-def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64:
+@torch_op(
+ (
+ "aten::bitwise_right_shift.Tensor",
+ "_operator::__rshift__",
+ ),
+ trace_only=True,
+)
+def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt:
"""bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
- negative = op.Less(self, 0)
- self = op.Cast(self, to=UINT64.dtype)
- other = op.Cast(other, to=UINT64.dtype)
-
- # Simulate arithmetic shift using logical shift
- # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting
- mask = op.BitShift(
- # 0xFFFFFFFFFFFFFFFF
- op.Cast(op.Constant(value_int=-1), to=UINT64.dtype),
- other,
- direction="RIGHT",
- )
- mask = op.BitwiseNot(mask)
- # Do logical shift
- shifted = op.BitShift(self, other, direction="RIGHT")
- # Compute the arithmetic shifted value assuming the sign bit was set
- negative_shifted = op.BitwiseOr(shifted, mask)
- # Choose the shifted value based on the sign bit
- return op.Where(
- negative, op.Cast(negative_shifted, to=INT64.dtype), op.Cast(shifted, to=INT64.dtype)
- )
-
+ assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+ dtype = self.dtype if self.dtype is not None else other.dtype
+ assert dtype is not None
+
+ if dtype.bitwidth == 8:
+ unsigned_dtype = ir.DataType.UINT8
+ signed_dtype = ir.DataType.INT8
+ mask = ir.tensor(0xFF, dtype=unsigned_dtype)
+ elif dtype.bitwidth == 16:
+ unsigned_dtype = ir.DataType.UINT16
+ signed_dtype = ir.DataType.INT16
+ mask = ir.tensor(0xFFFF, dtype=unsigned_dtype)
+ elif dtype.bitwidth == 32:
+ unsigned_dtype = ir.DataType.UINT32
+ signed_dtype = ir.DataType.INT32
+ mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype)
+ elif dtype.bitwidth == 64:
+ unsigned_dtype = ir.DataType.UINT64
+ signed_dtype = ir.DataType.INT64
+ mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype)
+ else:
+ raise NotImplementedError(f"Not implemented for type {dtype}")
-@torch_op("aten::bitwise_right_shift")
-def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
- """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
negative = op.Less(self, 0)
- self = op.Cast(self, to=UINT8.dtype)
- other = op.Cast(other, to=UINT8.dtype)
+ self = op.Cast(self, to=unsigned_dtype)
+ other = op.Cast(other, to=unsigned_dtype)
# Simulate arithmetic shift using logical shift
# Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting
- mask = op.BitShift(
- op.Cast(op.Constant(value_int=0xFF), to=UINT8.dtype), other, direction="RIGHT"
- )
+ mask = op.BitShift(mask, other, direction="RIGHT")
mask = op.BitwiseNot(mask)
# Do logical shift
shifted = op.BitShift(self, other, direction="RIGHT")
@@ -1344,29 +1399,68 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
negative_shifted = op.BitwiseOr(shifted, mask)
# Choose the shifted value based on the sign bit
return op.Where(
- negative, op.Cast(negative_shifted, to=INT8.dtype), op.Cast(shifted, to=INT8.dtype)
+ negative, op.Cast(negative_shifted, to=signed_dtype), op.Cast(shifted, to=signed_dtype)
)
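+ # Worked INT8 example for -4 >> 1: the logical shift of 0xFC gives 0x7E, the
+ # mask ~(0xFF >> 1) = 0x80 re-sets the vacated sign bit, and 0x7E | 0x80 =
+ # 0xFE = -2, which is the arithmetic -4 >> 1.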
@torch_op(
- (
- "aten::bitwise_xor",
- "aten::bitwise_xor.Tensor",
- "aten::bitwise_xor.Scalar",
- "aten::bitwise_xor.Scalar_Tensor",
- )
+ ("aten::bitwise_right_shift.Tensor_Scalar", "aten::__rshift__.Scalar"), trace_only=True
)
-def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_right_shift_tensor_scalar(self: TInt, other: int) -> TInt:
+ """bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor"""
+ other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+ return aten_bitwise_right_shift(self, other_tensor)
+
+
+@torch_op("aten::bitwise_right_shift.Scalar_Tensor", trace_only=True)
+def aten_bitwise_right_shift_scalar_tensor(self: int, other: TInt) -> TInt:
+ """bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+ self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+ return aten_bitwise_right_shift(self_tensor, other)
+
+
+@torch_op("aten::bitwise_xor.Tensor", trace_only=True)
+def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor:
"""bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
- # logical_xor implements the BOOL variant
- return op.BitwiseXor(self, other)
+ assert self.dtype == other.dtype or self.dtype is None or other.dtype is None
+ dtype = self.dtype if self.dtype is not None else other.dtype
+ assert dtype is not None
+
+ if dtype.is_integer():
+ return op.BitwiseXor(self, other)
+ if dtype == ir.DataType.BOOL:
+ return op.Xor(self, other)
+ raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
+
+
+@torch_op("aten::bitwise_xor.Scalar", trace_only=True)
+def aten_bitwise_xor_scalar(self: TTensor, other: int) -> TTensor:
+ """bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor"""
+ other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype))
+ return aten_bitwise_xor(self, other_tensor)
+
+@torch_op("aten::bitwise_xor.Scalar_Tensor", trace_only=True)
+def aten_bitwise_xor_scalar_tensor(self: int, other: TTensor) -> TTensor:
+ """bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+ self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype))
+ return aten_bitwise_xor(self_tensor, other)
-def aten_blackman_window(window_length: int) -> TensorType:
+
+@torch_op("aten::blackman_window", trace_only=True)
+def aten_blackman_window(
+ window_length: int,
+ dtype: int = 1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
- raise NotImplementedError()
+ if dtype is None or dtype == -1:
+ dtype = 1  # Default to FLOAT (ONNX DataType value 1)
+ return op.BlackmanWindow(window_length, output_datatype=dtype)
def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:
@@ -1375,7 +1469,7 @@ def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::bmm")
+@torch_op("aten::bmm", trace_only=True)
def aten_bmm(self: TFloat, mat2: TFloat) -> TFloat:
"""bmm(Tensor self, Tensor mat2) -> Tensor"""
@@ -1388,10 +1482,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::broadcast_to")
-def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor:
+@torch_op("aten::broadcast_to", trace_only=True)
+def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
-
+ size = common_ops.merge_dims(size)
return op.Expand(self, size)
@@ -1424,15 +1518,13 @@ def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
return aten_cat(tensors, dim=dim)
-@torch_op("aten::cat")
+@torch_op(("aten::cat", "aten::concat", "aten::concatenate"), trace_only=True)
def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
"""cat(Tensor[] tensors, int dim=0) -> Tensor"""
- # NOTE: Having empty tensors when concatenating along non-zero dimension
- # is not supported.
- # TODO(justinchuby): Filter these tensors out with Sequence ops before
- # calling ConcatFromSequence.
- return op.ConcatFromSequence(tensors, axis=dim)
+ # Remove None tensors
+ tensors = [tensor for tensor in tensors if tensor is not None]
+ return op.Concat(*tensors, axis=dim)
def aten_ccol_indices(self: TensorType) -> TensorType:
@@ -1455,14 +1547,14 @@ def aten_cdist(
raise NotImplementedError()
-@torch_op("aten::ceil")
+@torch_op("aten::ceil", trace_only=True)
def aten_ceil(self: TFloat) -> TFloat:
"""ceil(Tensor self) -> Tensor"""
return op.Ceil(self)
-@torch_op("math::ceil")
+@torch_op("math::ceil", trace_only=True)
def python_math_ceil(self: TFloat) -> TInt:
"""ceil(Tensor self) -> Tensor"""
ceil = op.Ceil(self)
@@ -1515,38 +1607,68 @@ def aten_choose_qparams_optimized(
raise NotImplementedError()
-@torch_op("aten::chunk")
-def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
- """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
- # This will create a Sequence of tensors
- neg_1 = op.Constant(value_ints=[-1])
- # Get size of specified dim
- self_shape = op.Shape(self)
- dim_size = op.Gather(self_shape, dim, axis=0)
- # Compute size/chunk to get the number of data in one chunk
- num_per_chunk = op.Div(dim_size, chunks)
- num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator]
+if version_utils.torch_older_than("2.7.0"):
+ # PyTorch <2.7 does not support determining the number of outputs for the Split op
+ # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6
+ @torch_op("aten::chunk")
+ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
+ """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
+ # This will create a Sequence of tensors
+ neg_1 = op.Constant(value_ints=[-1])
+ # Get size of specified dim
+ self_shape = op.Shape(self)
+ dim_size = op.Gather(self_shape, dim, axis=0)
+ # Compute size/chunk to get the number of data in one chunk
+ num_per_chunk = op.Div(dim_size, chunks)
+ num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator]
- # Compute real chunk number
- num_chunk = op.Div(dim_size, num_per_chunk)
- # Get something like [n, n, n, n, ...], total num_chunk
- list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
+ # Compute real chunk number
+ num_chunk = op.Div(dim_size, num_per_chunk)
+ # Get something like [n, n, n, n, ...], total num_chunk
+ list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
- remainder = op.Mod(dim_size, num_per_chunk)
- if remainder > 0: # type: ignore[operator]
- # Append the remainder to the [n, n, n, n, ..., r]
- list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
+ remainder = op.Mod(dim_size, num_per_chunk)
+ if remainder > 0: # type: ignore[operator]
+ # Append the remainder to the [n, n, n, n, ..., r]
+ list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
- return op.SplitToSequence(self, list_split, axis=dim)
+ return op.SplitToSequence(self, list_split, axis=dim)
+else:
+ @torch_op("aten::chunk", trace_only=True)
+ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
+ """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
+ if chunks == 1:
+ return op.Identity(self)
+ return op.Split(self, axis=dim, num_outputs=chunks)
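+ # Both paths match torch.chunk semantics: splitting a size-10 axis into
+ # chunks=3 yields pieces of [4, 4, 2], since each chunk holds
+ # ceil(10/3) = 4 elements and the last chunk takes the remainder.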
-@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True)
-def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal:
- """clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
- clamped = self
+
+@torch_op("aten::clamp", trace_only=True)
+def aten_clamp(self: TReal, min: Optional[float] = None, max: Optional[float] = None) -> TReal:
+ """clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"""
+
+ if min is None and max is None:
+ return op.Identity(self)
+
+ if min is not None:
+ min = op.CastLike(min, self)
+
+ if max is not None:
+ max = op.CastLike(max, self)
+
+ return op.Clip(self, min, max)
+
+
+@torch_op("aten::clamp.Tensor", trace_only=True)
+def aten_clamp_tensor(
+ self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None
+) -> TReal:
+ """clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
if min is None and max is None:
- return clamped
+ return op.Identity(self)
+
+ clamped = self
# If min is greater than max torch.clamp(..., min, max)
# sets all elements in input to the value of max.
@@ -1562,48 +1684,58 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] =
return clamped
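+ # Per the note above, min is applied before max: with min = 3 and max = 1
+ # every element clamps to 1, mirroring torch.clamp.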
-@torch_op("aten::clamp_max", traceable=True)
-def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
- """clamp_max(Tensor self, Tensor max) -> Tensor"""
+@torch_op("aten::clamp_max", trace_only=True)
+def aten_clamp_max(self: TReal, max_: float) -> TReal:
+ """clamp_max(Tensor self, Scalar max) -> Tensor"""
- self_size = op.Size(self)
- max_shape = op.Shape(max_)
- max_rank = op.Size(max_shape)
- if self_size == 0:
- result = op.Expand(self, max_shape)
+ # This implementation intentionally does not handle the case where self is an empty tensor
+ max_ = op.CastLike(max_, self)
+ return op.Clip(self, None, max_)
+
+
+@torch_op("aten::clamp_max.Tensor", trace_only=True)
+def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal:
+ """clamp_max.Tensor(Tensor self, Tensor max) -> Tensor"""
+
+ # This implementation intentionally does not handle the case where self is an empty tensor
+ max_rank = len(max_.shape)
+ if max_rank == 0:
+ max_ = op.CastLike(max_, self)
+ result = op.Clip(self, None, max_)
else:
- if max_rank == 0:
- max_ = op.CastLike(max_, self)
- result = op.Clip(self, None, max_)
- else:
- result = op.Min(self, max_)
+ result = op.Min(self, max_)
return result
-@torch_op("aten::clamp_min", traceable=True)
-def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
- """clamp_min(Tensor self, Tensor min) -> Tensor"""
+@torch_op("aten::clamp_min", trace_only=True)
+def aten_clamp_min(self: TReal, min_: float) -> TReal:
+ """clamp_min(Tensor self, Scalar min) -> Tensor"""
- self_size = op.Size(self)
- min_shape = op.Shape(min_)
- min_rank = op.Size(min_shape)
- if self_size == 0:
- result = op.Expand(self, min_shape)
+ # This implementation intentionally does not handle the case where self is an empty tensor
+ min_ = op.CastLike(min_, self)
+ return op.Clip(self, min_, None)
+
+
+@torch_op("aten::clamp_min.Tensor", trace_only=True)
+def aten_clamp_min_tensor(self: TReal, min_: TReal) -> TReal:
+ """clamp_min.Tensor(Tensor self, Tensor min) -> Tensor"""
+
+ # This implementation intentionally does not handle the case where self is an empty tensor
+ min_rank = len(min_.shape)
+ if min_rank == 0:
+ min_ = op.CastLike(min_, self)
+ result = op.Clip(self, min_, None)
else:
- if min_rank == 0:
- min_ = op.CastLike(min_, self)
- result = op.Clip(self, min_, None)
- else:
- result = op.Max(self, min_)
+ result = op.Max(self, min_)
return result
-@torch_op("aten::clone")
+@torch_op("aten::clone", trace_only=True)
def aten_clone(
self: TTensor,
- memory_format: str = "", # pylint: disable=unused-argument
+ memory_format: str = "",
) -> TTensor:
"""clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor"""
@@ -1642,13 +1774,6 @@ def aten_combinations(
raise NotImplementedError()
-@torch_op("aten::complex", private=True)
-def _aten_complex(real: TFloat, imag: TFloat) -> TFloat:
- """Non-broadcasting complex constructor."""
-
- return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)
-
-
@torch_op("aten::complex", trace_only=True)
def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""complex(Tensor real, Tensor imag) -> Tensor"""
@@ -1658,33 +1783,16 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
real = op.Expand(real, broadcasted_shape)
imag = op.Expand(imag, broadcasted_shape)
- return _aten_complex(real, imag)
-
-
-@torch_op("aten::concat")
-def aten_concat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
- """concat(Tensor[] tensors, int dim=0) -> Tensor"""
-
- # TODO(justinchuby): Combine the implementation with cat
- return op.ConcatFromSequence(tensors, axis=dim)
-
-
-@torch_op("aten::concatenate")
-def aten_concatenate(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
- """concatenate(Tensor[] tensors, int dim=0) -> Tensor"""
-
- # TODO(justinchuby): Combine the implementation with cat
- return op.ConcatFromSequence(tensors, axis=dim)
+ return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)
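+ # e.g. complex(real=[1., 2.], imag=[3., 4.]) packs into [[1., 3.], [2., 4.]],
+ # the trailing [real, imag] layout the complex ops in this file expect.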
-@torch_op("aten::conj")
+@torch_op("aten::conj", trace_only=True)
def aten_conj(self: TTensor) -> TTensor:
"""conj(Tensor(a) self) -> Tensor(a)"""
return op.Identity(self)
-@torch_op("aten::conj", complex=True, private=True)
def _complex_conjugate(self: TFloat) -> TFloat:
zero = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
@@ -1703,8 +1811,6 @@ def _complex_conjugate(self: TFloat) -> TFloat:
def aten_conj_complex(self: TFloat) -> TFloat:
"""conj(Tensor(a) self) -> Tensor(a)"""
- # TODO(#834): Allow calling scripted functions from other
- # scripted functions and remove trace only.
return _complex_conjugate(self)
@@ -1749,10 +1855,10 @@ def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTens
return op.Pad(self, onnx_padding, value)
-@torch_op("aten::contiguous")
+@torch_op("aten::contiguous", trace_only=True)
def aten_contiguous(
self: TTensor,
- memory_format: str = "contiguous_format", # pylint: disable=unused-argument
+ memory_format: str = "contiguous_format",
) -> TTensor:
"""contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)"""
@@ -1940,24 +2046,32 @@ def aten_convolution(
) -> TFloat:
"""convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor"""
+ rank = len(input.shape)
+
+ image_d = rank - 2
+
+ # NOTE: We assume the padding/dilation/stride sequences
+ # from the ATen op have either len == 1 or
+ # len == image_d (the number of spatial dims).
+
if not isinstance(padding, Sequence):
- padding = (padding, padding)
+ padding = [padding] * image_d
+ elif len(padding) == 1:
+ padding = [padding[0]] * image_d
pads = [*padding, *padding]
if not isinstance(dilation, Sequence):
- dilation = (dilation, dilation)
+ dilation = [dilation] * image_d
+ elif len(dilation) == 1:
+ dilation = [dilation[0]] * image_d
dilations = list(dilation)
if not isinstance(stride, Sequence):
- stride = (stride, stride)
+ stride = [stride] * image_d
+ elif len(stride) == 1:
+ stride = [stride[0]] * image_d
strides = list(stride)
- if bias is None:
- weight_dim_0 = op.Shape(weight, start=0, end=1)
- bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1]))
- zero = op.CastLike(0.0, input)
- bias = op.Expand(zero, bias_shape)
-
result = _aten_convolution_onnx(
input,
weight,
@@ -1973,12 +2087,11 @@ def aten_convolution(
return result
-@torch_op("aten::convolution", private=True, traceable=True)
def _aten_convolution_onnx(
input: TFloat,
weight: TFloat,
bias: TFloat,
- transposed: BOOL,
+ transposed: bool,
strides: Sequence[int],
pads: Sequence[int],
dilations: Sequence[int],
@@ -1992,7 +2105,7 @@ def _aten_convolution_onnx(
# Alternatively we could cast transposed to BOOL.
# E.g. `if op.Cast(transposed, BOOL.dtype): ...`
- no_batch = Rank(input) != Rank(weight)
+ no_batch = len(input.shape) != len(weight.shape)
if no_batch:
input = op.Unsqueeze(input, op.Constant(value_ints=[0]))
@@ -2076,11 +2189,11 @@ def aten_convolution_overrideable(
raise NotImplementedError()
-@torch_op("aten::copy")
+@torch_op("aten::copy", trace_only=True)
def aten_copy(
self: TTensor,
src: TTensor2,
- non_blocking: bool = False, # pylint: disable=unused-argument
+ non_blocking: bool = False,
) -> TTensor:
"""copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor"""
@@ -2091,7 +2204,11 @@ def aten_copy(
def aten__to_copy(
self: TTensor,
dtype: int = -1,
- non_blocking: bool = False, # pylint: disable=unused-argument
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+ non_blocking: bool = False,
+ memory_format: str = "",
) -> TTensor:
"""_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"""
@@ -2113,14 +2230,14 @@ def aten_corrcoef(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::cos")
+@torch_op("aten::cos", trace_only=True)
def aten_cos(self: TFloat) -> TFloat:
"""cos(Tensor self) -> Tensor"""
return op.Cos(self)
-@torch_op("aten::cosh")
+@torch_op("aten::cosh", trace_only=True)
def aten_cosh(self: TFloat) -> TFloat:
"""cosh(Tensor self) -> Tensor"""
@@ -2164,23 +2281,13 @@ def aten_cov(
raise NotImplementedError()
-@torch_op("aten::cross")
+@torch_op(("aten::cross", "aten::linalg_cross"))
def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
"""cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""
- zero = op.Constant(value_ints=[0])
- one = op.Constant(value_ints=[1])
- two = op.Constant(value_ints=[2])
- three = op.Constant(value_ints=[3])
- axes = op.Expand(dim, op.Constant(value_ints=[1]))
-
# Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
- a1 = op.Slice(self, zero, one, axes)
- a2 = op.Slice(self, one, two, axes)
- a3 = op.Slice(self, two, three, axes)
- b1 = op.Slice(other, zero, one, axes)
- b2 = op.Slice(other, one, two, axes)
- b3 = op.Slice(other, two, three, axes)
+ a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3)
+ b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3)
# Broadcasting is implicitly supported by Mul
c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
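+ # Quick basis check: cross([1, 0, 0], [0, 1, 0]) gives c1 = 0*0 - 0*1 = 0,
+ # c2 = 0*0 - 1*0 = 0 and c3 = 1*1 - 0*0 = 1, i.e. the z axis, as expected.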
@@ -2390,18 +2497,11 @@ def aten_cumsum(
cast = self
else:
cast = op.Cast(self, to=dtype)
- return _aten_cumsum_onnx(cast, dim)
-
-
-@torch_op("aten::cumsum", private=True, traceable=True)
-def _aten_cumsum_onnx(
- self: TRealUnlessInt16OrInt8, dim: Union[INT32, INT64]
-) -> TRealUnlessInt16OrInt8:
- if IsScalar(self):
+ if len(self.shape) == 0:
# A scalar
- result = op.Identity(self)
+ result = op.Identity(cast)
else:
- result = op.CumSum(self, dim)
+ result = op.CumSum(cast, dim)
return result
@@ -2411,7 +2511,7 @@ def aten_data(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::deg2rad", traceable=True)
+@torch_op("aten::deg2rad", trace_only=True)
def aten_deg2rad(self: TFloat) -> TFloat:
"""deg2rad(Tensor self) -> Tensor"""
@@ -2424,7 +2524,7 @@ def aten_dense_dim(self: TensorType) -> int:
raise NotImplementedError()
-@torch_op("aten::detach")
+@torch_op("aten::detach", trace_only=True)
def aten_detach(self: TensorType) -> TensorType:
"""detach(Tensor(a) self) -> Tensor(a)"""
@@ -2458,87 +2558,10 @@ def aten_diagflat(self: TensorType, offset: int = 0) -> TensorType:
@torch_op(("aten::diagonal", "aten::diagonal_copy"), trace_only=True)
-def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TReal:
+def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TTensor:
"""diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)"""
- # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims
- # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1
- # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2
- # [0,1,2] -> [0,1,2] when dim1=1 and dim2=2
- if dim1 < 0:
- dim1 = dim1 + len(self.shape)
- if dim2 < 0:
- dim2 = dim2 + len(self.shape)
-
- self_rank = len(self.shape)
- perm = list(range(self_rank))
- perm.remove(dim1)
- perm.remove(dim2)
- perm.append(dim1)
- perm.append(dim2)
-
- # If rank=2, then axes=[0]; if rank=3, then axes=[1]
- # This is because computing diagonal sum is on dim2 after transpose by perm
- axes = [self_rank - 2]
-
- return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes)
-
-
-@torch_op("aten::diagonal", private=True, traceable=True)
-def _aten_diagonal_onnx(
- self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
-) -> TTensor:
- neg_1 = op.Constant(value_ints=[-1])
- dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row
- dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col
- mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
- tmp_tensor = op.ConstantOfShape(mask_shape)
- mask = op.EyeLike(tmp_tensor, k=offset)
- mask = op.CastLike(mask, self)
- self_t = op.Transpose(self, perm=perm)
- result = op.Mul(self_t, mask)
- result = op.ReduceSum(result, keepdims=False, axes=axes)
- # min(row, col)
- min_dim_size = op.Min(dim1_size, dim2_size)
- # take 2 tensors as example:
- # one is 3x5 in size, min_dim_size = 3, dim1_size = 3
- # the other is 5x3 in size, min_dim_size = 3, dim1_size = 5
- # 3 rows x 5 cols 5 rows x 3 cols
- # offset diagonal offset diagonal
- # ---------------- ----------------
- # -4 0 -6 0
- # -3 0 -5 0
- # -2 1 -4 1
- # -1 2 -3 2
- # 0 3 -2 3
- # 1 3 -1 3
- # 2 3 0 3
- # 3 2 1 2
- # 4 1 2 1
- # 5 0 3 0
- # 6 0 4 0
-
- # From above table, we can get the logic below
- if offset < 0:
- # row + offset
- length = dim1_size + offset
- start = op.Constant(value_ints=[0])
- else: # offset >= 0
- # col - offset
- length = dim2_size - offset
- start = op.Reshape(op.Constant(value_int=offset), neg_1)
-
- # max(min(length, min(row, col)), 0)
- length = op.Max(op.Min(length, min_dim_size), 0)
- end = start + length
- result = op.Slice(result, start, end, axes=axes)
-
- return result
-
-
-@torch_op("aten::diagonal", trace_only=True)
-def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1) -> BOOL:
- """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)"""
+ is_bool = self.dtype == BOOL.dtype
# perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims
# [0,1,2] -> [2,0,1] when dim1=0 and dim2=1
@@ -2560,23 +2583,21 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1
# This is because computing diagonal sum is on dim2 after transpose by perm
axes = [self_rank - 2]
- return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes)
-
-
-@torch_op("aten::diagonal", private=True)
-def _aten_diagonal_bool_onnx(
- self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
-) -> BOOL:
neg_1 = op.Constant(value_ints=[-1])
dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row
dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col
mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
- tmp_tensor = op.ConstantOfShape(mask_shape)
- mask = op.EyeLike(tmp_tensor, k=offset)
- self_int = op.Cast(self, to=INT64.dtype)
- mask_int = op.Cast(mask, to=INT64.dtype)
- self_int_t = op.Transpose(self_int, perm=perm)
- result = op.Mul(self_int_t, mask_int)
+ mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
+
+ if is_bool:
+ self_int = op.Cast(self, to=INT64.dtype)
+ mask_int = op.Cast(mask, to=INT64.dtype)
+ self_int_t = op.Transpose(self_int, perm=perm)
+ result = op.Mul(self_int_t, mask_int)
+ else:
+ mask = op.CastLike(mask, self)
+ self_t = op.Transpose(self, perm=perm)
+ result = op.Mul(self_t, mask)
result = op.ReduceSum(result, keepdims=False, axes=axes)
# min(row, col)
min_dim_size = op.Min(dim1_size, dim2_size)
@@ -2599,20 +2620,23 @@ def _aten_diagonal_bool_onnx(
# 6 0 4 0
# From above table, we can get the logic below
+ offset_val = op.Constant(value_ints=[offset])
if offset < 0:
# row + offset
- length = dim1_size + offset
+ length = op.Add(dim1_size, offset_val)
start = op.Constant(value_ints=[0])
else: # offset >= 0
# col - offset
- length = dim2_size - offset
- start = op.Reshape(op.Constant(value_int=offset), neg_1)
+ length = op.Sub(dim2_size, offset_val)
+ start = offset_val
# max(min(length, min(row, col)), 0)
- length = op.Max(op.Min(length, min_dim_size), 0)
- end = start + length
+ length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
+ end = op.Add(start, length)
result = op.Slice(result, start, end, axes=axes)
- result = op.Cast(result, to=BOOL.dtype)
+
+ if is_bool:
+ result = op.Cast(result, to=BOOL.dtype)
return result
@@ -2667,17 +2691,14 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
@torch_op(
(
- "aten::div",
"aten::div.Tensor",
"aten::div.Scalar",
- # When rounding_mode is None, performs a true division
- # https://pytorch.org/docs/stable/generated/torch.div.html
- "aten::div.Tensor_mode",
- "aten::div.Scalar_mode",
- "aten::divide",
- "aten::true_divide",
- "_operator::truediv",
- )
+ "aten::divide.Tensor",
+ "aten::divide.Scalar",
+ "aten::true_divide.Tensor",
+ "aten::true_divide.Scalar",
+ ),
+ trace_only=True,
)
def aten_div(self: TFloat, other: TFloat) -> TFloat:
"""div.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -2686,14 +2707,19 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
return op.Div(self, other)
+@torch_op("_operator::truediv", trace_only=True)
+def operator_truediv(self: TensorType, other: TensorType) -> FLOAT:
+ return op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))
+
+
@torch_op(
(
- "aten::div",
"aten::div.Tensor",
"aten::div.Scalar",
- "aten::divide",
- "aten::true_divide",
- "_operator::truediv",
+ "aten::divide.Tensor",
+ "aten::divide.Scalar",
+ "aten::true_divide.Tensor",
+ "aten::true_divide.Scalar",
),
complex=True,
)
@@ -2721,55 +2747,51 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:
@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
-def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
+def aten_div_mode(self: TReal, other: TReal, rounding_mode: Optional[str] = None) -> TReal:
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
- # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
- assert rounding_mode in {"trunc", "floor"}
-
- if rounding_mode == "trunc":
- # Rounds the results of the division towards zero.
- # Equivalent to C-style integer division
- result = aten_trunc(op.Div(self, other))
- else: # rounding_mode == "floor"
- result = op.Floor(op.Div(self, other))
-
- return result
+ assert rounding_mode in {"trunc", "floor", None}
+ if self.dtype.is_integer():
+ quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))
-@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
-def aten_div_mode_int(self: TInt, other: TInt, rounding_mode: str) -> TInt:
- """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+ if rounding_mode == "trunc":
+ # Rounds the results of the division towards zero.
+ # Equivalent to C-style integer division
+ result = aten_trunc(quotient)
+ return op.CastLike(result, self)
+ if rounding_mode == "floor":
+ result = op.Floor(quotient)
+ return op.CastLike(result, self)
- Variant for integer inputs.
- """
- # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
- assert rounding_mode in {"trunc", "floor"}
+ assert rounding_mode is None
+ # When rounding_mode is None, the return type is float32
+ return quotient
- quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))
+ # Float inputs
if rounding_mode == "trunc":
# Rounds the results of the division towards zero.
# Equivalent to C-style integer division
- result = aten_trunc(quotient)
- else: # rounding_mode == "floor"
- result = op.Floor(quotient)
+ return aten_trunc(op.Div(self, other))
+ if rounding_mode == "floor":
+ return op.Floor(op.Div(self, other))
- return op.CastLike(result, self)
+ return op.Div(self, other)
-@torch_op("aten::dot")
+@torch_op("aten::dot", trace_only=True)
def aten_dot(self: TFloat, tensor: TFloat) -> TFloat:
"""dot(Tensor self, Tensor tensor) -> Tensor"""
return op.MatMul(self, tensor)
-@torch_op("aten::dropout", traceable=True)
+@torch_op("aten::dropout", trace_only=True)
def aten_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
"""dropout(Tensor input, float p, bool train) -> Tensor"""
- if IsScalar(input):
+ if len(input.shape) == 0:
input = op.Reshape(input, op.Constant(value_ints=[-1]))
result, _ = op.Dropout(input, p, train)
result = op.Squeeze(result)
@@ -2789,7 +2811,7 @@ def aten_dstack(tensors: Sequence[TensorType]) -> TensorType:
def aten_einsum(
equation: str,
tensors: Sequence[TReal],
- path: Optional[int] = None, # pylint: disable=unused-argument
+ path: Optional[int] = None,
) -> TReal:
"""einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor"""
@@ -2797,14 +2819,14 @@ def aten_einsum(
return op.Einsum(*tensors, equation=equation)
-@torch_op("aten::embedding")
+@torch_op("aten::embedding", trace_only=True)
def aten_embedding(
weight: TTensor,
- indices: TTensor,
+ indices: TInt,
padding_idx: int = -1,
scale_grad_by_freq: bool = False,
sparse: bool = False,
-): # pylint: disable=unused-argument
+) -> TTensor:
# embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
return op.Gather(weight, indices)
@@ -2828,9 +2850,9 @@ def aten_embedding_bag(
weight: TFloat,
indices: INT64,
offsets: INT64,
- scale_grad_by_freq: bool = False, # pylint: disable=unused-argument
+ scale_grad_by_freq: bool = False,
mode: int = 0, # [0,1,2] indicate ["sum", "mean", "max"]
- sparse: bool = False, # pylint: disable=unused-argument
+ sparse: bool = False,
per_sample_weights: Optional[TFloat] = None,
include_last_offset: bool = False,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
@@ -2867,8 +2889,8 @@ def _aten_embedding_bag_onnx(
indices_1d = op.Reshape(indices, neg_1)
# Get weight out according to indices_1d,
new_weight = op.Gather(weight, indices_1d)
- # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
- new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after the first Gather step because Shape(indices) == Shape(per_sample_weights)
+ new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=[1]))
weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
indices_size = op.Shape(indices_1d)
@@ -2962,9 +2984,9 @@ def aten_embedding_bag_padding_idx(
weight: TFloat,
indices: INT64,
offsets: INT64,
- scale_grad_by_freq: bool = False, # pylint: disable=unused-argument
+ scale_grad_by_freq: bool = False,
mode: int = 0, # [0,1,2] indicate ["sum", "mean", "max"]
- sparse: bool = False, # pylint: disable=unused-argument
+ sparse: bool = False,
per_sample_weights: Optional[TFloat] = None,
include_last_offset: bool = False,
padding_idx: int = -1,
@@ -2974,9 +2996,9 @@ def aten_embedding_bag_padding_idx(
We add default values for the attributes to accommodate _embedding_bag as well:
_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
"""
- assert (
- padding_idx is not None
- ), "padding_idx must not be None. This is likely a dispatcher error"
+ assert padding_idx is not None, (
+ "padding_idx must not be None. This is likely a dispatcher error"
+ )
if per_sample_weights is None:
per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
@@ -3007,8 +3029,8 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
# Get weight out according to indices,
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
indices_weight = op.Gather(weight, indices)
- # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
- indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after the first Gather step because Shape(indices) == Shape(per_sample_weights)
+ indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=[1]))
# The element in sequence must be FLOAT32 dtype due to ORT bug
indices_weight = op.Cast(indices_weight, to=FLOAT.dtype)
@@ -3101,7 +3123,7 @@ def aten_embedding_dense_backward(
raise NotImplementedError()
-@torch_op("aten::embedding_renorm", traceable=True)
+@torch_op("aten::embedding_renorm", trace_only=True)
def aten_embedding_renorm(
weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
@@ -3150,35 +3172,42 @@ def aten_embedding_sparse_backward(
raise NotImplementedError()
-@torch_op(("aten::empty", "aten::empty.memory_format"))
-def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var]
- # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+@torch_op("aten::empty.memory_format", trace_only=True)
+def aten_empty(
+ size: Sequence[INT64],
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+ memory_format: str = "",
+) -> TensorType: # type: ignore[type-var]
+ """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+ if dtype == -1:
+ dtype = FLOAT.dtype
- # using Zeros to simulate np.empty()
- size = op.Cast(size, to=INT64.dtype)
- zero = op.Constant(value_float=0.0)
- zero = op.Cast(zero, to=dtype)
+ # using Zeros to simulate empty()
+ zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
+ size = common_ops.merge_dims(size)
return op.Expand(zero, size)
@torch_op("aten::empty_like", trace_only=True)
-def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
+def aten_empty_like(
+ self: TTensor,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+ memory_format: str = "",
+) -> TTensor:
"""empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
- # NOTE: trace_only because both if branches need to be the same type, but we have
- # a cast in the if branch.
-
- if dtype == -1:
+ if dtype == -1 or dtype is None:
zero = op.CastLike(0, self)
else:
zero = op.Cast(0, to=dtype)
- return _aten_empty_like_onnx(self, zero)
-
-
-@torch_op("aten::empty_like", private=True)
-def _aten_empty_like_onnx(self: TTensor, zero) -> TTensor:
shape = op.Shape(self)
return op.Expand(zero, shape)
@@ -3191,28 +3220,32 @@ def aten_empty_quantized(
raise NotImplementedError()
-@torch_op("aten::empty_strided")
+@torch_op("aten::empty_strided", trace_only=True)
def aten_empty_strided(
- size: INT64,
- stride: INT64, # pylint: disable=unused-argument
+ size: Sequence[INT64],
+ stride: INT64,
+ layout: str = "",
+ dtype: int = FLOAT.dtype,
+ device: str = "",
+ pin_memory: bool = False,
) -> TTensor: # type: ignore[type-var]
# empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
# using Zeros to simulate empty()
- size = op.Cast(size, to=INT64.dtype)
- zero = op.Constant(value_float=0.0)
+ zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
+ size = common_ops.merge_dims(size)
return op.Expand(zero, size)
-@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar"))
+@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"), trace_only=True)
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
"""eq.Tensor(Tensor self, Tensor other) -> Tensor"""
return op.Equal(self, other)
-@torch_op("aten::equal")
+@torch_op("aten::equal", trace_only=True)
def aten_equal(self: TTensor, other: TTensor) -> BOOL:
"""equal(Tensor self, Tensor other) -> bool"""
@@ -3231,14 +3264,14 @@ def aten_erfinv(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::exp")
+@torch_op("aten::exp", trace_only=True)
def aten_exp(self: TFloat) -> TFloat:
"""exp(Tensor self) -> Tensor"""
return op.Exp(self)
-@torch_op("aten::exp2", traceable=True)
+@torch_op("aten::exp2", trace_only=True)
def aten_exp2(self: TFloat) -> TFloat:
"""exp2(Tensor self) -> Tensor"""
@@ -3247,17 +3280,18 @@ def aten_exp2(self: TFloat) -> TFloat:
return op.Pow(two, self)
-@torch_op("aten::expand")
-def aten_expand(self: TTensor, size: TInt) -> TTensor:
+@torch_op("aten::expand", trace_only=True)
+def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor:
"""expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
- size = op.Cast(size, to=INT64.dtype)
# NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
# To support -1 dim, we need to convert -1 to 1.
- size = op.Abs(size)
- return op.Expand(self, size)
+ # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely
+    # and isn't expected to appear in correct usages of SymInt.
+ size = [1 if isinstance(s, int) and s == -1 else s for s in size]
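+    # e.g. size=[-1, 4] on a (2, 1) input becomes [1, 4], and Expand broadcasts to (2, 4).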
+ return op.Expand(self, common_ops.merge_dims(size))
-@torch_op("aten::expand_as", traceable=True)
+@torch_op("aten::expand_as", trace_only=True)
def aten_expand_as(self: TTensor, other: TTensor) -> TTensor:
"""expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
@@ -3272,12 +3306,6 @@ def aten_expand_copy(self: TensorType, size: INT64, implicit: bool = False) -> T
raise NotImplementedError()
-def aten_expm1(self: TensorType) -> TensorType:
- """expm1(Tensor self) -> Tensor"""
-
- raise NotImplementedError()
-
-
def aten_eye(n: int) -> TensorType:
"""eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
@@ -3418,15 +3446,14 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
raise NotImplementedError()
-@torch_op(("aten::fill", "aten::fill.Tensor"))
-def aten_fill(self: TTensor, value: TTensor) -> TTensor:
+@torch_op(("aten::fill.Tensor", "aten::fill.Scalar"))
+def aten_fill(self: TTensor, value: TTensor2) -> TTensor:
"""fill.Tensor(Tensor self, Tensor value) -> Tensor"""
- # after fill, the self Tensor should keep origianl type
+ # Cast the value before Expand so it can be constant folded
+ value = op.CastLike(value, self)
shape = op.Shape(self)
- expanded = op.Expand(value, shape)
- result = op.CastLike(expanded, self)
- return result
+ return op.Expand(value, shape)
def aten_fix(self: TensorType) -> TensorType:
@@ -3435,17 +3462,63 @@ def aten_fix(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::flip")
-def aten_flip(self: TTensor, dims: INT64) -> TTensor:
+@torch_op("aten::flatten.using_ints", trace_only=True)
+def aten_flatten(self: TTensor, start_dim: int = 0, end_dim: int = -1) -> TTensor:
+ """flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"""
+ dim = len(self.shape)
+ if dim == 1:
+ return op.Identity(self)
+ # use ONNX's Flatten operator for cases where the output shape is 2D
+ if start_dim == 1:
+ if end_dim in (-1, dim - 1):
+ return op.Flatten(self, axis=start_dim)
+ elif start_dim == 0:
+ if end_dim in (-2, dim - 2):
+ return op.Flatten(self, axis=end_dim + 1)
+
+ # if end_dim is negative add dim
+ if end_dim < 0:
+ end_dim = dim + end_dim
+
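+    # The general case builds the target shape manually, e.g. a (2, 3, 4, 5) input
+    # with start_dim=1, end_dim=2 concatenates [2], [-1], [5] and reshapes to (2, 12, 5).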
+ input_size = op.Shape(self)
+ dim_head = op.Slice(
+ input_size,
+ op.Constant(value_ints=[0]),
+ op.Constant(value_ints=[start_dim]),
+ op.Constant(value_ints=[0]),
+ )
+ final_dims = [dim_head, op.Constant(value_ints=[-1])]
+ if end_dim < dim - 1:
+ dim_tail = op.Slice(
+ input_size,
+ op.Constant(value_ints=[end_dim + 1]),
+ op.Constant(value_ints=[dim]),
+ op.Constant(value_ints=[0]),
+ )
+ final_dims = [
+ dim_head,
+ op.Constant(value_ints=[-1]),
+ dim_tail,
+ ]
+
+ final_shape = op.Concat(*final_dims, axis=0)
+ return op.Reshape(self, final_shape)
+
+
+@torch_op("aten::flip", trace_only=True)
+def aten_flip(self: TTensor, dims: Sequence[int]) -> TTensor:
"""flip(Tensor self, int[] dims) -> Tensor"""
- shape_dim = op.Shape(dims)
- neg_1 = op.Constant(value_int=-1)
- starts = op.Expand(neg_1, shape_dim) # something like [-1, -1, -1]
- steps = op.Expand(neg_1, shape_dim) # something like [-1, -1, -1]
- ends = op.Expand(_INT64_MIN, shape_dim) # something like [-xxx, -xxx, -xxx]
- result = op.Slice(self, starts, ends, dims, steps)
- return result
+ if not dims:
+ # Nothing to flip
+ return op.Identity(self)
+
+ rank = len(dims)
+ starts = op.Constant(value_ints=[-1] * rank) # something like [-1, -1, -1]
+ steps = starts # something like [-1, -1, -1]
+ ends = op.Constant(value_ints=[_INT64_MIN] * rank) # something like [-xxx, -xxx, -xxx]
+ dims = op.Constant(value_ints=dims)
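+    # e.g. dims=[0, 2]: Slice walks axes 0 and 2 backwards (step -1 starting at
+    # index -1), reversing each of those axes.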
+ return op.Slice(self, starts, ends, dims, steps)
def aten_fliplr(self: TensorType) -> TensorType:
@@ -3460,25 +3533,49 @@ def aten_flipud(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::floor")
-def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::floor", trace_only=True)
+def aten_floor(self: TFloat) -> TFloat:
"""floor(Tensor self) -> Tensor"""
return op.Floor(self)
-@torch_op("math::floor")
-def python_math_floor(self: TFloatOrBFloat16) -> TInt:
+@torch_op("math::floor", trace_only=True)
+def python_math_floor(self: TFloat) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
return op.Cast(floor, to=INT64.dtype)
-@torch_op(("aten::floor_divide", "_operator::floordiv"))
-def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
+@torch_op("aten::floor_divide", trace_only=True)
+def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""
- return op.Floor(op.Div(self, other))
+ if self.dtype.is_floating_point():
+ return op.Floor(op.Div(self, other))
+
+ assert self.dtype.is_integer()
+
+ if not self.dtype.is_signed():
+ return op.Div(self, other)
+
+ # Convert truncation to flooring
+ # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70
+ # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
+ # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
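+    # e.g. self=-7, other=2: op.Div truncates to -3; the signs differ and the
+    # remainder is nonzero, so offset=1 and the result is -3 - 1 = -4 == floor(-3.5).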
+ offset = op.And(
+ op.Not(op.Equal(op.Sign(self), op.Sign(other))),
+ op.Cast(op.Mod(self, other), to=BOOL.dtype),
+ )
+ offset = op.Cast(offset, to=self.dtype)
+ return op.Sub(op.Div(self, other), offset)
+
+
+@torch_op("_operator::floordiv", trace_only=True)
+def operator_floordiv(self: INT64, other: INT64) -> INT64:
+ # We implement floor_divide only for positive inputs (using integer division)
+ # because that is the usual intended case and is the most efficient.
+ return op.Div(self, other)
def aten_fmax(self: TensorType, other: TensorType) -> TensorType:
@@ -3493,14 +3590,14 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::fmod")
+@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"), trace_only=True)
def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8:
"""fmod.Tensor(Tensor self, Tensor other) -> Tensor"""
return op.Mod(self, other, fmod=1)
-@torch_op("aten::frac")
+@torch_op("aten::frac", trace_only=True)
def aten_frac(self: TFloat) -> TFloat:
"""frac(Tensor self) -> Tensor
@@ -3531,32 +3628,41 @@ def aten_from_file(
raise NotImplementedError()
-@torch_op("aten::full")
-def aten_full(size: INT64, fill_value: FLOAT, dtype: int = FLOAT.dtype):
+@torch_op("aten::full", trace_only=True)
+def aten_full(
+ size: Union[INT64, INT32],
+ fill_value: TensorType,
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+ if dtype != -1:
+ fill_value = op.Cast(fill_value, to=dtype)
+
size = op.Cast(size, to=INT64.dtype)
- fill_value = op.Cast(fill_value, to=dtype)
return op.Expand(fill_value, size)
-@torch_op("aten::full_like")
-def aten_full_like(self, fill_value: TTensor) -> TTensor:
+@torch_op("aten::full_like", trace_only=True)
+def aten_full_like(
+ self: TensorType,
+ fill_value: TensorType,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
- fill_value = op.CastLike(fill_value, self)
- self_shape = op.Shape(self)
-
- return op.Expand(fill_value, self_shape)
-
-
-@torch_op("aten::full_like")
-def aten_full_like_dtype(self, fill_value: TTensor, dtype: int) -> TTensor:
- """full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+ if dtype == -1:
+ fill_value = op.CastLike(fill_value, self)
+ else:
+ fill_value = op.Cast(fill_value, to=dtype)
- fill_value = op.Cast(fill_value, to=dtype)
self_shape = op.Shape(self)
-
return op.Expand(fill_value, self_shape)
@@ -3580,25 +3686,30 @@ def aten_fused_moving_avg_obs_fake_quant(
raise NotImplementedError()
-@torch_op("aten::gather", traceable=True)
+@torch_op("aten::gather", trace_only=True)
def aten_gather(
self: TReal,
dim: int,
index: TInt,
- sparse_grad: bool = False, # pylint: disable=unused-argument
+ sparse_grad: bool = False,
) -> TReal:
"""gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"""
- if IsScalar(index): # When (index) is empty, return (self)
- result = self
- else:
- if IsScalar(self): # Unsqueeze for GatherElements op
- self = op.Reshape(self, op.Constant(value_ints=[-1]))
- if op.Size(index) == 0: # Return empty array
- result = op.CastLike(index, self)
+ if len(self.shape) == 0:
+ if len(index.shape) == 0:
+ return op.Identity(self)
else:
- index = op.Cast(index, to=INT64.dtype)
- result = op.GatherElements(self, index, axis=dim)
+ return op.Expand(self, op.Shape(index))
+
+ is_scalar_index = len(index.shape) == 0
+ if is_scalar_index:
+ index = op.Unsqueeze(index, [0])
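+        # A scalar index is lifted to shape (1,) for GatherElements and squeezed back below.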
+
+ index = op.Cast(index, to=INT64.dtype)
+ result = op.GatherElements(self, index, axis=dim)
+
+ if is_scalar_index:
+ result = op.Squeeze(result, [0])
return result
@@ -3617,25 +3728,21 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
@torch_op(
- ("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal", "_operator::ge")
+ ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
+ trace_only=True,
)
-def aten_ge(self: TReal, other: TReal) -> BOOL:
+def aten_ge(self: TTensor, other: TTensor) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""
- return op.GreaterOrEqual(self, other)
-
-
-@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal"))
-def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL:
- """ge.Tensor(Tensor self, Tensor other) -> Tensor"""
+ if self.dtype == ir.DataType.BOOL:
+ # self, other, self >= other
+ # F, F, T
+ # F, T, F
+ # T, F, T
+ # T, T, T
+ return op.Or(self, op.Not(other))
- # self, other, self >= other
- # F, F, T
- # F, T, F
- # T, F, T
- # T, T, T
-
- return op.Or(self, op.Not(other))
+ return op.GreaterOrEqual(self, other)
def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]:
@@ -3650,8 +3757,7 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType:
raise NotImplementedError()
-# NOTE: The name is made up for `getitem` to be included in the registry
-@torch_op("aten::getitem")
+@torch_op(("_operator::getitem", "aten::getitem"))
def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor:
return op.SequenceAt(self, i)
@@ -3748,19 +3854,6 @@ def aten_grid_sampler_3d_backward(
raise NotImplementedError()
-def aten_group_norm(
- input: TensorType,
- num_groups: int,
- weight: Optional[TensorType] = None,
- bias: Optional[TensorType] = None,
- eps: float = 1e-05,
- cudnn_enabled: bool = True,
-) -> TensorType:
- """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""
-
- raise NotImplementedError()
-
-
def aten_gru_cell(
input: TensorType,
hx: TensorType,
@@ -3774,35 +3867,57 @@ def aten_gru_cell(
raise NotImplementedError()
-@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater", "_operator::gt"))
-def aten_gt(self: TReal, other: TReal) -> BOOL:
+@torch_op(
+ ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"),
+ trace_only=True,
+)
+def aten_gt(self: TTensor, other: TTensor) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""
- return op.Greater(self, other)
-
+ if self.dtype == ir.DataType.BOOL:
+ # self, other, self > other
+ # F, F, F
+ # F, T, F
+ # T, F, T
+ # T, T, F
-@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater"))
-def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL:
- """gt.Tensor(Tensor self, Tensor other) -> Tensor"""
- # self, other, self > other
- # F, F, F
- # F, T, F
- # T, F, T
- # T, T, F
+ return op.And(self, op.Not(other))
- return op.And(self, op.Not(other))
+ return op.Greater(self, other)
-def aten_hamming_window(window_length: int) -> TensorType:
+@torch_op("aten::hamming_window", trace_only=True)
+def aten_hamming_window(
+ window_length: int,
+ dtype: int = 1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
- raise NotImplementedError()
-
-
-def aten_hann_window(window_length: int) -> TensorType:
+ if dtype is None or dtype == -1:
+ dtype = 1
+    # ONNX uses different alpha/beta values for the Hamming window.
+    # PyTorch uses alpha=0.54, beta=0.46, while ONNX uses
+    # alpha=0.543478, beta=0.456522. This causes a slight difference
+    # in the output values, but we still use the HammingWindow op for performance.
+ return op.HammingWindow(window_length, output_datatype=dtype)
+
+
+@torch_op("aten::hann_window", trace_only=True)
+def aten_hann_window(
+ window_length: int,
+ dtype: int = 1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
- raise NotImplementedError()
+ if dtype is None or dtype == -1:
+ dtype = 1
+ return op.HannWindow(window_length, output_datatype=dtype)
def aten_hardshrink(self: TensorType, lambd: float = 0.5) -> TensorType:
@@ -3819,7 +3934,7 @@ def aten_hardshrink_backward(
raise NotImplementedError()
-@torch_op("aten::heaviside")
+@torch_op("aten::heaviside", trace_only=True)
def aten_heaviside(self: TReal, values: TReal) -> TReal:
"""heaviside(Tensor self, Tensor values) -> Tensor"""
@@ -3864,7 +3979,7 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::hstack")
+# Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
def aten_hstack(tensors: Sequence[TTensor]) -> TTensor:
"""hstack(Tensor[] tensors) -> Tensor"""
@@ -3937,7 +4052,6 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
return op.Shape(broadcasted)
-@torch_op("aten::index.Tensor", private=True, trace_only=True)
def _aten_index_onnx(
self: TensorType,
indices: Sequence[Optional[INT64]],
@@ -3965,7 +4079,7 @@ def _aten_index_onnx(
not_none_indices = [idx for idx in indices if idx is not None]
broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
final_index = op.Concat(
- *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
+ *(op.Unsqueeze(op.Expand(idx, broadcast_shape), [-1]) for idx in not_none_indices),
axis=-1,
)
@@ -3974,7 +4088,7 @@ def _aten_index_onnx(
if _has_none_in_middle(indices):
# If there is None in the middle, Advanced Indexing cannot decide where to put
# the new dimensions. So it places them in the front, like GatherND does.
- return self
+ return op.Identity(self)
# When the indices are consecutive, Advanced Indexing will place the new dimensions
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
@@ -4103,7 +4217,7 @@ def aten_index_copy(
raise NotImplementedError()
-@torch_op(("aten::index_put", "aten::_unsafe_index_put"))
+@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
def aten_index_put(
self: TReal,
indices: Sequence[INT64],
@@ -4116,19 +4230,83 @@ def aten_index_put(
`_.
"""
- # TODO(justinchuby): Handle when indicies has more than one element
- index = op.SequenceAt(indices, 0)
- new_index = op.Unsqueeze(index, [-1])
-
- if op.Cast(accumulate, to=BOOL.dtype):
- result = op.ScatterND(self, new_index, values, reduction="add")
+ def _make_reshape_list_broadcastable(reshape_list, values_shape):
+ # Remove ones until the rank of reshape_list matches values_shape.
+ while len(reshape_list) > len(values_shape) and 1 in reshape_list:
+ reshape_list.remove(1)
+
+ # Now ensure each dimension is broadcastable:
+ # This is mandatory when mixing basic and advanced indexing
+        # Example: data (10, 3, 4), indices [[0, 1], :, [0, 1]], values (2, 3):
+        # the per-dim reshape lists should be [[2, 1], [1, 3], [2, 1]]
+ for i, r in enumerate(reshape_list):
+ if r not in (1, values_shape[i]):
+ value_index = values_shape.index(r)
+                # Swap elements.
+                # For the example above, the current reshape list for the last dim is [1, 2];
+                # to make it broadcastable, we swap the elements.
+ reshape_list[value_index], reshape_list[i] = r, 1
+
+ return reshape_list
+
+ # Ensure the number of indices matches the tensor rank.
+ self_rank = len(self.shape)
+ if len(indices) < self_rank:
+ indices = list(indices) + [None] * (self_rank - len(indices))
+
+ # Get values shape
+ values_shape = tuple(values.shape)
+
+ index_vectors = []
+ for i in range(self_rank):
+ if indices[i] is None:
+ # For a full slice along dim i, create a range index [0, self.shape[i]).
+ idx = op.Range(0, self.shape[i], 1)
+ reshape_update = self.shape[i]
+ else:
+ idx = indices[i]
+ reshape_update = math.prod(idx.shape)
+            # When the index is more than 1D, flatten it and adjust the values shape too
+            # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
+            # Indices -> (2*4,) and values shape (2*4, 3)
+ if len(idx.shape) > 1:
+ values_shape = (reshape_update, *values_shape[len(idx.shape) :])
+
+ # Flatten index (always working with 1D index in each dim)
+ idx = op.Reshape(idx, [-1])
+
+ # Create a reshape pattern: one value per index dimension,
+ # with the current dimension set to the update size.
+ reshape_list = [1] * len(indices)
+ reshape_list[i] = reshape_update
+
+ # Adjust the reshape list to match the values shape.
+ reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
+
+ # Reshape and expand the index.
+ idx = op.Reshape(idx, reshape_list, allowzero=True)
+ idx = op.Expand(idx, values_shape)
+
+ # Flatten the index to 1D and unsqueeze to form a column vector.
+ idx = op.Reshape(idx, [-1])
+ idx = op.Unsqueeze(idx, axes=[1])
+ index_vectors.append(idx)
+
+ # Concatenate the index vectors along axis=1 to form the final indices.
+ new_index = op.Concat(*index_vectors, axis=1)
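+    # e.g. self(10, 3) with indices=[idx] where idx has shape (2,) and values has
+    # shape (2, 3): dim 0 contributes the flattened idx and dim 1 a Range(0, 3);
+    # both are expanded to (2, 3), flattened to (6,), and stacked into a (6, 2) index.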
+
+ # Flatten values to match the indices
+ flat_values = op.Reshape(values, [-1])
+
+ if accumulate:
+ result = op.ScatterND(self, new_index, flat_values, reduction="add")
else:
- result = op.ScatterND(self, new_index, values)
+ result = op.ScatterND(self, new_index, flat_values)
return result
-@torch_op("aten::index_put")
+@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
@@ -4137,37 +4315,18 @@ def aten_index_put_bool(
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
- index = op.SequenceAt(indices, 0) # assume indices only have 1 element
- # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
- index_int = op.Cast(index, to=INT32.dtype)
- # if all False, return self
- if op.ReduceSum(index_int) == 0:
- result = self
- else:
- # change array([F,F,T,F,F]) to array([2])
- index = op.ArgMax(index_int) # assume index only have 1 True
- # change array([2]) to array([2,2,2,2,2])
- self_dim_1 = op.Shape(self, start=1, end=2)
- index_dim_0 = op.Shape(index, start=0, end=1)
- shape = op.Concat(self_dim_1, index_dim_0, axis=0)
- new_ind = op.Expand(index, shape)
- new_ind_t = op.Transpose(new_ind)
-
- # values must have same rank with input(self)
- if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator]
- values = op.Unsqueeze(values, op.Constant(value_ints=[0]))
-
- if op.Cast(accumulate, to=BOOL.dtype):
- zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
- zeros = op.CastLike(zeros, values)
- result = op.ScatterElements(zeros, new_ind_t, values)
- # FIXME: type promotion
- result = op.CastLike(result, self)
- result = op.Add(result, self)
- else:
- result = op.ScatterElements(self, new_ind_t, values)
-
- return result
+    # TODO: Support indices with more than one element
+    index = indices[0]
+    # accumulate should always be False here; True does not make sense, but an assert would be ideal
+ # Reshape indices so it can be properly broadcasted
+ self_rank = len(self.shape)
+ index_rank = len(index.shape)
+ if self_rank > index_rank:
+ index_shape = op.Shape(index)
+ padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)])
+ padded_shape = op.Concat(index_shape, padding, axis=0)
+ index = op.Reshape(index, padded_shape)
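+    # e.g. self(2, 3, 4) with a (2,)-shaped mask: the mask is reshaped to (2, 1, 1)
+    # so Where broadcasts it over the remaining dims.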
+ return op.Where(index, values, self)
def aten_index_reduce(
@@ -4183,11 +4342,11 @@ def aten_index_reduce(
raise NotImplementedError()
-@torch_op("aten::index_select", traceable=True)
+@torch_op("aten::index_select", trace_only=True)
def aten_index_select(self: TTensor, dim: int, index: IntType) -> TTensor:
"""index_select(Tensor self, int dim, Tensor index) -> Tensor"""
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
@@ -4258,12 +4417,15 @@ def aten_instance_norm(
if use_input_stats:
return op.InstanceNormalization(input, weight, bias, epsilon=eps)
- assert (
- running_mean is not None and running_var is not None
- ), "running_mean and running_var must be provided when use_input_stats is False"
+ assert running_mean is not None and running_var is not None, (
+ "running_mean and running_var must be provided when use_input_stats is False"
+ )
batch_size = op.Shape(input, start=0, end=1)
- bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0))
+ bn_input = op.Reshape(
+ input,
+ op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
+ )
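+    # (N, C, ...) is collapsed to (1, N*C, ...) so that BatchNormalization with the
+    # tiled per-instance stats below normalizes each (n, c) slice independently.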
weight = op.Tile(weight, batch_size)
bias = op.Tile(bias, batch_size)
running_mean = op.Tile(running_mean, batch_size)
@@ -4279,7 +4441,7 @@ def aten_instance_norm(
momentum=1.0 - momentum,
training_mode=False,
)
- return op.Reshape(norm, op.Shape(input))
+ return op.Reshape(norm, op.Shape(input), allowzero=True)
def aten_int_repr(self: TensorType) -> TensorType:
@@ -4360,21 +4522,11 @@ def aten_is_pinned(self: TensorType, device: Optional[str] = None) -> bool:
raise NotImplementedError()
-@torch_op("aten::is_same_size")
+# is_same_size is decomposed by PyTorch
def aten_is_same_size(self: TTensor, other: TTensor) -> BOOL:
"""is_same_size(Tensor self, Tensor other) -> bool"""
- # Cannot compare different shape of two tensors using op.Equal()
- # So we need to compare the rank first, if rank is same, then compare shape
- result = op.Equal(Rank(self), Rank(other))
- if result: # Same rank, then compare shape
- self_shape = op.Shape(self)
- other_shape = op.Shape(other)
- result_bool = op.Equal(self_shape, other_shape)
- result_int = op.Cast(result_bool, to=INT8.dtype)
- result = op.Cast(op.ReduceMin(result_int, keepdims=False), to=BOOL.dtype)
-
- return result
+ raise NotImplementedError
def aten_is_set_to(self: TensorType, tensor: TensorType) -> bool:
@@ -4401,7 +4553,7 @@ def aten_isclose(
other: TReal,
rtol: float = 1e-05,
atol: float = 1e-08,
- equal_nan: bool = False, # pylint: disable=unused-argument
+ equal_nan: bool = False,
) -> BOOL:
"""isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor"""
@@ -4424,7 +4576,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:
@torch_op("aten::isinf")
-def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isinf(self: TFloat) -> BOOL:
"""isinf(Tensor self) -> Tensor"""
# Added Cast inside the function so it can support all real dtypes naturally
@@ -4433,14 +4585,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
@torch_op("aten::isnan")
-def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
+def aten_isnan(self: TFloat) -> BOOL:
"""isnan(Tensor self) -> Tensor"""
return op.IsNaN(self)
@torch_op("aten::isneginf")
-def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isneginf(self: TFloat) -> BOOL:
"""isneginf(Tensor self) -> Tensor"""
# Added Cast inside the function so it can support all real dtypes naturally
@@ -4449,7 +4601,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
@torch_op("aten::isposinf")
-def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
+def aten_isposinf(self: TFloat) -> BOOL:
"""isposinf(Tensor self) -> Tensor"""
# Added Cast inside the function so it can support all real dtypes naturally
@@ -4521,6 +4673,7 @@ def aten_layer_norm(
weight: Optional[TReal] = None,
bias: Optional[TReal] = None,
eps: float = 1e-05,
+ cudnn_enable: bool = True,
) -> TReal:
"""layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
@@ -4528,28 +4681,10 @@ def aten_layer_norm(
start_axis = -len(normalized_shape)
if weight is None:
- one = op.Constant(value_float=1.0)
+ one = op.Constant(value=ir.tensor(1, dtype=input.dtype))
weight = op.Expand(one, op.Shape(input, start=start_axis))
- if bias is None:
- zero = op.Constant(value_float=0.0)
- bias = op.Expand(zero, op.Shape(input, start=start_axis))
-
- return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps)
-
-
-@torch_op("aten::layer_norm", private=True)
-def _aten_layer_norm_onnx(
- input: TReal,
- weight: TReal,
- bias: TReal,
- axis: int,
- eps: float = 1e-05,
-) -> TReal:
- """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"""
-
- # TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982
- result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps)
+ result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)
return result
@@ -4565,30 +4700,36 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op(("aten::le", "aten::le.Tensor", "_operator::le"))
-def aten_le(self: TReal, other: TReal) -> BOOL:
+@torch_op(
+ ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"),
+ trace_only=True,
+)
+def aten_le(self: TTensor, other: TTensor) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""
- return op.LessOrEqual(self, other)
-
-
-@torch_op(("aten::le", "aten::le.Tensor", "aten::less_equal"))
-def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
- """le.Tensor(Tensor self, Tensor other) -> Tensor"""
+ if self.dtype == ir.DataType.BOOL:
+ # self, other, self <= other
+ # F, F, T
+ # F, T, T
+ # T, F, F
+ # T, T, T
- # self, other, self <= other
- # F, F, T
- # F, T, T
- # T, F, F
- # T, T, T
+ return op.Or(other, op.Not(self))
- return op.Or(other, op.Not(self))
+ return op.LessOrEqual(self, other)
-def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
+@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
+def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor:
"""lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""
- raise NotImplementedError()
+ weight = op.CastLike(weight, self)
+ diff = op.Sub(end, self)
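+    # Select between two algebraically equal forms per element for numerical accuracy:
+    # self + weight * diff when weight < 0.5, end - (1 - weight) * diff otherwise.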
+ return op.Where(
+ op.Less(weight, 0.5),
+ op.Add(self, op.Mul(weight, diff)),
+ op.Sub(end, op.Mul(diff, op.Sub(1.0, weight))),
+ )
def aten_lgamma(self: TensorType) -> TensorType:
@@ -4609,7 +4750,7 @@ def aten_lift_fresh(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::lift_fresh_copy")
+@torch_op("aten::lift_fresh_copy", trace_only=True)
def aten_lift_fresh_copy(self: TensorType) -> TensorType:
"""lift_fresh_copy(Tensor self) -> Tensor"""
@@ -4626,10 +4767,19 @@ def aten_linear_backward(
@torch_op("aten::linspace", trace_only=True)
def aten_linspace(
- start: TFloat, end: TFloat, steps: int, dtype: int = FLOAT.dtype
+ start: TFloat,
+ end: TFloat,
+ steps: int,
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TensorType:
"""linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+ if dtype == -1 or dtype is None:
+ dtype = FLOAT.dtype
+
# Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
if steps == 0:
return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
@@ -4651,43 +4801,43 @@ def aten_linspace(
)
-@torch_op("aten::log")
-def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::log", trace_only=True)
+def aten_log(self: TFloat) -> TFloat:
"""log(Tensor self) -> Tensor"""
return op.Log(self)
-@torch_op("aten::log10")
-def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::log10", trace_only=True)
+def aten_log10(self: TFloat) -> TFloat:
"""log10(Tensor self) -> Tensor"""
return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))
@torch_op("aten::log1p")
-def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log1p(self: TFloat) -> TFloat:
"""log1p(Tensor self) -> Tensor"""
return op.Log(op.Add(self, 1.0))
-@torch_op("aten::log2")
-def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::log2", trace_only=True)
+def aten_log2(self: TFloat) -> TFloat:
"""log2(Tensor self) -> Tensor"""
return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))
-@torch_op("aten::logaddexp")
-def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::logaddexp", trace_only=True)
+def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp(Tensor self, Tensor other) -> Tensor"""
return op.Log(op.Add(op.Exp(self), op.Exp(other)))
-@torch_op("aten::logaddexp2")
-def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::logaddexp2", trace_only=True)
+def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp2(Tensor self, Tensor other) -> Tensor"""
two = op.CastLike(2.0, self)
summation = op.Add(op.Pow(two, self), op.Pow(two, other))
@@ -4695,11 +4845,11 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr
return op.Div(op.Log(summation), op.Log(two))
-@torch_op("aten::logcumsumexp", traceable=True)
-def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
+@torch_op("aten::logcumsumexp", trace_only=True)
+def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
"""logcumsumexp(Tensor self, int dim) -> Tensor"""
- if IsScalar(self):
+ if len(self.shape) == 0:
result = self
else:
# Make dim 1-d
@@ -4719,88 +4869,70 @@ def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
return result
-@torch_op("aten::logdet")
+@torch_op("aten::logdet", trace_only=True)
def aten_logdet(self: TFloat) -> TFloat:
"""logdet(Tensor self) -> Tensor"""
return op.Log(op.Det(self))
-@torch_op(
- (
- "aten::logical_and",
- "aten::bitwise_and",
- "aten::bitwise_and.Tensor",
- "aten::bitwise_and.Scalar",
- "aten::bitwise_and.Scalar_Tensor",
- )
-)
-def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_and", trace_only=True)
+def aten_logical_and(self: TTensor, other: TTensor) -> BOOL:
"""logical_and(Tensor self, Tensor other) -> Tensor"""
- return op.And(self, other)
+ assert self.dtype == other.dtype
+ if self.dtype == ir.DataType.BOOL:
+ return op.And(self, other)
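+    # Non-bool inputs follow truthiness (nonzero -> True), e.g. inputs 2 and 0
+    # cast to True and False, so And yields False.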
+ return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
-@torch_op(("aten::logical_not", "aten::bitwise_not"))
-def aten_logical_not(self: BOOL) -> BOOL:
+
+@torch_op("aten::logical_not", trace_only=True)
+def aten_logical_not(self: TTensor) -> BOOL:
"""logical_not(Tensor self) -> Tensor"""
- return op.Not(self)
+ if self.dtype == ir.DataType.BOOL:
+ return op.Not(self)
+ return op.Not(op.Cast(self, to=BOOL.dtype))
-@torch_op(
- (
- "aten::logical_or",
- "aten::bitwise_or",
- "aten::bitwise_or.Tensor",
- "aten::bitwise_or.Scalar",
- "aten::bitwise_or.Scalar_Tensor",
- "aten::add",
- "aten::add.Tensor",
- )
-)
-def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_or", trace_only=True)
+def aten_logical_or(self: TTensor, other: TTensor) -> BOOL:
"""logical_or(Tensor self, Tensor other) -> Tensor"""
- return op.Or(self, other)
+ assert self.dtype == other.dtype
+ if self.dtype == ir.DataType.BOOL:
+ return op.Or(self, other)
+ return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
-@torch_op(
- (
- "aten::logical_xor",
- "aten::bitwise_xor",
- "aten::bitwise_xor.Tensor",
- "aten::bitwise_xor.Scalar",
- "aten::bitwise_xor.Scalar_Tensor",
- )
-)
-def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
+
+@torch_op("aten::logical_xor", trace_only=True)
+def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL:
"""logical_xor(Tensor self, Tensor other) -> Tensor"""
- return op.Xor(self, other)
+ assert self.dtype == other.dtype
+ if self.dtype == ir.DataType.BOOL:
+ return op.Xor(self, other)
+ return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
-@torch_op("aten::logit", private=True)
-def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
- return op.Log(op.Div(self, op.Sub(1.0, self)))
+@torch_op("aten::logit", trace_only=True)
+def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
+ """logit(Tensor self, float? eps=None) -> Tensor"""
+ one = ir.tensor(1, dtype=self.dtype)
-@torch_op("aten::logit", private=True)
-def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
- eps = op.CastLike(eps, self)
- one = op.CastLike(1.0, self)
- temporary_self = op.Where(self <= one - eps, self, one - eps)
- z = op.Where(temporary_self < eps, eps, temporary_self)
+ if eps is None:
+ return op.Log(op.Div(self, op.Sub(one, self)))
- return op.Log(op.Div(z, op.Sub(one, z)))
+ one_minus_eps = ir.tensor(1 - eps, dtype=self.dtype)
+ eps = ir.tensor(eps, dtype=self.dtype)
+ temporary_self = op.Where(self <= one_minus_eps, self, one_minus_eps)
+ z = op.Where(temporary_self < eps, eps, temporary_self)
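+    # z is self clamped to [eps, 1 - eps] so the log below stays finite.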
-@torch_op("aten::logit", trace_only=True)
-def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
- """logit(Tensor self, float? eps=None) -> Tensor"""
- if eps is None:
- return _aten_logit_onnx(self)
- return _aten_logit_clamp_onnx(self, eps)
+ return op.Log(op.Div(z, op.Sub(one, z)))
def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> TensorType:
@@ -4809,11 +4941,11 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
raise NotImplementedError()
-@torch_op("aten::logsumexp", traceable=True)
+@torch_op("aten::logsumexp", trace_only=True)
def aten_logsumexp(self: TFloat, dim: INT64, keepdim: int = False) -> TFloat:
"""logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"""
- if IsScalar(self):
+ if len(self.shape) == 0:
# A scalar
result = self
else:
@@ -4821,12 +4953,6 @@ def aten_logsumexp(self: TFloat, dim: INT64, keepdim: int = False) -> TFloat:
return result
-def aten_lshift(self: TensorType, other: TensorType) -> TensorType:
- """__lshift__.Tensor(Tensor self, Tensor other) -> Tensor"""
-
- raise NotImplementedError()
-
-
def aten_lstm_cell(
input: TensorType,
hx: Sequence[TensorType],
@@ -4861,26 +4987,24 @@ def aten_lstm_mps_backward(
raise NotImplementedError()
-@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less", "_operator::lt"))
-def aten_lt(self: TReal, other: TReal) -> BOOL:
+@torch_op(
+ ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
+ trace_only=True,
+)
+def aten_lt(self: TTensor, other: TTensor) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""
+ if self.dtype == ir.DataType.BOOL:
+ # self, other, self < other
+ # F, F, F
+ # F, T, T
+ # T, F, F
+ # T, T, F
+ return op.And(other, op.Not(self))
+
return op.Less(self, other)
-@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less"))
-def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL:
- """lt.Tensor(Tensor self, Tensor other) -> Tensor"""
-
- # self, other, self < other
- # F, F, F
- # F, T, T
- # T, F, F
- # T, T, F
-
- return op.And(other, op.Not(self))
-
-
def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
"""lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor"""
@@ -4910,9 +5034,6 @@ def aten_mH(self: TRealOrUInt8) -> TRealOrUInt8:
def aten_mH_complex(self: TFloat) -> TFloat:
"""mH(Tensor(a) self) -> Tensor(a)"""
- # TODO(#834): Allow calling scripted functions from other
- # scripted functions and remove trace only.
-
# c is the last dimension being the real and imaginary parts
trasposed = op.Einsum(self, equation="...ijc->...jic")
return _complex_conjugate(trasposed)
@@ -4946,8 +5067,8 @@ def aten_margin_ranking_loss(
@torch_op(
- ("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"),
- traceable=True,
+ ("aten::masked_fill.Scalar", "aten::masked_fill.Tensor"),
+ trace_only=True,
)
def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
"""masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"""
@@ -4957,10 +5078,26 @@ def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
return op.Where(mask, value_cast, self)
-def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:
+@torch_op(("aten::masked_scatter"), trace_only=True)
+def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor:
"""masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"""
- raise NotImplementedError()
+ if len(mask.shape) < len(self.shape):
+ mask = op.Expand(mask, op.Shape(self))
+ else:
+ self = op.Expand(self, op.Shape(mask))
+ index = op.Transpose(op.NonZero(mask), perm=[1, 0])
+
+ # NOTE: source can have more elements than needed.
+ # It could also have arbitrary shape.
+    # This is not supported by ONNX::ScatterND, so we need to flatten and slice the source tensor.
+ source = op.Reshape(source, op.Constant(value_ints=[-1]))
+ axes = op.Constant(value_ints=[0])
+ starts = op.Constant(value_ints=[0])
+ ends = op.Gather(op.Shape(index), op.Constant(value_ints=[0]), axis=0)
+ source = op.Slice(source, starts, ends, axes)
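+    # e.g. a mask with 3 True entries yields an index of shape (3, rank), so only
+    # the first 3 elements of the flattened source are scattered.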
+
+ return op.ScatterND(self, index, source)
def aten_masked_select(self: TensorType, mask: TensorType) -> TensorType:
@@ -4977,7 +5114,7 @@ def aten_masked_select_backward(
raise NotImplementedError()
-@torch_op("aten::matmul")
+@torch_op("aten::matmul", trace_only=True)
def aten_matmul(
self: TRealUnlessInt16OrInt8, other: TRealUnlessInt16OrInt8
) -> TRealUnlessInt16OrInt8:
@@ -5018,27 +5155,18 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::max")
+@torch_op("aten::max", trace_only=True)
def aten_max(self: TReal) -> TReal:
"""max(Tensor self) -> Tensor"""
- self_is_scalar = IsScalar(self)
- if self_is_scalar:
- self = op.Reshape(self, op.Constant(value_ints=[-1]))
-
- result = op.ReduceMax(self, keepdims=False)
-
- if self_is_scalar:
- result = op.Squeeze(result)
-
- return result
+ return op.ReduceMax(self, keepdims=False)
-@torch_op("aten::max.dim", traceable=True)
+@torch_op("aten::max.dim", trace_only=True)
def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, INT64]:
"""max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"""
- if IsScalar(self):
+ if len(self.shape) == 0:
result = self
indices = op.Constant(value_int=0)
else:
@@ -5048,18 +5176,14 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I
return result, indices
-@torch_op(("aten::maximum", "aten::max.other"))
-def aten_maximum(self: TReal, other: TReal) -> TReal:
+@torch_op("aten::maximum", trace_only=True)
+def aten_maximum(self: TTensor, other: TTensor) -> TTensor:
"""maximum(Tensor self, Tensor other) -> Tensor"""
- return op.Max(self, other)
-
-
-@torch_op(("aten::maximum", "aten::max.other"))
-def aten_maximum_bool(self: BOOL, other: BOOL) -> BOOL:
- """maximum(Tensor self, Tensor other) -> Tensor"""
+ if self.dtype == ir.DataType.BOOL:
+ return op.Or(self, other)
- return op.Or(self, other)
+ return op.Max(self, other)
@torch_op("aten::mean")
@@ -5070,16 +5194,15 @@ def aten_mean(self: TReal) -> TReal:
return op.Squeeze(result)
-@torch_op("aten::mean.dim", traceable=True)
+@torch_op("aten::mean.dim", trace_only=True)
def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
"""mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
- if IsScalar(self):
+ if len(self.shape) == 0:
result = self
else:
- if IsScalar(dim):
- dim = op.Unsqueeze(dim, axes=0)
- result = op.ReduceMean(self, dim, keepdims=keepdim)
+ dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+ result = op.ReduceMean(self, dims, keepdims=keepdim)
return result
@@ -5095,17 +5218,17 @@ def aten_meshgrid(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::min")
+@torch_op("aten::min", trace_only=True)
def aten_min(self: TReal) -> TReal:
"""min(Tensor self) -> Tensor"""
return op.ReduceMin(self, keepdims=False)
-@torch_op("aten::min.dim", traceable=True)
+@torch_op("aten::min.dim", trace_only=True)
def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, TInt]:
"""min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"""
- if IsScalar(self):
+ if len(self.shape) == 0:
result = self
indices = op.Constant(value_int=0)
else:
@@ -5116,18 +5239,14 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T
return result, indices
-@torch_op(("aten::minimum", "aten::min.other"))
-def aten_minimum(self: TReal, other: TReal) -> TReal:
+@torch_op("aten::minimum", trace_only=True)
+def aten_minimum(self: TTensor, other: TTensor) -> TTensor:
"""minimum(Tensor self, Tensor other) -> Tensor"""
- return op.Min(self, other)
-
+ if self.dtype == ir.DataType.BOOL:
+ return op.And(self, other)
-@torch_op(("aten::minimum", "aten::min.other"))
-def aten_minimum_bool(self: BOOL, other: BOOL) -> BOOL:
- """minimum(Tensor self, Tensor other) -> Tensor"""
-
- return op.And(self, other)
+ return op.Min(self, other)
def aten_miopen_batch_norm(
@@ -5398,7 +5517,7 @@ def aten_mkldnn_max_pool3d_backward(
raise NotImplementedError()
-@torch_op("aten::mm")
+@torch_op("aten::mm", trace_only=True)
def aten_mm(
self: TRealUnlessInt16OrInt8, mat2: TRealUnlessInt16OrInt8
) -> TRealUnlessInt16OrInt8:
@@ -5466,27 +5585,28 @@ def aten_msort(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"))
-def aten_mul(self: TReal, other: TReal) -> TReal:
+@torch_op(
+ ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
+ trace_only=True,
+)
+def aten_mul(self: TTensor, other: TTensor) -> TTensor:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""
- return op.Mul(self, other)
-
-
-@torch_op(("aten::mul", "aten::mul.Tensor"))
-def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
- """ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""
+ if self.dtype == ir.DataType.BOOL:
+ return op.And(self, other)
- # TODO(justinchuby): Handle cases where type reconcilation is not enough,
- # since different ONNX operators are used based on different data types.
-
- return op.And(self, other)
+ return op.Mul(self, other)
-@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), complex=True)
+@torch_op(
+ ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
+ trace_only=True,
+ complex=True,
+)
def aten_mul_complex(self: TReal, other: TReal) -> TReal:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""
+ # TODO(justinchuby): Maybe use Split to simplify the logic
self_real = op.Slice(self, [0], [1], axes=[-1])
self_imag = op.Slice(self, [1], [2], axes=[-1])
other_real = op.Slice(other, [0], [1], axes=[-1])
@@ -5506,22 +5626,22 @@ def aten_mul_complex(self: TReal, other: TReal) -> TReal:
return op.Concat(real, imag, axis=-1)
-@torch_op("aten::multinomial")
+@torch_op("aten::multinomial", trace_only=True)
def aten_multinomial(
self: TFloat,
num_samples: int,
- replacement: bool = False, # pylint: disable=unused-argument
+ replacement: bool = False,
) -> TInt:
"""multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor"""
# ONNX Multinomial doesn't support 1D input
- if Rank(self) == 1:
+ if len(self.shape) == 1:
unsqueezed_input = op.Unsqueeze(self, axes=0)
else:
unsqueezed_input = self
# ONNX multinomial expects log probability
log_input = op.Log(unsqueezed_input)
result = op.Multinomial(log_input, dtype=INT64.dtype, sample_size=num_samples)
- if Rank(self) == 1:
+ if len(self.shape) == 1:
result = op.Squeeze(result)
return result
@@ -5532,10 +5652,11 @@ def aten_multiply(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
+@torch_op("aten::mv", trace_only=True)
def aten_mv(self: TensorType, vec: TensorType) -> TensorType:
"""mv(Tensor self, Tensor vec) -> Tensor"""
- raise NotImplementedError()
+ return op.MatMul(self, vec)
def aten_mvlgamma(self: TensorType, p: int) -> TensorType:
@@ -5595,18 +5716,13 @@ def aten_nansum(
raise NotImplementedError()
-@torch_op("aten::narrow", traceable=True)
+@torch_op("aten::narrow", trace_only=True)
def aten_narrow(self: TTensor, dim: INT64, start: INT64, length: INT64) -> TTensor:
"""narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"""
- if IsScalar(dim):
- dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
-
- if IsScalar(start):
- start = op.Reshape(start, op.Constant(value_ints=[-1]))
-
- if IsScalar(length):
- length = op.Reshape(length, op.Constant(value_ints=[-1]))
+ dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
+ start = op.Reshape(start, op.Constant(value_ints=[-1]))
+ length = op.Reshape(length, op.Constant(value_ints=[-1]))
end = op.Add(start, length)
return op.Slice(self, start, end, dim)
@@ -5682,14 +5798,18 @@ def aten_native_batch_norm(
axes.pop(1)
axes = op.Constant(value_ints=axes)
if running_mean is None: # Using input mean
- running_mean = op.Squeeze(op.ReduceMean(input, axes))
+ running_mean = op.ReduceMean(input, axes, keepdims=False)
if running_var is None: # Using input var
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
- running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))
+ running_var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False)
+ # TODO: This is a temporary workaround: BatchNormalization is forced into
+ # training mode in PyTorch, but ORT currently supports training mode only
+ # for opset versions lower than 14.
+ training = False
# We have to split to two private functions, because BatchNormalization returns
# three outputs when training_mode=True and one when it is False.
if training:
@@ -5717,7 +5837,6 @@ def aten_native_batch_norm(
return norm, input_mean, input_rstd
-@torch_op("aten::native_batch_norm", private=True)
def _aten_native_batch_norm_training_onnx(
input: TFloat,
weight: TFloat,
@@ -5769,7 +5888,6 @@ def _aten_native_batch_norm_training_onnx(
return norm, mean, rstd, running_mean, new_running_var
-@torch_op("aten::native_batch_norm", private=True)
def _aten_native_batch_norm_inference_onnx(
input: TFloat,
weight: TFloat,
@@ -5829,14 +5947,18 @@ def aten__native_batch_norm_legit_functional(
axes.pop(1)
axes = op.Constant(value_ints=axes)
if running_mean is None: # Using input mean
- running_mean = op.Squeeze(op.ReduceMean(input, axes))
+ running_mean = op.ReduceMean(input, axes, keepdims=False)
if running_var is None: # Using input var
mean = op.ReduceMean(input, axes)
input_sub_mean = op.Sub(input, mean)
sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
- running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes))
+ running_var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False)
+ # TODO: This is a temporary workaround: BatchNormalization is forced into
+ # training mode in PyTorch, but ORT currently supports training mode only
+ # for opset versions lower than 14.
+ training = False
# We have to split to two private functions, because BatchNormalization returns
# three outputs when training_mode=True and one when it is False.
if training:
@@ -5899,16 +6021,10 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::native_dropout")
-def aten_native_dropout(
- input: TFloatOrBFloat16, p: float, train: bool = True
-) -> Tuple[TFloatOrBFloat16, BOOL]:
+@torch_op("aten::native_dropout", trace_only=True)
+def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
"""native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""
- # Python bool attributes need to be explicitly converted to BOOL
- # because the underlying attribute type is int
- # TODO(#872): Allow ONNX Script to handle this conversion
- train = op.Cast(train, to=BOOL.dtype)
result, mask = op.Dropout(input, p, train)
return result, mask
@@ -5926,9 +6042,9 @@ def aten_native_group_norm(
input: TFloat,
weight: Optional[TFloat] = None,
bias: Optional[TFloat] = None,
- N: Optional[INT64] = None, # pylint: disable=unused-argument
- C: Optional[INT64] = None, # pylint: disable=unused-argument
- HxW: Optional[INT64] = None, # pylint: disable=unused-argument
+ N: Optional[INT64] = None,
+ C: Optional[INT64] = None,
+ HxW: Optional[INT64] = None,
group: int = 1,
eps: float = 1e-05,
) -> Tuple[TFloat, TFloat, TFloat]:
@@ -5941,22 +6057,10 @@ def aten_native_group_norm(
if bias is None: # Set to 0.0 as default, the shape is Channel size
bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
- # Accoding to Torch, return rstd instead of var
- norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps)
- return norm, mean, rstd
-
-
-@torch_op("aten::native_group_norm", private=True)
-def _aten_native_group_norm_onnx(
- input: TFloat,
- weight: TFloat,
- bias: TFloat,
- group: INT64,
- eps: float,
-) -> Tuple[TFloat, TFloat, TFloat]:
# Because onnx.GroupNorm() needs size=group for weight and bias,
# but the torch aten function's inputs need size=channel, the sizes mismatch.
# So we have to use onnx.InstanceNorm() to simulate it instead.
+ # This implementation should be simplified after opset 21
neg_1 = op.Constant(value_ints=[-1])
# Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter
group_tensor = op.Reshape(group, neg_1)
@@ -5969,14 +6073,16 @@ def _aten_native_group_norm_onnx(
input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
)
# Reshape back to input's shape
- norm = op.Reshape(norm, op.Shape(input))
+ norm = op.Reshape(norm, op.Shape(input), allowzero=True)
# Using the input weight and bias to apply the affine transform,
# but we need to unsqueeze them to the target shape for easy broadcasting
input_rank = Rank(input)
axes_unsqueeze = op.Range(1, input_rank - 1, 1)
weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
+ weight_full_shape = op.CastLike(weight_full_shape, norm)
norm_mul_weight = op.Mul(norm, weight_full_shape)
+ bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight)
norm_result = op.Add(norm_mul_weight, bias_full_shape)
# Compute mean and rstd, but using Torch algorithm
# The returned shape for mean and vstd should be [N, group, -1]
@@ -5991,7 +6097,9 @@ def _aten_native_group_norm_onnx(
sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean)
# In Pytorch, vstd = 1/(sqrt(var + eps))
var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False)
- rstd = op.Div(1.0, op.Sqrt(var + eps))
+ eps = op.Constant(value=ir.tensor(eps, dtype=input.dtype))
+ one = op.Constant(value=ir.tensor(1.0, dtype=input.dtype))
+ rstd = op.Div(one, op.Sqrt(op.Add(var, eps)))
# Get the correct shape [N, group] for mean again
mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=False)
return norm_result, mean, rstd
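+
+
+# A minimal numpy sketch (not part of the op) of the statistics above: the
+# input is viewed as [N, group, -1] and rstd = 1 / sqrt(var + eps), matching
+# the reciprocal standard deviation PyTorch returns.
+#
+#   import numpy as np
+#
+#   n, group, eps = 2, 2, 1e-5
+#   x = np.random.rand(n, 4, 8).astype(np.float32)
+#   xg = x.reshape(n, group, -1)
+#   mean = xg.mean(axis=-1)                              # [N, group]
+#   var = ((xg - mean[..., None]) ** 2).mean(axis=-1)    # [N, group]
+#   rstd = 1.0 / np.sqrt(var + eps)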
@@ -6017,7 +6125,7 @@ def aten_native_group_norm_backward(
@torch_op("aten::native_layer_norm", trace_only=True)
def aten_native_layer_norm(
input: TReal,
- normalized_shape: INT64,
+ normalized_shape: Sequence[int],
weight: Optional[TReal] = None,
bias: Optional[TReal] = None,
eps: float = 1e-05,
@@ -6066,14 +6174,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
raise NotImplementedError()
-@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"))
+@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"), trace_only=True)
def aten_ne(self: TReal, other: TReal) -> BOOL:
"""ne.Tensor(Tensor self, Tensor other) -> Tensor"""
return op.Not(op.Equal(self, other))
-@torch_op(("aten::neg", "_operator::neg"))
+@torch_op(("aten::neg", "_operator::neg"), trace_only=True)
def aten_neg(self: TReal) -> TReal:
"""neg(Tensor self) -> Tensor"""
@@ -6086,111 +6194,94 @@ def aten_negative(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::new_empty")
-def aten_new_empty(self: TTensor, size: INT64) -> TTensor:
- """new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
- # using zero to simulate empty array
- result = op.ConstantOfShape(size)
- return op.CastLike(result, self)
-
-
-@torch_op("aten::new_empty")
-def aten_new_empty_dtype(
- self: TTensor, # pylint: disable=unused-argument
+@torch_op("aten::new_empty", trace_only=True)
+def aten_new_empty(
+ self: TTensor,
size: INT64,
- dtype: int,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TTensor:
"""new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
# using zeros to simulate an empty array
result = op.ConstantOfShape(size)
+ if dtype == -1:
+ return op.CastLike(result, self)
return op.Cast(result, to=dtype)
-@torch_op("aten::new_empty_strided")
+@torch_op("aten::new_empty_strided", trace_only=True)
def aten_new_empty_strided(
self: TTensor,
size: INT64,
- stride: INT64, # pylint: disable=unused-argument
-) -> TTensor:
- """new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
- # using zero to simulate empty array
- zero = op.ConstantOfShape(size)
- return op.CastLike(zero, self)
-
-
-@torch_op("aten::new_empty_strided")
-def aten_new_empty_strided_dtype(
- self: TTensor, # pylint: disable=unused-argument
- size: INT64,
- stride: INT64, # pylint: disable=unused-argument
- dtype: int,
+ stride: INT64,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TTensor:
"""new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
# using zeros to simulate an empty array
zero = op.ConstantOfShape(size)
+ if dtype == -1:
+ return op.CastLike(zero, self)
return op.Cast(zero, to=dtype)
-@torch_op("aten::new_full")
-def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
- # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
-
- fill_value = op.CastLike(fill_value, self)
- return op.Expand(fill_value, size)
-
-
-@torch_op("aten::new_full")
-def aten_new_full_dtype(
- self: TTensor, # pylint: disable=unused-argument
+@torch_op("aten::new_full", trace_only=True)
+def aten_new_full(
+ self: TTensor,
size: INT64,
- fill_value: TTensor,
- dtype: int,
+ fill_value: TensorType,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TTensor:
# new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- fill_value = op.Cast(fill_value, to=dtype)
+ if dtype == -1:
+ fill_value = op.CastLike(fill_value, self)
+ else:
+ fill_value = op.Cast(fill_value, to=dtype)
return op.Expand(fill_value, size)
-@torch_op("aten::new_ones")
-def aten_new_ones(self: TReal, size: INT64) -> TReal: # pylint: disable=unused-argument
+@torch_op("aten::new_ones", trace_only=True)
+def aten_new_ones(
+ self: TReal,
+ size: INT64,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TReal:
"""new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
one = op.Constant(value_float=1.0)
result = op.Expand(one, size)
- return op.CastLike(result, self)
+ if dtype == -1:
+ return op.CastLike(result, self)
+ return op.Cast(result, to=dtype)
-@torch_op("aten::new_ones")
-def aten_new_ones_dtype(
- self: TReal, # pylint: disable=unused-argument
+@torch_op("aten::new_zeros", trace_only=True)
+def aten_new_zeros(
+ self: TReal,
size: INT64,
- dtype: int,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TReal:
- one = op.Constant(value_float=1.0)
- result = op.Expand(one, size)
- return op.Cast(result, to=dtype)
-
-
-@torch_op("aten::new_zeros")
-def aten_new_zeros(self: TReal, size: INT64) -> TReal:
"""new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
result = op.ConstantOfShape(size)
- return op.CastLike(result, self)
-
-
-@torch_op("aten::new_zeros")
-def aten_new_zeros_dtype(
- self: TReal, # pylint: disable=unused-argument
- size: INT64,
- dtype: int,
-) -> TReal:
- result = op.ConstantOfShape(size)
+ if dtype == -1:
+ return op.CastLike(result, self)
return op.Cast(result, to=dtype)
@@ -6200,7 +6291,7 @@ def aten_nextafter(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::nonzero")
+@torch_op("aten::nonzero", trace_only=True)
def aten_nonzero(self: TTensor) -> INT64:
"""nonzero(Tensor self) -> Tensor"""
# NOTE: In torch the return shape is [n, d], while in onnx [d, n],
@@ -6220,7 +6311,7 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp
raise NotImplementedError()
-@torch_op(("aten::normal", "aten::normal_functional"), traceable=True)
+@torch_op("aten::normal_functional", trace_only=True)
def aten_normal(
self: TTensor,
mean: float = 0.0,
@@ -6228,26 +6319,28 @@ def aten_normal(
) -> TFloat: # type: ignore[type-var]
"""normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor"""
- if IsScalar(self):
+ if len(self.shape) == 0:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.RandomNormalLike(self, mean=mean, scale=std)
return result
-@torch_op("aten::normal.float_float")
+@torch_op("aten::normal.float_float", trace_only=True)
def aten_normal_float_float(
mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype
) -> TensorType:
"""normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+ if dtype == -1:
+ dtype = FLOAT.dtype
# Create a dummy tensor for RandomNormalLike to get the shape
dummy_tensor = op.ConstantOfShape(size)
result = op.RandomNormalLike(dummy_tensor, mean=mean, scale=std)
return op.Cast(result, to=dtype)
-@torch_op("aten::normal.float_Tensor")
+@torch_op("aten::normal.float_Tensor", trace_only=True)
def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat:
"""normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor"""
@@ -6257,7 +6350,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat:
return op.Add(op.Mul(std, sampled), mean_casted)
-@torch_op("aten::normal.Tensor_float")
+@torch_op("aten::normal.Tensor_float", trace_only=True)
def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat:
"""normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor"""
@@ -6266,7 +6359,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat:
return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean)
-@torch_op("aten::normal.Tensor_Tensor")
+@torch_op("aten::normal.Tensor_Tensor", trace_only=True)
def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat:
"""normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor"""
@@ -6287,10 +6380,17 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::ones")
-def aten_ones(size: IntType, dtype: int = FLOAT.dtype):
+@torch_op("aten::ones", trace_only=True)
+def aten_ones(
+ size: IntType,
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+):
"""ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
+ if dtype == -1:
+ dtype = FLOAT.dtype
size = op.Cast(size, to=INT64.dtype)
one = op.Constant(value_float=1.0)
one = op.Cast(one, to=dtype)
@@ -6298,26 +6398,26 @@ def aten_ones(size: IntType, dtype: int = FLOAT.dtype):
@torch_op("aten::ones_like", trace_only=True)
-def aten_ones_like(self: TTensor, dtype: int = -1) -> TTensor:
- """ones_like.
+def aten_ones_like(
+ self: TTensor,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+ memory_format: str = "",
+) -> TTensor:
+ """ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
before calling this function.
"""
- # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
-
- # NOTE: trace_only because both if branches need to be the same type, but we have
- # a cast in the if branch.
+ if dtype is None:
+ dtype = -1
if dtype == -1:
one = op.CastLike(1, self)
else:
one = op.Cast(1, to=dtype)
- return _aten_ones_like_onnx(self, one)
-
-
-@torch_op("aten::ones_like", private=True)
-def _aten_ones_like_onnx(self: TTensor, one) -> TTensor:
shape = op.Shape(self)
return op.Expand(one, shape)
@@ -6403,16 +6503,42 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType:
raise NotImplementedError()
-def aten_pixel_shuffle(self: TensorType, upscale_factor: int) -> TensorType:
+@torch_op("aten::pixel_shuffle", trace_only=True)
+def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
"""pixel_shuffle(Tensor self, int upscale_factor) -> Tensor"""
+ if len(self.shape) == 4:
+ return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD")
- raise NotImplementedError()
+ # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
+ batch_dims = op.Shape(self, end=-3)
+ chw_in_dims = op.Shape(self, start=-3)
+
+ reshaped_self = op.Reshape(
+ self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
+ )
+ depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
+ final_dims = op.Shape(depth_to_space, start=1)
+ output_shape = op.Concat(batch_dims, final_dims, axis=0)
+ return op.Reshape(depth_to_space, output_shape, allowzero=True)
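+
+
+# A minimal check (assuming torch is available) of the rank > 4 path above:
+# pixel_shuffle acts independently on each leading index, so collapsing the
+# leading dims to one batch dim, shuffling, and restoring them is equivalent.
+#
+#   import torch
+#
+#   x = torch.rand(3, 5, 2 * 2 * 7, 4, 6)    # two leading batch dims
+#   expected = torch.pixel_shuffle(x, 2)
+#   collapsed = torch.pixel_shuffle(x.reshape(-1, 2 * 2 * 7, 4, 6), 2)
+#   assert torch.equal(expected, collapsed.reshape(3, 5, 7, 8, 12))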
-def aten_pixel_unshuffle(self: TensorType, downscale_factor: int) -> TensorType:
+@torch_op("aten::pixel_unshuffle", trace_only=True)
+def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
"""pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor"""
+ if len(self.shape) == 4:
+ return op.SpaceToDepth(self, blocksize=downscale_factor)
- raise NotImplementedError()
+ # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
+ batch_dims = op.Shape(self, end=-3)
+ chw_in_dims = op.Shape(self, start=-3)
+
+ reshaped_self = op.Reshape(
+ self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
+ )
+ space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
+ final_dims = op.Shape(space_to_depth, start=1)
+ output_shape = op.Concat(batch_dims, final_dims, axis=0)
+ return op.Reshape(space_to_depth, output_shape, allowzero=True)
def aten_poisson(self: TensorType, generator: Optional[str] = None) -> TensorType:
@@ -6455,19 +6581,49 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op(
- ("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow")
-)
+@torch_op(("aten::pow.Tensor_Tensor", "_operator::pow"), trace_only=True)
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""
-
+ # TODO(justinchuby): Add type promotion
return op.Pow(self, exponent)
-def aten_prelu(self: TensorType, weight: TensorType) -> TensorType:
+@torch_op("aten::pow.Tensor_Scalar", trace_only=True)
+def aten_pow_tensor_scalar(self: TReal, exponent: float) -> TReal:
+ """pow(Tensor self, Scalar exponent) -> Tensor"""
+ if self.dtype.is_floating_point():
+ # self is floating point (e.g. float16): create the exponent constant in
+ # self's dtype so no extra cast is needed
+ return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))
+ # For integer types, we need to cast self to the exponent type
+ if isinstance(exponent, int):
+ # The scalar exponent can be an int
+ return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))
+
+ # exponent is float so we cast self to match the exponent type.
+ # More precisely if self is float64, we should cast exponent to float64; but
+ # this is uncommon and should be fixed when we create a general type promotion
+ # mechanism for torchlib
+ return op.Pow(op.Cast(self, to=FLOAT.dtype), exponent)
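+
+
+# A minimal torch sketch of the dtype rules above: a float tensor keeps its
+# dtype, an int tensor with an int exponent stays integral, and an int tensor
+# with a float exponent is promoted to the default float dtype.
+#
+#   import torch
+#
+#   assert torch.pow(torch.tensor([2.0], dtype=torch.float16), 3).dtype == torch.float16
+#   assert torch.pow(torch.tensor([2]), 3).dtype == torch.int64
+#   assert torch.pow(torch.tensor([2]), 0.5).dtype == torch.float32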
+
+
+@torch_op("aten::pow.Scalar", trace_only=True)
+def aten_pow_scalar(self: float, exponent: TTensor) -> TTensor:
+ """pow.Scalar(Scalar self, Tensor exponent) -> Tensor"""
+ return op.Pow(op.Cast(self, to=exponent.dtype), exponent)
+
+
+@torch_op(("aten::prelu", "aten::_prelu_kernel"), trace_only=True)
+def aten_prelu(self: TReal, weight: TReal) -> TReal:
"""prelu(Tensor self, Tensor weight) -> Tensor"""
- raise NotImplementedError()
+ rank = len(self.shape)
+ if rank == 0:
+ # e.g. self: [], weight: [1]
+ weight = op.Squeeze(weight)
+ elif rank >= 2:
+ # e.g. self: [5,10,5], weight: [10]
+ weight = op.Reshape(weight, [1, -1] + [1] * (rank - 2))
+ return op.PRelu(self, weight)
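+
+
+# A minimal numpy sketch of the weight reshape above: a per-channel weight of
+# shape [C] becomes [1, C, 1, ...] so it broadcasts against an [N, C, ...]
+# input, e.g. self [5, 10, 5] with weight [10] -> weight [1, 10, 1].
+#
+#   import numpy as np
+#
+#   x = np.random.randn(5, 10, 5)
+#   w = np.random.rand(10)
+#   prelu = np.where(x >= 0, x, w.reshape(1, -1, 1) * x)   # PRelu semantics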
def aten_prelu_backward(
@@ -6478,10 +6634,22 @@ def aten_prelu_backward(
raise NotImplementedError()
-def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
+@torch_op("aten::prod", trace_only=True)
+def aten_prod(self: TReal, dtype: int = -1) -> TReal:
"""prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
- raise NotImplementedError()
+ if dtype != -1 and dtype is not None:
+ self = op.Cast(self, to=dtype)
+ return op.ReduceProd(self)
+
+
+@torch_op("aten::prod.dim_int", trace_only=True)
+def aten_prod_dim_int(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal:
+ """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
+
+ if dtype != -1 and dtype is not None:
+ self = op.Cast(self, to=dtype)
+ return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
def aten_promote_types(type1: int, type2: int) -> int:
@@ -6701,37 +6869,48 @@ def aten_quantized_rnn_tanh_cell(
raise NotImplementedError()
-@torch_op("aten::rad2deg", traceable=True)
+@torch_op("aten::rad2deg", trace_only=True)
def aten_rad2deg(self: TFloat) -> TFloat:
"""rad2deg(Tensor self) -> Tensor"""
return op.Mul(self, op.CastLike(180.0 / _MATH_PI, self))
-@torch_op("aten::rand")
-def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
+@torch_op("aten::rand", trace_only=True)
+def aten_rand(
+ size: INT64,
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TReal:
"""rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
-
+ if dtype == -1:
+ dtype = FLOAT.dtype
shaper = op.ConstantOfShape(size)
return op.RandomUniformLike(shaper, dtype=dtype)
-@torch_op("aten::rand_like")
-def aten_rand_like(self: TFloat) -> TFloat:
- """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
-
- return op.RandomUniformLike(self)
-
-
-@torch_op("aten::rand_like")
-def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType:
+@torch_op("aten::rand_like", trace_only=True)
+def aten_rand_like(
+ self: TFloat, dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False
+) -> TFloat:
"""rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+ if dtype == -1:
+ return op.RandomUniformLike(self)
return op.RandomUniformLike(self, dtype=dtype)
-@torch_op("aten::randint")
-def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType:
+@torch_op("aten::randint", trace_only=True)
+def aten_randint(
+ high: INT64,
+ size: INT64,
+ dtype: int = INT64.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
shaper = op.ConstantOfShape(size)
@@ -6743,9 +6922,15 @@ def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorTy
return op.Cast(rand_int, to=dtype)
-@torch_op("aten::randint.low")
+@torch_op("aten::randint.low", trace_only=True)
def aten_randint_low(
- low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype
+ low: INT64,
+ high: INT64,
+ size: INT64,
+ dtype: int = INT64.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> TensorType:
"""randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
@@ -6760,21 +6945,15 @@ def aten_randint_low(
return op.Cast(rand_int, to=dtype)
-@torch_op("aten::randint_like")
-def aten_randint_like(self: TensorType, high: INT64) -> IntType:
- """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
-
- self_float = op.Cast(self, to=FLOAT.dtype)
- rand = op.RandomUniformLike(self_float)
- # Scale to [0, high] first
- rand_scaled = op.Mul(rand, op.CastLike(high, rand))
- # Round to ints
- rand_int = op.Floor(rand_scaled)
- return op.CastLike(rand_int, self)
-
-
-@torch_op("aten::randint_like")
-def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType:
+@torch_op("aten::randint_like", trace_only=True)
+def aten_randint_like(
+ self: TensorType,
+ high: INT64,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> IntType:
"""randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
self_float = op.Cast(self, to=FLOAT.dtype)
@@ -6783,11 +6962,21 @@ def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> Tensor
rand_scaled = op.Mul(rand, op.CastLike(high, rand))
# Round to ints
rand_int = op.Floor(rand_scaled)
+ if dtype == -1:
+ return op.CastLike(rand_int, self)
return op.Cast(rand_int, to=dtype)
-@torch_op("aten::randint_like.low_dtype")
-def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType:
+@torch_op("aten::randint_like.low_dtype", trace_only=True)
+def aten_randint_like_low_dtype(
+ self: TensorType,
+ low: INT64,
+ high: INT64,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> IntType:
"""randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None.
@@ -6801,55 +6990,47 @@ def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> In
rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
# Round to ints
rand_int = op.Floor(rand_translated)
- return op.CastLike(rand_int, self)
-
-
-@torch_op("aten::randint_like.low_dtype")
-def aten_randint_like_low_dtype_dtype(
- self: TensorType, low: INT64, high: INT64, dtype: int
-) -> TensorType:
- """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
-
- self_float = op.Cast(self, to=FLOAT.dtype)
- rand = op.RandomUniformLike(self_float)
- # Translate to [low, high] first
- high = op.Cast(high, to=FLOAT.dtype)
- low = op.Cast(low, to=FLOAT.dtype)
- rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low)
- # Round to ints
- rand_int = op.Floor(rand_translated)
+ if dtype == -1:
+ return op.CastLike(rand_int, self)
return op.Cast(rand_int, to=dtype)
-@torch_op("aten::randn")
-def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal:
+@torch_op("aten::randn", trace_only=True)
+def aten_randn(
+ size: INT64,
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TReal:
"""randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
shaper = op.ConstantOfShape(size)
return op.RandomNormalLike(shaper, dtype=dtype)
-@torch_op("aten::randn_like")
-def aten_randn_like(self: TFloat) -> TFloat:
- """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
-
- return op.RandomNormalLike(self)
-
-
-@torch_op("aten::randn_like")
-def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType:
+@torch_op("aten::randn_like", trace_only=True)
+def aten_randn_like(
+ self: TFloat, dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False
+) -> TFloat:
"""randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
+ if dtype == -1:
+ return op.RandomNormalLike(self)
return op.RandomNormalLike(self, dtype=dtype)
-def aten_randperm(n: int) -> TensorType:
+def aten_randperm(
+ n: int, layout: str = "", device: str = "", pin_memory: bool = False
+) -> TensorType:
"""randperm(int n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
raise NotImplementedError()
-def aten_range(start: float, end: float) -> TensorType:
+def aten_range(
+ start: float, end: float, layout: str = "", device: str = "", pin_memory: bool = False
+) -> TensorType:
"""range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
raise NotImplementedError()
@@ -6867,8 +7048,8 @@ def aten_real(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::reciprocal")
-def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::reciprocal", trace_only=True)
+def aten_reciprocal(self: TFloat) -> TFloat:
"""reciprocal(Tensor self) -> Tensor"""
return op.Reciprocal(self)
@@ -6886,10 +7067,15 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::remainder")
-def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op(
+ ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True
+)
+def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
+ if self.dtype.is_integer():
+ return op.Mod(self, other)
+
# TODO(justinchuby): Improve fp16 precision by following the logic in
# https://github.com/pytorch/pytorch/blob/3a823e46170778cc32783f27596c77d0103084a9/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L264-L277
@@ -6899,13 +7085,6 @@ def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrB
return op.Sub(self, op.Mul(rounded_quotient, other))
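+
+
+# A minimal numpy sketch of the floating-point branch above: Python-style
+# remainder is self - floor(self / other) * other, which (unlike C fmod)
+# takes the sign of the divisor.
+#
+#   import numpy as np
+#
+#   a, b = np.array([-7.0, 7.0]), np.array([3.0, -3.0])
+#   result = a - np.floor(a / b) * b
+#   assert np.array_equal(result, np.array([2.0, -2.0]))
+#   assert np.array_equal(result, np.mod(a, b))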
-@torch_op("aten::remainder")
-def aten_remainder_int(self: TInt, other: TInt) -> TInt:
- """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
-
- return op.Mod(self, other)
-
-
def aten_rename(self: TensorType, names: Optional[str]) -> TensorType:
"""rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)"""
@@ -6918,39 +7097,141 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
raise NotImplementedError()
-@torch_op("aten::repeat")
-def aten_repeat(self: TTensor, repeats: TInt) -> TTensor:
+@torch_op("aten::repeat", trace_only=True)
+def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
"""repeat(Tensor self, SymInt[] repeats) -> Tensor"""
- if op.Size(repeats) == 0:
- result = self
+ if len(repeats) == 0:
+ return self
+ self_expanded = op.Expand(self, [1] * len(repeats))
+ return op.Tile(self_expanded, repeats)
+
+
+@torch_op("aten::repeat_interleave.self_int", trace_only=True)
+def aten_repeat_interleave_self_int(
+ self: TensorType, repeats: int, dim: Optional[int] = None
+) -> TensorType:
+ """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+
+ The trick is to repeat along a new, orthogonal dimension and then reshape.
+
+ .. code-block:: python
+
+ x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+ x.repeat_interleave(2, dim=0)
+
+ is equivalent to:
+
+ .. code-block:: python
+
+ x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+ x.repeat((1, 2)).reshape((-1, x.shape[1]))
+ """
+ if dim is None:
+ raise NotImplementedError("No conversion available yet when dim is None.")
+
+ self_rank = len(self.shape)
+ pos_dim = (dim + self_rank) % self_rank
+ unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
+ if isinstance(repeats, int):
+ tiles = [1] * (self_rank + 1)
+ tiles[pos_dim + 1] = repeats
+ tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
else:
- # TODO(justinchuby): Make ones_like a function when onnxscript supports it
- repeats = op.Cast(repeats, to=INT64.dtype)
- # shape = ones_like(repeats) := {
- one = op.Constant(value_int=1)
- repeats_shape = op.Shape(repeats)
- shape = op.Expand(one, repeats_shape)
- # }
- self_expanded = op.Expand(self, shape)
- result = op.Tile(self_expanded, repeats)
- return result
+ # repeats is a symbolic tensor
+ tile_repeat = op.Concat(
+ op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)),
+ op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
+ op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)),
+ axis=0,
+ )
+ tiled = op.Expand(unsqueezed, tile_repeat)
+ if self_rank == 1:
+ return op.Identity(tiled)
+ final_shape = op.Concat(
+ op.Shape(self, start=0, end=dim),
+ op.Constant(value_ints=[-1]),
+ op.Shape(self, start=pos_dim + 1),
+ axis=0,
+ )
+ return op.Reshape(tiled, final_shape)
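+
+
+# A minimal numpy sketch of the unsqueeze/expand/reshape trick above for an
+# integer `repeats`, e.g. repeat_interleave(2, dim=0) on a [2, 3] input.
+#
+#   import numpy as np
+#
+#   x = np.array([[0, 1, 2], [3, 4, 5]])
+#   expanded = np.tile(x[:, np.newaxis, :], (1, 2, 1))   # [2, 2, 3]
+#   result = expanded.reshape(-1, x.shape[1])            # [4, 3]
+#   assert np.array_equal(result, np.repeat(x, 2, axis=0))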
-def aten_repeat_interleave(
- repeats: TensorType, output_size: Optional[int] = None
+@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
+def aten_repeat_interleave_Tensor(
+ self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
) -> TensorType:
- """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
+ """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor
- raise NotImplementedError()
+ When `repeats` is a tensor, each row is repeated
+ a different number of times.
+ There are multiple strategies; here is one.
+ .. code-block:: python
-@torch_op("aten::reshape")
-def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
- """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
+ import torch
- # Reshape only support INT64 as 'shape'
- shape = op.Cast(shape, to=INT64.dtype)
+ x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+ times = torch.tensor([2, 3], dtype=torch.int64)
+ y = x.repeat_interleave(times, dim=0)
+ print("repeat_interleave")
+ print(y)
+
+ ci = times.cumsum(dim=0)
+ rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
+ srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
+ indices = srows.reshape((-1, ))
+ print("decomposed")
+ print(x[indices, :])
+ """
+ if repeats is None:
+ repeats = self
+ self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
+ if dim is None:
+ # flatten
+ self = op.Reshape(self, [-1])
+ rank = 1
+ else:
+ rank = len(self.shape)
+
+ if rank > 2:
+ shape_x0 = op.Shape(self, start=0, end=1)
+ shape_x = op.Shape(self, start=1)
+ self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
+ elif rank == 1:
+ shape_x = None
+ self = op.Reshape(self, [-1, 1])
+ else:
+ if rank != 2:
+ raise NotImplementedError(
+ f"rank(self)={rank} not implemented for repeat_interleave"
+ )
+ shape_x = None
+
+ ci = op.CumSum(repeats, [0])
+ last_ci = op.Gather(ci, [-1])
+ trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
+ rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
+ srows = op.Sub(
+ op.Shape(self, start=0, end=1),
+ op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
+ )
+ indices = op.Reshape(srows, [-1])
+ values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
+ if rank == 2:
+ return values
+ # shape_x is None when the input was 1-D; for rank > 2 it holds the
+ # trailing dimensions to restore.
+ return op.Reshape(
+ values,
+ op.Concat([-1], shape_x, axis=0) if shape_x is not None else [-1],
+ )
+
+
+@torch_op("aten::reshape", trace_only=True)
+def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
+ """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
+ shape = common_ops.merge_dims(shape)
return op.Reshape(self, shape)
@@ -6960,14 +7241,14 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::resolve_conj")
+@torch_op("aten::resolve_conj", trace_only=True)
def aten_resolve_conj(self: TTensor) -> TTensor:
"""resolve_conj(Tensor(a) self) -> Tensor(a)"""
return op.Identity(self)
-@torch_op("aten::resolve_neg")
+@torch_op("aten::resolve_neg", trace_only=True)
def aten_resolve_neg(self: TTensor) -> TTensor:
"""resolve_neg(Tensor(a) self) -> Tensor(a)"""
@@ -7019,74 +7300,84 @@ def aten_rnn_tanh_cell(
@torch_op("aten::roll", trace_only=True)
-def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor:
+def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor:
"""roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""
+ if isinstance(shifts, int):
+ shifts = [shifts]
+
+ if isinstance(dims, int):
+ dims = [dims]
+
self_rank = len(self.shape)
if self_rank == 0:
- return self
+ return op.Identity(self)
elif self.shape[0] == 0: # empty tensor
- return self
+ return op.Identity(self)
+
+ # NOTE: In pytorch, default value of dims is an empty list.
+ if len(dims) == 0: # Empty sequence
+ assert len(shifts) == 1, "shifts should be a single integer if dims is empty"
+ return _aten_roll_shift_no_dim_onnx(self, shifts[0])
else:
- # NOTE: In pytorch, default value of dims is an empty list.
- if len(dims) == 0: # Empty sequence
- # assert isinstance(shifts, int)
- return _aten_roll_shift_no_dim_onnx(self, shifts)
- else:
- # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list
- result = self
- for i in range(len(shifts)): # pylint: disable=consider-using-enumerate
- shift = op.Gather(shifts, i, axis=0)
- dim = dims[i]
- result = _aten_roll_shift_and_dim_onnx(result, shift, dim)
- return result
+ assert len(shifts) == len(dims)
+ result = self
+ for i, shift in enumerate(shifts):
+ dim = dims[i]
+ result = _aten_roll_shift_and_dim_onnx(result, shift, dim)
+ return result
@torch_op("aten::roll", trace_only=True, complex=True)
-def aten_roll_complex(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor:
+def aten_roll_complex(
+ self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()
+) -> TTensor:
"""roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""
+ if isinstance(shifts, int):
+ shifts = [shifts]
+
+ if isinstance(dims, int):
+ dims = [dims]
+
self_rank = len(self.shape)
if self_rank == 1:
- return self
+ return op.Identity(self)
if self.shape[0] == 0: # empty tensor
- return self
+ return op.Identity(self)
self_real = op.Slice(self, [0], [1], axes=[-1])
self_imag = op.Slice(self, [1], [2], axes=[-1])
if not dims:
- # assert isinstance(shifts, int)
- shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts)
- shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts)
+ assert len(shifts) == 1, "shifts should be a single integer if dims is empty"
+ shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0])
+ shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0])
result = op.Concat(shift_real, shift_imag, axis=-1)
else:
- # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list
+ assert len(shifts) == len(dims)
for i, dim in enumerate(dims):
- shift = op.Gather(shifts, i, axis=0)
- self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim)
- self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim)
+ self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim)
+ self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim)
result = op.Concat(self_real, self_imag, axis=-1)
return result
-@torch_op("aten::roll", private=True)
-def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor:
+def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor:
neg_1 = op.Constant(value_ints=[-1])
# flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D]
self_flatten = op.Reshape(self, neg_1)
# Compute slice length
- shift_tensor = op.Reshape(shift, neg_1)
- if shift_tensor < 0:
+ if shift < 0:
# For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end
- slice_length = -shift_tensor
+ slice_length = op.Constant(value_ints=[-shift])
else:
# For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end
# The effect equals to move [D] to the beginning
- slice_length = op.Size(self_flatten) - shift_tensor
+ slice_length = op.Size(self_flatten) - op.Constant(value_ints=[shift])
# Get second part of the tensor, e.g. [A,B,C]
suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length)
# Get first part of the tensor, e.g. [D]
@@ -7096,15 +7387,13 @@ def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor:
return op.Reshape(result, op.Shape(self))
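+
+
+# A minimal numpy sketch of the flatten/slice/concat roll above: for a
+# positive shift, the last `shift` elements of the flat view move to the
+# front, matching np.roll on the flattened tensor.
+#
+#   import numpy as np
+#
+#   x = np.array([[0, 1], [2, 3]])
+#   flat = x.reshape(-1)                     # [0, 1, 2, 3]
+#   shift = 1
+#   split = flat.size - shift                # slice_length for shift >= 0
+#   rolled = np.concatenate([flat[split:], flat[:split]]).reshape(x.shape)
+#   assert np.array_equal(rolled, np.roll(x, shift))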
-@torch_op("aten::roll", private=True)
-def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor:
+def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor:
neg_1 = op.Constant(value_ints=[-1])
- dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1)
- shift_tensor = op.Reshape(shift, neg_1)
- if shift_tensor < 0:
- slice_length = -shift_tensor
+ dim_tensor = op.Constant(value_ints=[dim])
+ if shift < 0:
+ slice_length = op.Constant(value_ints=[-shift])
else:
- slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor
+ slice_length = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift])
# from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix
suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor)
prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor)
@@ -7118,7 +7407,7 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
raise NotImplementedError()
-@torch_op("aten::round")
+@torch_op("aten::round", trace_only=True)
def aten_round(self: TFloat) -> TFloat:
"""round(Tensor self) -> Tensor"""
@@ -7167,50 +7456,49 @@ def aten_rrelu(
raise NotImplementedError()
-def aten_rshift(self: TensorType, other: TensorType) -> TensorType:
- """__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"""
-
- raise NotImplementedError()
-
-
-@torch_op("aten::rsqrt")
-def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::rsqrt", trace_only=True)
+def aten_rsqrt(self: TFloat) -> TFloat:
"""rsqrt(Tensor self) -> Tensor"""
return op.Reciprocal(op.Sqrt(self))
-@torch_op(("aten::rsub", "aten::rsub.Scalar"))
+# Do not register rsub. It will be decomposed and type promoted by torch
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
- return op.Sub(other, op.Mul(self, alpha))
-
-
-@torch_op(("aten::rsub", "aten::rsub.Scalar"), trace_only=True, complex=True)
-def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
- """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-
- return aten_rsub(self, other, alpha)
+ raise NotImplementedError
@torch_op("aten::scalar_tensor", trace_only=True)
-def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> RealType:
+def aten_scalar_tensor(
+ s: TensorType,
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> RealType:
"""scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+ if dtype == -1:
+ dtype = FLOAT.dtype
- # Set trace_only=True because different if branches return different dtypes
- # which is not supported in an ONNX function
return common_ops.cast_to(s, dtype=dtype)
@torch_op("aten::scalar_tensor", trace_only=True, complex=True)
def aten_scalar_tensor_complex(
- s: Union[FLOAT, DOUBLE], dtype: int = COMPLEX64.dtype
+ s: Union[FLOAT, DOUBLE],
+ dtype: int = COMPLEX64.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
) -> RealType:
"""scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
# NOTE: When the input is originally complex, this function is invoked.
# When the input is originally real, aten_scalar_tensor is invoked instead.
+ if dtype == -1:
+ dtype = COMPLEX64.dtype
if dtype == COMPLEX128.dtype:
result = op.Cast(s, to=DOUBLE.dtype)
elif dtype == COMPLEX64.dtype:
@@ -7222,16 +7510,38 @@ def aten_scalar_tensor_complex(
return result
-@torch_op("aten::scalar_tensor", trace_only=True)
-def aten_scalar_tensor_sym_number(s: RealType, dtype: int = FLOAT.dtype) -> RealType:
- """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+@torch_op("aten::scatter.src", trace_only=True)
+def aten_scatter_src(
+ self: TTensor,
+ dim: int, # we have to use int here because ScatterElements() will use this attribute
+ index: TInt,
+ src: TTensor,
+) -> TTensor:
+ """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
+ if len(index.shape) == 0:
+ index = op.Unsqueeze(index, [0])
+ if len(src.shape) == 0:
+ src = op.Unsqueeze(src, [0])
+ return op.ScatterElements(self, index, src, axis=dim)
- # Set trace_only=True because different if branches return different dtypes
- # which is not supported in an ONNX function
- return common_ops.cast_to(s, dtype=dtype)
+@torch_op("aten::scatter.value", trace_only=True)
+def aten_scatter_value(
+ self: TTensor,
+ dim: int, # we have to use int here because ScatterElements() will use this attribute
+ index: TInt,
+ value: float,
+) -> TTensor:
+ """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
+ # Ensure value is a scalar tensor and expand it to match index shape
+ if len(index.shape) == 0:
+ index = op.Unsqueeze(index, [0])
+ scalar_tensor = ir.tensor([value], dtype=self.dtype)
+ src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)
+ return op.ScatterElements(self, index, src, axis=dim)
-@torch_op("aten::scatter_add")
+
+@torch_op("aten::scatter_add", trace_only=True)
def aten_scatter_add(
self: TReal,
dim: int, # we have to use int here because ScatterElements() will use this attribute
@@ -7244,14 +7554,14 @@ def aten_scatter_add(
return op.ScatterElements(self, index, src, axis=dim, reduction="add")
-@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True)
+@torch_op("aten::scatter_reduce.two", trace_only=True)
def aten_scatter_reduce(
self: TReal,
dim: int, # we have to use int here because ScatterElements() will use this attribute
index: TInt,
src: TReal,
reduce: str,
- include_self: bool = True, # pylint: disable=unused-argument
+ include_self: bool = True,
):
"""scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor"""
@@ -7263,24 +7573,66 @@ def aten_scatter_reduce(
"amax": "max",
}
onnx_reduce = reduce_mode[reduce]
- return _aten_scatter_reduce_onnx(self, index, src, dim, onnx_reduce)
+ dtype = src.dtype or self.dtype
+ assert dtype is not None, "dtype should not be None"
-
-@torch_op("aten::scatter_reduce", private=True)
-def _aten_scatter_reduce_onnx(
- self: TReal,
- index: TInt,
- src: TReal,
- dim: int,
- onnx_reduce: str,
-):
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
if self_is_scalar: # assert (index_rank == 0 and rank_src == 0)
neg_1 = op.Constant(value_ints=[-1])
self = op.Reshape(self, neg_1)
index = op.Reshape(index, neg_1)
src = op.Reshape(src, neg_1)
+
+ if not include_self:
+ # The ONNX standard always assumes the value from self is part of the
+ # reduction. As a first step, the impacted values are replaced by ones
+ # chosen so that the result of the reduction is unchanged whether or
+ # not they take part in it:
+ # -inf if the reduction is max, inf for min, 0 for add, 1 for mul.
+ # mean is not supported.
+ if onnx_reduce == "max":
+ if dtype in {
+ ir.DataType.FLOAT16,
+ ir.DataType.FLOAT,
+ ir.DataType.DOUBLE,
+ }:
+ value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype)
+ elif dtype == ir.DataType.BFLOAT16:
+ value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype)
+ elif dtype == ir.DataType.BOOL:
+ value = ir.tensor([False], dtype=dtype)
+ else:
+ value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype)
+ reduction_init = "min"
+ elif onnx_reduce == "min":
+ if dtype in {
+ ir.DataType.FLOAT16,
+ ir.DataType.FLOAT,
+ ir.DataType.DOUBLE,
+ }:
+ value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype)
+ elif dtype == ir.DataType.BFLOAT16:
+ value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype)
+ elif dtype == ir.DataType.BOOL:
+ value = ir.tensor([True], dtype=dtype)
+ else:
+ value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype)
+ reduction_init = "max"
+ elif onnx_reduce == "add":
+ value = ir.tensor([0], dtype=dtype)
+ reduction_init = "none"
+ elif onnx_reduce == "mul":
+ value = ir.tensor([1], dtype=dtype)
+ reduction_init = "none"
+ else:
+ value = ir.tensor([0], dtype=dtype)
+ reduction_init = "none"
+
+ cst = op.ConstantOfShape(op.Shape(src), value=value)
+ self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init)
+
result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce)
+
if self_is_scalar:
result = op.Squeeze(result)
return result
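+
+
+# A minimal numpy sketch (names illustrative) of the include_self=False
+# handling above for reduce="amax": positions touched by `index` are first
+# overwritten with the dtype minimum so the original self values cannot win.
+#
+#   import numpy as np
+#
+#   self_ = np.array([10.0, 10.0, 10.0])
+#   index, src = np.array([0, 1]), np.array([1.0, 2.0])
+#   self_[index] = np.finfo(self_.dtype).min   # step 1: neutralize
+#   np.maximum.at(self_, index, src)           # step 2: scatter-reduce max
+#   assert np.array_equal(self_, np.array([1.0, 2.0, 10.0]))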
@@ -7314,7 +7666,7 @@ def aten_segment_reduce(
raise NotImplementedError()
-@torch_op(("aten::select", "aten::select.int"))
+@torch_op("aten::select.int", trace_only=True)
def aten_select(self: TTensor, dim: int, index: int) -> TTensor:
"""select(Tensor self, int dim, int index) -> Tensor"""
@@ -7329,13 +7681,13 @@ def aten_select_backward(
raise NotImplementedError()
-@torch_op("aten::select_scatter")
+@torch_op("aten::select_scatter", trace_only=True)
def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int) -> TensorType:
"""select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor"""
# Change src rank to self rank according to dim
# e.g. if self is [2,3,4], src is [2,4], dim=1, then update is [2,1,4]
- update = op.Unsqueeze(src, axes=dim)
+ update = op.Unsqueeze(src, axes=[dim])
# Change index rank to the same as 'update' [2,1,4]
indices = op.Expand(index, op.Shape(update))
return op.ScatterElements(self, indices, update, axis=dim, reduction="none")
@@ -7360,8 +7712,8 @@ def aten_sgn(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::sigmoid")
-def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::sigmoid", trace_only=True)
+def aten_sigmoid(self: TFloat) -> TFloat:
"""sigmoid(Tensor self) -> Tensor"""
return op.Sigmoid(self)
@@ -7380,21 +7732,36 @@ def aten_signbit(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::sin")
+@torch_op("aten::sin", trace_only=True)
def aten_sin(self: TFloat) -> TFloat:
"""sin(Tensor self) -> Tensor"""
return op.Sin(self)
-@torch_op("aten::sinh")
+@torch_op("aten::sinh", trace_only=True)
def aten_sinh(self: TFloat) -> TFloat:
"""sinh(Tensor self) -> Tensor"""
return op.Sinh(self)
-@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True)
+@torch_op(("aten::slice.Tensor"), trace_only=True, complex=True)
+def aten_slice_complex(
+ self: TTensor,
+ dim: int = 0,
+ start: Optional[INT64] = None,
+ end: Optional[INT64] = None,
+ step: Optional[INT64] = None,
+) -> TTensor:
+ """slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)"""
+ if dim < 0:
+ # Account for the complex dimension in ONNX
+ dim = len(self.shape) + dim - 1
+ return aten_slice(self, dim, start, end, step)
+
+
+@torch_op(("aten::slice.Tensor"), trace_only=True)
def aten_slice(
self: TTensor,
dim: int = 0,
@@ -7488,7 +7855,7 @@ def aten_slice_scatter(
zero,
op.Unsqueeze(step, zero),
)
- index_base = op.Unsqueeze(index_base, -1)
+ index_base = op.Unsqueeze(index_base, [-1])
# Use trace only to construct the perm attribute in Transpose
dims = None
@@ -7522,15 +7889,15 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), trace_only=True)
-def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
+@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
+def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
result = op.Softmax(self, axis=dim)
- if dtype != -1:
+ if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)
if self_is_scalar:
# Convert to scalar when input is scalar
@@ -7539,27 +7906,20 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB
return result
-@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), traceable=True)
-def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
- """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""
-
- self_is_scalar = IsScalar(self)
- if self_is_scalar:
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
- result = op.Softmax(self, axis=dim)
- if self_is_scalar:
- # Convert to scalar when input is scalar
- result = op.Squeeze(result)
-
- return result
-
-
+@torch_op("aten::sort", trace_only=True)
def aten_sort(
- self: TensorType, dim: int = -1, descending: bool = False
-) -> tuple[TensorType, TensorType]:
- """sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"""
+ self: TReal, dim: int = -1, descending: bool = False, stable: bool = False
+) -> tuple[TReal, INT64]:
+ """sort(Tensor self, int dim=-1, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices)"""
- raise NotImplementedError()
+ self_is_scalar = len(self.shape) == 0
+ if self_is_scalar:
+ return op.Identity(self), op.Constant(value_int=0)
+ shape = op.Shape(self)
+ dim_size = op.Gather(shape, dim, axis=0)
+ dim_size = op.Reshape(dim_size, op.Constant(value_ints=[1]))
+ values, indices = op.TopK(self, dim_size, axis=dim, largest=descending, sorted=True)
+ return values, indices
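+
+
+# A minimal torch sketch of the TopK rewrite above: taking the top dim_size
+# elements with largest=descending yields a full sort along that axis.
+#
+#   import torch
+#
+#   x = torch.tensor([3.0, 1.0, 2.0])
+#   values, indices = torch.topk(x, k=x.shape[0], largest=False, sorted=True)
+#   assert torch.equal(values, torch.tensor([1.0, 2.0, 3.0]))  # ascending
+#   assert torch.equal(indices, torch.tensor([1, 2, 0]))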
def aten_sparse_dim(self: TensorType) -> int:
@@ -7602,8 +7962,8 @@ def aten_split_with_sizes_copy(
raise NotImplementedError()
-@torch_op("aten::sqrt")
-def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::sqrt", trace_only=True)
+def aten_sqrt(self: TFloat) -> TFloat:
"""sqrt(Tensor self) -> Tensor"""
return op.Sqrt(self)
@@ -7615,25 +7975,18 @@ def aten_square(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::squeeze")
+@torch_op("aten::squeeze", trace_only=True)
def aten_squeeze(self: TTensor) -> TTensor:
"""squeeze(Tensor(a) self) -> Tensor(a)"""
return op.Squeeze(self)
-@torch_op("aten::squeeze.dim")
+@torch_op("aten::squeeze.dim", trace_only=True)
def aten_squeeze_dim(self: TTensor, dim: int) -> TTensor:
- result = self
- if Rank(self) > 0: # type: ignore[operator]
- # check if specified dimension is 1, do squeeze
- shape = op.Shape(self)
- dim_size = op.Gather(shape, dim, axis=0)
- if dim_size == 1:
- dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
- result = op.Squeeze(self, dims)
-
- return result
+ if len(self.shape) == 0:
+ return op.Identity(self)
+ return op.Squeeze(self, [dim])
@torch_op("aten::squeeze.dim", complex=True, trace_only=True)
@@ -7642,6 +7995,9 @@ def aten_squeeze_dim_complex(self: TTensor, dim: int) -> TTensor:
# Account for the complex dimension in ONNX
dim = dim - 1
+ if len(self.shape) == 1:
+ # The single dimension is the complex dimension
+ return op.Identity(self)
return aten_squeeze_dim(self, dim)
@@ -7668,168 +8024,126 @@ def aten_stack_complex(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTen
return aten_stack(tensors, dim)
-@torch_op("aten::stack")
+@torch_op("aten::stack", trace_only=True)
def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString:
"""stack(Tensor[] tensors, int dim=0) -> Tensor"""
+ if isinstance(tensors, Sequence):
+ unsqueezed = [op.Unsqueeze(t, op.Constant(value_ints=[dim])) for t in tensors]
+ return op.Concat(*unsqueezed, axis=dim)
return op.ConcatFromSequence(tensors, axis=dim, new_axis=1)
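+
+
+# A minimal numpy sketch of the Sequence branch above: stacking is
+# concatenating the inputs after unsqueezing each along the stack axis.
+#
+#   import numpy as np
+#
+#   a, b = np.zeros((2, 3)), np.ones((2, 3))
+#   stacked = np.concatenate([a[:, np.newaxis, :], b[:, np.newaxis, :]], axis=1)
+#   assert np.array_equal(stacked, np.stack([a, b], axis=1))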
-def aten_std(self: TensorType, unbiased: bool = True) -> TensorType:
+# std is decomposed by PyTorch
+def aten_std(self: TReal, unbiased: bool = True) -> TReal:
"""std(Tensor self, bool unbiased=True) -> Tensor"""
+ var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
+ return op.Sqrt(var)
- raise NotImplementedError()
+# std_dim is decomposed by PyTorch
+def aten_std_dim(
+ self: TReal,
+ dim: Sequence[int],
+ unbiased: Optional[bool] = True,
+ keepdim: Optional[bool] = False,
+) -> TReal:
+ """std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor"""
-def aten_std_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, TensorType]:
- """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
+ var = _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
+ return op.Sqrt(var)
- raise NotImplementedError()
+# std is decomposed by PyTorch
+def aten_std_correction(
+ self: TReal,
+ # FIXME(justinchuby): Make dim Optional[Sequence[int]]
+ dim: Optional[int] = None,
+ correction: Optional[float] = None,
+ keepdim: bool = False,
+) -> TReal:
+ """std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor"""
-@torch_op("aten::stft", private=True)
-def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
- signal_rank = Rank(self)
- if signal_rank == 1:
- # Add a batch dimension
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
- return self, signal_rank
-
-
-@torch_op("aten::stft", private=True)
-def _center_window_around_zeros_if_needed(
- window: TFloatOrBFloat16, n_fft: int
-) -> TFloatOrBFloat16:
- # first dimension
- n_win = op.Shape(window, start=0, end=1)
- # Center window around zeros if needed (required by ONNX's STFT)
- if n_win < n_fft:
- left = (n_fft - n_win) / 2
-
- right = n_fft - left - n_win
- left = op.Reshape(left, op.Constant(value_ints=[1]))
- right = op.Reshape(right, op.Constant(value_ints=[1]))
-
- left_win = op.Expand(op.Constant(value_ints=[0]), left)
- right_win = op.Expand(op.Constant(value_ints=[0]), right)
- right_win = op.CastLike(right_win, window)
- left_win = op.CastLike(left_win, window)
- window = op.Concat(left_win, window, right_win, axis=0)
- return window
-
-
-@torch_op("aten::stft", private=True)
-def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
- left = (n_fft - win_length) / 2
-
- right = n_fft - left - win_length
- left = op.Reshape(left, op.Constant(value_ints=[1]))
- right = op.Reshape(right, op.Constant(value_ints=[1]))
- win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
-
- left_win = op.Expand(op.Constant(value_ints=[0]), left)
- right_win = op.Expand(op.Constant(value_ints=[0]), right)
- window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
- return op.Concat(left_win, window_list, right_win, axis=0)
-
-
-@torch_op("aten::stft", private=True)
-def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
- n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
- window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
- return window
-
-
-@torch_op("aten::stft", private=True)
-def _normalize_fft_result(
- signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
-) -> TFloatOrBFloat16:
- n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
- sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
- result = result / sqrt_nfft
- return result
+ if correction is None:
+ correction = 1.0
+ if dim is None:
+ var = _aten_var_onnx(self, correction=correction, keepdim=keepdim)
+ else:
+ var = _aten_var_dim_onnx(self, dims=dim, correction=correction, keepdim=keepdim)
+ return op.Sqrt(var)
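
Each std variant above is the square root of the corrected variance; in numpy terms (ddof plays the role of `correction`):

import numpy as np

x = np.array([1.0, 2.0, 4.0])
correction = 1.0  # the default used above when correction is None
var = np.square(x - x.mean()).sum() / (x.size - correction)
assert np.isclose(np.sqrt(var), np.std(x, ddof=1))
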
-@torch_op("aten::stft", private=True)
-def _aten_stft_onnx(
- signal: TFloatOrBFloat16,
- frame_step_const: INT64,
- window: Union[TFloatOrBFloat16, INT64],
- frame_length_const: INT64,
- signal_rank: INT64,
- onesided: int,
-) -> TFloatOrBFloat16:
- window = op.CastLike(window, signal)
- result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
- result = op.Transpose(result, perm=[0, 2, 1, 3])
- # Remove batch dimension, if needed
- if signal_rank == 1:
- result = op.Squeeze(result, op.Constant(value_ints=[0]))
- return result
+# std_mean is decomposed by PyTorch
+def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
+ """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
-@torch_op("aten::stft", trace_only=True)
-def aten_stft(
- self: TFloatOrBFloat16,
- n_fft: int,
- hop_length: Optional[int] = None,
- win_length: Optional[int] = None,
- window: Optional[TFloatOrBFloat16] = None,
- normalized: bool = False,
- onesided: Optional[bool] = None,
- return_complex: Optional[bool] = None,
-) -> TFloatOrBFloat16:
- """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
-
- # NOTE: regarless of the value of return_complex, we always return a real representation.
- del return_complex
-
- # Get STFT sizes
- if hop_length is None:
- # core dump
- # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
- hop_length = n_fft // 4
- frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
- frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
-
- # Pre-process input if needed
- self, signal_rank = _add_batch_dimension(self)
-
- # Get window and make sure it's the same size as `win_length` or `n_fft`
- if window is not None and window.shape[0] is not None:
- window = _center_window_around_zeros_if_needed(window, n_fft)
- elif window is None:
- if win_length is not None:
- window = _create_window_from_win_length(win_length, n_fft)
- else:
- window = _create_window_from_n_fft(n_fft)
+    # bool(True) and int(1) are the same value in ONNX, so "unbiased" can be passed
+    # directly as "correction"; otherwise it would need an explicit mapping.
+ var, mean = _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)
+ return op.Sqrt(var), mean
- if onesided is None or onesided:
- onesided = 1
- else:
- onesided = 0
- # remove batch dimension included
- result = _aten_stft_onnx(
- self, frame_step_const, window, frame_length_const, signal_rank, onesided
+
+# std_mean is decomposed by PyTorch
+def aten_std_mean_dim(
+ self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
+) -> Tuple[TReal, TReal]:
+ """std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)"""
+
+    # Although dim is Optional in the signature, we assume it is provided for this
+    # overload, i.e. dim is not None.
+ var, mean = _aten_var_mean_dim_onnx(
+ self, dims=dim, correction=float(unbiased), keepdim=keepdim
)
+ return op.Sqrt(var), mean
- # Normalize, if needed
- if normalized:
- result = _normalize_fft_result(self, result, n_fft)
- return result
+# std_mean is decomposed by PyTorch
+def aten_std_mean_correction(
+ self: TReal,
+ # FIXME(justinchuby): Make dim Optional[Sequence[int]]
+ dim: Optional[int] = None,
+ correction: Optional[float] = None,
+ keepdim: bool = False,
+) -> Tuple[TReal, TReal]:
+ """std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)"""
+
+ if correction is None:
+ correction = 1.0
+
+ if dim is None:
+ var, mean = _aten_var_mean_onnx(self, correction=correction, keepdim=keepdim)
+ else:
+ var, mean = _aten_var_mean_dim_onnx(
+ self, dims=dim, correction=correction, keepdim=keepdim
+ )
+ return op.Sqrt(var), mean
-@torch_op(("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"))
+@torch_op(
+ (
+ "aten::sub.Tensor",
+ "aten::sub.Scalar",
+ "aten::subtract.Tensor",
+ "aten::subtract.Scalar",
+ "_operator::sub",
+ ),
+ trace_only=True,
+)
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
- alpha = op.CastLike(alpha, other)
- other = op.Mul(other, alpha)
-
+ if alpha != 1.0:
+ alpha = op.CastLike(alpha, other)
+ other = op.Mul(other, alpha)
return op.Sub(self, other)
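
For reference, sub.Tensor computes self - alpha * other, which is why the Mul is skipped above when alpha == 1.0; a quick numpy check:

import numpy as np

self_, other, alpha = np.array([10.0, 20.0]), np.array([1.0, 2.0]), 3.0
result = self_ - alpha * other
# result -> [ 7. 14.]
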
@torch_op(
- ("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"),
+ (
+ "aten::sub.Tensor",
+ "aten::sub.Scalar",
+ "aten::subtract.Tensor",
+ "aten::subtract.Scalar",
+ ),
trace_only=True,
complex=True,
)
@@ -7839,53 +8153,35 @@ def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return aten_sub(self, other, alpha=alpha)
-@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
-def aten_sum_dim_IntList(
- self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
-) -> TReal:
- """sum(Tensor self, SymInt dim, bool keepdim, *, ScalarType? dtype=None) -> Tensor"""
-
- # NOTE: trace_only because both if branches need to be the same type, but we have
- # a cast in the if branch.
-
- # TODO: Combine the overloads when OptionalHasElement() works
- if dim is None:
- result = _aten_sum_dim_none(self, keepdim=keepdim)
+@torch_op("aten::sum", trace_only=True)
+def aten_sum(self: TReal, dtype: int = -1) -> TReal:
+ """sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
+ if len(self.shape) == 0:
+ result = op.Identity(self)
else:
- result = _aten_sum_dim_onnx(self, dim, keepdim=keepdim)
-
- if dtype != -1:
+ result = op.ReduceSum(self, keepdims=False)
+ if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)
-
return result
-@torch_op("aten::sum", private=True, traceable=True)
-def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
- self_is_scalar = IsScalar(self)
- if self_is_scalar:
- self = op.Reshape(self, op.Constant(value_ints=[-1]))
-
- if IsScalar(dim):
+@torch_op("aten::sum.dim_IntList", trace_only=True)
+def aten_sum_dim_IntList(
+ self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
+) -> TReal:
+ """sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
+ if len(self.shape) == 0:
+ result = op.Identity(self)
+ elif dim is None:
+ result = op.ReduceSum(self, keepdims=keepdim)
+ else:
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
dim = op.Cast(dim, to=INT64.dtype)
- result = op.ReduceSum(self, dim, keepdims=keepdim)
+ result = op.ReduceSum(self, dim, keepdims=keepdim)
- if self_is_scalar:
- result = op.Squeeze(result)
- return result
-
-
-@torch_op("aten::sum", private=True)
-def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal:
- self_is_scalar = IsScalar(self)
- if self_is_scalar:
- self = op.Reshape(self, op.Constant(value_ints=[-1]))
-
- result = op.ReduceSum(self, keepdims=keepdim)
+ if dtype != -1 and dtype is not None:
+ result = op.Cast(result, to=dtype)
- if self_is_scalar:
- result = op.Squeeze(result)
return result
@@ -7915,17 +8211,10 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::sym_size")
-def aten_sym_size(self: TReal, dim: int = 0) -> TReal:
- """sym_size(Tensor self, int dim) -> Tensor"""
- # NOTE: onnxscript doesn't support attribute process,
- # so op.Shape(self, start=dim, end=dim + 1) is not supported.
- shape = op.Shape(self)
- # Reshape helps dim from int to tensor, and
- # input arguments support attribute processing.
- start = op.Reshape(dim, op.Constant(value_ints=[1]))
- end = op.Reshape(dim + 1, op.Constant(value_ints=[1]))
- return op.Slice(shape, start, end)
+@torch_op("aten::sym_size.int", trace_only=True)
+def aten_sym_size(self: TensorType, dim: int = 0) -> INT64:
+ """sym_size.int(Tensor self, int dim) -> SymInt"""
+ return op.Squeeze(op.Shape(self, end=dim + 1, start=dim))
def aten_symeig(
@@ -7936,7 +8225,7 @@ def aten_symeig(
raise NotImplementedError()
-@torch_op("aten::t", traceable=True)
+@torch_op("aten::t", trace_only=True)
def aten_t(self: TTensor) -> TTensor:
"""t(Tensor(a) self) -> Tensor(a)"""
@@ -7969,33 +8258,33 @@ def aten_take_along_dim(
raise NotImplementedError()
-@torch_op("aten::tan")
+@torch_op("aten::tan", trace_only=True)
def aten_tan(self: TFloat) -> TFloat:
"""tan(Tensor self) -> Tensor"""
return op.Tan(self)
-@torch_op("aten::tanh")
+@torch_op("aten::tanh", trace_only=True)
def aten_tanh(self: TFloat) -> TFloat:
"""tanh(Tensor self) -> Tensor"""
return op.Tanh(self)
-@torch_op("aten::tensor.bool")
+@torch_op("aten::tensor.bool", trace_only=True)
def aten_tensor_bool(self: bool, dtype: int) -> TensorType:
tensor = op.Constant(value_int=self)
return op.Cast(tensor, to=dtype)
-@torch_op("aten::tensor.float")
+@torch_op("aten::tensor.float", trace_only=True)
def aten_tensor_float(self: float, dtype: int) -> TensorType:
tensor = op.Constant(value_float=self)
return op.Cast(tensor, to=dtype)
-@torch_op("aten::tensor.int")
+@torch_op("aten::tensor.int", trace_only=True)
def aten_tensor_int(self: int, dtype: int) -> TensorType:
tensor = op.Constant(value_int=self)
return op.Cast(tensor, to=dtype)
@@ -8045,7 +8334,7 @@ def aten_tile(self: TTensor, dims: INT64) -> TTensor:
exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
self_shape = op.Shape(self)
self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
- self = op.Reshape(self, self_final_shape)
+ self = op.Reshape(self, self_final_shape, allowzero=True)
return op.Tile(self, dims)
@@ -8112,20 +8401,14 @@ def aten_to_sparse_csr(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::topk", traceable=True)
+@torch_op("aten::topk", trace_only=True)
def aten_topk(
- self: TReal, k: INT64, dim: int = -1, largest: bool = True, sorted: bool = True
+ self: TReal, k: int, dim: int = -1, largest: bool = True, sorted: bool = True
) -> Tuple[TReal, INT64]:
"""topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)"""
- self_is_scalar = IsScalar(self)
- if self_is_scalar:
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
- k = op.Reshape(op.Cast(k, to=INT64.dtype), op.Constant(value_ints=[1]))
- values, indices = op.TopK(self, k, axis=dim, largest=largest, sorted=sorted)
- if self_is_scalar:
- values = op.Squeeze(values, op.Constant(value_ints=[0]))
- indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+ # We do not handle scalar inputs for topk
+ values, indices = op.TopK(self, [k], axis=dim, largest=largest, sorted=sorted)
return values, indices
@@ -8141,8 +8424,8 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
raise NotImplementedError()
-@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True)
-def aten_transpose(self, dim0: int, dim1: int):
+@torch_op("aten::transpose.int", trace_only=True)
+def aten_transpose(self: TTensor, dim0: int, dim1: int) -> TTensor:
"""transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"""
    # Use trace only to construct the perm attribute in Transpose
@@ -8161,8 +8444,8 @@ def aten_transpose(self, dim0: int, dim1: int):
return result
-@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True, complex=True)
-def aten_transpose_complex(self, dim0: int, dim1: int):
+@torch_op("aten::transpose.int", trace_only=True, complex=True)
+def aten_transpose_complex(self: TTensor, dim0: int, dim1: int) -> TTensor:
"""transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"""
    # Use trace only to construct the perm attribute in Transpose
@@ -8199,7 +8482,7 @@ def aten_triangular_solve(
raise NotImplementedError()
-@torch_op("aten::tril")
+@torch_op("aten::tril", trace_only=True)
def aten_tril(self: TTensor, diagonal: int = 0) -> TTensor:
"""tril(Tensor self, int diagonal=0) -> Tensor"""
@@ -8227,7 +8510,7 @@ def aten_triplet_margin_loss(
raise NotImplementedError()
-@torch_op("aten::triu")
+@torch_op("aten::triu", trace_only=True)
def aten_triu(self: TTensor, diagonal: int = 0) -> TTensor:
"""triu(Tensor self, int diagonal=0) -> Tensor"""
@@ -8240,40 +8523,43 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::trunc")
-def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op("aten::trunc", trace_only=True)
+def aten_trunc(self: TFloat) -> TFloat:
"""trunc(Tensor self) -> Tensor"""
-
- # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
- integer_parts = op.Floor(op.Abs(self))
- is_negative = op.Less(self, 0.0)
- return op.Where(is_negative, op.Neg(integer_parts), integer_parts)
+ # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-2658170591
+ return op.Floor(op.Abs(self)) * op.Sign(self)
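
The new one-liner uses the identity trunc(x) == floor(|x|) * sign(x), i.e. rounding toward zero; a quick numpy check:

import numpy as np

x = np.array([-2.7, -0.5, 0.5, 2.7])
result = np.floor(np.abs(x)) * np.sign(x)
assert (result == np.trunc(x)).all()  # [-2., -0., 0., 2.]
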
-def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::type_as", trace_only=True)
+def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
"""type_as(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.CastLike(self, other)
-@torch_op(("aten::unbind", "aten::unbind.int"))
+@torch_op("aten::unbind.int", trace_only=True)
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
- split_sizes = op.Constant(value_int=1)
- return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False)
+ if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
+        # We can create a definite Split op when the input shape is static.
+        # Only torch>=2.7 generates the correct number of outputs for Split.
+ outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
+ return [op.Squeeze(out, [dim]) for out in outputs]
+
+ return op.SplitToSequence(self, axis=dim, keepdims=False)
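
A numpy sketch of the static-shape path above: split into shape[dim] pieces, then squeeze the split axis away from each piece.

import numpy as np

x = np.arange(6).reshape(2, 3)
dim = 0
pieces = np.split(x, x.shape[dim], axis=dim)     # op.Split(num_outputs=2)
outputs = [p.squeeze(axis=dim) for p in pieces]  # op.Squeeze(out, [dim])
# outputs -> [array([0, 1, 2]), array([3, 4, 5])]
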
-@torch_op("aten::unflatten")
-def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
+@torch_op("aten::unflatten.int", trace_only=True)
+def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
"""unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)"""
self_size = op.Shape(self)
# PyTorch accepts negative dim as reversed counting
- self_rank = op.Size(self_size)
- dim = self_rank + dim
- dim = dim % self_rank
+ self_rank = len(self.shape)
+ if dim < 0:
+ dim = self_rank + dim
head_start_idx = op.Constant(value_ints=[0])
head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
@@ -8283,9 +8569,17 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
tail_end_idx = op.Constant(value_ints=[_INT64_MAX])
tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx)
- final_shape = op.Concat(head_part_rank, sizes, tail_part_rank, axis=0)
+ sizes = [op.Reshape(size, op.Constant(value_ints=[1])) for size in sizes]
- return op.Reshape(self, final_shape)
+    # corner case 1: the head part is empty
+ if dim == 0:
+ final_shape = op.Concat(*sizes, tail_part_rank, axis=0)
+    # corner case 2: the tail part is empty
+ elif dim == self_rank - 1:
+ final_shape = op.Concat(head_part_rank, *sizes, axis=0)
+ else:
+ final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
+ return op.Reshape(self, final_shape, allowzero=True)
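
With static shapes, the Concat above simply builds head + sizes + tail as the target shape; a plain-Python sketch (the symbolic version above avoids assuming static shapes):

import numpy as np

x = np.arange(24).reshape(2, 12)
dim, sizes = 1, [3, 4]
shape = list(x.shape)
# Replace shape[dim] with the unflattened sizes; the head or tail slice may be
# empty, which is the corner case handled explicitly above.
final_shape = shape[:dim] + sizes + shape[dim + 1 :]
result = x.reshape(final_shape)
# result.shape -> (2, 3, 4)
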
@torch_op("aten::unfold", trace_only=True)
@@ -8294,44 +8588,42 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
self_rank = len(self.shape)
if self_rank == 0:
- result = op.Unsqueeze(self, 0)
+ result = op.Unsqueeze(self, [0])
else:
# Handle negative dimension
if dimension < 0:
dimension = dimension + self_rank
- dim_size = self.shape[dimension]
- target_end = (dim_size - size) // step + 1
- if target_end >= 1: # the rank of final reuslt will be self_rank + 1
- self_rank = self_rank + 1
+
+ input_shape = op.Shape(self)
+ dim_size = op.Gather(input_shape, op.Constant(value_ints=[dimension]))
+
+ # Create indices for each window
+ window_starts = op.Range(0, op.Sub(dim_size, size - 1), step)
+
+ # Create the base indices for one window
+ window_indices = list(range(size))
+
+ # Broadcast to create all indices
+ starts_expanded = op.Unsqueeze(window_starts, [1]) # [num_windows, 1]
+ indices_expanded = op.Unsqueeze(window_indices, [0]) # [1, size]
+ all_indices = op.Add(starts_expanded, indices_expanded) # [num_windows, size]
+
+ # Gather along the specified dimension
+ result = op.Gather(self, all_indices, axis=dimension)
+
+ # The result shape is now [..., num_windows, size, ...] with num_windows at position 'dimension'.
+ # We need to move the size dimension to the end:
+ # Current shape: [..., num_windows, size, ...]
+ # Target shape: [..., num_windows, ..., size]
+
+ # Move the size dimension (at position dimension+1) to the end
    # perm needs to be list[int], so it has to be generated in trace_only mode
- perm = list(range(self_rank))
- # from [0,1,2,3,4] -> [0,1,3,4,2] when dimension=1
+ perm = list(range(self_rank + 1))
perm.append(perm.pop(dimension + 1))
- result = _aten_unfold_onnx(self, dimension, size, step, target_end, perm)
- return result
+ result = op.Transpose(result, perm=perm)
-@torch_op("aten::unfold", private=True)
-def _aten_unfold_onnx(
- self: TTensor, dim: int, size: int, step: int, target_end: int, perm: Sequence[int]
-) -> TTensor:
- dims = op.Reshape(op.Constant(value_int=dim), op.Constant(value_ints=[-1]))
- # FIXME(justinchuby): obtain the dtype for SequenceEmpty, currently it assumes float
- seq_result = op.SequenceEmpty()
- i = op.Constant(value_int=0)
- cond = i < target_end
- while cond: # because for loop cannot work here, so use while loop
- starts = op.Reshape(i * step, [-1]) # starts is [0, step, step*2, step*3, ...]
- ends = starts + size # ends is [0+size, step+size, step*2+size, step*3+size, ...]
- slice_result = op.Slice(self, starts, ends, dims)
- # sequence only support float32
- slice_result_float32 = op.Cast(slice_result, to=FLOAT.dtype)
- seq_result = op.SequenceInsert(seq_result, slice_result_float32)
- i = i + 1
- cond = i < target_end
- concat_result = op.ConcatFromSequence(seq_result, axis=dim, new_axis=1)
- result = op.Transpose(concat_result, perm=perm)
- return op.CastLike(result, self)
+ return result
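
A numpy sketch of the Range/Gather/Transpose construction above for the non-scalar case: broadcast window starts against per-window offsets, gather along `dimension`, then move the window-size axis to the end.

import numpy as np

def unfold(x: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
    dim_size = x.shape[dimension]
    starts = np.arange(0, dim_size - (size - 1), step)    # op.Range
    indices = starts[:, None] + np.arange(size)[None, :]  # [num_windows, size]
    gathered = np.take(x, indices, axis=dimension)        # op.Gather(axis=dimension)
    perm = list(range(gathered.ndim))
    perm.append(perm.pop(dimension + 1))                  # size axis to the end
    return gathered.transpose(perm)

# unfold(np.arange(7), 0, size=3, step=2) -> [[0 1 2], [2 3 4], [4 5 6]]
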
def aten_unfold_backward(
@@ -8359,16 +8651,84 @@ def aten_unique_consecutive(
raise NotImplementedError()
+@torch_op("aten::_unique", trace_only=True)
+def aten__unique(
+ self: TensorType,
+ sorted: bool = True, # pylint: disable=unused-argument
+ return_inverse: bool = False,
+) -> tuple[TensorType, TensorType]:
+ """_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""
+
+ unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
+ input_size = op.Shape(self)
+ if return_inverse:
+ inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
+ else:
+ input_numel = op.ReduceProd(input_size, keepdims=False)
+ if input_numel == 0:
+ inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
+ else:
+ inverse_indices = op.ConstantOfShape([0])
+ inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
+ return unique_values, inverse_indices
+
+
+@torch_op("aten::_unique2", trace_only=True)
+def aten__unique2(
+ self: TensorType,
+ sorted: bool = True, # pylint: disable=unused-argument
+ return_inverse: bool = False,
+ return_counts: bool = False,
+) -> tuple[TensorType, TensorType, TensorType]:
+ """_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
+
+ unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
+ input_size = op.Shape(self)
+ if return_inverse:
+ inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
+ else:
+ input_numel = op.ReduceProd(input_size, keepdims=False)
+ if input_numel == 0:
+ inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
+ else:
+ inverse_indices = op.ConstantOfShape([0])
+ inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
+ if not return_counts:
+ counts = op.ConstantOfShape([0])
+ counts = op.Cast(counts, to=INT64.dtype)
+ return unique_values, inverse_indices, counts
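
ONNX Unique with axis=None flattens its input, so np.unique on the flattened array reproduces the triple; the inverse indices are reshaped back to the input shape when requested, as above:

import numpy as np

x = np.array([[2, 1], [2, 3]])
values, inverse, counts = np.unique(x.ravel(), return_inverse=True, return_counts=True)
inverse = inverse.reshape(x.shape)
# values -> [1 2 3]; inverse -> [[1 0], [1 2]]; counts -> [1 2 1]
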
+
+
+@torch_op("aten::unique_dim", trace_only=True)
def aten_unique_dim(
self: TensorType,
dim: int,
- sorted: bool = True,
+ sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
return_counts: bool = False,
) -> tuple[TensorType, TensorType, TensorType]:
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
- raise NotImplementedError()
+ unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)
+ input_size = op.Shape(self)
+ # Normalize dim to be non-negative
+ input_ndim = op.Max(op.Size(input_size), op.Constant(value_ints=[1]))
+ dim = op.Mod(dim, input_ndim)
+ if return_inverse:
+ inverse_indices = op.Reshape(
+ inverse_indices,
+ op.Reshape(op.Slice(input_size, dim, dim + 1), op.Constant(value_ints=[-1])),
+ )
+ else:
+ inverse_indices = op.ConstantOfShape([0])
+ inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
+ if return_counts:
+ output_size = op.Shape(unique_values)
+ counts = op.Reshape(counts, op.Reshape(op.Slice(output_size, dim, dim + 1), [-1]))
+ else:
+ counts = op.ConstantOfShape([0])
+ counts = op.Cast(counts, to=INT64.dtype)
+ return unique_values, inverse_indices, counts
def aten_unique_dim_consecutive(
@@ -8385,10 +8745,11 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType
raise NotImplementedError()
-def aten_unsafe_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
+@torch_op("aten::unsafe_split.Tensor")
+def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]:
"""unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""
- raise NotImplementedError()
+ return op.SplitToSequence(self, split_size, axis=dim)
def aten_unsafe_split_with_sizes(
@@ -8399,12 +8760,11 @@ def aten_unsafe_split_with_sizes(
raise NotImplementedError()
-@torch_op("aten::unsqueeze")
+@torch_op("aten::unsqueeze", trace_only=True)
def aten_unsqueeze(self: TTensor, dim: int) -> TTensor:
"""unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"""
- dim = op.Cast(dim, to=INT64.dtype)
- return op.Unsqueeze(self, dim)
+ return op.Unsqueeze(self, [dim])
def aten_unsqueeze_copy(self: TensorType, dim: int) -> TensorType:
@@ -8441,7 +8801,7 @@ def aten_vander(
raise NotImplementedError()
-@torch_op("aten::var", trace_only=True)
+# var is decomposed by PyTorch
def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
"""var(Tensor self, bool unbiased=True) -> Tensor"""
@@ -8450,7 +8810,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
return _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
-@torch_op("aten::var.dim", trace_only=True)
+# var is decomposed by PyTorch
def aten_var_dim(
self: TReal,
dim: Sequence[int],
@@ -8462,7 +8822,7 @@ def aten_var_dim(
return _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
-@torch_op("aten::var.correction", trace_only=True)
+# var is decomposed by PyTorch
def aten_var_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8482,7 +8842,7 @@ def aten_var_correction(
return var
-@torch_op("aten::var", private=True, traceable=True)
+# var is decomposed by PyTorch
def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
mean = op.ReduceMean(self, keepdims=keepdim)
sub_mean = op.Sub(self, mean)
@@ -8499,7 +8859,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TRe
return var
-@torch_op("aten::var.dim", private=True, traceable=True)
+# var is decomposed by PyTorch
def _aten_var_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> TReal:
@@ -8520,7 +8880,7 @@ def _aten_var_dim_onnx(
return var
-@torch_op("aten::var_mean", trace_only=True)
+# var_mean is decomposed by PyTorch
def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
"""var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
@@ -8529,7 +8889,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
return _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)
-@torch_op("aten::var_mean.dim", trace_only=True)
+# var_mean is decomposed by PyTorch
def aten_var_mean_dim(
self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
@@ -8540,7 +8900,7 @@ def aten_var_mean_dim(
return _aten_var_mean_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
-@torch_op("aten::var_mean.correction", trace_only=True)
+# var_mean is decomposed by PyTorch
def aten_var_mean_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8562,7 +8922,7 @@ def aten_var_mean_correction(
return var, mean
-@torch_op("aten::var_mean", private=True)
+# var_mean is decomposed by PyTorch
def _aten_var_mean_onnx(
self: TReal, correction: float = 1.0, keepdim: bool = False
) -> Tuple[TReal, TReal]:
@@ -8582,7 +8942,7 @@ def _aten_var_mean_onnx(
return var, mean
-@torch_op("aten::var_mean.dim", private=True)
+# var_mean is decomposed by PyTorch
def _aten_var_mean_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> Tuple[TReal, TReal]:
@@ -8610,32 +8970,31 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::view")
-def aten_view(self: TTensor, size: IntType) -> TTensor:
+@torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True)
+def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
- size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
- return op.Reshape(self, size)
+ size = common_ops.merge_dims(size)
+ return op.Reshape(self, size, allowzero=True)
-@torch_op("aten::view", complex=True)
-def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
+@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True)
+def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
- size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
- complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
- return op.Reshape(self, complex_size)
+ complex_size = common_ops.merge_dims([*size, 2])
+ return op.Reshape(self, complex_size, allowzero=True)
-@torch_op("aten::view_as")
+@torch_op("aten::view_as", trace_only=True)
def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
"""view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
size = op.Shape(other)
- return op.Reshape(self, size)
+ return op.Reshape(self, size, allowzero=True)
-@torch_op("aten::view_as_complex")
+@torch_op("aten::view_as_complex", trace_only=True)
def aten_view_as_complex(self: TTensor) -> TTensor:
"""view_as_complex(Tensor(a) self) -> Tensor(a)"""
@@ -8644,7 +9003,7 @@ def aten_view_as_complex(self: TTensor) -> TTensor:
return op.Identity(self)
-@torch_op("aten::view_as_complex_copy")
+@torch_op("aten::view_as_complex_copy", trace_only=True)
def aten_view_as_complex_copy(self: TTensor) -> TTensor:
"""view_as_complex_copy(Tensor self) -> Tensor"""
@@ -8653,7 +9012,7 @@ def aten_view_as_complex_copy(self: TTensor) -> TTensor:
return op.Identity(self)
-@torch_op("aten::view_as_real", complex=True)
+@torch_op("aten::view_as_real", complex=True, trace_only=True)
def aten_view_as_real(self: TTensor) -> TTensor:
"""view_as_real(Tensor(a) self) -> Tensor(a)"""
@@ -8662,7 +9021,7 @@ def aten_view_as_real(self: TTensor) -> TTensor:
return op.Identity(self)
-@torch_op("aten::view_as_real_copy", complex=True)
+@torch_op("aten::view_as_real_copy", complex=True, trace_only=True)
def aten_view_as_real_copy(self: TTensor) -> TTensor:
"""view_as_real_copy(Tensor self) -> Tensor"""
@@ -8671,15 +9030,15 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
return op.Identity(self)
-@torch_op("aten::view_copy")
-def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
+@torch_op("aten::view_copy", trace_only=True)
+def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor:
"""view_copy(Tensor self, SymInt[] size) -> Tensor"""
- size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
+ size = common_ops.merge_dims(size)
return op.Reshape(self, size)
-@torch_op("aten::vstack")
+# Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
"""vstack(Tensor[] tensors) -> Tensor"""
@@ -8697,7 +9056,15 @@ def reshape_to_2d(tensor):
return op.ConcatFromSequence(tensors_2d, axis=0)
-@torch_op(("aten::where", "aten::where.self"))
+@torch_op(
+ (
+ "aten::where.Scalar",
+ "aten::where.ScalarSelf",
+ "aten::where.ScalarOther",
+ "aten::where.self",
+ ),
+ trace_only=True,
+)
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""
@@ -8710,33 +9077,44 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::zeros")
-def aten_zeros(size: IntType, dtype: int = FLOAT.dtype):
+@torch_op("aten::zeros", trace_only=True)
+def aten_zeros(
+ size: Sequence[INT64],
+ dtype: int = FLOAT.dtype,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+) -> TensorType:
"""zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+ if dtype == -1:
+ dtype = FLOAT.dtype
- size = op.Cast(size, to=INT64.dtype)
- zero = op.Constant(value_float=0.0)
- zero = op.Cast(zero, to=dtype)
+ zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
+ size = common_ops.merge_dims(size)
return op.Expand(zero, size)
@torch_op("aten::zeros_like", trace_only=True)
-def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor:
+def aten_zeros_like(
+ self: TTensor,
+ dtype: int = -1,
+ layout: str = "",
+ device: str = "",
+ pin_memory: bool = False,
+ memory_format: str = "",
+) -> TTensor:
"""zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.
+ if dtype is None:
+ dtype = -1
if dtype == -1:
zero = op.CastLike(0, self)
else:
zero = op.Cast(0, to=dtype)
- return _aten_zeros_like_onnx(self, zero)
-
-
-@torch_op("aten::zeros_like", private=True)
-def _aten_zeros_like_onnx(self: TTensor, zero) -> TTensor:
shape = op.Shape(self)
return op.Expand(zero, shape)
diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py
index f35b4f611b..ea92dc347d 100644
--- a/onnxscript/function_libs/torch_lib/ops/fft.py
+++ b/onnxscript/function_libs/torch_lib/ops/fft.py
@@ -21,98 +21,33 @@
from onnxscript.onnx_types import TensorType
-@torch_op(
- ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
- private=True,
- complex=True,
- traceable=True,
-)
def _fftn_onnx_normalization(
- self,
- transformed: TFloat,
+ self: TFloat,
normalization: int,
- forward: bool,
- dims: Sequence[int],
-) -> TFloat:
- # Obtain the total_sample_count (n) for normalization
- self_shape = op.Shape(self)
- total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
- total_sample_count = op.CastLike(total_sample_count, transformed)
-
- # Normalize the result
- # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
- # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
- if normalization == 1:
- # "forward" - normalize by 1/n
- if forward:
- result = op.Div(transformed, op.Sqrt(total_sample_count))
- else:
- result = op.Mul(transformed, op.Sqrt(total_sample_count))
- elif normalization == 2:
- # "ortho" - normalize by 1/sqrt(n)
- if forward:
- result = op.Div(transformed, total_sample_count)
- else:
- result = transformed
- else:
- # "backward" - no normalization
- if forward:
- result = transformed
- else:
- result = op.Mul(transformed, total_sample_count)
-
- return result
-
-
-@torch_op(
- ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
- trace_only=True,
- private=True,
- complex=True,
-)
-def _fftn_onnx(
- self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
+ signal_size: INT64,
+ inverse: bool = False,
) -> TFloat:
- """Standard complex to complex or real to complex FFT (forward or backward).
-
- This is a private shared function for implementing the various FFT functions.
-
- Args:
- self: The input tensor.
- dims: The dimensions to apply FFT.
- normalization: The normalization mode.
- inverse: Whether to compute the inverse FFT.
- onesided: Whether to compute the one-sided FFT, which retains only the
- positive frequencies.
-
- Returns:
- The transformed tensor.
- """
-
- # NOTE: trace_only because we need to process each dimension in a loop
- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
-
- # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
- # dimension at the beginning to represent the batch dimension.
- transformed = op.Unsqueeze(self, axes=[0])
-
- # Add 1 to account for the batch dimension when counting axes from the left
- new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
-
- for dim in new_dims[:-1]:
- transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
-
- # Torch computers one-sided FFT on the last dimension only.
- if onesided:
- transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
+ """Normalize in forward or backward direction."""
+ # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
+ # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
+ # Modes:
+ # 0: no normalization (backward)
+    # 1: "ortho" - divide by sqrt(signal_size) (ortho)
+ # 2: divide by signal_size (forward)
+ signal_size = op.CastLike(signal_size, self)
+ if not inverse:
+ # Forward normalization
+ if normalization == 1:
+ self = op.Div(self, op.Sqrt(signal_size))
+ elif normalization == 2:
+ self = op.Div(self, signal_size)
else:
- transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)
-
- # Remove the batch dimension
- transformed = op.Squeeze(transformed, axes=[0])
-
- return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
+ # Backward normalization, accounting for op.DFT already dividing by signal_size
+ if normalization == 0:
+ self = op.Mul(self, signal_size)
+ elif normalization == 1:
+ self = op.Mul(self, op.Sqrt(signal_size))
+ return self
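
The forward modes map directly onto numpy's FFT norm conventions, which makes them easy to sanity-check (the backward branch above additionally compensates for op.DFT already dividing by the signal size):

import numpy as np

x = np.random.default_rng(0).standard_normal(8)
n = x.size
# Mode 1 ("ortho") divides by sqrt(n); mode 2 ("forward") divides by n.
assert np.allclose(np.fft.fft(x) / np.sqrt(n), np.fft.fft(x, norm="ortho"))
assert np.allclose(np.fft.fft(x) / n, np.fft.fft(x, norm="forward"))
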
@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
@@ -124,14 +59,34 @@ def aten__fft_c2c(
Standard complex to complex FFT (forward or backward).
"""
- # NOTE: trace_only because we need to negate forward
- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
+ # NOTE: SymInt dim is not supported because DFT-17 needs a static axis
# ONNX DFT input assumes the last dimension is the complex dimension.
- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
- dim = [d - 1 if d < 0 else d for d in dim]
- return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)
+
+ unsqueeze_first_dim = 0 in dim
+    # 1. Add a batch dimension if the transform includes dim 0, since the 0th
+    #    dimension of ONNX DFT is the batch dimension.
+    # 2. ONNX DFT input assumes the last dimension is the complex dimension,
+    #    so shift every dim by 1 when a batch dimension is added.
+
+ if unsqueeze_first_dim:
+ transformed = op.Unsqueeze(self, axes=[0])
+ dim = [d + 1 for d in dim]
+ else:
+ transformed = self
+
+ for dimension in reversed(dim):
+ transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False)
+ transformed = _fftn_onnx_normalization(
+ transformed,
+ normalization,
+ op.Shape(transformed, start=dimension, end=dimension + 1),
+ not forward,
+ )
+
+ if unsqueeze_first_dim:
+ transformed = op.Squeeze(transformed, axes=[0])
+
+ return transformed
@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
@@ -139,24 +94,52 @@ def aten__fft_c2r(
self: TFloat,
dim: Sequence[int],
normalization: int,
- last_dim_size: INT64, # pylint: disable=unused-argument
+ last_dim_size: INT64,
) -> TFloat:
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
- Complex to real inverse FFT.
+    Complex to real inverse FFT. Assumes that the input tensor is the output of a previous FFT operation.
"""
-
- # TODO(justinchuby): Figure out what last_dim_size does
-
- self_rank = len(self.shape)
- # ONNX DFT input assumes the last dimension is the complex dimension.
- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
- dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
- transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
- # Take only the real part
- real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])
-
- return op.Squeeze(real_part, axes=[-1])
+ if len(dim) != 1:
+ raise NotImplementedError("Only one dimension is supported for inverse FFT")
+
+ dimension = dim[0]
+ unsqueeze_first_dim = dimension == 0
+    # 1. Add a batch dimension if the transform is over dim 0, since the 0th
+    #    dimension of ONNX DFT is the batch dimension.
+    # 2. ONNX DFT input assumes the last dimension is the complex dimension,
+    #    so shift the dim by 1 when a batch dimension is added.
+
+ if unsqueeze_first_dim:
+ transformed = op.Unsqueeze(self, axes=[0])
+ dimension = 1
+ else:
+ transformed = self
+
+ # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
+ # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
+ # place no such restriction on the ONNX side.
+ transformed = op.DFT(
+ transformed,
+ dft_length=last_dim_size,
+ axis=dimension,
+ inverse=True,
+ onesided=False,
+ )
+ transformed = _fftn_onnx_normalization(
+ transformed,
+ normalization,
+ op.Shape(transformed, start=dimension, end=dimension + 1),
+ inverse=True,
+ )
+
+ if unsqueeze_first_dim:
+ transformed = op.Squeeze(transformed, axes=[0])
+
+ # Remove the imaginary part
+ transformed = op.Slice(transformed, [0], [1], [-1])
+ transformed = op.Squeeze(transformed, axes=[-1])
+
+ return transformed
@torch_op("aten::_fft_r2c", trace_only=True)
@@ -168,17 +151,37 @@ def aten__fft_r2c(
Real to complex forward FFT.
"""
- # Add a new dimension at the end
- signal = op.Unsqueeze(self, axes=[-1])
# No need to fill the imaginary part because ONNX DFT accepts real inputs
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
- self_rank = len(self.shape)
- # ONNX DFT input assumes the last dimension is the complex dimension.
- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
- dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
+ unsqueeze_first_dim = 0 in dim
+    # 1. Add a new dimension at the end for the complex part, and a batch
+    #    dimension if the transform includes dim 0.
+    # 2. ONNX DFT input assumes the last dimension is the complex dimension,
+    #    so shift every dim by 1 when a batch dimension is added.
+
+ if unsqueeze_first_dim:
+ transformed = op.Unsqueeze(self, axes=[0, -1])
+ dim = [d + 1 for d in dim]
+ else:
+ transformed = op.Unsqueeze(self, axes=[-1])
+
+ for idx, dimension in enumerate(reversed(dim)):
+ transformed = _fftn_onnx_normalization(
+ transformed,
+ normalization,
+ op.Shape(transformed, start=dimension, end=dimension + 1),
+ inverse=False,
+ )
+ if idx > 0:
+ transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
+ else:
+ # Torch computes one-sided FFT on the last dimension only.
+ transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=onesided)
+
+ if unsqueeze_first_dim:
+ transformed = op.Squeeze(transformed, axes=[0])
- return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)
+ return transformed
def aten_fft_fft(
diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py
index 7890fb1c0b..c9d870bd86 100644
--- a/onnxscript/function_libs/torch_lib/ops/linalg.py
+++ b/onnxscript/function_libs/torch_lib/ops/linalg.py
@@ -12,17 +12,15 @@
from __future__ import annotations
+import math
from typing import Optional, Sequence
-from onnxscript import BOOL, FLOAT, INT64
-from onnxscript.function_libs.torch_lib.ops import common as common_ops
+from onnxscript import BOOL
from onnxscript.function_libs.torch_lib.registration import torch_op
-from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TTensor
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType
-IsScalar = common_ops.IsScalar
-
def aten_linalg_cholesky(self: TensorType, upper: bool = False) -> TensorType:
"""linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor"""
@@ -44,13 +42,14 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType:
raise NotImplementedError()
-def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> TensorType:
+def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
"""linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor"""
+ # Same implementation as aten_cross
raise NotImplementedError()
-@torch_op(("aten::linalg_det", "aten::det"))
+@torch_op(("aten::_linalg_det", "aten::linalg_det", "aten::det"))
def aten_linalg_det(A: TFloat) -> TFloat:
"""linalg_det(Tensor A) -> Tensor"""
@@ -326,73 +325,30 @@ def aten_linalg_vector_norm(
if dtype != -1:
self = op.Cast(self, to=dtype)
- if dim is None or (isinstance(dim, tuple) and len(dim) == 0):
+ if dim is None:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
keepdim = False
- return _aten_linalg_vector_norm_no_dim_onnx(self, ord, keepdim)
- else:
- return _aten_linalg_vector_norm_onnx(self, ord, dim, keepdim)
-
-
-@torch_op("aten::linalg_vector_norm", private=True)
-def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool) -> TFloat:
- self_is_scalar = IsScalar(self)
- if self_is_scalar:
- self = op.Unsqueeze(self, axes=[0])
-
- self = op.Abs(self)
- ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT
- # TODO(justinchuby): Evaluate IsInf in trace mode
- if op.IsInf(ord, detect_negative=0, detect_positive=1):
- result = op.ReduceMax(self, keepdims=keepdim)
- elif op.IsInf(ord, detect_negative=1, detect_positive=0):
- result = op.ReduceMin(self, keepdims=keepdim)
- elif ord == 0.0: # sum(x!=0) means count non-zero elements
- self_bool = op.Cast(self, to=BOOL.dtype)
- self_0_1 = op.CastLike(self_bool, self)
- result = op.ReduceSum(self_0_1, keepdims=False)
- # TODO(microsoft/onnxruntime#18338): Use ReduceL1/L2 when ONNX Runtime is fixed
else:
- ord_float = op.CastLike(ord, self)
- self_pow = op.Pow(self, ord_float)
- result = op.Pow(op.ReduceSum(self_pow, keepdims=keepdim), op.Div(1.0, ord_float))
-
- if self_is_scalar:
- result = op.Squeeze(result)
-
- return result
-
-
-@torch_op("aten::linalg_vector_norm", private=True)
-def _aten_linalg_vector_norm_onnx(
- self: TFloat, ord: float, dim: INT64, keepdim: bool
-) -> TFloat:
- self_is_scalar = IsScalar(self)
- if self_is_scalar:
- self = op.Unsqueeze(self, axes=[0])
-
- dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
- self = op.Abs(self)
- ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT
- # TODO(justinchuby): Evaluate IsInf in trace mode
- if op.IsInf(ord, detect_negative=0, detect_positive=1):
- result = op.ReduceMax(self, dim, keepdims=keepdim)
- elif op.IsInf(ord, detect_negative=1, detect_positive=0):
- result = op.ReduceMin(self, dim, keepdims=keepdim)
+ dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
+
+ if math.isinf(ord):
+ self = op.Abs(self)
+ if ord > 0:
+ return op.ReduceMax(self, dim, keepdims=keepdim)
+ else:
+ return op.ReduceMin(self, dim, keepdims=keepdim)
elif ord == 0.0: # sum(x!=0) means count non-zero elements
self_bool = op.Cast(self, to=BOOL.dtype)
self_0_1 = op.CastLike(self_bool, self)
- result = op.ReduceSum(self_0_1, dim, keepdims=keepdim)
+ return op.ReduceSum(self_0_1, dim, keepdims=keepdim)
elif ord == 1.0:
- result = op.ReduceL1(self, dim, keepdims=keepdim)
+ return op.ReduceL1(self, dim, keepdims=keepdim)
elif ord == 2.0:
- result = op.ReduceL2(self, dim, keepdims=keepdim)
+ return op.ReduceL2(self, dim, keepdims=keepdim)
else:
- ord_float = op.CastLike(ord, self)
- self_pow = op.Pow(self, ord_float)
- result = op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), op.Div(1.0, ord_float))
-
- if self_is_scalar:
- result = op.Squeeze(result)
-
- return result
+ if ord < 0 or ord % 2 != 0:
+ # Not an even integer (could be odd, fractional or negative), use Abs
+ self = op.Abs(self)
+ self_pow = op.Pow(self, ord)
+ exp = op.CastLike(1 / ord, self)
+ return op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), exp)
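
The fallthrough branch computes sum(|x| ** ord) ** (1 / ord), applying Abs only when Pow would not already erase the sign (i.e. ord is not a positive even integer); a numpy cross-check for an odd order:

import numpy as np

x = np.array([-3.0, 4.0])
ord_ = 3.0
result = np.sum(np.abs(x) ** ord_) ** (1.0 / ord_)
assert np.isclose(result, np.linalg.norm(x, ord=3))  # (27 + 64) ** (1/3)
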
diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 7730008efb..2a7a46ec28 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -17,16 +17,14 @@
import math
from typing import Optional, Sequence, Tuple, TypeVar, Union
-import onnx
-
-from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64
+from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
TFloat,
- TFloatOrBFloat16,
TFloatOrUInt8,
+ TInt,
TReal,
TTensor,
)
@@ -36,62 +34,14 @@
_MATH_PI = math.pi
Rank = common_ops.Rank
+_INT64_MAX = 9223372036854775807
+_INT64_MIN = -9223372036854775808
+
# All float types but float32
TFloatUnlessFloat32 = TypeVar("TFloatUnlessFloat32", bound=Union[BFLOAT16, FLOAT16, DOUBLE])
-@torch_op("aten::aten_adaptive_avg_pool1d", traceable=True)
-def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
- """adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor"""
-
- # assert output_size == [1]
- # TODO(justinchuby): Specify input constraints
-
- if Rank(self) == 2:
- # Unbatched case
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
- pooled = op.GlobalAveragePool(self)
- result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
- else:
- result = op.GlobalAveragePool(self)
-
- return result
-
-
-@torch_op("aten::aten_adaptive_avg_pool2d", traceable=True)
-def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
- """adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor"""
-
- # assert output_size == [1, 1]
- # TODO(justinchuby): Specify input constraints
-
- if Rank(self) == 3:
- # Unbatched case
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
- pooled = op.GlobalAveragePool(self)
- result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
- else:
- result = op.GlobalAveragePool(self)
-
- return result
-
-
-@torch_op("aten::aten_adaptive_avg_pool3d", traceable=True)
-def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
- """adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor"""
-
- # assert output_size == [1, 1, 1]
- # TODO(justinchuby): Specify input constraints
-
- if Rank(self) == 4:
- # Unbatched case
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
- pooled = op.GlobalAveragePool(self)
- result = op.Squeeze(pooled, op.Constant(value_ints=[0]))
- else:
- result = op.GlobalAveragePool(self)
-
- return result
+# NOTE: Implementations of adaptive_avg_pool are handled by the PyTorch decompositions
def aten_adaptive_max_pool1d(
@@ -206,7 +156,7 @@ def aten_avg_pool2d(
padding: Sequence[int] = (0, 0),
ceil_mode: bool = False,
count_include_pad: bool = True,
- divisor_override: Optional[int] = None, # pylint: disable=unused-argument
+ divisor_override: Optional[int] = None,
) -> TFloat:
"""avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor"""
@@ -267,7 +217,7 @@ def aten_avg_pool3d(
padding: Sequence[int] = (0, 0, 0),
ceil_mode: bool = False,
count_include_pad: bool = True,
- divisor_override: Optional[int] = None, # pylint: disable=unused-argument
+ divisor_override: Optional[int] = None,
) -> TFloat:
"""avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor"""
@@ -343,21 +293,17 @@ def aten_binary_cross_entropy_backward(
raise NotImplementedError()
-@torch_op("aten::celu")
-def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
+@torch_op("aten::celu", trace_only=True)
+def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat:
"""celu(Tensor self, Scalar alpha=1.0) -> Tensor"""
- return op.Celu(self, alpha=alpha) # op.Celu only support float32
-
+ if self.dtype != FLOAT.dtype:
+ self_upcasted = op.Cast(self, to=FLOAT.dtype)
-@torch_op("aten::celu", traceable=True)
-def aten_celu_type_promoted(
- self: TFloatUnlessFloat32, alpha: float = 1.0
-) -> TFloatUnlessFloat32:
- """celu(Tensor self, Scalar alpha=1.0) -> Tensor"""
+    # op.Celu only supports float32
+ return op.Cast(op.Celu(self_upcasted, alpha=alpha), to=self.dtype)
- self_upcasted = op.Cast(self, to=FLOAT.dtype)
- return op.CastLike(op.Celu(self_upcasted, alpha=alpha), self)
+ return op.Celu(self, alpha=alpha)
@torch_op("aten::col2im", trace_only=True)
@@ -409,15 +355,15 @@ def aten_conv_depthwise3d(
raise NotImplementedError()
-@torch_op("aten::cross_entropy_loss", traceable=True)
+@torch_op("aten::cross_entropy_loss", trace_only=True)
def aten_cross_entropy_loss(
- self: TFloatOrBFloat16,
+ self: TFloat,
target: IntType,
- weight: Optional[TFloatOrBFloat16] = None,
+ weight: Optional[TFloat] = None,
reduction: int = 1, # default is 'mean'
ignore_index: int = -100,
    label_smoothing: float = 0.0,  # ignored because ONNX does not support it
-) -> TFloatOrBFloat16:
+) -> TFloat:
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""
if reduction == 0: # "none"
@@ -436,7 +382,7 @@ def aten_cross_entropy_loss(
return result
-@torch_op("aten::elu")
+@torch_op("aten::elu", trace_only=True)
def aten_elu(
self: TFloat,
alpha: float = 1.0,
@@ -445,9 +391,10 @@ def aten_elu(
) -> TFloat:
"""elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor"""
- # del scale
- # del input_scale
- return op.Elu(self, alpha=alpha)
+ input_scale = op.CastLike(input_scale, self)
+ scale = op.CastLike(scale, self)
+ self = op.Mul(self, input_scale)
+ return op.Mul(op.Elu(self, alpha=alpha), scale)
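
A numpy sketch of what this lowering computes: the input is scaled before the Elu and the output after, matching the three ops in order.

import numpy as np

def elu(x, alpha=1.0, scale=1.0, input_scale=1.0):
    x = x * input_scale                                   # op.Mul(self, input_scale)
    core = np.where(x > 0, x, alpha * (np.exp(x) - 1.0))  # op.Elu(self, alpha=alpha)
    return scale * core                                   # op.Mul(..., scale)

# elu(np.array([-1.0, 1.0]), alpha=1.0, scale=2.0, input_scale=0.5)
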
def aten_elu_backward(
@@ -526,34 +473,32 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
return result
-@torch_op("aten::gelu", private=True)
def _aten_gelu_approximate_none(self: TReal) -> TReal:
"""gelu(Tensor self, *, str approximate='none') -> Tensor"""
# GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)]
- inner = op.Div(self, 1.4142135623730951)
+ inner = op.Div(self, ir.tensor(1.4142135623730951, dtype=self.dtype))
erf = op.Erf(inner)
- inner = op.Add(erf, 1)
- inner = op.Mul(self, inner)
- result = op.Mul(0.5, inner)
+ inner = op.Add(erf, ir.tensor(1, dtype=self.dtype))
+ inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner)
+ result = op.Mul(self, inner)
return result
-@torch_op("aten::gelu", private=True)
def _aten_gelu_approximate_tanh(self: TReal) -> TReal:
"""gelu(Tensor self, *, str approximate='none') -> Tensor"""
# GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]}
- cubed = op.Pow(self, 3)
- inner = op.Mul(0.044715, cubed)
+ cubed = op.Pow(self, ir.tensor(3, dtype=self.dtype))
+ inner = op.Mul(ir.tensor(0.044715, dtype=self.dtype), cubed)
inner = op.Add(self, inner)
- # Prefer explicit graph construction over precomputed constants for clarity.
- two_over_pi = op.CastLike(op.Div(2.0, _MATH_PI), self)
- inner = op.Mul(op.Sqrt(two_over_pi), inner)
+ # math.sqrt(2.0/math.pi) = 0.7978845608028654
+ sqrt_two_over_pi = ir.tensor(0.7978845608028654, dtype=self.dtype)
+ inner = op.Mul(sqrt_two_over_pi, inner)
inner = op.Tanh(inner)
- inner = op.Add(inner, 1)
- inner = op.Mul(self, inner)
- result = op.Mul(0.5, inner)
+ inner = op.Add(inner, ir.tensor(1, dtype=self.dtype))
+ inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner)
+ result = op.Mul(self, inner)
return result
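# Plain-Python reference for the tanh approximation above (illustrative only):
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
import math

def gelu_tanh_ref(x: float) -> float:
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))

# sqrt(2/pi) is the 0.7978845608028654 constant baked into the graph above
assert abs(math.sqrt(2.0 / math.pi) - 0.7978845608028654) < 1e-15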
@@ -565,10 +510,13 @@ def aten_gelu_backward(
raise NotImplementedError()
-def aten_glu(self: TensorType, dim: int = -1) -> TensorType:
+@torch_op("aten::glu")
+def aten_glu(self: TFloat, dim: int = -1) -> TFloat:
"""glu(Tensor self, int dim=-1) -> Tensor"""
- raise NotImplementedError()
+ first, second = op.Split(self, axis=dim, num_outputs=2)
+ result = op.Mul(first, op.Sigmoid(second))
+ return result
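# Numpy sketch of the GLU decomposition above (illustrative, not part of the diff):
# split the input in half along `dim` and gate the first half with sigmoid of the second.
import numpy as np

def glu_ref(x: np.ndarray, dim: int = -1) -> np.ndarray:
    first, second = np.split(x, 2, axis=dim)   # mirrors op.Split(num_outputs=2)
    return first * (1.0 / (1.0 + np.exp(-second)))

assert glu_ref(np.arange(4.0).reshape(1, 4)).shape == (1, 2)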
def aten_glu_backward(grad_output: TensorType, self: TensorType, dim: int) -> TensorType:
@@ -590,13 +538,63 @@ def aten_glu_backward_jvp(
raise NotImplementedError()
+@torch_op("aten::group_norm", trace_only=True)
+def aten_group_norm(
+ input: TFloat,
+ num_groups: int,
+ weight: Optional[TFloat] = None,
+ bias: Optional[TFloat] = None,
+ eps: float = 1e-05,
+ cudnn_enabled: bool = True,
+) -> TensorType:
+ """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor"""
+
+ # We do not need the N, C, HxW values explicitly; the input tensor already carries that information
+ if weight is None: # Default to 1.0; the shape is the channel size
+ weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2))
+
+ if bias is None: # Default to 0.0; the shape is the channel size
+ bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2))
+
+ # onnx.GroupNormalization needs weight and bias of size=group,
+ # but the torch aten function takes them with size=channel, so the sizes mismatch.
+ # We therefore use onnx.InstanceNormalization to simulate it
+ neg_1 = op.Constant(value_ints=[-1])
+ # Create weight_inst_norm and bias_inst_norm, copied from the Torch ONNX converter
+ group_tensor = op.Reshape(num_groups, neg_1)
+ # A 0 in the shape list keeps that dimension unchanged; InstanceNormalization needs shape [0, group, -1]
+ shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
+ input_reshaped = op.Reshape(input, shape_input)
+ weight_inst_norm = op.Expand(
+ op.CastLike(op.Constant(value_float=1.0), input), group_tensor
+ )
+ bias_inst_norm = op.Expand(op.CastLike(op.Constant(value_float=0.0), input), group_tensor)
+ norm = op.InstanceNormalization(
+ input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
+ )
+ # Reshape back to input's shape
+ norm = op.Reshape(norm, op.Shape(input))
+ # Apply the affine transform with the input weight and bias,
+ # unsqueezing them to the target shape so that broadcasting is easy
+ input_rank = Rank(input)
+ one = op.Constant(value_int=1)
+ axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one)
+ weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
+ bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
+ weight_full_shape = op.CastLike(weight_full_shape, norm)
+ norm_mul_weight = op.Mul(norm, weight_full_shape)
+ bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight)
+ norm_result = op.Add(norm_mul_weight, bias_full_shape)
+ return norm_result
+
+
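# Equivalent numpy computation for the InstanceNormalization trick above (illustrative sketch,
# assuming an NCHW-like layout; not part of the diff):
import numpy as np

def group_norm_ref(x, num_groups, weight=None, bias=None, eps=1e-5):
    n, c = x.shape[:2]
    xg = x.reshape(n, num_groups, -1)                # shape_input = [0, group, -1]
    norm = (xg - xg.mean(-1, keepdims=True)) / np.sqrt(xg.var(-1, keepdims=True) + eps)
    norm = norm.reshape(x.shape)
    affine_shape = (1, c) + (1,) * (x.ndim - 2)      # unsqueeze weight/bias for broadcasting
    w = np.ones(c) if weight is None else weight
    b = np.zeros(c) if bias is None else bias
    return norm * w.reshape(affine_shape) + b.reshape(affine_shape)

assert group_norm_ref(np.random.rand(2, 4, 8, 8), num_groups=2).shape == (2, 4, 8, 8)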
def aten_glu_jvp(glu: TensorType, x: TensorType, dx: TensorType, dim: int) -> TensorType:
"""glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor"""
raise NotImplementedError()
-@torch_op("aten::hardsigmoid")
+@torch_op("aten::hardsigmoid", trace_only=True)
def aten_hardsigmoid(self: TFloat) -> TFloat:
"""hardsigmoid(Tensor self) -> Tensor"""
@@ -629,12 +627,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T
return op.Clip(self, min_val, max_val)
+@torch_op("aten::hardtanh_backward", trace_only=True)
def aten_hardtanh_backward(
grad_output: TensorType, self: TensorType, min_val: float, max_val: float
) -> TensorType:
"""hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""
- raise NotImplementedError()
+ max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0)
+ min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0)
+ return op.Mul(op.Mul(grad_output, max_mask), min_mask)
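# Numpy sketch of the masking above (illustrative): the gradient passes through
# only where min_val <= x <= max_val, which the two masks encode.
import numpy as np

def hardtanh_backward_ref(grad_output, x, min_val=-1.0, max_val=1.0):
    max_mask = np.where(x > max_val, 0.0, 1.0)
    min_mask = np.where(x < min_val, 0.0, 1.0)
    return grad_output * max_mask * min_mask

assert (hardtanh_backward_ref(np.ones(3), np.array([-2.0, 0.0, 2.0]))
        == np.array([0.0, 1.0, 0.0])).all()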
def aten_huber_loss(
@@ -653,16 +654,138 @@ def aten_huber_loss_backward(
raise NotImplementedError()
+def _get_im2col_indices_along_dim(
+ input_d: TInt,
+ kernel_size_d: int,
+ dilation_d: int,
+ padding_d: int,
+ stride_d: int,
+):
+ # Input is always 4-D (N, C, H, W)
+ # Calculate the indices of the sliding blocks along spatial dimension d.
+ # Slide the kernel over the input along each dim d:
+ # each dimension d ranges from 0 to input[d] + 2*padding[d] - dilation[d]*(kernel_size[d]-1)
+ # with steps = stride
+
+ blocks_d = input_d + ((padding_d * 2) - (dilation_d * (kernel_size_d - 1)))
+
+ # Stride kernel over input and find starting indices along dim d
+ blocks_d_indices = op.Range(0, blocks_d, stride_d)
+ blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0])
+
+ # Apply dilation on kernel and find its indices along dim d
+ kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d)
+ kernel_mask = op.Unsqueeze(kernel_grid, [1])
+
+ # Broadcast and add the kernel starting positions (indices) with
+ # kernel_grid along dim d to get the block indices along dim d
+ block_mask = op.Add(blocks_d_indices, kernel_mask)
+
+ return block_mask
+
+
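# Worked example of the index computation above (plain Python, illustrative):
# input_d=3, kernel_size_d=2, dilation_d=1, padding_d=0, stride_d=1.
blocks_d = 3 + (0 * 2 - 1 * (2 - 1))                         # = 2
starts = list(range(0, blocks_d, 1))                         # op.Range -> [0, 1]
kernel_grid = list(range(0, 2 * 1, 1))                       # op.Range -> [0, 1]
block_mask = [[k + s for s in starts] for k in kernel_grid]  # broadcasted op.Add
assert block_mask == [[0, 1], [1, 2]]  # the blocks_row_indices used in aten_im2col below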
+def _get_im2col_padded_input(input, padding_h, padding_w):
+ # Input is always 4-D tensor (N, C, H, W)
+ # Padding tensor has the following format: (padding_h, padding_w)
+ # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
+ pad = op.Concat(
+ op.Constant(value_ints=[0, 0]),
+ op.Unsqueeze(padding_h, [0]),
+ op.Unsqueeze(padding_w, [0]),
+ op.Constant(value_ints=[0, 0]),
+ op.Unsqueeze(padding_h, [0]),
+ op.Unsqueeze(padding_w, [0]),
+ axis=0,
+ )
+ return op.Pad(input, pad)
+
+
+def _get_im2col_output_shape(input, kernel_h, kernel_w):
+ input_shape = op.Shape(input)
+ batch_dim = op.Gather(input_shape, 0, axis=0)
+ channel_dim = op.Gather(input_shape, 1, axis=0)
+ channel_unfolded = op.Mul(channel_dim, kernel_h * kernel_w)
+
+ return op.Concat(
+ op.Unsqueeze(batch_dim, [0]),
+ op.Unsqueeze(channel_unfolded, [0]),
+ op.Constant(value_ints=[-1]),
+ axis=0,
+ )
+
+
+@torch_op("aten::im2col", trace_only=True)
def aten_im2col(
- self: TensorType,
+ self: TReal,
kernel_size: Sequence[int],
- dilation: Sequence[int],
- padding: Sequence[int],
- stride: Sequence[int],
+ dilation: Sequence[int] = (1, 1),
+ padding: Sequence[int] = (0, 0),
+ stride: Sequence[int] = (1, 1),
) -> TensorType:
- """im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor"""
+ """im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor"""
- raise NotImplementedError()
+ input_shape = op.Shape(self)
+ input_h = op.Gather(input_shape, 2, axis=0)
+ input_w = op.Gather(input_shape, 3, axis=0)
+
+ if not isinstance(kernel_size, Sequence):
+ kernel_size = (kernel_size, kernel_size)
+ kernel_sizes = list(kernel_size)
+
+ if not isinstance(dilation, Sequence):
+ dilation = (dilation, dilation)
+ dilations = list(dilation)
+
+ if not isinstance(padding, Sequence):
+ padding = (padding, padding)
+ pads = list(padding)
+
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ strides = list(stride)
+
+ stride_h, stride_w = strides[0], strides[1]
+ padding_h, padding_w = pads[0], pads[1]
+ dilation_h, dilation_w = dilations[0], dilations[1]
+ kernel_h, kernel_w = kernel_sizes[0], kernel_sizes[1]
+
+ blocks_row_indices = _get_im2col_indices_along_dim(
+ input_h, kernel_h, dilation_h, padding_h, stride_h
+ )
+ blocks_col_indices = _get_im2col_indices_along_dim(
+ input_w, kernel_w, dilation_w, padding_w, stride_w
+ )
+
+ output_shape = _get_im2col_output_shape(self, kernel_h, kernel_w)
+ padded_input = _get_im2col_padded_input(self, padding_h, padding_w)
+
+ # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
+ # [[[[1., 2., 3.,],
+ # [4., 5., 6.,],
+ # [7., 8., 9.,]]]]
+ # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
+ # [[[[[1., 2., 3.],
+ # [4., 5., 6.]],
+ # [[4., 5., 6.],
+ # [7., 8., 9.]]]]]
+ # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
+ # [[[[[[1., 2.],
+ # [4., 5.]],
+ # [[2., 3.],
+ # [5., 6.]]],
+ # [[[4., 5.],
+ # [7., 8.]],
+ # [[5., 6.],
+ # [8., 9.]]]]]]
+ # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 4, 4) to get:
+ # [[[1., 2., 4., 5.],
+ # [2., 3., 5., 6.],
+ # [4., 5., 7., 8.],
+ # [5., 6., 8., 9.]]]
+ output = op.Gather(padded_input, blocks_row_indices, axis=2)
+ output = op.Gather(output, blocks_col_indices, axis=4)
+ output = op.Transpose(output, perm=[0, 1, 2, 4, 3, 5])
+ return op.Reshape(output, output_shape)
def aten_infinitely_differentiable_gelu_backward(
@@ -679,8 +802,8 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te
raise NotImplementedError()
-@torch_op("aten::leaky_relu")
-def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16:
+@torch_op("aten::leaky_relu", trace_only=True)
+def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
"""leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""
return op.LeakyRelu(self, alpha=negative_slope)
@@ -694,31 +817,27 @@ def aten_leaky_relu_backward(
raise NotImplementedError()
-@torch_op("aten::linear")
-def aten_linear(input: TFloat, weight: TFloat) -> TFloat:
+@torch_op("aten::linear", trace_only=True)
+def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat:
"""linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""
- # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
- # Optimizers may consider this path and replace it with Gemm
- # We do not use Gemm here because input can have batch dimensions, which Gemm does not support
- weight_transposed = op.Transpose(weight, perm=[1, 0])
- return op.MatMul(input, weight_transposed)
-
-
-@torch_op("aten::linear")
-def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:
- """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""
-
- # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
- # Optimizers may consider this path and replace it with Gemm
- # We do not use Gemm here because input can have batch dimensions, which Gemm does not support
- weight_transposed = op.Transpose(weight, perm=[1, 0])
+ if len(input.shape) == 2 and len(weight.shape) == 2:
+ # Use Gemm for the rank 2 input
+ return op.Gemm(input, weight, bias, transB=True)
+ if len(weight.shape) == 1:
+ # In rare cases the weight can be 1d
+ weight_transposed = op.Unsqueeze(weight, [1])
+ else:
+ assert len(weight.shape) == 2
+ weight_transposed = op.Transpose(weight, perm=[1, 0])
mul = op.MatMul(input, weight_transposed)
+ if bias is None:
+ return mul
return op.Add(mul, bias)
@torch_op("aten::log_sigmoid")
-def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_log_sigmoid(self: TFloat) -> TFloat:
"""log_sigmoid(Tensor self) -> Tensor"""
return op.Log(op.Sigmoid(self))
@@ -871,7 +990,6 @@ def aten_max_pool2d(
return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3)
-@torch_op("internal::max_pool", private=True, traceable=True)
def _aten_max_pool_onnx(
self: TFloatOrUInt8,
kernel_shape: Sequence[int],
@@ -883,7 +1001,7 @@ def _aten_max_pool_onnx(
) -> TFloatOrUInt8:
self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+ self = op.Unsqueeze(self, [0])
pool_result, _ = op.MaxPool(
self,
@@ -895,7 +1013,7 @@ def _aten_max_pool_onnx(
)
if self_rank_is_unbatched_rank:
- pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+ pool_result = op.Squeeze(pool_result, [0])
return pool_result
@@ -1003,7 +1121,6 @@ def aten_max_pool3d_with_indices(
)
-@torch_op("internal::max_pool_with_indices", private=True, traceable=True)
def _aten_max_pool_with_indices_onnx(
self: TFloatOrUInt8,
kernel_size: Sequence[int],
@@ -1018,7 +1135,7 @@ def _aten_max_pool_with_indices_onnx(
) -> Tuple[TFloatOrUInt8, INT64]:
self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
if self_rank_is_unbatched_rank:
- self = op.Unsqueeze(self, axes=0)
+ self = op.Unsqueeze(self, axes=[0])
pool_result, indices = op.MaxPool(
self,
@@ -1073,8 +1190,8 @@ def _aten_max_pool_with_indices_onnx(
indices = op.Sub(indices, delta)
if self_rank_is_unbatched_rank:
- pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
- indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+ pool_result = op.Squeeze(pool_result, [0])
+ indices = op.Squeeze(indices, [0])
return (pool_result, indices)
@@ -1159,7 +1276,7 @@ def aten_mkldnn_reorder_conv3d_weight(
raise NotImplementedError()
-@torch_op("aten::mse_loss", traceable=True)
+@torch_op("aten::mse_loss", trace_only=True)
def aten_mse_loss(self: TReal, target: TReal, reduction: int = 1) -> TReal:
"""mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor"""
# FIXME: When reduction=0, the shape(result) will be different than other case
@@ -1235,10 +1352,11 @@ def aten_multilabel_margin_loss_forward(
raise NotImplementedError()
-@torch_op("aten::nll_loss", traceable=True)
+@torch_op("aten::nll_loss", trace_only=True)
def aten_nll_loss(
self: TFloat,
target: INT64,
+ weight: Optional[TFloat] = None,
reduction: int = 1,
ignore_index: int = -100,
) -> TFloat:
@@ -1246,62 +1364,22 @@ def aten_nll_loss(
self_rank_is_1 = Rank(self) == 1
if self_rank_is_1: # self rank should be at least 2
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+ self = op.Unsqueeze(self, [0])
rank_target = Rank(target)
if rank_target == 0: # target rank should be at least 1
- target = op.Unsqueeze(target, op.Constant(value_ints=[0]))
+ target = op.Unsqueeze(target, [0])
if reduction == 0:
- result = op.NegativeLogLikelihoodLoss(
- self, target, ignore_index=ignore_index, reduction="none"
- )
+ reduction_str = "none"
elif reduction == 1:
- result = op.NegativeLogLikelihoodLoss(
- self, target, ignore_index=ignore_index, reduction="mean"
- )
+ reduction_str = "mean"
else: # assert reduction == 2
- result = op.NegativeLogLikelihoodLoss(
- self, target, ignore_index=ignore_index, reduction="sum"
- )
-
- if self_rank_is_1:
- result = op.Squeeze(result)
-
- return result
-
-
-@torch_op("aten::nll_loss", traceable=True)
-def aten_nll_loss_weight(
- self: TFloat,
- target: INT64,
- weight: TFloat,
- reduction: int = 1,
- ignore_index: int = -100,
-) -> TFloat:
- """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"""
-
- self_rank_is_1 = Rank(self) == 1
- if self_rank_is_1:
- # self rank should be at least 2
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-
- rank_target = Rank(target)
- if rank_target == 0: # target rank should be at least 1
- target = op.Unsqueeze(target, op.Constant(value_ints=[0]))
+ reduction_str = "sum"
- if reduction == 0:
- result = op.NegativeLogLikelihoodLoss(
- self, target, weight, ignore_index=ignore_index, reduction="none"
- )
- elif reduction == 1:
- result = op.NegativeLogLikelihoodLoss(
- self, target, weight, ignore_index=ignore_index, reduction="mean"
- )
- else:
- result = op.NegativeLogLikelihoodLoss(
- self, target, weight, ignore_index=ignore_index, reduction="sum"
- )
+ result = op.NegativeLogLikelihoodLoss(
+ self, target, weight, ignore_index=ignore_index, reduction=reduction_str
+ )
if self_rank_is_1:
result = op.Squeeze(result)
@@ -1361,16 +1439,23 @@ def aten_nll_loss_backward(
raise NotImplementedError()
+@torch_op("aten::nll_loss_forward", trace_only=True)
def aten_nll_loss_forward(
self: TensorType,
target: TensorType,
weight: Optional[TensorType],
reduction: int,
- ignore_index: INT64,
+ ignore_index: int,
) -> tuple[TensorType, TensorType]:
"""nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)"""
- raise NotImplementedError()
+ output = aten_nll_loss(self, target, weight, reduction, ignore_index)
+ # FIXME: Fake a total_weight tensor for now. It should depend on weight, reduction, and ignore_index
+ if weight is None:
+ total_weight = op.CastLike(op.Size(output), self)
+ else:
+ total_weight = op.CastLike(op.Size(output), weight)
+ return output, total_weight
def aten_nll_loss_nd(
@@ -1391,12 +1476,56 @@ def aten_one_hot(self: TensorType, num_classes: int = -1) -> TensorType:
raise NotImplementedError()
+def _process_padding(padding: Sequence[INT64 | int], rank: int) -> INT64:
+ """Convert PyTorch padding for ONNX Pad."""
+ assert isinstance(padding, (list, tuple))
+ if all(isinstance(pad, int) for pad in padding):
+ paddings = padding
+ zeros = [0] * (rank * 2 - len(paddings))
+ paddings = [*paddings, *zeros]
+ paddings = paddings[-2::-2] + paddings[-1::-2]
+ return op.Constant(value=ir.tensor(paddings, dtype=ir.DataType.INT64))
+ else:
+ paddings = []
+ for pad in padding:
+ if isinstance(pad, int):
+ paddings.append(op.Constant(value_ints=[pad]))
+ else:
+ # Dynamic value
+ paddings.append(op.Reshape(pad, [-1]))
+ # Create a series of 1d zero tensors
+ zero = op.Constant(value_ints=[0])
+ zeros = [zero] * (rank * 2 - len(paddings))
+ paddings = [*paddings, *zeros]
+ # Interleave the padding values
+ paddings = paddings[-2::-2] + paddings[-1::-2]
+ return op.Concat(*paddings, axis=0)
+
+
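# Worked example of the static branch above (plain Python, illustrative):
# torch's last-dim-first (begin, end) pairs become ONNX's
# [x1_begin, x2_begin, ..., x1_end, x2_end, ...] layout.
padding, rank = [1, 2], 4                        # e.g. F.pad(x, (1, 2)) on an NCHW tensor
paddings = [*padding, *([0] * (rank * 2 - len(padding)))]
paddings = paddings[-2::-2] + paddings[-1::-2]   # all begins, then all ends
assert paddings == [0, 0, 0, 1, 0, 0, 0, 2]      # pad W by 1 at the start, 2 at the end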
+@torch_op("aten::pad", trace_only=True)
def aten_pad(
- self: TensorType, pad: INT64, mode: str = "constant", value: Optional[float] = None
+ self: TensorType,
+ pad: Sequence[INT64],
+ mode: str = "constant",
+ value: Optional[float] = None,
) -> TensorType:
"""pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor"""
- raise NotImplementedError()
+ rank = len(self.shape)
+ paddings = _process_padding(pad, rank)
+ const_value = (
+ op.Constant(value=ir.tensor(value, dtype=ir.DataType(self.dtype)))
+ if value is not None
+ else None
+ )
+ onnx_mode = {
+ "constant": "constant",
+ "reflect": "reflect",
+ "replicate": "edge",
+ "circular": "wrap",
+ }[mode]
+
+ return op.Pad(self, paddings, constant_value=const_value, mode=onnx_mode)
def aten_pad_sequence(
@@ -1407,18 +1536,15 @@ def aten_pad_sequence(
raise NotImplementedError()
-@torch_op("aten::reflection_pad1d")
-def aten_reflection_pad1d(self: TFloat, padding: INT64) -> TFloat:
+@torch_op("aten::reflection_pad1d", trace_only=True)
+def aten_reflection_pad1d(self: TFloat, padding: Sequence[INT64]) -> TFloat:
"""reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor"""
# assert len(padding) == 2
# The padding argument is [x, y]; it needs to be converted to the ONNX format [0, x, 0, y]
- start = op.Slice(padding, [0], [1], axes=[0])
- end = op.Slice(padding, [1], [2], axes=[0])
- padding_onnx = op.Concat(
- op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0
- )
- return op.Pad(self, padding_onnx, mode="reflect")
+ rank = len(self.shape)
+ paddings = _process_padding(padding, rank)
+ return op.Pad(self, paddings, mode="reflect")
def aten_reflection_pad1d_backward(
@@ -1429,37 +1555,12 @@ def aten_reflection_pad1d_backward(
raise NotImplementedError()
-@torch_op("aten::reflection_pad2d")
-def aten_reflection_pad2d(self: TTensor, padding: INT64) -> TTensor:
+@torch_op("aten::reflection_pad2d", trace_only=True)
+def aten_reflection_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
"""reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor"""
- # Convert torch padding format to onnx padding format
- # Python code is:
- # dim = len(self.shape)
- # paddings = list(padding[:]) + [0] * (dim * 2 - len(padding))
- # paddings = paddings[-2::-2] + paddings[-1::-2]
-
- neg_1 = op.Constant(value_ints=[-1])
- zero = op.Constant(value_ints=[0])
- # [0] * (rank * 2 - len(padding))
- rank = Rank(self)
- zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1)
- zeros = op.Expand(zero, zero_count)
- # list(padding[:]) + [0] * (dim * 2 - len(padding))
- torch_paddings = op.Concat(padding, zeros, axis=0)
- # paddings[-2::-2]
- size_d = op.Size(torch_paddings)
- steps = op.Constant(value_ints=[-2])
- starts = steps
- ends = op.Sub(starts, size_d)
- odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
- # paddings[-1::-2]
- starts = neg_1
- ends = op.Sub(starts, size_d)
- even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
- # paddings[-2::-2] + paddings[-1::-2]
- onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-
- return op.Pad(self, onnx_padding, mode="reflect")
+ rank = len(self.shape)
+ paddings = _process_padding(padding, rank)
+ return op.Pad(self, paddings, mode="reflect")
def aten_reflection_pad2d_backward(
@@ -1470,10 +1571,12 @@ def aten_reflection_pad2d_backward(
raise NotImplementedError()
-def aten_reflection_pad3d(self: TensorType, padding: INT64) -> TensorType:
+@torch_op("aten::reflection_pad3d", trace_only=True)
+def aten_reflection_pad3d(self: TensorType, padding: Sequence[INT64]) -> TensorType:
"""reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor"""
-
- raise NotImplementedError()
+ rank = len(self.shape)
+ paddings = _process_padding(padding, rank)
+ return op.Pad(self, paddings, mode="reflect")
def aten_reflection_pad3d_backward(
@@ -1484,14 +1587,14 @@ def aten_reflection_pad3d_backward(
raise NotImplementedError()
-@torch_op("aten::relu")
+@torch_op("aten::relu", trace_only=True)
def aten_relu(self: TReal) -> TReal:
"""relu(Tensor self) -> Tensor"""
return op.Relu(self)
-@torch_op("aten::relu6", traceable=True)
+@torch_op("aten::relu6", trace_only=True)
def aten_relu6(self: TReal) -> TReal:
"""relu6(Tensor self) -> Tensor"""
@@ -1499,18 +1602,13 @@ def aten_relu6(self: TReal) -> TReal:
return op.Min(op.Relu(self), six)
-@torch_op("aten::replication_pad1d")
-def aten_replication_pad1d(self: TensorType, padding: INT64) -> TensorType:
+@torch_op("aten::replication_pad1d", trace_only=True)
+def aten_replication_pad1d(self: TensorType, padding: Sequence[INT64]) -> TensorType:
"""replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor"""
- # assert len(padding) == 2
- # Input of padding argument should be [x,y], need change to onnx format [0, x, 0, y]
- start = op.Slice(padding, [0], [1], axes=[0])
- end = op.Slice(padding, [1], [2], axes=[0])
- padding_onnx = op.Concat(
- op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0
- )
- return op.Pad(self, padding_onnx, mode="edge")
+ rank = len(self.shape)
+ paddings = _process_padding(padding, rank)
+ return op.Pad(self, paddings, mode="edge")
def aten_replication_pad1d_backward(
@@ -1521,32 +1619,13 @@ def aten_replication_pad1d_backward(
raise NotImplementedError()
-@torch_op("aten::replication_pad2d")
-def aten_replication_pad2d(self: TTensor, padding: INT64) -> TTensor:
+@torch_op("aten::replication_pad2d", trace_only=True)
+def aten_replication_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
"""replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor"""
- neg_1 = op.Constant(value_ints=[-1])
- zero = op.Constant(value_ints=[0])
- # [0] * (rank * 2 - len(padding))
- rank = Rank(self)
- zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1)
- zeros = op.Expand(zero, zero_count)
- # list(padding[:]) + [0] * (dim * 2 - len(padding))
- torch_paddings = op.Concat(padding, zeros, axis=0)
- # paddings[-2::-2]
- size_d = op.Size(torch_paddings)
- steps = op.Constant(value_ints=[-2])
- starts = steps
- ends = op.Sub(starts, size_d)
- odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
- # paddings[-1::-2]
- starts = neg_1
- ends = op.Sub(starts, size_d)
- even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
- # paddings[-2::-2] + paddings[-1::-2]
- onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-
- return op.Pad(self, onnx_padding, mode="edge")
+ rank = len(self.shape)
+ paddings = _process_padding(padding, rank)
+ return op.Pad(self, paddings, mode="edge")
def aten_replication_pad2d_backward(
@@ -1557,32 +1636,13 @@ def aten_replication_pad2d_backward(
raise NotImplementedError()
-@torch_op("aten::replication_pad3d")
-def aten_replication_pad3d(self: TTensor, padding: INT64) -> TTensor:
+@torch_op("aten::replication_pad3d", trace_only=True)
+def aten_replication_pad3d(self: TTensor, padding: Sequence[INT64]) -> TTensor:
"""replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor"""
- neg_1 = op.Constant(value_ints=[-1])
- zero = op.Constant(value_ints=[0])
- # [0] * (rank * 2 - len(padding))
- rank = Rank(self)
- zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1)
- zeros = op.Expand(zero, zero_count)
- # list(padding[:]) + [0] * (dim * 2 - len(padding))
- torch_paddings = op.Concat(padding, zeros, axis=0)
- # paddings[-2::-2]
- size_d = op.Size(torch_paddings)
- steps = op.Constant(value_ints=[-2])
- starts = steps
- ends = op.Sub(starts, size_d)
- odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
- # paddings[-1::-2]
- starts = neg_1
- ends = op.Sub(starts, size_d)
- even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
- # paddings[-2::-2] + paddings[-1::-2]
- onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-
- return op.Pad(self, onnx_padding, mode="edge")
+ rank = len(self.shape)
+ paddings = _process_padding(padding, rank)
+ return op.Pad(self, paddings, mode="edge")
def aten_replication_pad3d_backward(
@@ -1620,7 +1680,6 @@ def aten_rrelu_with_noise_backward(
raise NotImplementedError()
-@torch_op("aten::scaled_dot_product_attention", private=True)
def _causal_attention_mask(query: TFloat, key: TFloat) -> TFloat:
"""Create a causal mask for the given query and key tensors.
@@ -1636,20 +1695,30 @@ def _causal_attention_mask(query: TFloat, key: TFloat) -> TFloat:
Returns:
Tensor of shape [L, S]
"""
- target_length = op.Shape(query)[-2:-1]
- source_length = op.Shape(key)[-2:-1]
+ q_shape = op.Shape(query)
+ k_shape = op.Shape(key)
+
+ target_length = op.Slice(
+ q_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
+ )
+ source_length = op.Slice(
+ k_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1])
+ )
# attn_mask = torch.ones(L, S) := {
size = op.Concat(target_length, source_length, axis=0)
- attn_mask = op.Expand(1.0, size)
+ attn_mask = op.Expand(op.Constant(value_float=1.0), size)
# }
attn_mask = op.Trilu(attn_mask, upper=0)
# The causal mask has 0s in the lower triangle and -inf in the upper triangle.
- attn_mask = op.Where(op.Equal(attn_mask, 0.0), op.Constant(value_float=-float("inf")), 0.0)
+ attn_mask = op.Where(
+ op.Equal(attn_mask, op.Constant(value_float=0.0)),
+ op.Constant(value_float=-float("inf")),
+ op.Constant(value_float=0.0),
+ )
attn_mask = op.CastLike(attn_mask, query)
return attn_mask
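# Numpy sketch of the mask built above for L = S = 3 (illustrative, not part of the diff):
# zeros on and below the diagonal, -inf above it.
import numpy as np

attn_mask = np.tril(np.ones((3, 3)))                   # op.Trilu(attn_mask, upper=0)
attn_mask = np.where(attn_mask == 0.0, -np.inf, 0.0)   # op.Where(op.Equal(..., 0), -inf, 0)
assert (attn_mask == np.array([[0.0, -np.inf, -np.inf],
                               [0.0, 0.0, -np.inf],
                               [0.0, 0.0, 0.0]])).all()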
-@torch_op("aten::scaled_dot_product_attention", private=True)
def _attention_scale(query: TFloat) -> TFloat:
"""Calculate the scale factor for the attention result.
@@ -1659,22 +1728,85 @@ def _attention_scale(query: TFloat) -> TFloat:
Returns:
Scalar scale factor := 1 / math.sqrt(query.size(-1))
"""
- embedding_size = op.CastLike(op.Shape(query)[-1], query)
- scale = op.Div(1.0, op.Sqrt(embedding_size))
+ q_shape = op.Shape(query)
+ q_last_dim = op.Gather(q_shape, op.Constant(value_ints=[-1]))
+ embedding_size = op.CastLike(q_last_dim, query)
+ one = op.Constant(value_float=1.0)
+ cast_one = op.CastLike(one, query)
+ scale = op.Div(cast_one, op.Sqrt(embedding_size))
return scale
+def _attention_repeat_kv_for_group_query(
+ query: TFloat, key: TFloat, value: TFloat
+) -> Tuple[TFloat, TFloat]:
+ """Expand key and value for group query attention.
+
+ repeat_interleave is applied on key and value to match the number of heads in query.
+
+ Args:
+ query: Tensor of shape [B, q_num_heads, q_S, E]
+ key: Tensor of shape [B, k_num_heads, kv_S, E]
+ value: Tensor of shape [B, v_num_heads, kv_S, E]
+
+ Returns:
+ Tuple of (expanded_key, expanded_value) where:
+ - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
+ - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E]
+ """
+
+ assert (
+ query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0
+ ), (
+ "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
+ )
+
+ # NOTE: QKV are expected to be 4D tensors
+
+ batch_size = op.Shape(query, start=0, end=1) # [B]
+ q_num_heads = op.Shape(query, start=1, end=2) # [Hq]
+ kv_num_heads = op.Shape(key, start=1, end=2) # [Hk]
+ qk_head_size = op.Shape(key, start=3, end=4) # [Dk]
+ v_head_size = op.Shape(value, start=3, end=4) # [Dv]
+ new_kv_seq_len = op.Shape(key, start=2, end=3) # [T]
+
+ interleave_dim = op.Div(q_num_heads, kv_num_heads) # Hq / Hk
+ two = op.Constant(value_int=2)
+ k_unsqueezed = op.Unsqueeze(key, two) # [B, Hk, 1, T, Dk]
+ v_unsqueezed = op.Unsqueeze(value, two) # [B, Hv, 1, T, Dv]
+
+ k_expand_shape = op.Concat(
+ batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
+ )
+ k_expand = op.Expand(k_unsqueezed, k_expand_shape)
+ v_expand_shape = op.Concat(
+ batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
+ )
+ v_expand = op.Expand(v_unsqueezed, v_expand_shape)
+
+ k_attention_shape = op.Concat(
+ batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
+ )
+ v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0)
+
+ expanded_key = op.Reshape(k_expand, k_attention_shape)
+ expanded_value = op.Reshape(v_expand, v_attention_shape)
+
+ return expanded_key, expanded_value
+
+
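# The Unsqueeze/Expand/Reshape sequence above is a repeat_interleave over the head
# axis; numpy sketch (illustrative, shapes are made up; torch equivalent would be
# key.repeat_interleave(Hq // Hk, dim=1)):
import numpy as np

B, Hq, Hk, T, D = 1, 8, 2, 5, 16
key = np.zeros((B, Hk, T, D), dtype=np.float32)
expanded_key = np.repeat(key, Hq // Hk, axis=1)   # each kv head repeated Hq/Hk times
assert expanded_key.shape == (B, Hq, T, D)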
@torch_op("aten::scaled_dot_product_attention", trace_only=True)
def aten_scaled_dot_product_attention(
query: TFloat,
key: TFloat,
value: TFloat,
- attn_mask: Optional[TFloat] = None,
+ attn_mask: Optional[TensorType] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
+ enable_gqa: bool = False,
) -> TFloat:
- """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+ """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -1690,9 +1822,13 @@ def aten_scaled_dot_product_attention(
L is the target sequence length, S is the source sequence length, and E is the embedding size.
"""
# Use trace_only to handle optional inputs
- assert (not is_causal) or (
- is_causal and attn_mask is None
- ), "is_causal and attn_mask cannot be set at the same time"
+ assert (not is_causal) or (is_causal and attn_mask is None), (
+ "is_causal and attn_mask cannot be set at the same time"
+ )
+
+ assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+ "only 4D query, key, and value are supported"
+ )
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if scale is None:
@@ -1702,17 +1838,28 @@ def aten_scaled_dot_product_attention(
if is_causal:
attn_mask = _causal_attention_mask(query, key)
+ if enable_gqa:
+ key, value = _attention_repeat_kv_for_group_query(query, key, value)
+ else:
+ assert query.shape[1] == key.shape[1] == value.shape[1], (
+ "SDPA (MHA) requires q_num_heads = kv_num_heads"
+ )
+
if attn_mask is None:
return _aten_scaled_dot_product_attention_no_mask_onnx(
query, key, value, scale, dropout_p
)
+ if attn_mask.dtype == ir.DataType.BOOL:
+ return _aten_scaled_dot_product_attention_bool_mask_onnx(
+ query, key, value, attn_mask, scale, dropout_p
+ )
+
return _aten_scaled_dot_product_attention_float_mask_onnx(
query, key, value, attn_mask, scale, dropout_p
)
-@torch_op("aten::_scaled_dot_product_flash_attention", private=True)
def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
query: TFloat,
) -> Tuple[FLOAT, INT64, INT64, FLOAT]:
@@ -1720,15 +1867,11 @@ def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
)
logsumexp = op.Expand(0.0, query_first_three_dims)
- # TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
- empty_tensor_int = op.Cast(
- op.ConstantOfShape(
- op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
- ),
- to=INT64.dtype,
+ empty_tensor_int = op.ConstantOfShape(
+ op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)
empty_tensor_float = op.ConstantOfShape(
- op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
+ op.Constant(value=ir.tensor([], dtype=ir.DataType.FLOAT))
)
empty_int = op.Constant(value_int=0)
@@ -1742,7 +1885,7 @@ def aten__scaled_dot_product_flash_attention(
value: TFloat,
dropout_p: float = 0.0,
is_causal: bool = False,
- return_debug_mask: bool = False, # pylint: disable=unused-argument
+ return_debug_mask: bool = False,
scale: Optional[float] = None,
) -> Tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]:
"""_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1779,7 +1922,6 @@ def aten__scaled_dot_product_flash_attention(
)
-@torch_op("aten::_scaled_dot_product_efficient_attention", private=True, traceable=True)
def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
query: TFloat,
compute_log_sumexp: bool,
@@ -1788,9 +1930,9 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
query = op.Transpose(query, perm=[0, 2, 1, 3])
query_shape = op.Shape(query)
- query_first_dims = query_shape[:1]
- query_second_dims = query_shape[1:2]
- num_heads = query_shape[-2:-1]
+ query_first_dims = op.Slice(query_shape, op.Constant(value_ints=[_INT64_MIN]), [1])
+ query_second_dims = op.Slice(query_shape, [1], [2])
+ num_heads = op.Slice(query_shape, [-2], [-1])
if compute_log_sumexp:
logsumexp_dim = op.Cast(
@@ -1803,22 +1945,50 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))
# See Note [Seed and Offset]:
- empty_tensor_int = op.Cast(
- op.ConstantOfShape(
- op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
- ),
- to=INT64.dtype,
+ empty_tensor_int = op.ConstantOfShape(
+ op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)
return logsum_exp, empty_tensor_int
+@torch_op("aten::_scaled_dot_product_flash_attention_for_cpu", trace_only=True)
+def aten__scaled_dot_product_flash_attention_for_cpu(
+ query: TFloat,
+ key: TFloat,
+ value: TFloat,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ attn_mask: Optional[TFloat] = None,
+ scale: Optional[float] = None,
+) -> Tuple[TFloat, FLOAT]:
+ """_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)"""
+ result = aten_scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ )
+ query_shape = op.Shape(query)
+ query_first_dims = op.Slice(query_shape, [0], [1])
+ query_second_dims = op.Slice(query_shape, [1], [2])
+ num_heads = op.Slice(query_shape, [-2], [-1])
+ logsumexp_dim = op.Cast(
+ op.Ceil(op.Cast(query_second_dims, to=FLOAT.dtype) / 32.0) * 32.0, to=INT64.dtype
+ )
+ logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, logsumexp_dim, axis=0))
+ return result, logsum_exp
+
+
@torch_op("aten::_scaled_dot_product_efficient_attention", trace_only=True)
def aten__scaled_dot_product_efficient_attention(
query: TFloat,
key: TFloat,
value: TFloat,
- attn_bias: Optional[TFloat], # pylint: disable=unused-argument
+ attn_bias: Optional[TFloat],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
@@ -1827,7 +1997,7 @@ def aten__scaled_dot_product_efficient_attention(
"""_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)"""
result = aten_scaled_dot_product_attention(
- query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
+ query, key, value, attn_bias, dropout_p=dropout_p, is_causal=is_causal, scale=scale
)
# The following outputs are not consumed by the graph.
@@ -1846,58 +2016,6 @@ def aten__scaled_dot_product_efficient_attention(
)
-@torch_op("aten::scaled_dot_product_attention", trace_only=True)
-def aten_scaled_dot_product_attention_bool_mask(
- query: TFloat,
- key: TFloat,
- value: TFloat,
- attn_mask: Optional[BOOL] = None,
- dropout_p: float = 0.0,
- is_causal: bool = False,
- scale: Optional[float] = None,
-) -> TFloat:
- """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
-
- Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
-
- Equivalent to the PyTorch code::
- scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
- attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
- attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
- attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
- attn_weight = torch.dropout(attn_weight, dropout_p)
- return attn_weight @ V
-
- where Q, K, V are the query, key, and value tensors, respectively.
- L is the target sequence length, S is the source sequence length, and E is the embedding size.
- """
- # Use trace_only to handle optional inputs
- assert (not is_causal) or (
- is_causal and attn_mask is None
- ), "is_causal and attn_mask cannot be set at the same time"
-
- if scale is None:
- scale = _attention_scale(query)
- scale = op.CastLike(scale, query)
-
- if is_causal:
- attn_mask = _causal_attention_mask(query, key)
- # The causal mask is always float
- return _aten_scaled_dot_product_attention_float_mask_onnx(
- query, key, value, attn_mask, scale, dropout_p
- )
-
- if attn_mask is None:
- return _aten_scaled_dot_product_attention_no_mask_onnx(
- query, key, value, scale, dropout_p
- )
-
- return _aten_scaled_dot_product_attention_bool_mask_onnx(
- query, key, value, attn_mask, scale, dropout_p
- )
-
-
-@torch_op("aten::scaled_dot_product_attention", private=True)
def _aten_scaled_dot_product_attention_no_mask_onnx(
query: TFloat,
key: TFloat,
@@ -1907,9 +2025,9 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
) -> TFloat:
# Swap the last two axes of key
key_shape = op.Shape(key)
- key_last_dim = key_shape[-1:]
- key_second_last_dim = key_shape[-2:-1]
- key_first_dims = key_shape[:-2]
+ key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
+ key_second_last_dim = op.Slice(key_shape, [-2], [-1])
+ key_first_dims = op.Slice(key_shape, op.Constant(value_ints=[_INT64_MIN]), [-2])
# Contract the dimensions that are not the last two so we can transpose
# with a static permutation.
key_squeezed_shape = op.Concat(
@@ -1928,11 +2046,11 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
op.MatMul(query_scaled, key_transposed_scaled),
axis=-1,
)
- attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+ if dropout_p != 0:
+ attn_weight, _ = op.Dropout(attn_weight, dropout_p)
return op.MatMul(attn_weight, value)
-@torch_op("aten::scaled_dot_product_attention", private=True)
def _aten_scaled_dot_product_attention_bool_mask_onnx(
query: TFloat,
key: TFloat,
@@ -1943,9 +2061,9 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
) -> TFloat:
# Swap the last two axes of key
key_shape = op.Shape(key)
- key_last_dim = key_shape[-1:]
- key_second_last_dim = key_shape[-2:-1]
- key_first_dims = key_shape[:-2]
+ key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
+ key_second_last_dim = op.Slice(key_shape, [-2], [-1])
+ key_first_dims = op.Slice(key_shape, op.Constant(value_ints=[_INT64_MIN]), [-2])
# Contract the dimensions that are not the last two so we can transpose
# with a static permutation.
key_squeezed_shape = op.Concat(
@@ -1961,16 +2079,24 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
query_scaled = op.Mul(query, op.Sqrt(scale))
key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
# Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
- attn_mask = op.Where(attn_mask, 0.0, op.Constant(value_float=-float("inf")))
+ zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
+ neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
+ attn_mask = op.Where(attn_mask, zero, neg_inf)
attn_weight = op.Softmax(
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
axis=-1,
)
- attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+ # When scaled dot product attention is used with a boolean mask, the softmax might return NaN values
+ # because an entirely -inf row (fully masked padding tokens) yields 0/0 (NaN) in the softmax output.
+ # Since there is no safe/masked softmax implementation in ONNX, we handle the NaN values explicitly to match
+ # the behavior of PyTorch with boolean masks.
+ # Reference: https://github.com/pytorch/pytorch/issues/103749
+ attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
+ if dropout_p != 0:
+ attn_weight, _ = op.Dropout(attn_weight, dropout_p)
return op.MatMul(attn_weight, value)
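# Why the IsNaN fix above is needed (numpy sketch, illustrative; not part of the diff):
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

w = softmax(np.array([-np.inf, -np.inf]))   # fully masked row -> all NaN
assert np.isnan(w).all()
w = np.where(np.isnan(w), 0.0, w)           # the op.Where(op.IsNaN(...), 0, ...) fix
assert (w == 0.0).all()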
-@torch_op("aten::scaled_dot_product_attention", private=True)
def _aten_scaled_dot_product_attention_float_mask_onnx(
query: TFloat,
key: TFloat,
@@ -1981,9 +2107,9 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
) -> TFloat:
# Swap the last two axes of key
key_shape = op.Shape(key)
- key_last_dim = key_shape[-1:]
- key_second_last_dim = key_shape[-2:-1]
- key_first_dims = key_shape[:-2]
+ key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
+ key_second_last_dim = op.Slice(key_shape, [-2], [-1])
+ key_first_dims = op.Slice(key_shape, op.Constant(value_ints=[_INT64_MIN]), [-2])
# Contract the dimensions that are not the last two so we can transpose
# with a static permutation.
key_squeezed_shape = op.Concat(
@@ -2002,7 +2128,8 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
axis=-1,
)
- attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+ if dropout_p != 0:
+ attn_weight, _ = op.Dropout(attn_weight, dropout_p)
return op.MatMul(attn_weight, value)
@@ -2012,10 +2139,11 @@ def aten_sigmoid_backward(grad_output: TensorType, output: TensorType) -> Tensor
raise NotImplementedError()
-def aten_silu(self: TensorType) -> TensorType:
+@torch_op("aten::silu", trace_only=True)
+def aten_silu(self: TFloat) -> TFloat:
"""silu(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Mul(self, op.Sigmoid(self))
def aten_silu_backward(grad_output: TensorType, self: TensorType) -> TensorType:
@@ -2202,27 +2330,20 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str:
return "align_corners" if align_corners else "pytorch_half_pixel"
-@torch_op(
- (
- "aten::upsample_bicubic2d",
- "aten::upsample_bilinear2d",
- "aten::upsample_nearest1d",
- "aten::upsample_nearest2d",
- "aten::upsample_nearest3d",
- ),
- private=True,
-)
def _aten_upsample_output_size(
self: TReal,
output_size: INT64,
mode: str,
coordinate_transformation_mode: str,
+ antialias: int = 0,
) -> TReal:
- self_shape = op.Shape(self)
- starts = op.Constant(value_ints=[0])
- ends = op.Constant(value_ints=[2])
- batch_channel = op.Slice(self_shape, starts, ends)
- output_size = op.Concat(batch_channel, output_size, axis=0)
+ batch_and_channel = op.Shape(self, end=2, start=0)
+ # When output_size is passed in as a list of integers, the torch.onnx
+ # graph builder may fail to determine the output type of op.Concat.
+ # We cast it to INT64 to ensure the output type is well defined
+ output_size = op.Cast(output_size, to=INT64.dtype)
+ # Append the batch and channel dimensions to the requested output size
+ output_size = op.Concat(batch_and_channel, output_size, axis=0)
return op.Resize(
self,
None,
@@ -2231,25 +2352,28 @@ def _aten_upsample_output_size(
mode=mode,
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
+ antialias=antialias,
)
-@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True)
def _aten_upsample_scales(
self: TReal,
- scale_factors: TFloat,
+ scale_factors: Sequence[float],
mode: str,
coordinate_transformation_mode: str,
+ antialias: int = 0,
) -> TReal:
- scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
- scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
return op.Resize(
self,
None,
- scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
+ op.Constant(
+ value_floats=[1.0, 1.0, *scale_factors]
+ ), # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode=mode,
coordinate_transformation_mode=coordinate_transformation_mode,
+ nearest_mode="floor",
+ antialias=antialias,
)
@@ -2274,6 +2398,28 @@ def aten_upsample_bicubic2d(
)
+@torch_op("aten::_upsample_bicubic2d_aa", trace_only=True)
+def aten__upsample_bicubic2d_aa(
+ self: TReal,
+ output_size: INT64,
+ align_corners: bool,
+ scales_h: Optional[float] = None,
+ scales_w: Optional[float] = None,
+) -> TReal:
+ """_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
+
+ # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch
+ # except when align_corners is True, in which case the behavior is unclear.
+ coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
+ return _aten_upsample_output_size(
+ self,
+ output_size,
+ mode="cubic",
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ antialias=1,
+ )
+
+
@torch_op("aten::upsample_bicubic2d.vec", trace_only=True)
def aten_upsample_bicubic2d_vec(
self: TReal,
@@ -2287,7 +2433,7 @@ def aten_upsample_bicubic2d_vec(
if scale_factors is not None:
result = _aten_upsample_scales(
self,
- op.Constant(value_floats=scale_factors),
+ scale_factors,
mode="cubic",
coordinate_transformation_mode=coordinate_transformation_mode,
)
@@ -2336,6 +2482,28 @@ def aten_upsample_bilinear2d(
)
+@torch_op("aten::_upsample_bilinear2d_aa", trace_only=True)
+def aten__upsample_bilinear2d_aa(
+ self: TReal,
+ output_size: INT64,
+ align_corners: bool,
+ scales_h: Optional[float] = None,
+ scales_w: Optional[float] = None,
+) -> TReal:
+ """_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
+
+ # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch
+ # except when align_corners is True, in which case the behavior is unclear.
+ coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
+ return _aten_upsample_output_size(
+ self,
+ output_size,
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ mode="linear",
+ antialias=1,
+ )
+
+
@torch_op("aten::upsample_bilinear2d.vec", trace_only=True)
def aten_upsample_bilinear2d_vec(
self: TReal,
@@ -2349,11 +2517,12 @@ def aten_upsample_bilinear2d_vec(
if scale_factors is not None:
result = _aten_upsample_scales(
self,
- op.Constant(value_floats=scale_factors),
+ scale_factors,
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)
else:
+ assert output_size is not None
result = _aten_upsample_output_size(
self,
output_size,
@@ -2382,9 +2551,8 @@ def aten_upsample_linear1d(
self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None
) -> TReal:
"""upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor"""
- # FIXME(justinchuby): Support when scales is provided and align_corners is False
- del scales
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
+ # scales is ignored in PyTorch
return _aten_upsample_output_size(
self,
output_size,
@@ -2407,31 +2575,35 @@ def aten_upsample_linear1d_backward(
@torch_op("aten::upsample_nearest1d", trace_only=True)
def aten_upsample_nearest1d(
- self: TReal, size: INT64, scale_factor: Optional[float] = None
+ self: TReal, output_size: INT64, scales: Optional[float] = None
) -> TReal:
"""upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor"""
- if size is not None:
- return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
+ if scales is not None:
+ return _aten_upsample_scales(self, [scales], "nearest", "asymmetric")
else:
- return _aten_upsample_nearest1d_scales(self, scale_factor)
+ return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")
-@torch_op("aten::upsample_nearest1d", private=True)
-def _aten_upsample_nearest1d_scales(
- self: TReal,
- scale_factors: TFloat,
+@torch_op(
+ (
+ "aten::upsample_nearest1d.vec",
+ "aten::upsample_nearest2d.vec",
+ "aten::upsample_nearest3d.vec",
+ ),
+ trace_only=True,
+)
+def aten_upsample_nearestnd_vec(
+ input: TReal,
+ output_size: Optional[INT64] = None,
+ scale_factors: Optional[Sequence[float]] = None,
) -> TReal:
- scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
- scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
- return op.Resize(
- self,
- None,
- scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
- None,
- mode="nearest",
- coordinate_transformation_mode="asymmetric",
- nearest_mode="floor",
- )
+ """upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor"""
+
+ if scale_factors is not None:
+ return _aten_upsample_scales(input, scale_factors, "nearest", "asymmetric")
+ else:
+ assert output_size is not None
+ return _aten_upsample_output_size(input, output_size, "nearest", "asymmetric")
def aten_upsample_nearest1d_backward(
@@ -2448,18 +2620,21 @@ def aten_upsample_nearest1d_backward(
@torch_op("aten::upsample_nearest2d", trace_only=True)
def aten_upsample_nearest2d(
self: TReal,
- size: INT64,
+ output_size: INT64,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TReal:
"""upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor"""
- # NOTE: trace_only because optional attributes are not supported by ONNX
- # TODO(justinchuby): Conditionally use scales
- del scales_h
- del scales_w
-
- return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
+ if scales_h is not None and scales_w is not None:
+ return _aten_upsample_scales(
+ self,
+ [scales_h, scales_w],
+ "nearest",
+ "asymmetric",
+ )
+ else:
+ return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")
def aten_upsample_nearest2d_backward(
@@ -2477,18 +2652,22 @@ def aten_upsample_nearest2d_backward(
@torch_op("aten::upsample_nearest3d", trace_only=True)
def aten_upsample_nearest3d(
self: TReal,
- size: INT64,
+ output_size: INT64,
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TReal:
"""upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor"""
- del scales_h
- del scales_w
- del scales_d
-
- return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
+ if scales_d is not None and scales_h is not None and scales_w is not None:
+ return _aten_upsample_scales(
+ self,
+ [scales_d, scales_h, scales_w],
+ "nearest",
+ "asymmetric",
+ )
+ else:
+ return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")
def aten_upsample_nearest3d_backward(
@@ -2528,6 +2707,33 @@ def aten_upsample_trilinear3d(
)
+@torch_op("aten::upsample_trilinear3d.vec", trace_only=True)
+def aten_upsample_trilinear3d_vec(
+ self: TReal,
+ output_size: INT64,
+ align_corners: bool,
+ scale_factors: Optional[Sequence[float]],
+) -> TReal:
+ """upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor"""
+
+ coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
+ if scale_factors is not None:
+ result = _aten_upsample_scales(
+ self,
+ scale_factors,
+ mode="linear",
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ )
+ else:
+ result = _aten_upsample_output_size(
+ self,
+ output_size,
+ mode="linear",
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ )
+ return result
+
+
def aten_upsample_trilinear3d_backward(
grad_output: TensorType,
output_size: INT64,
diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py
index 3136559b13..f53e9c1133 100644
--- a/onnxscript/function_libs/torch_lib/ops/prims.py
+++ b/onnxscript/function_libs/torch_lib/ops/prims.py
@@ -19,31 +19,35 @@
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import RealType, TTensor
from onnxscript.onnx_opset import opset18 as op
-from onnxscript.onnx_types import TensorType
+from onnxscript.onnx_types import BOOL, TensorType
-def prims_abs(self: TensorType) -> TensorType:
+@torch_op("prims::abs", trace_only=True)
+def prims_abs(self: TTensor) -> TTensor:
"""abs(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Abs(self)
+@torch_op("prims::acos", trace_only=True)
def prims_acos(self: TensorType) -> TensorType:
"""acos(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Acos(self)
+@torch_op("prims::acosh", trace_only=True)
def prims_acosh(self: TensorType) -> TensorType:
"""acosh(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Acosh(self)
-def prims_add(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::add", trace_only=True)
+def prims_add(self: TTensor, other: TTensor) -> TTensor:
"""add(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Add(self, other)
def prims_amax(
@@ -78,22 +82,25 @@ def prims_as_strided_scatter(
raise NotImplementedError()
-def prims_asin(self: TensorType) -> TensorType:
+@torch_op("prims::asin", trace_only=True)
+def prims_asin(self: TTensor) -> TTensor:
"""asin(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Asin(self)
-def prims_asinh(self: TensorType) -> TensorType:
+@torch_op("prims::asinh", trace_only=True)
+def prims_asinh(self: TTensor) -> TTensor:
"""asinh(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Asinh(self)
-def prims_atan(self: TensorType) -> TensorType:
+@torch_op("prims::atan", trace_only=True)
+def prims_atan(self: TTensor) -> TTensor:
"""atan(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Atan(self)
def prims_atan2(self: TensorType, other: TensorType) -> TensorType:
@@ -102,10 +109,11 @@ def prims_atan2(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_atanh(self: TensorType) -> TensorType:
+@torch_op("prims::atanh", trace_only=True)
+def prims_atanh(self: TTensor) -> TTensor:
"""atanh(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Atanh(self)
def prims_bessel_i0(self: TensorType) -> TensorType:
@@ -168,12 +176,33 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
+@torch_op("prims::broadcast_in_dim", trace_only=True)
def prims_broadcast_in_dim(
- a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
+ a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int]
) -> TensorType:
"""broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""
- raise NotImplementedError()
+ target_rank = len(shape)
+
+ if not broadcast_dimensions:
+ # Special case: no broadcast dimensions - all target dims should be 1
+ return op.Expand(a, common_ops.merge_dims(shape))
+
+ # Create base shape of all 1s
+ ones = [1] * target_rank
+
+ # For each broadcast dimension, we'll replace the 1 with the actual input dimension
+ # Since broadcast_dimensions is compile-time known, we can do this with individual operations
+ intermediate_shape = ones
+
+ for i, broadcast_dim in enumerate(broadcast_dimensions):
+ # Get the size of input dimension i as a one-element tensor
+ input_dim_value = op.Shape(a, start=i, end=i + 1)
+ intermediate_shape[broadcast_dim] = input_dim_value
+
+ # Reshape input to intermediate shape and expand to target
+ reshaped = op.Reshape(a, common_ops.merge_dims(intermediate_shape))
+ return op.Expand(reshaped, shape)
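+
+# For intuition, a NumPy sketch of the same algorithm (illustrative only, not
+# part of this module; broadcast_in_dim_ref is a hypothetical name):
+#
+#   import numpy as np
+#
+#   def broadcast_in_dim_ref(a, shape, broadcast_dimensions):
+#       intermediate = [1] * len(shape)
+#       for i, d in enumerate(broadcast_dimensions):
+#           intermediate[d] = a.shape[i]
+#       return np.broadcast_to(a.reshape(intermediate), shape)
+#
+#   # e.g. broadcast_in_dim_ref(np.ones((3,)), [2, 3], [1]).shape == (2, 3)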
def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType:
@@ -188,10 +217,11 @@ def prims_cbrt(self: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_ceil(self: TensorType) -> TensorType:
+@torch_op("prims::ceil", trace_only=True)
+def prims_ceil(self: TTensor) -> TTensor:
"""ceil(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Ceil(self)
def prims_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
@@ -239,16 +269,18 @@ def prims_copy_to(a: TensorType, b: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_cos(self: TensorType) -> TensorType:
+@torch_op("prims::cos", trace_only=True)
+def prims_cos(self: TTensor) -> TTensor:
"""cos(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Cos(self)
-def prims_cosh(self: TensorType) -> TensorType:
+@torch_op("prims::cosh", trace_only=True)
+def prims_cosh(self: TTensor) -> TTensor:
"""cosh(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Cosh(self)
@torch_op("prims::device_put")
@@ -268,10 +300,11 @@ def prims_digamma(self: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_div(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::div", trace_only=True)
+def prims_div(self: TTensor, other: TTensor) -> TTensor:
"""div(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Div(self, other)
def prims_empty(shape: INT64, dtype: int, device: str, requires_grad: bool) -> TensorType:
@@ -288,16 +321,18 @@ def prims_empty_strided(
raise NotImplementedError()
-def prims_eq(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::eq", trace_only=True)
+def prims_eq(self: TTensor, other: TTensor) -> TTensor:
"""eq(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Equal(self, other)
-def prims_erf(self: TensorType) -> TensorType:
+@torch_op("prims::erf", trace_only=True)
+def prims_erf(self: TTensor) -> TTensor:
"""erf(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Erf(self)
def prims_erf_inv(self: TensorType) -> TensorType:
@@ -318,10 +353,11 @@ def prims_erfcx(self: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_exp(self: TensorType) -> TensorType:
+@torch_op("prims::exp", trace_only=True)
+def prims_exp(self: TTensor) -> TTensor:
"""exp(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Exp(self)
def prims_exp2(self: TensorType) -> TensorType:
@@ -360,10 +396,11 @@ def prims_fill(self: TensorType, value: float) -> TensorType:
raise NotImplementedError()
-def prims_floor(self: TensorType) -> TensorType:
+@torch_op("prims::floor", trace_only=True)
+def prims_floor(self: TTensor) -> TTensor:
"""floor(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Floor(self)
def prims_fmax(self: TensorType, other: TensorType) -> TensorType:
@@ -406,16 +443,18 @@ def prims_gcd(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_ge(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::ge", trace_only=True)
+def prims_ge(self: TTensor, other: TTensor) -> TTensor:
"""ge(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.GreaterOrEqual(self, other)
-def prims_gt(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::gt", trace_only=True)
+def prims_gt(self: TTensor, other: TTensor) -> TTensor:
"""gt(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Greater(self, other)
def prims_hypot(self: TensorType, other: TensorType) -> TensorType:
@@ -462,10 +501,11 @@ def prims_item(a: TensorType) -> float:
raise NotImplementedError()
+@torch_op("prims::le", trace_only=True)
def prims_le(self: TensorType, other: TensorType) -> TensorType:
"""le(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.LessOrEqual(self, other)
def prims_lgamma(self: TensorType) -> TensorType:
@@ -474,10 +514,11 @@ def prims_lgamma(self: TensorType) -> TensorType:
raise NotImplementedError()
+@torch_op("prims::log", trace_only=True)
def prims_log(self: TensorType) -> TensorType:
"""log(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Log(self)
def prims_log10(self: TensorType) -> TensorType:
@@ -498,10 +539,11 @@ def prims_log2(self: TensorType) -> TensorType:
raise NotImplementedError()
+@torch_op("prims::lt", trace_only=True)
def prims_lt(self: TensorType, other: TensorType) -> TensorType:
"""lt(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Less(self, other)
def prims_maximum(self: TensorType, other: TensorType) -> TensorType:
@@ -528,10 +570,11 @@ def prims_minium_value(dtype: int) -> float:
raise NotImplementedError()
-def prims_mul(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::mul", trace_only=True)
+def prims_mul(self: TTensor, other: TTensor) -> TTensor:
"""mul(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Mul(self, other)
def prims_ndtri(self: TensorType) -> TensorType:
@@ -540,16 +583,18 @@ def prims_ndtri(self: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_ne(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::ne", trace_only=True)
+def prims_ne(self: TTensor, other: TTensor) -> TTensor:
"""ne(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Not(op.Equal(self, other))
-def prims_neg(self: TensorType) -> TensorType:
+@torch_op("prims::neg", trace_only=True)
+def prims_neg(self: TTensor) -> TTensor:
"""neg(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Neg(self)
def prims_nextafter(self: TensorType, other: TensorType) -> TensorType:
@@ -566,10 +611,11 @@ def prims_normal(
raise NotImplementedError()
-def prims_pow(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::pow", trace_only=True)
+def prims_pow(self: TTensor, other: TTensor) -> TTensor:
"""pow(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Pow(self, other)
def prims_prod(
@@ -598,16 +644,18 @@ def prims_remainder(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_reshape(a: TensorType, shape: INT64) -> TensorType:
+@torch_op("prims::reshape", trace_only=True)
+def prims_reshape(a: TTensor, shape: INT64) -> TTensor:
"""reshape(Tensor a, SymInt[] shape) -> Tensor"""
- raise NotImplementedError()
+ return op.Reshape(a, shape)
+@torch_op("prims::resize", trace_only=True)
def prims_resize(a: TensorType, shape: INT64) -> TensorType:
"""resize(Tensor a, SymInt[] shape) -> Tensor"""
- raise NotImplementedError()
+ return op.Expand(a, shape)
def prims_rev(a: TensorType, dims: Sequence[int]) -> TensorType:
@@ -616,10 +664,11 @@ def prims_rev(a: TensorType, dims: Sequence[int]) -> TensorType:
raise NotImplementedError()
+@torch_op("prims::round", trace_only=True)
def prims_round(self: TensorType) -> TensorType:
"""round(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Round(self)
def prims_rsqrt(self: TensorType) -> TensorType:
@@ -660,16 +709,18 @@ def prims_signbit(self: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_sin(self: TensorType) -> TensorType:
+@torch_op("prims::sin", trace_only=True)
+def prims_sin(self: TTensor) -> TTensor:
"""sin(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Sin(self)
-def prims_sinh(self: TensorType) -> TensorType:
+@torch_op("prims::sinh", trace_only=True)
+def prims_sinh(self: TTensor) -> TTensor:
"""sinh(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Sinh(self)
def prims_slice(
@@ -700,22 +751,25 @@ def prims_split_dim(a: TensorType, dim: int, outer_length: INT64) -> TensorType:
raise NotImplementedError()
-def prims_sqrt(self: TensorType) -> TensorType:
+@torch_op("prims::sqrt", trace_only=True)
+def prims_sqrt(self: TTensor) -> TTensor:
"""sqrt(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Sqrt(self)
-def prims_squeeze(a: TensorType, dimensions: Sequence[int]) -> TensorType:
+@torch_op("prims::squeeze", trace_only=True)
+def prims_squeeze(a: TTensor, dimensions: Sequence[int]) -> TTensor:
"""squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)"""
- raise NotImplementedError()
+ return op.Squeeze(a, axes=dimensions)
-def prims_sub(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("prims::sub", trace_only=True)
+def prims_sub(self: TTensor, other: TTensor) -> TTensor:
"""sub(Tensor self, Tensor other) -> Tensor"""
- raise NotImplementedError()
+ return op.Sub(self, other)
def prims_sum(
@@ -732,22 +786,25 @@ def prims_svd(A: TensorType, full_matrices: bool) -> tuple[TensorType, TensorTyp
raise NotImplementedError()
-def prims_tan(self: TensorType) -> TensorType:
+@torch_op("prims::tan", trace_only=True)
+def prims_tan(self: TTensor) -> TTensor:
"""tan(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Tan(self)
-def prims_tanh(self: TensorType) -> TensorType:
+@torch_op("prims::tanh", trace_only=True)
+def prims_tanh(self: TTensor) -> TTensor:
"""tanh(Tensor self) -> Tensor"""
- raise NotImplementedError()
+ return op.Tanh(self)
+@torch_op("prims::transpose", trace_only=True)
def prims_transpose(a: TensorType, permutation: Sequence[int]) -> TensorType:
"""transpose(Tensor(a) a, int[] permutation) -> Tensor(a)"""
- raise NotImplementedError()
+ return op.Transpose(a, perm=permutation)
def prims_trunc(self: TensorType) -> TensorType:
@@ -764,6 +821,7 @@ def prims_uniform(
raise NotImplementedError()
+@torch_op("prims::var", trace_only=True)
def prims_var(
inp: TensorType,
dims: Optional[Sequence[int]],
@@ -772,7 +830,26 @@ def prims_var(
) -> TensorType:
"""var(Tensor inp, int[]? dims, *, int correction, ScalarType? output_dtype=None) -> Tensor"""
- raise NotImplementedError()
+ if not dims:
+ # dims can be an empty sequence in practice. Use None so the axes input is
+ # omitted from the ONNX graph and the reduction runs over all dimensions.
+ dims = None
+ sub_mean = op.Sub(inp, op.ReduceMean(inp, dims, keepdims=True))
+ sqr_mean = op.Mul(sub_mean, sub_mean)
+ var = op.ReduceMean(sqr_mean, dims, keepdims=False)
+ # Apply the correction: rescale the biased variance by N / (N - correction)
+ if correction != 0:
+ inp_shape = op.Shape(inp)
+ dim_size = op.Gather(inp_shape, dims, axis=0)
+ numel_float = op.CastLike(op.ReduceProd(dim_size, keepdims=False), inp)
+ mul = op.Mul(var, numel_float)
+ # Subtract the correction value
+ sub = op.Sub(numel_float, op.CastLike(correction, inp))
+ var = op.Div(mul, sub)
+
+ if output_dtype is not None and output_dtype != -1:
+ var = op.Cast(var, to=output_dtype)
+
+ return var
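+
+# Sketch of the math above (illustrative): with N the product of the reduced
+# dimension sizes, the biased variance mean((x - mean(x))**2) is rescaled as
+#
+#   var_corrected = var_biased * N / (N - correction)
+#
+# e.g. x = [1.0, 2.0, 3.0] with correction=1: var_biased = 2/3, so
+# var_corrected = (2/3) * 3 / 2 = 1.0, matching torch.var(x, correction=1).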
def prims_view_of(a: TensorType) -> TensorType:
@@ -781,10 +858,11 @@ def prims_view_of(a: TensorType) -> TensorType:
raise NotImplementedError()
-def prims_where(pred: TensorType, a: TensorType, b: TensorType) -> TensorType:
+@torch_op("prims::where", trace_only=True)
+def prims_where(pred: BOOL, a: TTensor, b: TTensor) -> TTensor:
"""where(Tensor pred, Tensor a, Tensor b) -> Tensor"""
- raise NotImplementedError()
+ return op.Where(pred, a, b)
def prims_zeta(self: TensorType, other: TensorType) -> TensorType:
diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
new file mode 100644
index 0000000000..92962a9ea6
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
@@ -0,0 +1,63 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
+# pylint: disable=unused-argument
+"""quantized_decomposed ops defined in https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
+
+- No inplace operators.
+- None of the functions should use the script() decorator, because we want
+ to delay compilation of the functions.
+"""
+
+from __future__ import annotations
+
+from onnxscript.function_libs.torch_lib.ops import common
+from onnxscript.function_libs.torch_lib.registration import torch_op
+from onnxscript.onnx_opset import opset18 as op
+from onnxscript.onnx_types import TensorType
+
+
+@torch_op(
+ (
+ "quantized_decomposed::quantize_per_tensor",
+ "quantized_decomposed::quantize_per_tensor.tensor",
+ "quantized_decomposed::quantize_per_tensor.tensor2",
+ ),
+ trace_only=True,
+)
+def quantized_decomposed_quantize_per_tensor(
+ input: TensorType,
+ scale: float,
+ zero_point: int,
+ quant_min: int,
+ quant_max: int,
+ dtype: int,
+) -> TensorType:
+ # TODO(justinchuby): Use dtype when we use opset 21
+ return op.QuantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))
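+
+# Per the ONNX spec, QuantizeLinear computes (rounding half to even)
+#
+#   q = saturate(round(x / scale) + zero_point)
+#
+# saturating to the range of the zero_point's dtype; quant_min/quant_max are
+# accepted here for signature compatibility only, and clamping follows the
+# dtype range instead.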
+
+
+@torch_op(
+ (
+ "quantized_decomposed::dequantize_per_tensor",
+ "quantized_decomposed::dequantize_per_tensor.tensor",
+ "quantized_decomposed::dequantize_per_tensor.tensor2",
+ ),
+ trace_only=True,
+)
+def quantized_decomposed_dequantize_per_tensor(
+ input: TensorType,
+ scale: float,
+ zero_point: int,
+ quant_min: int,
+ quant_max: int,
+ dtype: int,
+ out_dtype: int = -1,
+) -> TensorType:
+ # TODO(justinchuby): Use dtype when we use opset 21
+ dequantized = op.DequantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))
+ if out_dtype in (-1, None):
+ # out_dtype can be None as well; both None and -1 mean no cast
+ return dequantized
+ assert out_dtype > 0, f"out_dtype must be -1 or > 0, not {out_dtype}"
+ return op.Cast(dequantized, to=out_dtype)
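+
+# The inverse mapping, per the ONNX spec for DequantizeLinear:
+#
+#   x = (q - zero_point) * scale
+#
+# e.g. with scale=0.5 and zero_point=2, q=6 dequantizes to (6 - 2) * 0.5 = 2.0,
+# optionally cast to out_dtype afterwards.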
diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py
index 6719581f62..1b123394d3 100644
--- a/onnxscript/function_libs/torch_lib/ops/special.py
+++ b/onnxscript/function_libs/torch_lib/ops/special.py
@@ -15,14 +15,12 @@
import math
from typing import Optional, Sequence
-from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
-from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType
_MATH_PI = math.pi
-IsScalar = common_ops.IsScalar
def aten_special_airy_ai(x: TensorType) -> TensorType:
@@ -92,21 +90,21 @@ def aten_special_entr(self: TensorType) -> TensorType:
@torch_op(("aten::erf", "aten::special_erf"))
-def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_erf(self: TFloat) -> TFloat:
"""erf(Tensor self) -> Tensor"""
return op.Erf(self)
@torch_op(("aten::erfc", "aten::special_erfc"))
-def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_erfc(self: TFloat) -> TFloat:
"""erfc(Tensor self) -> Tensor"""
return op.Sub(1, op.Erf(self))
@torch_op("aten::special_erfcx")
-def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
+def aten_special_erfcx(self: TFloat) -> TFloat:
"""special_erfcx(Tensor self) -> Tensor"""
return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self)))
@@ -130,10 +128,11 @@ def aten_special_expit(self: TensorType) -> TensorType:
raise NotImplementedError()
-def aten_special_expm1(self: TensorType) -> TensorType:
+@torch_op(("aten::expm1", "aten::special_expm1"))
+def aten_special_expm1(self: TFloat) -> TFloat:
"""special_expm1(Tensor self) -> Tensor"""
- raise NotImplementedError()
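+ # ONNX has no Expm1 operator, so exp(x) - 1 is computed directly. Unlike
+ # torch.special.expm1 this loses precision for |x| << 1: in float32,
+ # x = 1e-10 evaluates to 0.0 rather than ~1e-10.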
+ return op.Sub(op.Exp(self), 1)
def aten_special_gammainc(self: TensorType, other: TensorType) -> TensorType:
@@ -214,15 +213,13 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op(("aten::log_softmax", "aten::special_log_softmax"), trace_only=True)
-def aten_special_log_softmax(
- self: TFloatOrBFloat16, dim: int, dtype: int = -1
-) -> TFloatOrBFloat16:
+@torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True)
+def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""
- self_is_scalar = IsScalar(self)
+ self_is_scalar = len(self.shape) == 0
if self_is_scalar:
- self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+ self = op.Unsqueeze(self, [0])
result = op.LogSoftmax(self, axis=dim)
if dtype != -1:
result = op.Cast(result, to=dtype)
@@ -364,8 +361,8 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()
-@torch_op("aten::xlogy")
-def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
+@torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other"))
+def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat:
"""special_xlogy(Tensor self, Tensor other) -> Tensor"""
# https://pytorch.org/docs/stable/special.html#torch.special.xlogy
diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py
index 05d8f62179..162d69d747 100644
--- a/onnxscript/function_libs/torch_lib/registration.py
+++ b/onnxscript/function_libs/torch_lib/registration.py
@@ -1,9 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
"""Registry for aten functions."""
from __future__ import annotations
import re
-from types import FunctionType
from typing import Any, Callable, Generator, Optional
import onnxscript
@@ -99,8 +100,7 @@ def torch_op(
trace_only: bool = False,
private: bool = False,
complex: bool = False,
- traceable: bool = False,
-) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
+) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
"""Register a torch op.
Args:
@@ -113,24 +113,12 @@ def torch_op(
private: Whether the function is private (not directly exposed). It should
be true for all functions with names starting with "_".
complex: Whether the function expects complex-valued inputs.
- traceable: Whether the function can also be traced. This is an **experimental** flag.
- A function is traceable if it can both be scripted and traced to produce
- the same result for a given input. Specifically:
-
- - A function _can_ be tagged with traceable if its if branches (if any)
- can be statically evaluated.
- - A function _should_ be tagged with traceable if it contains if branches
- and/or CastLike nodes so that they can be evaluated away with the
- EXPERIMENTAL_PREFER_TRACING on.
- - A function without if branches or CastLike nodes _should not_ be tagged
- with traceable because inlining will do the same thing.
- - A function with `@graph` defined for a `Scan` op is not traceable yet.
"""
if registry is None:
registry = default_registry
def wrapper(
- func: FunctionType,
+ func: Callable,
) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction:
# Compile the function
custom_opset = onnxscript.values.Opset(domain=_constants.DOMAIN, version=1)
@@ -139,9 +127,7 @@ def wrapper(
if trace_only:
processed_func = onnxscript.values.TracedOnnxFunction(custom_opset, func)
else:
- assert isinstance(func, FunctionType)
processed_func = onnxscript.script(opset=custom_opset)(func)
- processed_func.experimental_traceable = traceable
assert registry is not None
for name_ in _check_and_normalize_names(name):
diff --git a/onnxscript/function_libs/torch_lib/tensor_typing.py b/onnxscript/function_libs/torch_lib/tensor_typing.py
index 7b5287f417..1f27c0cff0 100644
--- a/onnxscript/function_libs/torch_lib/tensor_typing.py
+++ b/onnxscript/function_libs/torch_lib/tensor_typing.py
@@ -42,7 +42,7 @@
INT64,
UINT8,
]
-_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
+_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
IntType = Union[INT8, INT16, INT32, INT64]
RealType = Union[
BFLOAT16,
@@ -61,7 +61,6 @@
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
TFloat = TypeVar("TFloat", bound=_FloatType)
-TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8])
TInt = TypeVar("TInt", bound=IntType)
TReal = TypeVar("TReal", bound=RealType)
diff --git a/onnxscript/ir/README.md b/onnxscript/ir/README.md
index dae5c09a5b..21d5cd124d 100644
--- a/onnxscript/ir/README.md
+++ b/onnxscript/ir/README.md
@@ -1,22 +1,3 @@
-# ONNX IR
+# Where is the code?
-An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
-
-## Features ✨
-
-- Full ONNX spec support: all valid models representable by ONNX protobuf, and a subset of invalid models (so you can load and fix them).
-- Low memory footprint: mmap'ed external tensors; unified interface for ONNX TensorProto, Numpy arrays and PyTorch Tensors etc. No tensor size limitation. Zero copies.
-- Straightforward access patterns: Access value information and traverse the graph topology at ease.
-- Robust mutation: Create as many iterators as you like on the graph while mutating it.
-- Speed: Performant graph manipulation, serialization/deserialization to Protobuf.
-- Pythonic and familiar APIs: Classes define Pythonic APIs and still map to ONNX protobuf concepts in an intuitive way.
-- No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.
-
-## Code Organization 🗺️
-
-- [`_protocols.py`](_protocols.py): Interfaces defined for all entities in the IR.
-- [`_core.py`](_core.py): Implementation of the core entities in the IR, including `Model`, `Graph`, `Node`, `Value`, and others.
-- [`_enums.py`](_enums.py): Definition of the type enums that correspond to the `DataType` and `AttributeType` in `onnx.proto`.
-- [`_name_authority.py`](_name_authority.py): The authority for giving names to entities in the graph, used internally.
-- [`_linked_list.py`](_linked_list.py): The data structure as the node container in the graph that supports robust iteration and mutation. Internal.
-- [`_metadata.py`](_metadata.py): Metadata store for all entities in the IR.
+The ONNX IR has been migrated to https://github.com/onnx/ir-py and is now developed as the standalone `onnx_ir` package. The original `onnxscript.ir` APIs are aliased here for backward compatibility.
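+
+Existing imports should keep working unchanged. For example:
+
+```python
+from onnxscript import ir  # now re-exports the onnx_ir package
+
+print(ir.DataType.FLOAT)  # DataType and friends resolve to onnx_ir
+```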
diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py
index 7bfbeabad9..6240347886 100644
--- a/onnxscript/ir/__init__.py
+++ b/onnxscript/ir/__init__.py
@@ -1,129 +1,4 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""In-memory intermediate representation for ONNX graphs."""
-
-__all__ = [
- # Modules
- "serde",
- # IR classes
- "Attr",
- "AttrFloat32",
- "AttrFloat32s",
- "AttrGraph",
- "AttrGraphs",
- "AttrInt64",
- "AttrInt64s",
- "AttrSparseTensor",
- "AttrSparseTensors",
- "AttrString",
- "AttrStrings",
- "AttrTensor",
- "AttrTensors",
- "TypeAndShape",
- "AttrTypeProto",
- "AttrTypeProtos",
- "SymbolicDim",
- "ExternalTensor",
- "StringTensor",
- "Function",
- "Graph",
- "GraphView",
- "Input",
- "Model",
- "Node",
- "RefAttr",
- "Shape",
- "Tensor",
- "Value",
- "TensorType",
- "OptionalType",
- "SequenceType",
- "SparseTensorType",
- # Protocols
- "ArrayCompatible",
- "DLPackCompatible",
- "TensorProtocol",
- "ValueProtocol",
- "ModelProtocol",
- "NodeProtocol",
- "GraphProtocol",
- "GraphViewProtocol",
- "AttributeProtocol",
- "ReferenceAttributeProtocol",
- "SparseTensorProtocol",
- "SymbolicDimProtocol",
- "ShapeProtocol",
- "TypeProtocol",
- "MapTypeProtocol",
- "FunctionProtocol",
- # Enums
- "AttributeType",
- "DataType",
- # Types
- "OperatorIdentifier",
- # Protobuf compatible types
- "TensorProtoTensor",
-]
-
-from onnxscript.ir import serde
-from onnxscript.ir._core import (
- Attr,
- AttrFloat32,
- AttrFloat32s,
- AttrGraph,
- AttrGraphs,
- AttrInt64,
- AttrInt64s,
- AttrSparseTensor,
- AttrSparseTensors,
- AttrString,
- AttrStrings,
- AttrTensor,
- AttrTensors,
- AttrTypeProto,
- AttrTypeProtos,
- ExternalTensor,
- Function,
- Graph,
- GraphView,
- Input,
- Model,
- Node,
- OptionalType,
- RefAttr,
- SequenceType,
- Shape,
- SparseTensorType,
- StringTensor,
- SymbolicDim,
- Tensor,
- TensorType,
- TypeAndShape,
- Value,
-)
-from onnxscript.ir._enums import (
- AttributeType,
- DataType,
-)
-from onnxscript.ir._protocols import (
- ArrayCompatible,
- AttributeProtocol,
- DLPackCompatible,
- FunctionProtocol,
- GraphProtocol,
- GraphViewProtocol,
- MapTypeProtocol,
- ModelProtocol,
- NodeProtocol,
- OperatorIdentifier,
- ReferenceAttributeProtocol,
- ShapeProtocol,
- SparseTensorProtocol,
- SymbolicDimProtocol,
- TensorProtocol,
- TypeProtocol,
- ValueProtocol,
-)
-from onnxscript.ir.serde import TensorProtoTensor
+# pylint: disable=wildcard-import,unused-wildcard-import
+from onnx_ir import * # type: ignore # noqa: F403
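+
+# Design note: the wildcard import keeps this module's public surface in sync
+# with onnx_ir automatically, so names formerly defined here (Model, Graph,
+# Node, Value, DataType, serde, ...) are expected to resolve to their onnx_ir
+# equivalents.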
diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py
deleted file mode 100644
index 7eba1cb283..0000000000
--- a/onnxscript/ir/_convenience.py
+++ /dev/null
@@ -1,287 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Convenience methods for constructing and manipulating the IR.
-
-This is an internal only module. We should choose to expose some of the methods
-after they are proven to be useful.
-"""
-
-from __future__ import annotations
-
-__all__ = [
- "convert_attribute",
- "convert_attributes",
- "replace_all_uses_with",
-]
-
-from typing import Mapping, Sequence, Union
-
-import onnx
-
-from onnxscript.ir import _core, _enums, _protocols, serde
-
-SupportedAttrTypes = Union[
- str,
- int,
- float,
- Sequence[int],
- Sequence[float],
- Sequence[str],
- _protocols.TensorProtocol, # This includes all in-memory tensor types
- onnx.TensorProto,
- _core.Attr,
- _core.RefAttr,
- _protocols.GraphProtocol,
- Sequence[_protocols.GraphProtocol],
- _protocols.TypeProtocol,
- Sequence[_protocols.TypeProtocol],
- None,
-]
-
-
-def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
- """Infer the attribute type based on the type of the Python object."""
- if isinstance(attr, int):
- return _enums.AttributeType.INT
- if isinstance(attr, float):
- return _enums.AttributeType.FLOAT
- if isinstance(attr, str):
- return _enums.AttributeType.STRING
- if isinstance(attr, (_core.Attr, _core.RefAttr)):
- return attr.type
- if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
- return _enums.AttributeType.INTS
- if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
- return _enums.AttributeType.FLOATS
- if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
- return _enums.AttributeType.STRINGS
- if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
- # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
- return _enums.AttributeType.TENSOR
- if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
- return _enums.AttributeType.GRAPH
- if isinstance(attr, Sequence) and all(
- isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
- ):
- return _enums.AttributeType.GRAPHS
- if isinstance(
- attr,
- (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
- ):
- return _enums.AttributeType.TYPE_PROTO
- if isinstance(attr, Sequence) and all(
- isinstance(
- x,
- (
- _core.TensorType,
- _core.SequenceType,
- _core.OptionalType,
- _protocols.TypeProtocol,
- ),
- )
- for x in attr
- ):
- return _enums.AttributeType.TYPE_PROTOS
- raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
-
-
-def convert_attribute(
- name: str,
- attr: SupportedAttrTypes,
- attr_type: _enums.AttributeType | None = None,
-) -> _core.Attr | _core.RefAttr:
- """Convert a Python object to a _core.Attr object.
-
- This method is useful when constructing nodes with attributes. It infers the
- attribute type based on the type of the Python value.
-
- Args:
- name: The name of the attribute.
- attr: The value of the attribute.
- attr_type: The type of the attribute. This is required when attr is None.
- When provided, it overrides the inferred type.
-
- Returns:
- A ``Attr`` object.
-
- Raises:
- ValueError: If :param:`attr` is ``None`` and :param:`attr_type` is not provided.
- TypeError: If the type of the attribute is not supported.
- """
- if attr is None:
- if attr_type is None:
- raise ValueError("attr_type must be provided when attr is None")
- return _core.Attr(name, attr_type, None)
-
- if isinstance(attr, (_core.Attr, _core.RefAttr)):
- if attr.name != name:
- raise ValueError(
- f"Attribute name '{attr.name}' does not match provided name '{name}'"
- )
- if attr_type is not None and attr.type != attr_type:
- raise ValueError(
- f"Attribute type '{attr.type}' does not match provided type '{attr_type}'"
- )
- return attr
-
- if attr_type is None:
- attr_type = _infer_attribute_type(attr)
-
- if attr_type == _enums.AttributeType.INT:
- return _core.AttrInt64(name, attr) # type: ignore
- if attr_type == _enums.AttributeType.FLOAT:
- return _core.AttrFloat32(name, attr) # type: ignore
- if attr_type == _enums.AttributeType.STRING:
- return _core.AttrString(name, attr) # type: ignore
- if attr_type == _enums.AttributeType.INTS:
- return _core.AttrInt64s(name, attr) # type: ignore
- if attr_type == _enums.AttributeType.FLOATS:
- return _core.AttrFloat32s(name, attr) # type: ignore
- if attr_type == _enums.AttributeType.STRINGS:
- return _core.AttrStrings(name, attr) # type: ignore
- if attr_type == _enums.AttributeType.TENSOR:
- if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
- return _core.AttrTensor(name, attr)
- if isinstance(attr, onnx.TensorProto):
- return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
- if attr_type == _enums.AttributeType.GRAPH:
- return _core.AttrGraph(name, attr) # type: ignore[arg-type]
- if attr_type == _enums.AttributeType.GRAPHS:
- return _core.AttrGraphs(name, attr) # type: ignore[arg-type]
- if attr_type == _enums.AttributeType.TYPE_PROTO:
- return _core.AttrTypeProto(name, attr) # type: ignore[arg-type]
- if attr_type == _enums.AttributeType.TYPE_PROTOS:
- return _core.AttrTypeProtos(name, attr) # type: ignore[arg-type]
- raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
-
-
-def convert_attributes(
- attrs: Mapping[str, SupportedAttrTypes],
-) -> list[_core.Attr | _core.RefAttr]:
- """Convert a dictionary of attributes to a list of _core.Attr objects.
-
- It infers the attribute type based on the type of the value. The supported
- types are: int, float, str, Sequence[int], Sequence[float], Sequence[str],
- :class:`_core.Tensor`, and :class:`_core.Attr`::
-
- >>> from onnxscript import ir
- >>> import onnx
- >>> import numpy as np
- >>> attrs = {
- ... "int": 1,
- ... "float": 1.0,
- ... "str": "hello",
- ... "ints": [1, 2, 3],
- ... "floats": [1.0, 2.0, 3.0],
- ... "strings": ["hello", "world"],
- ... "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])),
- ... "tensor_proto":
- ... onnx.TensorProto(
- ... dims=[3],
- ... data_type=onnx.TensorProto.FLOAT,
- ... float_data=[1.0, 2.0, 3.0],
- ... name="proto",
- ... ),
- ... "graph": ir.Graph([], [], nodes=[], name="graph0"),
- ... "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")],
- ... "type_proto": ir.TensorType(ir.DataType.FLOAT),
- ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
- ... }
- >>> convert_attributes(attrs)
- [AttrInt64('int', 1), AttrFloat32('float', 1.0), AttrString('str', 'hello'), AttrInt64s('ints', [1, 2, 3]), AttrFloat32s('floats', [1.0, 2.0, 3.0]), AttrStrings('strings', ['hello', 'world']), AttrTensor('tensor', Tensor(array([1., 2., 3.]), name='')), AttrTensor('tensor_proto', TensorProtoTensor(name='proto')), AttrGraph('graph', Graph(
- name='graph0',
- inputs=(
-
- ),
- outputs=(
-
- ),
- len()=0
- )), AttrGraphs('graphs', [Graph(
- name='graph1',
- inputs=(
-
- ),
- outputs=(
-
- ),
- len()=0
- ), Graph(
- name='graph2',
- inputs=(
-
- ),
- outputs=(
-
- ),
- len()=0
- )]), AttrTypeProto('type_proto', Tensor(FLOAT)), AttrTypeProtos('type_protos', [Tensor(FLOAT), Tensor(FLOAT)])]
-
- Args:
- attrs: A dictionary of {<attribute_name>: <attribute_value>} to convert.
-
- Returns:
- A list of _core.Attr objects.
- """
- attributes: list[_core.Attr | _core.RefAttr] = []
- for name, attr in attrs.items():
- attributes.append(convert_attribute(name, attr))
- return attributes
-
-
-def replace_all_uses_with(
- values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
- replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
-) -> None:
- """Replace all uses of the given values with the replacements.
-
- This is useful when nodes in the graph are replaced with new nodes, where
- the old users need to be updated to use the outputs of the new nodes.
-
- For example, suppose we have the following graph::
-
- A -> {B, C}
-
- We want to replace the node A with a new node D::
-
- >>> from onnxscript import ir
- >>> input = ir.Input("input")
- >>> node_a = ir.Node("", "A", [input])
- >>> node_b = ir.Node("", "B", node_a.outputs)
- >>> node_c = ir.Node("", "C", node_a.outputs)
- >>> node_d = ir.Node("", "D", [input])
- >>> replace_all_uses_with(node_a.outputs, node_d.outputs)
- >>> len(node_b.inputs)
- 1
- >>> node_b.inputs[0].producer().op_type
- 'D'
- >>> len(node_c.inputs)
- 1
- >>> node_c.inputs[0].producer().op_type
- 'D'
- >>> len(node_a.outputs[0].uses())
- 0
-
- When values and replacements are sequences, they are zipped into pairs. All
- users of the first value are replaced with the first replacement, and so on.
-
- .. note::
- You still need to update the graph outputs if any of the values being
- replaced are part of the graph outputs. Be sure to remove the old nodes
- from the graph using ``graph.remove()`` if they are no longer needed.
-
- Args:
- values: The value or values to be replaced.
- replacements: The new value or values to use as inputs.
- """
- if not isinstance(values, Sequence):
- values = (values,)
- if not isinstance(replacements, Sequence):
- replacements = (replacements,)
- if len(values) != len(replacements):
- raise ValueError("The number of values and replacements must match.")
- for value, replacement in zip(values, replacements):
- for user_node, index in tuple(value.uses()):
- user_node.replace_input_with(index, replacement)
diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py
deleted file mode 100644
index 1dedd0b6a7..0000000000
--- a/onnxscript/ir/_core.py
+++ /dev/null
@@ -1,2618 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""data structures for the intermediate representation."""
-
-# NOTES for developers:
-# NOTE: None of these classes will have a "to_onnx" or "from_protobuf" method because
-# We cannot assume that the build tool chain has protoc installed and would like
-# to keep this module protobuf free. This way we separate the concerns of the IR
-# and the serialization/deserialization.
-#
-# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead.
-
-from __future__ import annotations
-
-import abc
-import contextlib
-import dataclasses
-import math
-import mmap
-import os
-import sys
-import textwrap
-import typing
-from typing import (
- AbstractSet,
- Any,
- Collection,
- Generic,
- Iterable,
- Iterator,
- OrderedDict,
- Sequence,
- Union,
-)
-
-import numpy as np
-
-from onnxscript.ir import (
- _display,
- _enums,
- _linked_list,
- _metadata,
- _name_authority,
- _protocols,
- _type_casting,
-)
-
-if typing.TYPE_CHECKING:
- import numpy.typing as npt
- from typing_extensions import TypeGuard
-
-TArrayCompatible = typing.TypeVar(
- "TArrayCompatible",
- bound=Union[_protocols.ArrayCompatible, _protocols.DLPackCompatible],
-)
-
-# System is little endian
-_IS_LITTLE_ENDIAN = sys.byteorder == "little"
-# Data types that are not supported by numpy
-_NON_NUMPY_NATIVE_TYPES = frozenset(
- (
- _enums.DataType.BFLOAT16,
- _enums.DataType.FLOAT8E4M3FN,
- _enums.DataType.FLOAT8E4M3FNUZ,
- _enums.DataType.FLOAT8E5M2,
- _enums.DataType.FLOAT8E5M2FNUZ,
- _enums.DataType.INT4,
- _enums.DataType.UINT4,
- )
-)
-
-
-def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]:
- """Use this function to check if an object is compatible with numpy.
-
- Avoid isinstance checks with the ArrayCompatible protocol for performance reasons.
- """
- return hasattr(obj, "__array__")
-
-
-def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
- """Use this function to check if an object is compatible with DLPack.
-
- Avoid isinstance checks with the DLPackCompatible protocol for performance reasons.
- """
- return hasattr(obj, "__dlpack__")
-
-
-class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
- """Convenience Shared methods for classes implementing TensorProtocol."""
-
- __slots__ = ()
-
- def _printable_type_shape(self) -> str:
- """Return a string representation of the shape and data type."""
- return f"{self.dtype},{self.shape}"
-
- def _repr_base(self) -> str:
- """Base string for the repr method.
-
- Example: Tensor
- """
- return f"{self.__class__.__name__}<{self._printable_type_shape()}>"
-
- @property
- def size(self) -> int:
- """The number of elements in the tensor."""
- return np.prod(self.shape.numpy()) # type: ignore[return-value,attr-defined]
-
- @property
- def nbytes(self) -> int:
- """The number of bytes in the tensor."""
- # Use math.ceil because when dtype is INT4, the itemsize is 0.5
- return math.ceil(self.dtype.itemsize * self.size)
-
- def display(self, *, page: bool | None = None) -> None:
- rich = _display.require_rich()
-
- if rich is None:
- status_manager = contextlib.nullcontext()
- else:
- import rich.status # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel
-
- status_manager = rich.status.Status(f"Computing tensor stats for {self!r}")
-
- from onnxscript._thirdparty import ( # pylint: disable=import-outside-toplevel
- asciichartpy,
- )
-
- with status_manager:
- # Construct the text to display
- lines = []
- array = self.numpy().flatten()
- lines.append(repr(self))
- lines.append("")
- nan_values = np.isnan(array)
- nan_count = np.count_nonzero(nan_values)
- inf_count = np.count_nonzero(np.isinf(array))
- numbers = array[~nan_values]
- lines.append(
- f"Min: {np.min(numbers)}, Max: {np.max(numbers)}, "
- f"NaN count: {nan_count}, "
- f"Inf count: {inf_count}"
- )
- # Compute sparsity
- sparse_threshold = 1e-6
- # NOTE: count_nonzero() is faster than sum() for boolean arrays
- sparsity = np.count_nonzero(np.abs(array) < sparse_threshold) / array.size
- lines.append(f"Sparsity (abs<{sparse_threshold}): {sparsity:.2f}")
-
- # Compute histogram
- finite_numbers = array[np.isfinite(array)]
- lines.append("Histogram:")
- hist, bin_edges = np.histogram(finite_numbers, bins=80, density=False)
- lines.append(
- asciichartpy.plot(
- hist, bin_edges=bin_edges, cfg={"height": 8, "format": "{:8.0f}"}
- )
- )
-
- text = "\n".join(lines)
-
- if rich is None:
- print(text)
- elif page:
- import rich.console # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel
-
- console = rich.console.Console()
- with console.pager(styles=True):
- console.print(text)
- else:
- rich.print(text)
-
-
-def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) -> None:
- """Check if the numpy array dtype matches the IR data type.
-
- When the dtype is not one of the numpy native dtypes, the value needs to be:
-
- - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
- - ``uint8`` for uint4.
- - ``uint8`` for 8-bit data types.
- - ``uint16`` for bfloat16
- """
- if dtype in _NON_NUMPY_NATIVE_TYPES:
- if dtype.itemsize == 2 and array.dtype != np.uint16:
- # TODO(justinchuby): Support the storage dtypes like uint16 for bfloat16.
- raise TypeError(
- f"The numpy array dtype must be uint16 (not {array.dtype}) for IR data type {dtype}."
- )
- if dtype.itemsize == 1 and array.dtype != np.uint8:
- raise TypeError(
- f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}."
- )
- if dtype == _enums.DataType.INT4:
- if array.dtype not in (np.int8, np.uint8):
- raise TypeError(
- f"The numpy array dtype must be int8 or uint8 (not {array.dtype}) for IR data type {dtype}."
- )
- if dtype == _enums.DataType.UINT4:
- if array.dtype != np.uint8:
- raise TypeError(
- f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}."
- )
- return
-
- try:
- dtype_numpy = _enums.DataType.from_numpy(array.dtype)
- except TypeError as e:
- raise TypeError(
- "Failed to convert the numpy dtype to an IR data type. "
- "If you are using a non-native dtype, be sure to specify the corresponding IR dtype when "
- "creating a Tensor."
- ) from e
-
- if dtype_numpy != dtype:
- raise TypeError(
- f"The numpy array dtype {array.dtype} does not match the IR data type {dtype}."
- )
-
-
-class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
- """An immutable concrete tensor.
-
- This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array
- compatible object (e.g. ``np.ndarray``, ``torch.Tensor``) or a ``DLPack`` compatible object.
- The tensor is immutable and the data is not copied at initialization.
-
- To create a tensor from a numpy array::
-
- >>> import numpy as np
- >>> array = np.array([1, 2, 3])
- >>> tensor = Tensor(array)
- >>> # The tensor itself can be treated as a numpy array because it implements the __array__ method
- >>> np.allclose(tensor, array)
- True
-
- To get a numpy array from the tensor, call :meth:`numpy`. To convert the tensor
- to a byte string for serialization, call :meth:`tobytes`.
-
- It is recommended to check the size of the tensor first before accessing the
- underlying data, because accessing the data may be expensive and incur IO
- overhead.
-
- Subclass this class to efficiently handle different types of tensors from different frameworks.
-
- Attributes:
- name: The name of the tensor.
- shape: The shape of the tensor.
- dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
- doc_string: Documentation string.
- raw: The raw data behind this tensor. It can be anything.
- size: The number of elements in the tensor.
- nbytes: The number of bytes in the tensor.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- __slots__ = (
- "_raw",
- "_dtype",
- "_shape",
- "name",
- "doc_string",
- "_metadata_props",
- "_metadata",
- )
-
- def __init__(
- self,
- value: TArrayCompatible,
- dtype: _enums.DataType | None = None,
- *,
- shape: Shape | None = None,
- name: str = "",
- doc_string: str | None = None,
- metadata_props: dict[str, str] | None = None,
- ) -> None:
- """Initialize a tensor.
-
- Args:
- value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
- When the dtype is not one of the numpy native dtypes, the value needs
- to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
- when the value is a numpy array; :param:`dtype` must be specified in this case.
- dtype: The data type of the tensor. It can be None only when value is a numpy array.
- Users are responsible for making sure the dtype matches the value when value is not a numpy array.
- shape: The shape of the tensor. If None, the shape is obtained from the value.
- name: The name of the tensor.
- doc_string: The documentation string.
- metadata_props: The metadata properties.
-
- Raises:
- TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
- TypeError: If the value is a numpy array and the dtype is specified but does not match the dtype of the array.
- ValueError: If the shape is not specified and the value does not have a shape attribute.
- ValueError: If the dtype is not specified and the value is not a numpy array.
- """
- # NOTE: We should not do any copying here for performance reasons
- if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
- raise TypeError(f"Expected an array compatible object, got {type(value)}")
- if shape is None:
- # Obtain the shape from the value
- if not hasattr(value, "shape"):
- raise ValueError(
- f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
- "Please specify the shape explicitly."
- )
- self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009
- else:
- self._shape = shape
- self._shape._frozen = True
- if dtype is None:
- if isinstance(value, np.ndarray):
- self._dtype = _enums.DataType.from_numpy(value.dtype)
- else:
- raise ValueError(
- "The dtype must be specified when the value is not a numpy array."
- )
- else:
- if isinstance(value, np.ndarray):
- # Make sure the dtype matches the value
- _check_numpy_representation_type(value, dtype)
- # Users are responsible for making sure the dtype matches the value
- # when value is not a numpy array
- self._dtype = dtype
- self._raw = value
- self.name = name
- self.doc_string = doc_string
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props = metadata_props
-
- def __array__(self, dtype: Any = None) -> np.ndarray:
- if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
- return self._raw.__array__(dtype)
- assert _compatible_with_dlpack(
- self._raw
- ), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
- return np.from_dlpack(self._raw)
-
- def __dlpack__(self, *, stream: Any = None) -> Any:
- if _compatible_with_dlpack(self._raw):
- return self._raw.__dlpack__(stream=stream)
- return self.__array__().__dlpack__(stream=stream)
-
- def __dlpack_device__(self) -> tuple[int, int]:
- if _compatible_with_dlpack(self._raw):
- return self._raw.__dlpack_device__()
- return self.__array__().__dlpack_device__()
-
- def __repr__(self) -> str:
- return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
-
- @property
- def dtype(self) -> _enums.DataType:
- """The data type of the tensor. Immutable."""
- return self._dtype
-
- @property
- def shape(self) -> Shape:
- """The shape of the tensor. Immutable."""
- return self._shape
-
- @property
- def raw(self) -> TArrayCompatible:
- """Backing data of the tensor. Immutable."""
- return self._raw # type: ignore[return-value]
-
- def numpy(self) -> np.ndarray:
- """Return the tensor as a numpy array.
-
- When the data type is not supported by numpy, the value is the bit representation
- of the dtype:
-
- - ``int8`` for int4, with the sign bit extended to 8 bits.
- - ``uint8`` for uint4.
- - ``uint8`` for 8-bit data types like float8.
- - ``uint16`` for bfloat16.
- """
- if isinstance(self._raw, np.ndarray):
- return self._raw
- # We do not cache the value to save memory
- return self.__array__()
-
- def tobytes(self) -> bytes:
- """Returns the value as bytes encoded in little endian.
-
- Override this method for more efficient serialization when the raw
- value is not a numpy array.
- """
- # TODO(justinchuby): Support DLPack
- array = self.numpy()
- if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
- # Pack the array into int4
- array = _type_casting.pack_int4(array)
- else:
- assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
- if not _IS_LITTLE_ENDIAN:
- array = array.view(array.dtype.newbyteorder("<"))
- return array.tobytes()
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
-
-class ExternalTensor(TensorBase, _protocols.TensorProtocol):
- """An immutable concrete tensor with its data store on disk.
-
- This class uses memory mapping to avoid loading the tensor into memory,
- when the data type is supported by numpy. Otherwise, the tensor is loaded
- into memory lazily when accessed.
-
- Calling :attr:`shape` does not incur IO. Checking shape before loading
- the tensor is recommended if IO overhead and memory usage is a concern.
-
- To obtain an array, call :meth:`numpy`. To obtain the bytes,
- call :meth:`tobytes`.
-
- The :attribute:`path` can be a relative path or an absolute path.
- Serializers should handle the path correctly to conform with the ONNX spec.
-
- Attributes:
- path: The path to the data file. This can be a relative path or an absolute path.
- offset: The offset in bytes from the start of the file.
- length: The length of the data in bytes.
- dtype: The data type of the tensor.
- shape: The shape of the tensor.
- name: The name of the tensor. It must be specified.
- doc_string: The documentation string.
- metadata_props: The metadata properties.
- """
-
- __slots__ = (
- "_path",
- "_offset",
- "_length",
- "_dtype",
- "_shape",
- "name",
- "doc_string",
- "_array",
- "raw",
- "_metadata_props",
- "_metadata",
- )
-
- def __init__(
- self,
- path: os.PathLike | str,
- offset: int | None,
- length: int | None,
- dtype: _enums.DataType,
- *,
- shape: Shape,
- name: str,
- doc_string: str | None = None,
- metadata_props: dict[str, str] | None = None,
- ) -> None:
- self._path = path
- self._offset: int | None = offset
- self._length: int | None = length
- self._dtype: _enums.DataType = dtype
- self.name: str = name # mutable
- self._shape: Shape = shape
- self._shape._frozen = True
- self.doc_string: str | None = doc_string # mutable
- self._array: np.ndarray | None = None
- self.raw: mmap.mmap | None = None
- self._metadata_props = metadata_props
- self._metadata: _metadata.MetadataStore | None = None
-
- @property
- def path(self) -> str | os.PathLike:
- # Immutable
- return self._path
-
- @property
- def offset(self) -> int | None:
- # Immutable
- return self._offset
-
- @property
- def length(self) -> int | None:
- # Immutable
- return self._length
-
- @property
- def dtype(self) -> _enums.DataType:
- # Immutable
- return self._dtype
-
- @property
- def shape(self) -> Shape:
- # Immutable
- return self._shape
-
- def _load(self):
- assert self._array is None, "Bug: The array should be loaded only once."
- # Map the whole file into memory
- # TODO(justinchuby): Verify if this would exhaust the memory address space
- with open(self._path, "rb") as f:
- self.raw = mmap.mmap(
- f.fileno(),
- 0,
- access=mmap.ACCESS_READ,
- )
- # Handle the byte order correctly by always using little endian
- dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
- self._array = np.frombuffer(
- self.raw, dtype=dt, offset=self.offset or 0, count=self.size
- ).reshape(self.shape.numpy())
-
- def __array__(self, dtype: Any = None) -> np.ndarray:
- if self._array is None:
- self._load()
- assert self._array is not None
- return self._array.__array__(dtype)
-
- def __dlpack__(self, *, stream: Any = None) -> Any:
- return self.numpy().__dlpack__(stream=stream)
-
- def __repr__(self) -> str:
- return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})"
-
- def numpy(self) -> np.ndarray:
- """Return the tensor as a numpy array.
-
- The data will be memory-mapped and will not take up physical memory space.
- """
- if self._array is None:
- self._load()
- assert self._array is not None
- return self._array
-
- def tobytes(self) -> bytes:
- """Return the bytes of the tensor.
-
- This will load the tensor into memory.
- """
- if self.raw is None:
- self._load()
- assert self.raw is not None
- offset = self._offset or 0
- length = self._length or self.nbytes
- return self.raw[offset : offset + length]
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
-
-class StringTensor(TensorBase, _protocols.TensorProtocol):
- """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""
-
- __slots__ = (
- "_raw",
- "_shape",
- "name",
- "doc_string",
- "_metadata_props",
- "_metadata",
- )
-
- def __init__(
- self,
- value: Sequence[bytes] | npt.NDArray[np.bytes_],
- *,
- shape: Shape | None = None,
- name: str = "",
- doc_string: str | None = None,
- metadata_props: dict[str, str] | None = None,
- ) -> None:
- """Initialize a tensor.
-
- Args:
- value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes.
- shape: The shape of the tensor. If None, the shape is obtained from the value.
- name: The name of the tensor.
- doc_string: The documentation string.
- metadata_props: The metadata properties.
- """
- if shape is None:
- if not hasattr(value, "shape"):
- raise ValueError(
- f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
- "Please specify the shape explicitly."
- )
- self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009
- else:
- self._shape = shape
- self._shape._frozen = True
- self._raw = value
- self.name = name
- self.doc_string = doc_string
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props = metadata_props
-
- def __array__(self, dtype: Any = None) -> np.ndarray:
- if isinstance(self._raw, np.ndarray):
- return self._raw
- assert isinstance(
- self._raw, Sequence
- ), f"Bug: Expected a sequence, got {type(self._raw)}"
- return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())
-
- def __dlpack__(self, *, stream: Any = None) -> Any:
- del stream # unused
- raise TypeError("StringTensor does not support DLPack")
-
- def __dlpack_device__(self) -> tuple[int, int]:
- raise TypeError("StringTensor does not support DLPack")
-
- def __repr__(self) -> str:
- return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
-
- @property
- def dtype(self) -> _enums.DataType:
- """The data type of the tensor. Immutable."""
- return _enums.DataType.STRING
-
- @property
- def shape(self) -> Shape:
- """The shape of the tensor. Immutable."""
- return self._shape
-
- @property
- def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
- """Backing data of the tensor. Immutable."""
- return self._raw # type: ignore[return-value]
-
- def numpy(self) -> npt.NDArray[np.bytes_]:
- """Return the tensor as a numpy array."""
- return self.__array__()
-
- def tobytes(self) -> bytes:
- raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.")
-
- def string_data(self) -> Sequence[bytes]:
- """Return the string data of the tensor."""
- if isinstance(self._raw, np.ndarray):
- return self._raw.flatten().tolist()
- return self._raw
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
-
-class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
- __slots__ = ("_value",)
-
- def __init__(self, value: str | None) -> None:
- """Initialize a symbolic dimension.
-
- Args:
- value: The value of the dimension. It should not be an int.
- """
- if isinstance(value, int):
- raise TypeError("The value of a SymbolicDim cannot be an int")
- self._value = value
-
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, SymbolicDim):
- return self.value == other
- return self.value == other.value
-
- def __hash__(self) -> int:
- return hash(self.value)
-
- @property
- def value(self) -> str | None:
- return self._value
-
- def __str__(self) -> str:
- return f"{self._value}"
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self._value})"
-
-
-class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
- __slots__ = ("_dims", "_frozen")
-
- def __init__(
- self,
- dims: Iterable[int | SymbolicDim | str | None],
- /,
- denotations: Iterable[str | None] | None = None,
- frozen: bool = False,
- ) -> None:
- """Initialize a shape.
-
- Args:
-            dims: The dimensions of the shape. Each dimension can be an integer, a
-                SymbolicDim, a string, or None. When a ``dim`` is not an integer or a
-                SymbolicDim, it is converted to a SymbolicDim.
- denotations: The denotations of the dimensions. If None, the denotations are not set.
- Standard denotation can optionally be used to denote tensor
- dimensions with standard semantic descriptions to ensure
- that operations are applied to the correct axis of a tensor.
- Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
- for pre-defined dimension denotations.
- frozen: If True, the shape is immutable and cannot be modified. This
- is useful when the shape is initialized by a Tensor.
- """
- self._dims: list[int | SymbolicDim] = [
- SymbolicDim(dim) if not isinstance(dim, (int, SymbolicDim)) else dim
- for dim in dims
- ]
- self._denotations: list[str | None] = (
- list(denotations) if denotations is not None else [None] * len(self._dims)
- )
- if len(self._denotations) != len(self._dims):
- raise ValueError(
- "The number of denotations, when provided, must be equal to the number of dimensions."
- )
- self._frozen: bool = frozen
-
- @property
- def dims(self) -> tuple[int | SymbolicDim, ...]:
- """All dimensions in the shape.
-
- This property is read-only. Use __getitem__ and __setitem__ to modify the shape or create a new shape.
- """
- return tuple(self._dims)
-
- def rank(self) -> int:
- """The rank of the shape."""
- return len(self._dims)
-
- def numpy(self) -> tuple[int, ...]:
- if any(not isinstance(dim, int) for dim in self._dims):
- raise ValueError(f"Cannot convert the shape {self} to a tuple of ints")
- return tuple(dim for dim in self._dims) # type: ignore
-
- def __len__(self) -> int:
- return len(self._dims)
-
- def __iter__(self) -> Iterator[int | SymbolicDim]:
- return iter(self._dims)
-
- @typing.overload
- def __getitem__(self, index: int) -> int | SymbolicDim: ...
-
- @typing.overload
- def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ...
-
- def __getitem__(self, index):
- return tuple(self._dims)[index]
-
- def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
- """Set the dimension at the index.
-
- Args:
- index: The index of the dimension.
- value: The value of the dimension.
-
- Raises:
- TypeError: If the shape is frozen and cannot be modified.
- TypeError: If the value is not an int or SymbolicDim.
- """
- if self._frozen:
- raise TypeError("The shape is frozen and cannot be modified.")
- if isinstance(value, str) or value is None:
- value = SymbolicDim(value)
- if not isinstance(value, (int, SymbolicDim)):
- raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'")
-
- self._dims[index] = value
-
- def get_denotation(self, index: int) -> str | None:
- """Return the denotation of the dimension at the index.
-
- Args:
- index: The index of the dimension.
-
- Returns:
- The denotation of the dimension.
- """
- return self._denotations[index]
-
- def set_denotation(self, index: int, denotation: str | None) -> None:
- """Set the denotation of the dimension at the index.
-
- Args:
- index: The index of the dimension.
- denotation: The denotation of the dimension.
- """
- self._denotations[index] = denotation
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self._dims!r})"
-
- def __str__(self) -> str:
- """Return a string representation of the shape.
-
- E.g. [n,1,3]
- """
- return f"[{','.join([str(dim) for dim in self._dims])}]"
-
- def __eq__(self, other: object) -> bool:
- """Return True if the shapes are equal.
-
-        Two shapes are equal if all their dimensions are equal.
- """
- if isinstance(other, Shape):
- return self._dims == other._dims
- if not isinstance(other, Iterable):
- return False
- return self._dims == list(other)
-
- def __ne__(self, other: object) -> bool:
- return not self.__eq__(other)
-
-
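-# Editor's sketch (illustrative names): Shape accepts a mix of ints, strings
-# and None; non-int dims are wrapped in SymbolicDim, and a Shape is mutable
-# unless constructed with frozen=True:
-#
-#     >>> s = Shape([2, "N", None])
-#     >>> s.rank()
-#     3
-#     >>> s[1]
-#     SymbolicDim(N)
-#     >>> s[0] = 4
-#     >>> s == [4, "N", None]
-#     True
-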
-def _quoted(string: str) -> str:
- """Return a quoted string.
-
- This function is used to quote value/node names in the IR for better readability.
- """
- return f'"{string}"'
-
-
-class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
- """IR Node.
-
- If the ``graph`` is provided, the node will be added to the graph. Otherwise,
-    the user is responsible for calling ``graph.append(node)`` (or other mutation methods
- in :class:`Graph`) to add the node to the graph.
-
- After the node is initialized, it will add itself as a user of the input values.
-
- The output values of the node are created during node initialization and are immutable.
-    To change the output values, create a new node, then redirect the users of this
-    node's outputs (found via ``output.uses()``) to the new output values by calling
-    :meth:`replace_input_with` on each using node.
- """
-
- __slots__ = (
- "_name",
- "_domain",
- "_op_type",
- "_inputs",
- "_outputs",
- "_attributes",
- "_overload",
- "_version",
- "doc_string",
- "_metadata",
- "_metadata_props",
- "_graph",
- )
-
- def __init__(
- self,
- domain: str,
- op_type: str,
- inputs: Iterable[Value | None],
- attributes: Iterable[Attr | RefAttr] = (),
- *,
- overload: str = "",
- num_outputs: int | None = None,
- outputs: Sequence[Value] | None = None,
- version: int | None = None,
- graph: Graph | None = None,
- name: str | None = None,
- doc_string: str | None = None,
- metadata_props: dict[str, str] | None = None,
- ):
- """Initialize a node and add it as a user of the input values.
-
- Args:
- domain: The domain of the operator. For onnx operators, this is an empty string.
- op_type: The name of the operator.
- inputs: The input values. When an input is None, it is an empty input.
- attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
- overload: The overload name when the node is invoking a function.
- num_outputs: The number of outputs of the node. If not specified, the number is 1.
- outputs: The output values. If None, the outputs are created during initialization.
- version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
- graph: The graph that the node belongs to. If None, the node is not added to any graph.
- A `Node` must belong to zero or one graph.
- name: The name of the node. If None, the node is anonymous.
- doc_string: The documentation string.
- metadata_props: The metadata properties.
-
- Raises:
- TypeError: If the attributes are not Attr or RefAttr.
- ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
- ValueError: If an output value is None, when outputs is specified.
- ValueError: If an output value has a producer set already, when outputs is specified.
- """
- self._name = name
- self._domain: str = domain
- self._op_type: str = op_type
- # NOTE: Make inputs immutable with the assumption that they are not mutated
- # very often. This way all mutations can be tracked.
- # If necessary, we can cache the inputs and outputs as tuples.
- self._inputs: tuple[Value | None, ...] = tuple(inputs)
- # Values belong to their defining nodes. The values list is immutable
- self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
- attributes = tuple(attributes)
- if attributes and not isinstance(attributes[0], (Attr, RefAttr)):
- raise TypeError(
- f"Expected the attributes to be Attr or RefAttr, got {type(attributes[0])}. "
- "If you are copying the attributes from another node, make sure you call "
- "node.attributes.values() because it is a dictionary."
- )
- self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict(
- (attr.name, attr) for attr in attributes
- )
- self._overload: str = overload
- # TODO(justinchuby): Potentially support a version range
- self._version: int | None = version
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props: dict[str, str] | None = metadata_props
- self._graph: Graph | None = graph
- self.doc_string = doc_string
-
- # Add the node as a use of the inputs
- for i, input_value in enumerate(self._inputs):
- if input_value is not None:
- input_value._add_usage(self, i) # pylint: disable=protected-access
-
- # Add the node to the graph if graph is specified
- if self._graph is not None:
- self._graph.append(self)
-
- def _create_outputs(
- self, num_outputs: int | None, outputs: Sequence[Value] | None
- ) -> tuple[Value, ...]:
- """Check the parameters and create outputs for the node.
-
- Args:
- num_outputs: The number of outputs of the node.
- outputs: The output values of the node.
-
- Returns:
- The output values of the node.
-
- Raises:
- ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
- ValueError: If an output value is None.
- ValueError: If an output value has a producer set already.
- """
- # Check num_outputs and outputs are consistent
- if num_outputs is not None and outputs is not None and num_outputs != len(outputs):
- raise ValueError(
- "num_outputs must be the same as len(outputs) when num_outputs is specified."
- "num_outputs: {num_outputs}, outputs: {outputs}"
- )
- # 1. If outputs is specified (can be empty []), use the outputs
- if outputs is not None:
- # Check all output values are valid first
- for output in outputs:
- if output is None:
- raise ValueError(f"Output value cannot be None. All outputs: {outputs}")
- if output.producer() is not None:
- raise ValueError(
- f"Supplied output value cannot have a producer when used for initializing a Node. "
- f"Output: {output}. All outputs: {outputs}"
- )
- result = []
- for i, output in enumerate(outputs):
- output._producer = self # pylint: disable=protected-access
- output._index = i # pylint: disable=protected-access
- result.append(output)
- return tuple(result)
-
-        # 2. Otherwise, create num_outputs outputs, defaulting to 1 when unspecified
- if num_outputs is None:
- # Default to 1 output
- num_outputs = 1
- assert num_outputs is not None
- return tuple(Value(self, index=i) for i in range(num_outputs))
-
- def __str__(self) -> str:
- node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * (
- self._overload != ""
- )
- inputs_text = (
- "("
- + ", ".join(
- [
- (
- f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}"
- if x is not None
- else "None"
- )
- for x in self._inputs
- ]
- )
- + ")"
- )
- attributes_text = (
- (" {" + ", ".join([f"{k}={v}" for k, v in self._attributes.items()]) + "}")
- if self._attributes
- else ""
- )
- outputs_text = ", ".join(str(x) for x in self._outputs)
-
- return f"{outputs_text} ⬅️ {node_type_text}{inputs_text}{attributes_text}"
-
- def __repr__(self) -> str:
- return (
- f"{self.__class__.__name__}(name={self._name!r}, domain={self._domain!r}, "
- f"op_type={self._op_type!r}, inputs={self._inputs!r}, attributes={self._attributes!r}, "
- f"overload={self._overload!r}, outputs={self._outputs!r}, "
- f"version={self._version!r}, doc_string={self.doc_string!r})"
- )
-
- @property
- def name(self) -> str | None:
- return self._name
-
- @name.setter
- def name(self, value: str | None) -> None:
- self._name = value
-
- @property
- def domain(self) -> str:
- return self._domain
-
- @domain.setter
- def domain(self, value: str) -> None:
- self._domain = value
-
- @property
- def version(self) -> int | None:
- return self._version
-
- @version.setter
- def version(self, value: int | None) -> None:
- self._version = value
-
- @property
- def op_type(self) -> str:
- return self._op_type
-
- @op_type.setter
- def op_type(self, value: str) -> None:
- self._op_type = value
-
- @property
- def overload(self) -> str:
- return self._overload
-
- @overload.setter
- def overload(self, value: str) -> None:
- self._overload = value
-
- @property
- def inputs(self) -> Sequence[Value | None]:
- return self._inputs
-
- @inputs.setter
- def inputs(self, _: Any) -> None:
- raise AttributeError(
- "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
- )
-
- def replace_input_with(self, index: int, value: Value | None) -> None:
- """Replace an input with a new value."""
- if index < 0 or index >= len(self.inputs):
- raise ValueError(f"Index out of range: {index}")
- old_input = self.inputs[index]
-        self._inputs = tuple(
-            value if i == index else inp for i, inp in enumerate(self.inputs)
-        )
- if old_input is not None:
- old_input._remove_usage(self, index) # pylint: disable=protected-access
- if value is not None:
- value._add_usage(self, index) # pylint: disable=protected-access
-
- def prepend(self, /, nodes: Node | Iterable[Node]) -> None:
- """Insert a node before this node in the list of nodes in the graph.
-
- It is the same as calling ``graph.insert_before(self, nodes)``.
-
- Example::
-
- Before: previous_node -> self
- previous_node' -> node -> next_node'
- After: previous_node -> node -> self
- previous_node' -> next_node'
-
- Args:
- nodes: A node or a sequence of nodes to put before this node.
- """
- if self._graph is None:
- raise ValueError("The node to prepend to does not belong to any graph.")
- self._graph.insert_before(self, nodes)
-
- def append(self, /, nodes: Node | Iterable[Node]) -> None:
- """Insert a node after this node in the list of nodes in the graph.
-
- It is the same as calling ``graph.insert_after(self, nodes)``.
-
- Example::
-
- Before: previous_node -> self
- previous_node' -> node -> next_node'
- After: previous_node -> self -> node
- previous_node' -> next_node'
-
- Args:
- nodes: A node or a sequence of nodes to put after this node.
- """
- if self._graph is None:
- raise ValueError("The node to append to does not belong to any graph.")
- self._graph.insert_after(self, nodes)
-
- @property
- def outputs(self) -> Sequence[Value]:
- return self._outputs
-
- @outputs.setter
- def outputs(self, _: Sequence[Value]) -> None:
- raise AttributeError("outputs is immutable. Please create a new node instead.")
-
- @property
- def attributes(self) -> OrderedDict[str, Attr | RefAttr]:
- return self._attributes
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- @property
- def graph(self) -> Graph | None:
- return self._graph
-
- @graph.setter
- def graph(self, value: Graph | None) -> None:
- self._graph = value
-
- def op_identifier(self) -> _protocols.OperatorIdentifier:
- return self.domain, self.op_type, self.overload
-
- def display(self, *, page: bool | None = None) -> None:
- # Add the node's name to the displayed text
- print(f"Node: {self.name!r}")
- if self.doc_string:
- print(f"Doc: {self.doc_string}")
- super().display(page=page)
-
-
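-# Editor's sketch (hedged; `Input` is defined later in this module, and the
-# names are illustrative): creating a node and detaching one of its inputs.
-#
-#     >>> inp = Input(name="x")
-#     >>> node = Node("", "Relu", inputs=(inp,))   # defaults to one output
-#     >>> node.outputs[0].producer() is node
-#     True
-#     >>> node.replace_input_with(0, None)
-#     >>> len(inp.uses())
-#     0
-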
-class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable):
- """Tensor types that are non recursive types."""
-
- __slots__ = ("_dtype", "denotation")
-
- def __init__(self, dtype: _enums.DataType, *, denotation: str | None = None) -> None:
- self._dtype = dtype
- self.denotation = denotation
-
- @property
- def dtype(self) -> _enums.DataType:
- return self._dtype
-
- @dtype.setter
- def dtype(self, value: _enums.DataType) -> None:
- self._dtype = value
-
- @property
- def elem_type(self) -> _enums.DataType:
- """Return the element type of the tensor type"""
- return self.dtype
-
- def __eq__(self, other: object) -> bool:
- if self.__class__ is not other.__class__:
- return False
- return self.dtype == other.dtype # type: ignore[attr-defined]
-
- def __repr__(self) -> str:
- # Remove "Type" from name for display
- short_name = self.__class__.__name__[:-4]
- return f"{short_name}({self.dtype!r})"
-
-
-class TensorType(_TensorTypeBase):
- """A type that represents a tensor."""
-
- def __str__(self) -> str:
- return f"{self.dtype}"
-
-
-class SparseTensorType(_TensorTypeBase):
- """A type that represents a sparse tensor."""
-
-
-class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable):
- """Base for recursive types like Optional and Sequence."""
-
- __slots__ = ("_elem_type", "denotation")
-
- def __init__(
- self, elem_type: _protocols.TypeProtocol, *, denotation: str | None = None
- ) -> None:
- self._elem_type = elem_type
- self.denotation = denotation
-
- @property
- def dtype(self) -> _enums.DataType:
- return self._elem_type.dtype
-
- @dtype.setter
- def dtype(self, value: _enums.DataType) -> None:
- self._elem_type.dtype = value
-
- @property
- def elem_type(self) -> _protocols.TypeProtocol:
- return self._elem_type
-
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, _RecursiveTypeBase):
- return False
- if self.__class__ != other.__class__:
- return False
- # Recursively compare the type of the elements
- return self.elem_type == other.elem_type
-
- def __repr__(self) -> str:
- # Remove "Type" from name for display
- short_name = self.__class__.__name__[:-4]
- return f"{short_name}({self.elem_type!r})"
-
-
-class SequenceType(_RecursiveTypeBase):
- """A type that represents a sequence of elements."""
-
-
-class OptionalType(_RecursiveTypeBase):
- """A type that represents an optional element."""
-
-
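-# Editor's sketch (illustrative): the type classes above compose recursively,
-# e.g. a sequence of optional float tensors:
-#
-#     >>> t = SequenceType(OptionalType(TensorType(_enums.DataType.FLOAT)))
-#     >>> t.dtype == _enums.DataType.FLOAT
-#     True
-#     >>> t == SequenceType(OptionalType(TensorType(_enums.DataType.FLOAT)))
-#     True
-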
-class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
- """IR Value.
-
- A value is a named entity that can be used to represent an input or output of a graph,
- a function, or a node. The information it stores generalizes over ``ValueInfoProto``
- in the ONNX specification.
-
-    A :class:`Value` is either owned by exactly one node or not owned at all. When the
-    value is not owned, it must be an input of a graph or a function, and ``producer``
-    and ``index`` are ``None``.
-
- When the value is owned by a node, it is an output of the node.
- The node that produces the value can be accessed with :meth:`producer`.
- The index of the output of the node that produces the value can be accessed with
- :meth:`index`.
-
- To find all the nodes that use this value as an input, call :meth:`uses`.
-
- To check if the value is an output of a graph, call :meth:`is_graph_output`.
-
- Attributes:
- name: The name of the value. A value is always named when it is part of a graph.
- shape: The shape of the value.
- type: The type of the value.
- metadata_props: Metadata.
- """
-
- __slots__ = (
- "_const_value",
- "_index",
- "_metadata_props",
- "_metadata",
- "_name",
- "_producer",
- "_shape",
- "_type",
- "_uses",
- "doc_string",
- )
-
- def __init__(
- self,
- producer: Node | None,
- *,
- index: int | None,
- name: str | None = None,
- shape: Shape | None = None,
- type: _protocols.TypeProtocol | None = None,
- doc_string: str | None = None,
- const_value: _protocols.TensorProtocol
- | Sequence[_protocols.TensorProtocol]
- | None = None,
- ) -> None:
- # producer is None when the value is an input or an initializer
- self._producer: Node | None = producer
- self._index: int | None = index
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props: dict[str, str] | None = None
-
- self._name: str | None = name
- self._shape: Shape | None = shape
- self._type: _protocols.TypeProtocol | None = type
- # TODO(justinchuby): Handle initialization when a const value is provided
- # We can get shape and type information from the const value
- self._const_value = const_value
- # Use a collection of (Node, int) to store uses. This is needed
- # because a single use can use the same value multiple times.
- # Use a dictionary to preserve insertion order so that the visiting order is deterministic
- self._uses: dict[tuple[Node, int], None] = {}
- self.doc_string = doc_string
-
- def __repr__(self) -> str:
- value_name = self.name if self.name else "anonymous:" + str(id(self))
- producer = self.producer()
-        producer_text = (
-            (producer.name or "anonymous_node:" + str(id(producer)))
-            if producer is not None
-            else None
-        )
- return f"{self.__class__.__name__}({value_name!r}, type={self.type!r}, shape={self.shape}, producer={producer_text}, index={self.index()})"
-
- def __str__(self) -> str:
- value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
- shape_text = str(self.shape) if self.shape is not None else "?"
- type_text = str(self.type) if self.type is not None else "?"
-
- # Quote the name because in reality the names can have invalid characters
- # that make them hard to read
- return f"%{_quoted(value_name)}<{type_text},{shape_text}>"
-
- def producer(self) -> Node | None:
- """The node that produces this value."""
- return self._producer
-
- def index(self) -> int | None:
- """The index of the output of the defining node."""
- return self._index
-
- def uses(self) -> Collection[tuple[Node, int]]:
- """Return a set of uses of the value.
-
- The set contains tuples of ``(Node, index)`` where the index is the index of the input
- of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
- """
- return self._uses.keys()
-
- def _add_usage(self, use: Node, index: int) -> None:
- """Add a usage of this value.
-
- This is an internal method. It should only be called by the Node class.
- """
- self._uses[(use, index)] = None
-
- def _remove_usage(self, use: Node, index: int) -> None:
- """Remove a node from the uses of this value.
-
- This is an internal method. It should only be called by the Node class.
- """
- self._uses.pop((use, index))
-
- @property
- def name(self) -> str | None:
- return self._name
-
- @name.setter
- def name(self, value: str | None) -> None:
- self._name = value
-
- @property
- def type(self) -> _protocols.TypeProtocol | None:
- """The type of the tensor.
-
- Example types can be ``TensorType``, ``SparseTensorType``, ``SequenceType``, ``OptionalType``.
- To obtain the data type of the tensor, use ``type.dtype`` or conveniently
- :attribute:`dtype`.
- """
- return self._type
-
- @type.setter
- def type(self, value: _protocols.TypeProtocol | None) -> None:
- self._type = value
-
- @property
- def dtype(self) -> _enums.DataType | None:
- """The data type of the tensor."""
- if self._type is None:
- return None
- return self._type.dtype
-
- @dtype.setter
- def dtype(self, value: _enums.DataType) -> None:
- """Set the data type of the tensor.
-
- If the type is not set, it will be initialized to a new TensorType. To
- set the type as other types like ``SequenceType``, initialize the type
- then set :attribute:`type` instead.
- """
- if self._type is None:
- self._type = TensorType(value)
- else:
- self._type.dtype = value
-
- @property
- def shape(self) -> Shape | None:
- return self._shape
-
- @shape.setter
- def shape(self, value: Shape | None) -> None:
- if value is None:
- self._shape = None
- return
- if isinstance(value, Shape):
- self._shape = value
- return
- raise TypeError(f"Expected value to be a Shape or None, got '{type(value)}'")
-
- @property
- def const_value(
- self,
- ) -> _protocols.TensorProtocol | Sequence[_protocols.TensorProtocol] | None:
- """A concrete value.
-
- The value can be backed by different raw data types, such as numpy arrays.
-        The only guarantee is that it conforms to the TensorProtocol.
- """
- return self._const_value
-
- @const_value.setter
- def const_value(
- self,
- value: _protocols.TensorProtocol | Sequence[_protocols.TensorProtocol] | None,
- ) -> None:
- self._const_value = value
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- def is_graph_output(self) -> bool:
- """Whether the value is an output of a graph."""
- if (producer := self.producer()) is None:
- return False
- if (graph := producer.graph) is None:
- return False
- # Cannot use `in` because __eq__ may be defined by subclasses, even though
- # it is not recommended
- return any(output is self for output in graph.outputs)
-
-
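-# Editor's sketch (illustrative): setting `dtype` on a Value lazily creates a
-# TensorType; richer types (sequences, optionals) are assigned via `type`.
-#
-#     >>> v = Value(None, index=None, name="y")
-#     >>> v.dtype = _enums.DataType.INT64
-#     >>> isinstance(v.type, TensorType) and v.dtype == _enums.DataType.INT64
-#     True
-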
-class Input(Value):
- """Input of a Graph or a Function."""
-
- # Slots already defined in Value
- __slots__ = ()
-
- def __init__(
- self,
- name: str | None = None,
- shape: Shape | None = None,
- type: _protocols.TypeProtocol | None = None,
- doc_string: str | None = None,
- ) -> None:
- super().__init__(
- None, index=None, name=name, shape=shape, type=type, doc_string=doc_string
- )
-
-
-def _check_node_safe_to_remove(
- node: Node, to_remove: AbstractSet[Node], graph_outputs: AbstractSet[Value]
-) -> None:
- """Check if a node is safe to remove.
-
- 1. It checks to make sure there are no users of the node that are not
- to be removed before removing it.
- 2. It checks the node does not contribute to any graph outputs.
-
-    This check is typically O(1), assuming the number of uses of the node is small.
-
- Args:
- node: The node to check.
- to_remove: A set of nodes that are to be removed.
- This set is used to check if the node is still being used by other
- nodes that are not to be removed.
- graph_outputs: A set of values that are outputs of the graph.
-
- Raises:
-        ValueError: If an output of the node is an output of the graph.
- ValueError: If the node is still being used by other nodes not to be removed.
- """
- for output in node.outputs:
- if output in graph_outputs:
- raise ValueError(
- f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True."
- )
- for use, _ in output.uses():
- if use in to_remove:
- continue
- raise ValueError(
- f"Node '{use!r}' is still being used by other nodes that are not to be "
- f"removed. All of its uses: {list(output.uses())!r}"
- )
-
-
-class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
- """IR Graph.
-
- Graph represents a computation graph. In addition to the ONNX specification
- specified fields, it also contains a mapping of :attr:`opset_imports`. This
- allows different subgraphs to import different opsets. It is the responsibility
- of the deserializer to reconcile the different opsets.
-
- The `nodes` are not guaranteed to be topologically sorted. But the
- iteration order should be deterministic across different runs. It is the
- responsibility of the user to maintain a topological order of the nodes.
-
-    Note that there is no ``nodes`` attribute in the Graph. The Graph can be
- seen as a Sequence of nodes and should be used as such. For example, to obtain
- all nodes as a list, call ``list(graph)``.
-
- Attributes:
- name: The name of the graph.
- inputs: The input values of the graph.
- outputs: The output values of the graph.
- initializers: The initializers in the graph.
- doc_string: Documentation string.
- opset_imports: Opsets imported by the graph.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- __slots__ = (
- "name",
- "_inputs",
- "_outputs",
- "_initializers",
- "_doc_string",
- "_opset_imports",
- "_nodes",
- "_metadata",
- "_metadata_props",
- "_name_authority",
- )
-
- def __init__(
- self,
- inputs: Sequence[Input],
- outputs: Sequence[Value],
- *,
- nodes: Iterable[Node],
- initializers: Sequence[_protocols.TensorProtocol] = (),
- doc_string: str | None = None,
- opset_imports: dict[str, int] | None = None,
- name: str | None = None,
- metadata_props: dict[str, str] | None = None,
- ):
- self.name = name
-
- # Private fields that are not to be accessed by any other classes
- self._inputs = list(inputs)
- self._outputs = list(outputs)
- for initializer in initializers:
- if isinstance(initializer, str):
- raise TypeError(
- "Initializer must be a TensorProtocol, not a string. "
- "If you are copying the initializers from another graph, "
- "make sure you call graph.initializers.values() because it is a dictionary."
- )
- if initializer.name is None:
- raise ValueError(f"Initializer must have a name: {initializer}")
- self._initializers = {tensor.name: tensor for tensor in initializers}
- self._doc_string = doc_string
- self._opset_imports = opset_imports or {}
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props: dict[str, str] | None = metadata_props
- self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet()
-        # Be sure to initialize the name authority before extending the nodes
- # because it is used to name the nodes and their outputs
- self._name_authority = _name_authority.NameAuthority()
- # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
- self.extend(nodes)
-
- @property
- def inputs(self) -> list[Input]:
- return self._inputs
-
- @property
- def outputs(self) -> list[Value]:
- return self._outputs
-
- @property
- def initializers(self) -> dict[str, _protocols.TensorProtocol]:
- return self._initializers
-
- @property
- def doc_string(self) -> str | None:
- return self._doc_string
-
- @doc_string.setter
- def doc_string(self, value: str | None) -> None:
- self._doc_string = value
-
- @property
- def opset_imports(self) -> dict[str, int]:
- return self._opset_imports
-
- def __getitem__(self, index: int) -> Node:
- return self._nodes[index]
-
- def __len__(self) -> int:
- return len(self._nodes)
-
- def __iter__(self) -> Iterator[Node]:
- return iter(self._nodes)
-
- def __reversed__(self) -> Iterator[Node]:
- return reversed(self._nodes)
-
- def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
- """Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
- if node.graph is not None and node.graph is not self:
- raise ValueError(
- f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()."
- )
-        # Give the node and its output values names if they do not have one
- if node.name is None:
- self._name_authority.name_node(node)
- for value in node._outputs: # pylint: disable=protected-access
- if value.name is None:
- self._name_authority.name_value(value)
- node.graph = self
- return node
-
- # Mutation methods
- def append(self, node: Node, /) -> None:
- """Append a node to the graph in O(1) time.
-
- Args:
- node: The node to append.
-
- Raises:
- ValueError: If the node belongs to another graph.
- """
- self._set_node_graph_to_self_and_assign_names(node)
- self._nodes.append(node)
-
- def extend(self, nodes: Iterable[Node], /) -> None:
- """Extend the graph with the given nodes in O(#new_nodes) time.
-
- Args:
- nodes: The nodes to extend the graph with.
-
- Raises:
- ValueError: If any node belongs to another graph.
- """
- nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in nodes]
- self._nodes.extend(nodes)
-
- def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
- """Remove nodes from the graph in O(#num of nodes) time.
-
- If any errors are raise, to ensure the graph is not left in an inconsistent state,
- the graph is not modified.
-
- Args:
- nodes: The node to remove.
- safe: If True, performs the following actions before removal:
- 1. It checks to make sure there are no users of the node that are not
- to be removed before removing it.
- 2. It checks the node does not contribute to any graph outputs.
- 3. It removes references to all inputs so it is no longer a user of other nodes.
-
- Raises:
- ValueError: If any node to remove does not belong to this graph.
-            ValueError: (When ``safe=True``) If an output of the node is still an output of the graph.
- ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed.
- """
- if not isinstance(nodes, Iterable):
- nodes_set: AbstractSet[Node] = {nodes}
- else:
- nodes_set = frozenset(nodes)
- graph_outputs = frozenset(self.outputs)
- for node in nodes_set:
- if node.graph is not self:
- raise ValueError(f"The node '{node!r}' does not belong to this graph.")
- if safe:
- # Check 1, 2
- _check_node_safe_to_remove(node, nodes_set, graph_outputs)
- for node in nodes_set:
- if safe:
- # 3. Detach from all inputs so that it is no longer a user of other nodes
- for i in range(len(node.inputs)):
- node.replace_input_with(i, None)
- # Set attributes to remove the node from this graph
- node.graph = None
- self._nodes.remove(node)
-
- def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
- """Insert new nodes after the given node in O(#new_nodes) time.
-
- Args:
- node: The node to insert after.
- new_nodes: The new nodes to insert.
-
- Raises:
- ValueError: If any node belongs to another graph.
- """
- if isinstance(new_nodes, Node):
- new_nodes = (new_nodes,)
- new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes]
- self._nodes.insert_after(node, new_nodes)
-
- def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
- """Insert new nodes before the given node in O(#new_nodes) time.
-
- Args:
- node: The node to insert before.
- new_nodes: The new nodes to insert.
-
- Raises:
- ValueError: If any node belongs to another graph.
- """
- if isinstance(new_nodes, Node):
- new_nodes = (new_nodes,)
- new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes]
- self._nodes.insert_before(node, new_nodes)
-
- def sort(self) -> None:
- """Topologically sort the nodes in the graph."""
- raise NotImplementedError("Not implemented yet")
-
- # End of mutation methods
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- def __str__(self) -> str:
- return _graph_str(self)
-
- def __repr__(self) -> str:
- return _graph_repr(self)
-
-
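-# Editor's sketch (hedged; `Input` and `Node` as above, names illustrative):
-# building a one-node graph. Note that `remove(..., safe=True)` refuses to drop
-# a node whose output is a graph output or is still consumed by other nodes.
-#
-#     >>> x = Input(name="x")
-#     >>> relu = Node("", "Relu", inputs=(x,))
-#     >>> graph = Graph(inputs=[x], outputs=[relu.outputs[0]], nodes=[relu])
-#     >>> len(graph)
-#     1
-#     >>> # graph.remove(relu, safe=True) would raise ValueError here, because
-#     >>> # relu's output is an output of the graph.
-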
-def _graph_str(graph: Graph | GraphView) -> str:
- """Return a string representation of the graph."""
- # TODO(justinchuby): Show docstrings and metadata
- inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
- outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
- initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
- if initializers_text:
- initializers_text = (
- "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
- )
- signature = f"""\
-graph(
- name={graph.name or 'anonymous_graph:' + str(id(graph))},
- inputs=({textwrap.indent(inputs_text, ' '*8)}
- ),
- outputs=({textwrap.indent(outputs_text, ' '*8)}
- ),{textwrap.indent(initializers_text, ' '*4)}
-)"""
- node_count = len(graph)
- number_width = len(str(node_count))
- node_lines = []
- for i, node in enumerate(graph):
- node_name = node.name if node.name else f":anonymous_node:{id(node)}"
- node_text = f"# {node_name}\n{node}"
- indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
- # Remove the leading spaces
- indented_node_text = indented_node_text.strip()
- node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
- returns = ", ".join(str(x) for x in graph.outputs)
- body = (
- "{\n"
- + textwrap.indent("\n".join(node_lines), " " * 4)
- + textwrap.indent(f"\nreturn {returns}", " " * 4)
- + "\n}"
- )
-
- return f"{signature} {body}"
-
-
-def _graph_repr(graph: Graph | GraphView) -> str:
- """Return an repr string of the graph."""
- inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
- outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
- initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
- if initializers_text:
- initializers_text = (
- "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
- )
- return f"""\
-{graph.__class__.__name__}(
- name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
- inputs=({textwrap.indent(inputs_text, ' '*8)}
- ),
- outputs=({textwrap.indent(outputs_text, ' '*8)}
- ),{textwrap.indent(initializers_text, ' '*4)}
- len()={len(graph)}
-)"""
-
-
-class GraphView(Sequence[Node], _display.PrettyPrintable):
- """A read-only view on a graph.
-
- The GraphView is useful for analysis of a subgraph. It can be initialized
-    with a subset of nodes from a :class:`Graph`. Creating a GraphView does not
- change the ownership of the nodes, and so it is possible to create multiple
- GraphViews that contain the same nodes. If the underlying nodes / connections
- are mutated, the mutation will be reflected in all views as well.
-
- The graph view can be serialized to ONNX::
-
- graph_proto = ir.serde.serialize_graph(graph_view)
-
- It can also be used to create a model::
-
- model = ir.Model(graph_view, ir_version=8)
- model_proto = ir.serde.serialize_model(model)
-
- The model created with a GraphView will have a fixed topology, and its graph
- will remain read-only as a GraphView. No copying will be done during the
- initialization process.
-
- Attributes:
- name: The name of the graph.
- inputs: The input values of the graph.
- outputs: The output values of the graph.
- initializers: The initializers in the graph.
- doc_string: Documentation string.
- opset_imports: Opsets imported by the graph.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- __slots__ = (
- "name",
- "inputs",
- "outputs",
- "initializers",
- "doc_string",
- "opset_imports",
- "nodes",
- "_metadata",
- "_metadata_props",
- )
-
- def __init__(
- self,
- inputs: Sequence[Value],
- outputs: Sequence[Value],
- *,
- nodes: Iterable[Node],
- initializers: Sequence[_protocols.TensorProtocol] = (),
- doc_string: str | None = None,
- opset_imports: dict[str, int] | None = None,
- name: str | None = None,
- metadata_props: dict[str, str] | None = None,
- ):
- self.name = name
- self.inputs = tuple(inputs)
- self.outputs = tuple(outputs)
- for initializer in initializers:
- if initializer.name is None:
- raise ValueError(f"Initializer must have a name: {initializer}")
- self.initializers = {tensor.name: tensor for tensor in initializers}
- self.doc_string = doc_string
- self.opset_imports = opset_imports or {}
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props: dict[str, str] | None = metadata_props
- self._nodes: tuple[Node, ...] = tuple(nodes)
-
- def __getitem__(self, index: int) -> Node:
- return self._nodes[index]
-
- def __len__(self) -> int:
- return len(self._nodes)
-
- def __iter__(self) -> Iterator[Node]:
- return iter(self._nodes)
-
- def __reversed__(self) -> Iterator[Node]:
- return reversed(self._nodes)
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- def __str__(self) -> str:
- return _graph_str(self)
-
- def __repr__(self) -> str:
- return _graph_repr(self)
-
-
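-# Editor's sketch (hedged; reuses `graph` from the sketch above): a GraphView
-# shares its nodes with the graph it views, without taking ownership.
-#
-#     >>> view = GraphView(inputs=graph.inputs, outputs=graph.outputs, nodes=list(graph))
-#     >>> list(view) == list(graph)
-#     True
-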
-class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
-    """IR Model.
-
-    A model is a container for a graph and metadata.
-
-    Attributes:
-        graph: The graph of the model.
-        ir_version: The version of the IR.
-        producer_name: The name of the producer.
-        producer_version: The version of the producer.
-        domain: The domain of the model.
-        model_version: The version of the model.
-        doc_string: Documentation string.
-        functions: The functions defined in the model.
-        metadata_props: Metadata.
-    """
-
-    __slots__ = (
-        "graph",
-        "ir_version",
-        "producer_name",
-        "producer_version",
-        "domain",
-        "model_version",
-        "doc_string",
-        "_functions",
-        "_metadata",
-        "_metadata_props",
-    )
-
- def __init__(
- self,
- graph: Graph,
- *,
- ir_version: int,
- producer_name: str | None = None,
- producer_version: str | None = None,
- domain: str | None = None,
- model_version: int | None = None,
- doc_string: str | None = None,
- functions: Sequence[Function] = (),
-        metadata_props: dict[str, str] | None = None,
- ) -> None:
- self.graph: Graph = graph
- self.ir_version = ir_version
- self.producer_name = producer_name
- self.producer_version = producer_version
- self.domain = domain
- self.model_version = model_version
- self.doc_string = doc_string
- self._functions = {func.identifier(): func for func in functions}
- self._metadata: _metadata.MetadataStore | None = None
-        self._metadata_props: dict[str, str] | None = metadata_props
-
- @property
- def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
- return self._functions
-
- @property
- def opset_imports(self) -> dict[str, int]:
- return self.graph.opset_imports
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- def __str__(self) -> str:
- # TODO(justinchuby): Show docstrings and metadata
- signature = f"""\
-<
- ir_version={self.ir_version!r},
- opset_imports={self.opset_imports!r},
- producer_name={self.producer_name!r},
- producer_version={self.producer_version!r},
- domain={self.domain!r},
- model_version={self.model_version!r},
->"""
- graph_text = str(self.graph)
- functions_text = ",\n\n".join(str(func) for func in self.functions.values())
- return f"{signature}\n{graph_text}" + f"\n\n{functions_text}" * len(self.functions)
-
- def __repr__(self) -> str:
- return f"""\
-Model(
- ir_version={self.ir_version!r},
- opset_imports={self.opset_imports!r},
- producer_name={self.producer_name!r},
- producer_version={self.producer_version!r},
- domain={self.domain!r},
- model_version={self.model_version!r},
- functions={self.functions!r},
- graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
-)"""
-
-
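-# Editor's sketch (hedged; reuses the `graph` built in the sketch above):
-#
-#     >>> model = Model(graph, ir_version=8, producer_name="example")
-#     >>> model.opset_imports is graph.opset_imports
-#     True
-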
-class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
- """IR functions.
-
- Like a graph, a function can have nodes that are not topologically sorted. It is
- the responsibility of the user to maintain a topological order of the nodes.
-
-    Note that there is no ``nodes`` attribute in the Function. The Function can be
- seen as a Sequence of nodes and should be used as such. For example, to obtain
- all nodes as a list, call ``list(function)``.
-
- Attributes:
- name: The function name.
- domain: The domain this function is defined in.
- overload: The overload name when the function is overloaded.
- inputs: The input values of the function.
- attributes: The attributes this function defines.
- outputs: The output values of the function.
- opset_imports: Opsets imported by the function.
- doc_string: Documentation string.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- __slots__ = (
- "_domain",
- "_name",
- "_overload",
- "_graph",
- "_attributes",
- "_metadata",
- "_metadata_props",
- )
-
- def __init__(
- self,
- domain: str,
- name: str,
- overload: str = "",
- *,
- # Ensure the inputs and outputs of the function belong to a graph
- # and not from an outer scope
- graph: Graph,
- attributes: Sequence[Attr],
- metadata_props: dict[str, str] | None = None,
- ) -> None:
- self._domain = domain
- self._name = name
- self._overload = overload
- self._graph = graph
- self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props: dict[str, str] | None = metadata_props
-
- def identifier(self) -> _protocols.OperatorIdentifier:
- return self.domain, self.name, self.overload
-
- @property
- def name(self) -> str:
- return self._name
-
- @name.setter
- def name(self, value: str) -> None:
- self._name = value
-
- @property
- def domain(self) -> str:
- return self._domain
-
- @domain.setter
- def domain(self, value: str) -> None:
- self._domain = value
-
- @property
- def overload(self) -> str:
- return self._overload
-
- @overload.setter
- def overload(self, value: str) -> None:
- self._overload = value
-
- @property
- def inputs(self) -> list[Input]:
- return self._graph.inputs
-
- @property
- def outputs(self) -> list[Value]:
- return self._graph.outputs
-
- @property
- def attributes(self) -> OrderedDict[str, Attr]:
- return self._attributes
-
- def __getitem__(self, index: int) -> Node:
- return self._graph.__getitem__(index)
-
- def __len__(self) -> int:
- return self._graph.__len__()
-
- def __iter__(self) -> Iterator[Node]:
- return self._graph.__iter__()
-
- def __reversed__(self) -> Iterator[Node]:
- return self._graph.__reversed__()
-
- @property
- def doc_string(self) -> str | None:
- return self._graph.doc_string
-
- @doc_string.setter
- def doc_string(self, value: str | None) -> None:
- self._graph.doc_string = value
-
- @property
- def opset_imports(self) -> dict[str, int]:
- return self._graph.opset_imports
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attribute:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- # Mutation methods
- def append(self, node: Node, /) -> None:
- """Append a node to the function in O(1) time."""
- self._graph.append(node)
-
- def extend(self, nodes: Iterable[Node], /) -> None:
- """Extend the function with the given nodes in O(#new_nodes) time."""
- self._graph.extend(nodes)
-
- def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
- """Remove nodes from the graph in O(#num of nodes) time.
-
- If any errors are raise, to ensure the graph is not left in an inconsistent state,
- the graph is not modified.
-
- Args:
- nodes: The node to remove.
- safe: If True, performs the following actions before removal:
- 1. It checks to make sure there are no users of the node that are not
- to be removed before removing it.
- 2. It checks the node does not contribute to any graph outputs.
- 3. It removes references to all inputs so it is no longer a user of other nodes.
-
- Raises:
- ValueError: If any node to remove does not belong to this graph.
-            ValueError: (When ``safe=True``) If an output of the node is still an output of the graph.
- ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed.
- """
- self._graph.remove(nodes, safe=safe)
-
- def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
- """Insert new nodes after the given node in O(#new_nodes) time."""
- self._graph.insert_after(node, new_nodes)
-
- def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
- """Insert new nodes before the given node in O(#new_nodes) time."""
- self._graph.insert_before(node, new_nodes)
-
- def sort(self) -> None:
- """Topologically sort the nodes in the function."""
- self._graph.sort()
-
- # End of mutation methods
-
- def __str__(self) -> str:
- full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "")
- inputs_text = ",\n".join(str(x) for x in self.inputs)
- outputs_text = ",\n".join(str(x) for x in self.outputs)
- attributes_text = ",\n".join(
- f"{attr.name}: {attr.type}" + f" = {attr.value}" * (attr.value is None)
- for attr in self.attributes.values()
- )
- if attributes_text:
- attributes_text = (
- "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}"
- )
- signature = f"""\
-<
- opset_imports={self.opset_imports!r},
->
-def {full_name}(
- inputs=(
-{textwrap.indent(inputs_text, ' '*8)}
- ),{textwrap.indent(attributes_text, ' '*4)}
- outputs=(
-{textwrap.indent(outputs_text, ' '*8)}
- ),
-)"""
- node_count = len(self)
- number_width = len(str(node_count))
- node_lines = []
- for i, node in enumerate(self):
- node_name = node.name if node.name else f":anonymous_node:{id(node)}"
- node_text = f"# {node_name}\n{node}"
- indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
- # Remove the leading spaces
- indented_node_text = indented_node_text.strip()
- node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
- returns = ", ".join(str(x) for x in self.outputs)
- body = (
- "{\n"
- + textwrap.indent("\n".join(node_lines), " " * 4)
- + textwrap.indent(f"\nreturn {returns}", " " * 4)
- + "\n}"
- )
-
- return f"{signature} {body}"
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})"
-
-
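-# Editor's sketch (hedged; reuses the `graph` built in the sketch above): a
-# Function wraps a Graph and adds an identifier plus declared attributes.
-#
-#     >>> fn = Function("my.domain", "relu_fn", graph=graph, attributes=[])
-#     >>> fn.identifier()
-#     ('my.domain', 'relu_fn', '')
-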
-class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
- """Reference attribute."""
-
- __slots__ = ("_name", "_ref_attr_name", "_type", "doc_string")
-
- def __init__(
- self,
- name: str,
- ref_attr_name: str,
- type: _enums.AttributeType,
- *,
- doc_string: str | None = None,
- ) -> None:
- self._name = name
- self._ref_attr_name = ref_attr_name
- self._type = type
- self.doc_string = doc_string
-
- @property
- def name(self) -> str:
- return self._name
-
- @name.setter
- def name(self, value: str) -> None:
- self._name = value
-
- @property
- def ref_attr_name(self) -> str:
- return self._ref_attr_name
-
- @ref_attr_name.setter
- def ref_attr_name(self, value: str) -> None:
- self._ref_attr_name = value
-
- @property
- def type(self) -> _enums.AttributeType:
- return self._type
-
- @type.setter
- def type(self, value: _enums.AttributeType) -> None:
- self._type = value
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})"
-
-
-class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
- """Base class for ONNX attributes."""
-
- __slots__ = ("name", "type", "value", "doc_string")
-
- def __init__(
- self,
- name: str,
- type: _enums.AttributeType,
- value: Any,
- *,
- doc_string: str | None = None,
- ):
- self.name = name
- self.type = type
- self.value = value
- self.doc_string = doc_string
-
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, _protocols.AttributeProtocol):
- return False
-
- if self.name != other.name:
- return False
- if self.type != other.type:
- return False
- if self.value != other.value:
- return False
- if self.doc_string != other.doc_string:
- return False
- return True
-
- def __str__(self) -> str:
- return str(self.value)
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"
-
-
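-# Editor's sketch (illustrative; AttrInt64 is one of the specialized
-# subclasses defined below, which fix the AttributeType):
-#
-#     >>> a = AttrInt64("axis", 0)
-#     >>> a.type == _enums.AttributeType.INT and a.value == 0
-#     True
-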
-class _SpecializedAttr(Attr):
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self.name!r}, {self.value!r})"
-
-
-# NOTE: The following classes are just supporting classes (partially applied) for convenience
-# But I think they would be useful to have in the IR by having the type info
-# explicitly in the class type.
-class AttrFloat32(_SpecializedAttr):
- def __init__(self, name: str, value: float, doc_string: str | None = None):
- super().__init__(
- name,
- _enums.AttributeType.FLOAT,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrInt64(_SpecializedAttr):
- def __init__(self, name: str, value: int, doc_string: str | None = None):
- super().__init__(
- name,
- _enums.AttributeType.INT,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrString(_SpecializedAttr):
- def __init__(self, name: str, value: str, doc_string: str | None = None):
- super().__init__(
- name,
- _enums.AttributeType.STRING,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrTensor(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: _protocols.TensorProtocol,
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.TENSOR,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrGraph(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Graph,
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.GRAPH,
- value,
- doc_string=doc_string,
- )
-
- def __str__(self) -> str:
- return textwrap.indent("\n" + super().__str__(), " " * 4)
-
-
-class AttrFloat32s(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[float],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.FLOATS,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrInt64s(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[int],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.INTS,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrStrings(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[str],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.STRINGS,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrTensors(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[_protocols.TensorProtocol],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.TENSORS,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrGraphs(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[Graph],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.GRAPHS,
- value,
- doc_string=doc_string,
- )
-
-
-# NOTE: SparseTensor should be a sparse tensor proto
-class AttrSparseTensor(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[_protocols.SparseTensorProtocol],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.SPARSE_TENSOR,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrSparseTensors(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[_protocols.SparseTensorProtocol],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.SPARSE_TENSORS,
- value,
- doc_string=doc_string,
- )
-
-
-@dataclasses.dataclass
-class TypeAndShape:
- """Type and shape.
-
- Useful for constructing a type proto.
- """
-
- type: _protocols.TypeProtocol | None
- shape: Shape | None
-
-
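-# Editor's sketch (illustrative; AttrTypeProto is defined just below):
-#
-#     >>> ts = TypeAndShape(TensorType(_enums.DataType.FLOAT), Shape([1, "N"]))
-#     >>> AttrTypeProto("type_attr", ts).type == _enums.AttributeType.TYPE_PROTO
-#     True
-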
-class AttrTypeProto(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: TypeAndShape,
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.TYPE_PROTO,
- value,
- doc_string=doc_string,
- )
-
-
-class AttrTypeProtos(_SpecializedAttr):
- def __init__(
- self,
- name: str,
- value: Sequence[TypeAndShape],
- doc_string: str | None = None,
- ):
- super().__init__(
- name,
- _enums.AttributeType.TYPE_PROTOS,
- value,
- doc_string=doc_string,
- )
diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py
deleted file mode 100644
index 103e5b1700..0000000000
--- a/onnxscript/ir/_core_test.py
+++ /dev/null
@@ -1,580 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-from __future__ import annotations
-
-import pathlib
-import tempfile
-import unittest
-from typing import Any
-
-import numpy as np
-import onnx
-import onnx.external_data_helper
-import parameterized
-import torch
-
-from onnxscript.ir import _core, _enums
-
-
-class TensorTest(unittest.TestCase):
- def test_initialize(self):
- tensor = _core.Tensor(
- np.random.rand(1, 2).astype(np.float32),
- dtype=_enums.DataType.FLOAT,
- shape=_core.Shape((1, 2)),
- name="test",
- )
- self.assertEqual(tensor.name, "test")
- self.assertEqual(tensor.dtype, _enums.DataType.FLOAT)
- self.assertEqual(tensor.shape, _core.Shape((1, 2)))
- np.testing.assert_array_equal(tensor, tensor)
-
- def test_init_raises_when_value_is_not_array(self):
- with self.assertRaises(TypeError):
- _core.Tensor(42)
-
- def test_init_requires_type_when_value_is_not_np_array(self):
- torch_tensor = torch.tensor(42)
- with self.assertRaises(ValueError):
- _core.Tensor(torch_tensor)
-
- @parameterized.parameterized.expand(
- [
- ("bfloat16", np.uint16, _enums.DataType.BFLOAT16),
- (
- "float8e4m3fn",
- np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)})),
- _enums.DataType.FLOAT8E4M3FN,
- ),
- ("float8e4m3fnuz", np.uint8, _enums.DataType.FLOAT8E4M3FNUZ),
- ("float8e5m2", np.uint8, _enums.DataType.FLOAT8E5M2),
- ("float8e5m2fnuz", np.uint8, _enums.DataType.FLOAT8E5M2FNUZ),
- ("int4", np.int8, _enums.DataType.INT4),
- ("int4_uint8", np.uint8, _enums.DataType.INT4),
- ("uint4", np.uint8, _enums.DataType.UINT4),
- ]
- )
- def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: _enums.DataType):
- array = np.array([0b1, 0b11], dtype=np_dtype)
- tensor = _core.Tensor(array, dtype=dtype)
- self.assertEqual(tensor.dtype, dtype)
- np.testing.assert_array_equal(tensor, array)
-
- def test_initialize_with_just_np_array(self):
- array = np.random.rand(1, 2)
- tensor = _core.Tensor(array)
- np.testing.assert_array_equal(tensor, array)
-
- def test_initialize_raises_when_numpy_dtype_doesnt_match(self):
- array = np.random.rand(1, 2).astype(np.float32)
- with self.assertRaises(TypeError):
- _core.Tensor(array, dtype=_enums.DataType.INT64)
-
- def test_initialize_raises_when_numpy_dtype_doesnt_match_custom_dtype(self):
- custom_dtype = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)}))
- array = np.random.rand(1, 2).astype(custom_dtype)
- with self.assertRaises(TypeError):
- _core.Tensor(array, dtype=_enums.DataType.BFLOAT16)
-
- def test_initialize_with_torch_tensor(self):
- array = np.random.rand(1, 2).astype(np.int64)
- np_tensor = _core.Tensor(array)
- torch_tensor = _core.Tensor(torch.tensor(array), dtype=_enums.DataType.INT64)
- np.testing.assert_array_equal(torch_tensor, array)
- np.testing.assert_array_equal(torch_tensor, np_tensor)
-
- def test_dlpack_np_to_torch(self):
- array = np.random.rand(1, 2).astype(np.float32)
- tensor = _core.Tensor(array)
- torch_tensor = torch.from_dlpack(tensor)
- np.testing.assert_array_equal(torch_tensor, array)
-
- def test_dlpack_torch_to_np(self):
- torch_tensor = torch.rand(1, 2)
- tensor = _core.Tensor(torch_tensor, dtype=_enums.DataType.FLOAT)
- array = np.from_dlpack(tensor)
- np.testing.assert_array_equal(array, torch_tensor)
-
- def test_repr(self):
- tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32))
- self.assertIsInstance(repr(tensor), str)
-
- def test_dtype_returns_data_type_enum(self):
- tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32))
- self.assertEqual(tensor.dtype, _enums.DataType.FLOAT)
-
- def test_shape(self):
- tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32))
- self.assertEqual(tensor.shape, _core.Shape((1, 2)))
-
- def test_numpy_returns_np_array(self):
- array = np.random.rand(1, 2).astype(np.float32)
- tensor = _core.Tensor(array)
- np.testing.assert_equal(tensor.numpy(), array)
-
- def test_numpy_returns_data_when_dtype_is_not_supported(self):
- array = np.array([1], dtype=np.uint8)
- tensor = _core.Tensor(array, dtype=_enums.DataType.INT4)
- np.testing.assert_equal(tensor.numpy(), array)
-
- def test_tobytes(self):
- array = np.random.rand(1, 2).astype(np.float32)
- torch_tensor = torch.tensor(array)
- tensor = _core.Tensor(torch_tensor, dtype=_enums.DataType.FLOAT)
- self.assertEqual(tensor.tobytes(), array.tobytes())
-
- def test_tobytes_returns_packed_data_for_int4(self):
- array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8)
- # Test odd sized array
- assert len(array) % 2 == 1
- tensor = _core.Tensor(array, dtype=_enums.DataType.INT4)
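- # Two int4 values are packed per byte, with the first element in the low
- # nibble: [-8, -1] -> 0xF8, [0, 1] -> 0x10, [2, 7] -> 0x72. The odd final
- # element is zero-padded in the high nibble: [1] -> 0x01.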
- self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01")
-
- def test_tobytes_returns_packed_data_for_uint4(self):
- array = np.array([0, 1, 2, 7, 15], dtype=np.uint8)
- # Test odd sized array
- assert len(array) % 2 == 1
- tensor = _core.Tensor(array, dtype=_enums.DataType.UINT4)
- self.assertEqual(tensor.tobytes(), b"\x10r\x0f")
-
- def test_metadata(self):
- array = np.random.rand(1, 2).astype(np.float32)
- tensor = _core.Tensor(array)
- tensor.meta["test"] = 1
- self.assertEqual(tensor.meta["test"], 1)
- tensor.metadata_props["test"] = "any string"
- self.assertEqual(tensor.metadata_props["test"], "any string")
-
-
-class ExternalTensorTest(unittest.TestCase):
- """Test the memory mapped external tensor class."""
-
- def setUp(self):
- self.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
- self.external_data_name = "test_model.bin"
- self.base_path = self.temp_dir.name
- self.data = np.random.rand(2, 42).astype(np.float32)
- self.data_float16 = np.random.rand(2, 42).astype(np.float16)
- self.model = self._simple_model_with_external(
- self.base_path, self.external_data_name, self.data
- )
-
- def tearDown(self) -> None:
- self.temp_dir.cleanup()
-
- def _simple_model_with_external(
- self, base_path: str, external_data_name: str, data: np.ndarray
- ) -> onnx.ModelProto:
- input = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [None])
- output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [None])
- raw_data = data.tobytes()
- tensor = onnx.helper.make_tensor(
- "input", onnx.TensorProto.FLOAT, data.shape, raw_data, raw=True
- )
- raw_data2 = self.data_float16.tobytes()
- tensor2 = onnx.helper.make_tensor(
- "input2", onnx.TensorProto.FLOAT16, data.shape, raw_data2, raw=True
- )
- onnx.external_data_helper.set_external_data(
- tensor, external_data_name, offset=0, length=len(raw_data)
- )
- onnx.external_data_helper.set_external_data(
- tensor2, external_data_name, offset=len(raw_data), length=len(raw_data2)
- )
-
- node = onnx.helper.make_node("Identity", inputs=["input"], outputs=["output"])
- model = onnx.helper.make_model(
- onnx.helper.make_graph(
- [node], "test_graph", [input], [output], initializer=[tensor, tensor2]
- )
- )
- tensor.ClearField("raw_data")
- tensor2.ClearField("raw_data")
- # Save the data to disk
- with open(pathlib.Path(base_path) / external_data_name, "wb") as f:
- f.write(raw_data)
- f.write(raw_data2)
- return model
-
- def test_initialize(self):
- external_tensor = self.model.graph.initializer[0]
- external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
- tensor = _core.ExternalTensor(
- path=pathlib.Path(self.base_path) / external_info.location,
- offset=external_info.offset,
- length=external_info.length,
- dtype=_enums.DataType.FLOAT,
- name="input",
- shape=_core.Shape(external_tensor.dims),
- )
- self.assertEqual(tensor.dtype, _enums.DataType.FLOAT)
- np.testing.assert_equal(tensor, self.data)
- # Ensure repeated reads are consistent
- np.testing.assert_equal(tensor, self.data)
-
- def test_tobytes_returns_correct_data(self):
- external_tensor = self.model.graph.initializer[0]
- external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
- tensor = _core.ExternalTensor(
- path=pathlib.Path(self.base_path) / external_info.location,
- offset=external_info.offset,
- length=external_info.length,
- dtype=_enums.DataType.FLOAT,
- name="input",
- shape=_core.Shape(external_tensor.dims),
- )
- external_tensor2 = self.model.graph.initializer[1]
- external_info2 = onnx.external_data_helper.ExternalDataInfo(external_tensor2)
- tensor2 = _core.ExternalTensor(
- path=pathlib.Path(self.base_path) / external_info2.location,
- offset=external_info2.offset,
- length=external_info2.length,
- dtype=_enums.DataType.FLOAT16,
- name="input",
- shape=_core.Shape(external_tensor2.dims),
- )
- self.assertEqual(tensor.tobytes(), self.data.tobytes())
- self.assertEqual(tensor2.tobytes(), self.data_float16.tobytes())
- # Ensure repeated reads are consistent
- self.assertEqual(tensor.tobytes(), self.data.tobytes())
- self.assertEqual(tensor2.tobytes(), self.data_float16.tobytes())
-
-
-class SymbolicDimTest(unittest.TestCase):
- def test_init_raises_when_value_is_int(self):
- # Static dimensions should be python integers
- with self.assertRaises(TypeError):
- _core.SymbolicDim(42)
-
- @parameterized.parameterized.expand([("str", "any string"), ("None", None)])
- def test_equality_with_other_dimensions(self, _: str, value: Any):
- dim1 = _core.SymbolicDim(value)
- dim2 = _core.SymbolicDim(value)
- self.assertEqual(dim1, dim2)
-
- @parameterized.parameterized.expand([("str", "any string"), ("None", None)])
- def test_equality_with_python_values(self, _: str, value: Any):
- dim = _core.SymbolicDim(value)
- self.assertEqual(dim, value)
- self.assertIn(value, [dim])
- self.assertIn(dim, [value])
-
- @parameterized.parameterized.expand([("str", "any string"), ("None", None)])
- def test_it_is_hashable(self, _: str, value: Any):
- dim = _core.SymbolicDim(value)
- self.assertEqual(hash(dim), hash(value))
- self.assertIn(dim, {dim})
- self.assertIn(dim, {value})
-
-
-class ShapeTest(unittest.TestCase):
- def test_init_raises_when_denotations_and_dims_have_different_lengths(self):
- with self.assertRaisesRegex(ValueError, "denotations"):
- _core.Shape([42], ["DATA_CHANNEL", "BATCH"])
-
- def test_int_dimensions_are_python_ints(self):
- shape = _core.Shape([42])
- self.assertIsInstance(shape[0], int)
-
- @parameterized.parameterized.expand(
- [
- ("empty", (), ()),
- ("1d", (42,), (42,)),
- ("int", (42, 42), (42, 42)),
- ("str", ("any string", "any string"), ("any string", "any string")),
- ("None", (None, None), (None, None)),
- ]
- )
- def test_eq_with_other_shapes(
- self, _: str, dims_1: tuple[Any, ...], dims_2: tuple[Any, ...]
- ):
- shape_1 = _core.Shape(dims_1)
- shape_2 = _core.Shape(dims_2)
- self.assertEqual(shape_1, shape_2)
-
- @parameterized.parameterized.expand(
- [
- ("empty", ()),
- ("1d", (42,)),
- ("int", (42, 42)),
- ("str", ("any string", "any string")),
- ("None", (None, None)),
- ]
- )
- def test_eq_with_tuple(self, _: str, dims: tuple[Any, ...]):
- shape = _core.Shape(dims)
- self.assertEqual(shape, dims)
-
- @parameterized.parameterized.expand(
- [
- ("empty", []),
- (
- "1d",
- [
- 42,
- ],
- ),
- ("int", [42, 42]),
- ("str", ["any string", "any string"]),
- ("None", [None, None]),
- ]
- )
- def test_eq_with_list(self, _: str, dims: list[Any]):
- shape = _core.Shape(dims)
- self.assertEqual(shape, dims)
-
- def test_eq_with_np_shape(self):
- dims = (42,)
- array = np.zeros(dims)
- shape = _core.Shape(dims)
- self.assertEqual(shape, array.shape)
-
- @parameterized.parameterized.expand(
- [
- ("empty", (), (1,)),
- ("d", (42,), (0,)),
- ("rank", (42, 42), (42, 42, 42)),
- ("str", ("any string",), (42,)),
- ("None", (None, None), (None, 42)),
- ]
- )
- def test_ne_with_other_shapes(
- self, _: str, dims_1: tuple[Any, ...], dims_2: tuple[Any, ...]
- ):
- shape_1 = _core.Shape(dims_1)
- shape_2 = _core.Shape(dims_2)
- self.assertNotEqual(shape_1, shape_2)
-
- def test_ne_with_random_object(self):
- shape = _core.Shape((42,))
- self.assertNotEqual(shape, 42)
-
- def test_setitem_raises_when_shape_is_frozen(self):
- shape = _core.Shape([42], denotations=("DATA_CHANNEL",), frozen=True)
- with self.assertRaisesRegex(TypeError, "frozen"):
- shape[0] = 1
-
- def test_getitem(self):
- shape = _core.Shape([42], denotations=("DATA_CHANNEL",))
- self.assertEqual(shape[0], 42)
-
- def test_getitem_accepts_a_slice(self):
- shape = _core.Shape([1, 2, 3, 4])
- self.assertEqual(shape[1:3], (2, 3))
-
- @parameterized.parameterized.expand(
- [
- ("int", 42),
- ("str", "any string"),
- ("None", None),
- ("SymbolicDim", _core.SymbolicDim("any string")),
- ]
- )
- def test_setitem(self, _: str, value):
- shape = _core.Shape([0])
- shape[0] = value
- dim = shape[0]
- if isinstance(dim, _core.SymbolicDim):
- self.assertEqual(dim.value, value)
- else:
- self.assertEqual(dim, value)
-
- def test_get_denotation(self):
- shape = _core.Shape([42], denotations=("DATA_CHANNEL",))
- self.assertEqual(shape.get_denotation(0), "DATA_CHANNEL")
-
- def test_set_denotation(self):
- shape = _core.Shape([42, 0], ["DATA_CHANNEL", "BATCH"])
- shape.set_denotation(1, "UPDATED")
- self.assertEqual(shape.get_denotation(1), "UPDATED")
-
- def test_set_denotation_is_still_possible_when_shape_is_frozen(self):
- shape = _core.Shape([42], denotations=("DATA_CHANNEL",), frozen=True)
- shape.set_denotation(0, "UPDATED")
- self.assertEqual(shape.get_denotation(0), "UPDATED")
-
-
-class ValueTest(unittest.TestCase):
- def test_initialize(self):
- _ = _core.Value(None, index=0)
-
- def test_meta(self):
- value = _core.Value(None, index=0)
- value.meta["test"] = 1
- self.assertEqual(value.meta["test"], 1)
- value.metadata_props["test"] = "any string"
- self.assertEqual(value.metadata_props["test"], "any string")
-
- # TODO(justinchuby): Test all methods
-
-
-class NodeTest(unittest.TestCase):
- def setUp(self) -> None:
- self.v0 = _core.Value(None, index=None)
- self.v1 = _core.Value(None, index=None)
- self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)
-
- def test_init_with_values(self):
- self.assertEqual(self.node.domain, "test")
- self.assertEqual(self.node.op_type, "TestOp")
- self.assertEqual(self.node.inputs, (self.v0, self.v1))
- self.assertEqual(len(self.node.outputs), 3)
- self.assertEqual(self.node.attributes, {})
-
- def test_init_with_preinitialized_outputs(self):
- out_1 = _core.Value(
- None,
- index=None,
- name="out_1",
- shape=_core.Shape([1]),
- type=_core.TensorType(_enums.DataType.BFLOAT16),
- )
- out_2 = _core.Value(
- None,
- index=None,
- name="out_2",
- shape=_core.Shape([2]),
- type=_core.TensorType(_enums.DataType.INT4),
- )
- node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[out_1, out_2])
- self.assertEqual(node.outputs[0].name, "out_1")
- self.assertEqual(node.outputs[0].shape, _core.Shape([1]))
- self.assertEqual(node.outputs[0].dtype, _enums.DataType.BFLOAT16)
- self.assertEqual(node.outputs[1].name, "out_2")
- self.assertEqual(node.outputs[1].shape, _core.Shape([2]))
- self.assertEqual(node.outputs[1].dtype, _enums.DataType.INT4)
- self.assertIs(node.outputs[0], out_1)
- self.assertIs(node.outputs[1], out_2)
- self.assertIs(node.outputs[0].producer(), node)
- self.assertIs(node.outputs[1].producer(), node)
- self.assertIs(node.outputs[0].index(), 0)
- self.assertIs(node.outputs[1].index(), 1)
-
- def test_init_raises_when_num_outputs_does_not_match_outputs(self):
- with self.assertRaisesRegex(ValueError, "outputs"):
- _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=2, outputs=[])
-
- def test_init_with_zero_num_outputs(self):
- node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=0)
- self.assertEqual(node.outputs, ())
-
- def test_init_with_empty_outputs(self):
- node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[])
- self.assertEqual(node.outputs, ())
-
- def test_init_produces_one_output_with_unspecified_output_argument(self):
- node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1))
- self.assertEqual(len(node.outputs), 1)
-
- def test_metadata(self):
- self.node.meta["test"] = 1
- self.assertEqual(self.node.meta["test"], 1)
- self.node.metadata_props["test"] = "any string"
- self.assertEqual(self.node.metadata_props["test"], "any string")
-
- def test_it_is_added_to_a_graph_if_specified(self):
- graph = _core.Graph(
- (self.v0, self.v1), # type: ignore
- self.node.outputs,
- nodes=(self.node,),
- opset_imports={"": 1},
- )
- self.assertIn(self.node, graph)
-
- # TODO(justinchuby): Test all methods
-
-
-class GraphTest(unittest.TestCase):
- def setUp(self) -> None:
- self.v0 = _core.Input(name="v0")
- self.v1 = _core.Input(name="v1")
- self.node = _core.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
- self.graph = _core.Graph(
- (self.v0, self.v1),
- self.node.outputs,
- nodes=(self.node,),
- opset_imports={"": 1},
- )
-
- def test_initialize(self):
- self.assertEqual(self.graph.inputs, [self.v0, self.v1])
- self.assertEqual(self.graph.outputs, [*self.node.outputs])
- self.assertEqual(self.graph.opset_imports, {"": 1})
- self.assertEqual(self.graph.initializers, {})
- self.assertIsNone(self.graph.doc_string)
-
- def test_it_is_iterable_of_nodes(self):
- self.assertEqual(list(self.graph), [self.node])
-
- def test_metadata(self):
- self.graph.meta["test"] = 1
- self.assertEqual(self.graph.meta["test"], 1)
- self.graph.metadata_props["test"] = "any string"
- self.assertEqual(self.graph.metadata_props["test"], "any string")
-
- def test_remove_removes_node_from_graph(self):
- self.graph.remove(self.node)
- self.assertEqual(list(self.graph), [])
- self.assertIsNone(self.node.graph)
-
- def test_remove_does_not_change_input_users(self):
- self.graph.remove(self.node)
- self.assertEqual(tuple(self.v0.uses()), ((self.node, 0),))
- self.assertEqual(tuple(self.v1.uses()), ((self.node, 1),))
-
- def test_remove_does_not_change_graph_in_out(self):
- self.graph.remove(self.node)
- self.assertEqual(self.graph.inputs, [self.v0, self.v1])
- self.assertEqual(self.graph.outputs, list(self.node.outputs))
-
- def test_remove_raises_when_node_does_not_belong_to_graph(self):
- node = _core.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
- with self.assertRaisesRegex(ValueError, "graph"):
- self.graph.remove(node)
-
- def test_remove_safe_raises_when_node_output_is_graph_output(self):
- with self.assertRaisesRegex(ValueError, "output"):
- self.graph.remove(self.node, safe=True)
-
- def test_remove_safe_raises_when_node_has_users(self):
- v0 = _core.Input(name="v0")
- v1 = _core.Input(name="v1")
- add_node = _core.Node("", "Add", inputs=(v0, v1), num_outputs=1)
- identity_node = _core.Node("", "Identity", inputs=add_node.outputs, num_outputs=1)
- graph = _core.Graph(
- (v0, v1),
- identity_node.outputs,
- nodes=(add_node, identity_node),
- opset_imports={"": 1},
- )
- with self.assertRaisesRegex(ValueError, "used by other nodes"):
- graph.remove(add_node, safe=True)
-
- def test_remove_safe_removes_uses_of_removed_nodes(self):
- v0 = _core.Input(name="v0")
- v1 = _core.Input(name="v1")
- add_node = _core.Node("", "Add", inputs=(v0, v1), num_outputs=1)
- identity_node = _core.Node("", "Identity", inputs=add_node.outputs, num_outputs=1)
- graph = _core.Graph(
- (v0, v1),
- identity_node.outputs,
- nodes=(add_node, identity_node),
- opset_imports={"": 1},
- )
- # Remove add_node and check that it is no longer a consumer of v0 and v1
- sub_node = _core.Node("", "Sub", inputs=(v0, v1), num_outputs=1)
- identity_node.replace_input_with(0, sub_node.outputs[0])
- graph.insert_before(identity_node, sub_node)
- graph.remove(add_node, safe=True)
- self.assertEqual(tuple(v0.uses()), ((sub_node, 0),))
- self.assertEqual(tuple(v1.uses()), ((sub_node, 1),))
- self.assertEqual(tuple(graph), (sub_node, identity_node))
- self.assertEqual(add_node.inputs, (None, None))
-
- # TODO(justinchuby): Test graph mutation methods
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/ir/_display.py b/onnxscript/ir/_display.py
deleted file mode 100644
index 937af92995..0000000000
--- a/onnxscript/ir/_display.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Internal utilities for displaying the intermediate representation of a model.
-
-NOTE: All third-party imports should be scoped and imported only when used to avoid
-importing unnecessary dependencies.
-"""
-# pylint: disable=import-outside-toplevel
-
-from __future__ import annotations
-
-from typing import Any
-
-_LONG_TEXT_LIMIT = 3000
-
-
-def require_rich() -> Any:
- """Return the rich module if it is installed, otherwise None."""
- try:
- import rich
- except ImportError:
- return None
- return rich
-
-
-class PrettyPrintable:
- def display(self, *, page: bool | None = None) -> None:
- """Pretty print the object.
-
- Args:
- page: Whether to page the output if it is too long.
- """
- rich = require_rich()
- text = str(self)
-
- if rich is None:
- print(text)
- # Color print this message
- print(
- f"\n\n\u001b[36mTip: Install the rich library with 'pip install rich' to pretty print this {self.__class__.__name__}.\u001b[0m"
- )
- return
-
- if page is None and len(text) > _LONG_TEXT_LIMIT:
- # By default, page the output if it is too long
- page = True
- if page:
- import rich.console
- import rich.syntax
-
- console = rich.console.Console()
- syntax = rich.syntax.Syntax(text, "cpp", theme="ansi_light")
- with console.pager(styles=True):
- console.print(syntax)
- else:
- rich.print(text)
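-
-
-# A minimal usage sketch (illustrative only; ``_Report`` is a hypothetical
-# subclass, not part of this module):
-class _Report(PrettyPrintable):
- def __str__(self) -> str:
- return "report contents"
-
-
-# ``_Report().display()`` falls back to a plain print when rich is not
-# installed, and pages the output once str(self) exceeds _LONG_TEXT_LIMIT.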
diff --git a/onnxscript/ir/_display_test.py b/onnxscript/ir/_display_test.py
deleted file mode 100644
index 33e603a9b2..0000000000
--- a/onnxscript/ir/_display_test.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Test display() methods in various classes."""
-
-import contextlib
-import unittest
-
-import numpy as np
-
-import onnxscript.ir as ir
-
-
-class DisplayTest(unittest.TestCase):
- def test_tensor_display_does_not_raise_on_nan_values(self):
- array_with_nan = np.array([np.inf, -np.inf, np.nan, 5, -10], dtype=np.float32)
- tensor = ir.Tensor(array_with_nan, dtype=ir.DataType.FLOAT)
- with contextlib.redirect_stdout(None):
- tensor.display()
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py
deleted file mode 100644
index f2835fdad6..0000000000
--- a/onnxscript/ir/_enums.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""ONNX IR enums that matches the ONNX spec."""
-
-from __future__ import annotations
-
-import enum
-
-import numpy as np
-
-
-class AttributeType(enum.IntEnum):
- """Enum for the types of ONNX attributes."""
-
- UNDEFINED = 0
- FLOAT = 1
- INT = 2
- STRING = 3
- TENSOR = 4
- GRAPH = 5
- FLOATS = 6
- INTS = 7
- STRINGS = 8
- TENSORS = 9
- GRAPHS = 10
- SPARSE_TENSOR = 11
- SPARSE_TENSORS = 12
- TYPE_PROTO = 13
- TYPE_PROTOS = 14
-
- def __repr__(self) -> str:
- return self.name
-
- def __str__(self) -> str:
- return self.__repr__()
-
-
-class DataType(enum.IntEnum):
- """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""
-
- # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
- # but we should stick to the names used in the ONNX spec for consistency.
- UNDEFINED = 0
- FLOAT = 1
- UINT8 = 2
- INT8 = 3
- UINT16 = 4
- INT16 = 5
- INT32 = 6
- INT64 = 7
- STRING = 8
- BOOL = 9
- FLOAT16 = 10
- DOUBLE = 11
- UINT32 = 12
- UINT64 = 13
- COMPLEX64 = 14
- COMPLEX128 = 15
- BFLOAT16 = 16
- FLOAT8E4M3FN = 17
- FLOAT8E4M3FNUZ = 18
- FLOAT8E5M2 = 19
- FLOAT8E5M2FNUZ = 20
- UINT4 = 21
- INT4 = 22
-
- @classmethod
- def from_numpy(cls, dtype: np.dtype) -> DataType:
- """Returns the ONNX data type for the numpy dtype.
-
- Raises:
- TypeError: If the data type is not supported by ONNX.
- """
- if dtype not in _NP_TYPE_TO_DATA_TYPE:
- raise TypeError(f"Unsupported numpy data type: {dtype}")
- return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
-
- @property
- def itemsize(self) -> float:
- """Returns the size of the data type in bytes."""
- return _ITEMSIZE_MAP[self]
-
- def numpy(self) -> np.dtype:
- """Returns the numpy dtype for the ONNX data type.
-
- Raises:
- TypeError: If the data type is not supported by numpy.
- """
- if self not in _DATA_TYPE_TO_NP_TYPE:
- raise TypeError(f"Numpy does not support ONNX data type: {self}")
- return _DATA_TYPE_TO_NP_TYPE[self]
-
- def __repr__(self) -> str:
- return self.name
-
- def __str__(self) -> str:
- return self.__repr__()
-
-
-_ITEMSIZE_MAP = {
- DataType.FLOAT: 4,
- DataType.UINT8: 1,
- DataType.INT8: 1,
- DataType.UINT16: 2,
- DataType.INT16: 2,
- DataType.INT32: 4,
- DataType.INT64: 8,
- DataType.STRING: 1,
- DataType.BOOL: 1,
- DataType.FLOAT16: 2,
- DataType.DOUBLE: 8,
- DataType.UINT32: 4,
- DataType.UINT64: 8,
- DataType.COMPLEX64: 8,
- DataType.COMPLEX128: 16,
- DataType.BFLOAT16: 2,
- DataType.FLOAT8E4M3FN: 1,
- DataType.FLOAT8E4M3FNUZ: 1,
- DataType.FLOAT8E5M2: 1,
- DataType.FLOAT8E5M2FNUZ: 1,
- DataType.UINT4: 0.5,
- DataType.INT4: 0.5,
-}
-
-
-_NP_TYPE_TO_DATA_TYPE = {
- np.dtype("bool"): DataType.BOOL,
- np.dtype("complex128"): DataType.COMPLEX128,
- np.dtype("complex64"): DataType.COMPLEX64,
- np.dtype("float16"): DataType.FLOAT16,
- np.dtype("float32"): DataType.FLOAT,
- np.dtype("float64"): DataType.DOUBLE,
- np.dtype("int16"): DataType.INT16,
- np.dtype("int32"): DataType.INT32,
- np.dtype("int64"): DataType.INT64,
- np.dtype("int8"): DataType.INT8,
- np.dtype("object"): DataType.STRING,
- np.dtype("uint16"): DataType.UINT16,
- np.dtype("uint32"): DataType.UINT32,
- np.dtype("uint64"): DataType.UINT64,
- np.dtype("uint8"): DataType.UINT8,
-}
-
-# ONNX DataType to Numpy dtype. This mapping does not capture ONNX data
-# types that are not supported by numpy.
-_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
-_DATA_TYPE_TO_NP_TYPE.update(
- {
- DataType.FLOAT8E4M3FN: np.dtype("uint8"),
- DataType.FLOAT8E4M3FNUZ: np.dtype("uint8"),
- DataType.FLOAT8E5M2: np.dtype("uint8"),
- DataType.FLOAT8E5M2FNUZ: np.dtype("uint8"),
- DataType.UINT4: np.dtype("uint8"),
- DataType.INT4: np.dtype("int8"),
- DataType.BFLOAT16: np.dtype("uint16"),
- }
-)
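-
-
-# A minimal usage sketch of the mappings above (illustrative only;
-# ``_example_round_trip`` is a hypothetical helper, not part of this module):
-def _example_round_trip() -> None:
- dtype = DataType.from_numpy(np.dtype("float32"))
- assert dtype == DataType.FLOAT
- assert dtype.numpy() == np.dtype("float32")
- assert dtype.itemsize == 4
- # Types numpy cannot represent natively map to an unpacked storage dtype:
- assert DataType.INT4.numpy() == np.dtype("int8")
- assert DataType.INT4.itemsize == 0.5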
diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py
deleted file mode 100644
index a08debf0bf..0000000000
--- a/onnxscript/ir/_enums_test.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import unittest
-
-import numpy as np
-import onnx
-
-from onnxscript.ir import _enums
-
-
-class DataTypeTest(unittest.TestCase):
- def test_enums_are_the_same_as_spec(self):
- self.assertEqual(_enums.DataType.FLOAT, onnx.TensorProto.FLOAT)
- self.assertEqual(_enums.DataType.UINT8, onnx.TensorProto.UINT8)
- self.assertEqual(_enums.DataType.INT8, onnx.TensorProto.INT8)
- self.assertEqual(_enums.DataType.UINT16, onnx.TensorProto.UINT16)
- self.assertEqual(_enums.DataType.INT16, onnx.TensorProto.INT16)
- self.assertEqual(_enums.DataType.INT32, onnx.TensorProto.INT32)
- self.assertEqual(_enums.DataType.INT64, onnx.TensorProto.INT64)
- self.assertEqual(_enums.DataType.STRING, onnx.TensorProto.STRING)
- self.assertEqual(_enums.DataType.BOOL, onnx.TensorProto.BOOL)
- self.assertEqual(_enums.DataType.FLOAT16, onnx.TensorProto.FLOAT16)
- self.assertEqual(_enums.DataType.DOUBLE, onnx.TensorProto.DOUBLE)
- self.assertEqual(_enums.DataType.UINT32, onnx.TensorProto.UINT32)
- self.assertEqual(_enums.DataType.UINT64, onnx.TensorProto.UINT64)
- self.assertEqual(_enums.DataType.COMPLEX64, onnx.TensorProto.COMPLEX64)
- self.assertEqual(_enums.DataType.COMPLEX128, onnx.TensorProto.COMPLEX128)
- self.assertEqual(_enums.DataType.BFLOAT16, onnx.TensorProto.BFLOAT16)
- self.assertEqual(_enums.DataType.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FN)
- self.assertEqual(_enums.DataType.FLOAT8E4M3FNUZ, onnx.TensorProto.FLOAT8E4M3FNUZ)
- self.assertEqual(_enums.DataType.FLOAT8E5M2, onnx.TensorProto.FLOAT8E5M2)
- self.assertEqual(_enums.DataType.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E5M2FNUZ)
- self.assertEqual(_enums.DataType.UINT4, onnx.TensorProto.UINT4)
- self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4)
- self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED)
-
- def test_from_numpy_takes_np_dtype_and_returns_data_type(self):
- array = np.array([], dtype=np.float64)
- self.assertEqual(_enums.DataType.from_numpy(array.dtype), _enums.DataType.DOUBLE)
-
- def test_numpy_returns_np_dtype(self):
- self.assertEqual(_enums.DataType.DOUBLE.numpy(), np.dtype(np.float64))
-
- def test_itemsize_returns_size_of_data_type_in_bytes(self):
- self.assertEqual(_enums.DataType.DOUBLE.itemsize, 8)
- self.assertEqual(_enums.DataType.INT4.itemsize, 0.5)
-
- def test_repr_and_str_return_name(self):
- self.assertEqual(str(_enums.DataType.DOUBLE), "DOUBLE")
- self.assertEqual(repr(_enums.DataType.DOUBLE), "DOUBLE")
-
-
-class AttributeTypeTest(unittest.TestCase):
- def test_enums_are_the_same_as_spec(self):
- self.assertEqual(_enums.AttributeType.FLOAT, onnx.AttributeProto.FLOAT)
- self.assertEqual(_enums.AttributeType.INT, onnx.AttributeProto.INT)
- self.assertEqual(_enums.AttributeType.STRING, onnx.AttributeProto.STRING)
- self.assertEqual(_enums.AttributeType.TENSOR, onnx.AttributeProto.TENSOR)
- self.assertEqual(_enums.AttributeType.GRAPH, onnx.AttributeProto.GRAPH)
- self.assertEqual(_enums.AttributeType.FLOATS, onnx.AttributeProto.FLOATS)
- self.assertEqual(_enums.AttributeType.INTS, onnx.AttributeProto.INTS)
- self.assertEqual(_enums.AttributeType.STRINGS, onnx.AttributeProto.STRINGS)
- self.assertEqual(_enums.AttributeType.TENSORS, onnx.AttributeProto.TENSORS)
- self.assertEqual(_enums.AttributeType.GRAPHS, onnx.AttributeProto.GRAPHS)
- self.assertEqual(_enums.AttributeType.SPARSE_TENSOR, onnx.AttributeProto.SPARSE_TENSOR)
- self.assertEqual(
- _enums.AttributeType.SPARSE_TENSORS, onnx.AttributeProto.SPARSE_TENSORS
- )
- self.assertEqual(_enums.AttributeType.TYPE_PROTO, onnx.AttributeProto.TYPE_PROTO)
- self.assertEqual(_enums.AttributeType.TYPE_PROTOS, onnx.AttributeProto.TYPE_PROTOS)
- self.assertEqual(_enums.AttributeType.UNDEFINED, onnx.AttributeProto.UNDEFINED)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/ir/_graph_comparison.py b/onnxscript/ir/_graph_comparison.py
deleted file mode 100644
index 788b4b4d54..0000000000
--- a/onnxscript/ir/_graph_comparison.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Utilities for comparing IR graphs."""
-
-from __future__ import annotations
-
-from onnxscript.ir import _core
-
-# NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs
-# NOTE(justinchuby): A graph may be specified with a set of inputs and outputs
-
-
-def topologically_equal(graph1: _core.Graph, graph2: _core.Graph) -> bool:
- """Return true if the two graphs are topologically equivalent, without considering initializers.
-
- Args:
- graph1: The first graph to compare.
- graph2: The second graph to compare.
-
- Returns:
- True if the graphs are equal, False otherwise.
- """
- raise NotImplementedError()
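-
-
-# Intended usage once implemented (illustrative only):
-#
-# if topologically_equal(graph_a, graph_b):
-# ... # treat the graphs as structurally identical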
diff --git a/onnxscript/ir/_invariants.py b/onnxscript/ir/_invariants.py
deleted file mode 100644
index 8d009c3cc9..0000000000
--- a/onnxscript/ir/_invariants.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Utilities to enforce invariants on the IR."""
-
-from __future__ import annotations
-
-import functools
-from typing import Any, Callable
-
-
-class InvariantError(Exception):
- """Raised when an invariant is violated."""
-
-
-class PreconditionError(InvariantError):
- """Raised when a precondition is violated."""
-
-
-class PostconditionError(InvariantError):
- """Raised when a postcondition is violated."""
-
-
-def requires(
- preconditions: Callable[..., str | None],
-) -> Callable[..., Callable[..., Any]]:
- """Decorator to enforce preconditions on a function."""
- # TODO(justinchuby): Preserve python function signature with this decorator
-
- def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
- @functools.wraps(func)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- message = preconditions(*args, **kwargs)
- if message is not None:
- raise PreconditionError(message)
- return func(*args, **kwargs)
-
- return wrapper
-
- return decorator
-
-
-def ensures(
- postconditions: Callable[..., str | None],
-) -> Callable[..., Callable[..., Any]]:
- """Decorator to enforce postconditions on a function."""
-
- def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
- @functools.wraps(func)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- result = func(*args, **kwargs)
- message = postconditions(*args, **kwargs)
- if message is not None:
- raise PostconditionError(message)
- return result
-
- return wrapper
-
- return decorator
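-
-
-# A minimal usage sketch (illustrative only; ``_non_negative`` and ``_sqrt``
-# are hypothetical examples, not part of this module):
-def _non_negative(x: float) -> str | None:
- return None if x >= 0 else f"Expected a non-negative value, got {x}"
-
-
-@requires(_non_negative)
-def _sqrt(x: float) -> float:
- return x**0.5
-
-
-# _sqrt(4.0) returns 2.0; _sqrt(-1.0) raises PreconditionError with the
-# message produced by _non_negative.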
diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py
deleted file mode 100644
index 059a88f2b9..0000000000
--- a/onnxscript/ir/_linked_list.py
+++ /dev/null
@@ -1,278 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Mutable list for nodes in a graph with safe mutation properties."""
-
-from __future__ import annotations
-
-from typing import Generic, Iterable, Iterator, Sequence, TypeVar
-
-T = TypeVar("T")
-
-
-class _LinkBox(Generic[T]):
- """A link in a doubly linked list that has a reference to the actual object in the link.
-
- The :class:`_LinkBox` is a container for the actual object in the list. It is used to
- maintain the links between the elements in the linked list. The actual object is stored in the
- :attr:`value` attribute.
-
- By using a separate container for the actual object, we can safely remove the object from the
- list without losing the links. This allows us to remove the object from the list during
- iteration and place the object into a different list without breaking any chains.
-
- This is an internal class and should only be initialized by the :class:`DoublyLinkedSet`.
-
- Attributes:
- prev: The previous box in the list.
- next: The next box in the list.
- erased: A flag to indicate if the box has been removed from the list.
- owning_list: The :class:`DoublyLinkedSet` to which the box belongs.
- value: The actual object in the list.
- """
-
- __slots__ = ("prev", "next", "value", "owning_list")
-
- def __init__(self, owner: DoublyLinkedSet[T], value: T | None) -> None:
- """Create a new link box.
-
- Args:
- owner: The linked list to which this box belongs.
- value: The value to be stored in the link box. When the value is None,
- the link box is considered erased. The root box of the list
- should be created with a None value.
- """
- self.prev: _LinkBox[T] = self
- self.next: _LinkBox[T] = self
- self.value: T | None = value
- self.owning_list: DoublyLinkedSet[T] = owner
-
- @property
- def erased(self) -> bool:
- return self.value is None
-
- def erase(self) -> None:
- """Remove the link from the list and detach the value from the box."""
- if self.value is None:
- raise ValueError("_LinkBox is already erased")
- # Update the links
- prev, next_ = self.prev, self.next
- prev.next, next_.prev = next_, prev
- # Detach the value
- self.value = None
-
- def __repr__(self) -> str:
- return f"_LinkBox({self.value!r}, erased={self.erased}, prev={self.prev.value!r}, next={self.next.value!r})"
-
-
-class DoublyLinkedSet(Generic[T], Sequence[T]):
- """A doubly linked ordered set of nodes.
-
- The container can be viewed as a set as it does not allow duplicate values. The order of the
- elements is maintained. One can typically treat it as a doubly linked list with list-like
- methods implemented.
-
- Adding and removing elements from the set during iteration is safe. Moving elements
- from one set to another is also safe.
-
- During the iteration:
- - If new elements are inserted after the current node, the iterator will
- iterate over them as well.
- - If new elements are inserted before the current node, they will
- not be iterated over in this iteration.
- - If the current node is lifted and inserted in a different location,
- iteration will start from the "next" node at the _original_ location.
-
- Time complexity:
- Inserting and removing nodes from the set is O(1). Accessing nodes by index is O(n),
- although accessing nodes at either end of the set is O(1). I.e.
- ``linked_set[0]`` and ``linked_set[-1]`` are O(1).
-
- Values are tracked by identity rather than by hash. ``None`` is not a valid value in the set.
- """
-
- __slots__ = ("_root", "_length", "_value_ids_to_boxes")
-
- def __init__(self, values: Iterable[T] | None = None) -> None:
- # Using the root node simplifies the mutation implementation a lot
- # The list is circular. The root node is the only node that is not a part of the list values
- root_ = _LinkBox(self, None)
- self._root: _LinkBox = root_
- self._length = 0
- self._value_ids_to_boxes: dict[int, _LinkBox] = {}
- if values is not None:
- self.extend(values)
-
- def __iter__(self) -> Iterator[T]:
- """Iterate over the elements in the list.
-
- - If new elements are inserted after the current node, the iterator will
- iterate over them as well.
- - If new elements are inserted before the current node, they will
- not be iterated over in this iteration.
- - If the current node is lifted and inserted in a different location,
- iteration will start from the "next" node at the _original_ location.
- """
- box = self._root.next
- while box is not self._root:
- if box.owning_list is not self:
- raise RuntimeError(f"Element {box!r} is not in the list")
- if not box.erased:
- assert box.value is not None
- yield box.value
- box = box.next
-
- def __reversed__(self) -> Iterator[T]:
- """Iterate over the elements in the list in reverse order."""
- box = self._root.prev
- while box is not self._root:
- if not box.erased:
- assert box.value is not None
- yield box.value
- box = box.prev
-
- def __len__(self) -> int:
- assert self._length == len(
- self._value_ids_to_boxes
- ), "Bug in the implementation: length mismatch"
- return self._length
-
- def __getitem__(self, index: int) -> T:
- """Get the node at the given index.
-
- Complexity is O(n).
- """
- if index >= self._length or index < -self._length:
- raise IndexError(
- f"Index out of range: {index} not in range [-{self._length}, {self._length})"
- )
- if index < 0:
- # Look up from the end of the list
- iterator = reversed(self)
- item = next(iterator)
- for _ in range(-index - 1):
- item = next(iterator)
- else:
- iterator = iter(self) # type: ignore[assignment]
- item = next(iterator)
- for _ in range(index):
- item = next(iterator)
- return item
-
- def _insert_one_after(
- self,
- box: _LinkBox[T],
- new_value: T,
- ) -> _LinkBox[T]:
- """Insert a new value after the given box.
-
- All insertion methods should call this method to ensure that the list is updated correctly.
-
- Example::
- Before: A <-> B <-> C
- ^v0 ^v1 ^v2
- Call: _insert_one_after(B, v3)
- After: A <-> B <-> new_box <-> C
- ^v0 ^v1 ^v3 ^v2
-
- Args:
- box: The box after which the new value is to be inserted.
- new_value: The new value to be inserted.
- """
- if new_value is None:
- raise TypeError(f"{self.__class__.__name__} does not support None values")
- if box.value is new_value:
- # Do nothing if the new value is the same as the old value
- return box
- if box.owning_list is not self:
- raise ValueError(f"Value {box.value!r} is not in the list")
-
- if (new_value_id := id(new_value)) in self._value_ids_to_boxes:
- # If the value is already in the list, remove it first
- self.remove(new_value)
-
- # Create a new _LinkBox for the new value
- new_box = _LinkBox(self, new_value)
- # original_box <=> original_next
- # becomes
- # original_box <=> new_box <=> original_next
- original_next = box.next
- box.next = new_box
- new_box.prev = box
- new_box.next = original_next
- original_next.prev = new_box
-
- # Be sure to update the length and mapping
- self._length += 1
- self._value_ids_to_boxes[new_value_id] = new_box
-
- return new_box
-
- def _insert_many_after(
- self,
- box: _LinkBox[T],
- new_values: Iterable[T],
- ):
- """Insert multiple new values after the given box."""
- insertion_point = box
- for new_value in new_values:
- insertion_point = self._insert_one_after(insertion_point, new_value)
-
- def remove(self, value: T) -> None:
- """Remove a node from the list."""
- if (value_id := id(value)) not in self._value_ids_to_boxes:
- raise ValueError(f"Value {value!r} is not in the list")
- box = self._value_ids_to_boxes[value_id]
- # Remove the link box and detach the value from the box
- box.erase()
-
- # Be sure to update the length and mapping
- self._length -= 1
- del self._value_ids_to_boxes[value_id]
-
- def append(self, value: T) -> None:
- """Append a node to the list."""
- _ = self._insert_one_after(self._root.prev, value)
-
- def extend(
- self,
- values: Iterable[T],
- ) -> None:
- for value in values:
- self.append(value)
-
- def insert_after(
- self,
- value: T,
- new_values: Iterable[T],
- ) -> None:
- """Insert new nodes after the given node.
-
- Args:
- value: The value after which the new values are to be inserted.
- new_values: The new values to be inserted.
- """
- if (value_id := id(value)) not in self._value_ids_to_boxes:
- raise ValueError(f"Value {value!r} is not in the list")
- insertion_point = self._value_ids_to_boxes[value_id]
- return self._insert_many_after(insertion_point, new_values)
-
- def insert_before(
- self,
- value: T,
- new_values: Iterable[T],
- ) -> None:
- """Insert new nodes before the given node.
-
- Args:
- value: The value before which the new values are to be inserted.
- new_values: The new values to be inserted.
- """
- if (value_id := id(value)) not in self._value_ids_to_boxes:
- raise ValueError(f"Value {value!r} is not in the list")
- insertion_point = self._value_ids_to_boxes[value_id].prev
- return self._insert_many_after(insertion_point, new_values)
-
- def __repr__(self) -> str:
- return f"DoublyLinkedSet({list(self)})"
diff --git a/onnxscript/ir/_linked_list_test.py b/onnxscript/ir/_linked_list_test.py
deleted file mode 100644
index a82b0e172b..0000000000
--- a/onnxscript/ir/_linked_list_test.py
+++ /dev/null
@@ -1,380 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Unit tests for the _linked_list module."""
-
-from __future__ import annotations
-
-import unittest
-
-import parameterized
-
-from onnxscript.ir import _linked_list
-
-
-class _TestElement:
- def __init__(self, value):
- self.value = value
-
- def __repr__(self) -> str:
- return f"_TestElement({self.value})"
-
-
-class DoublyLinkedSetTest(unittest.TestCase):
- def test_empty_list(self):
- linked_list = _linked_list.DoublyLinkedSet()
- self.assertEqual(len(linked_list), 0)
- self.assertEqual(list(linked_list), [])
- self.assertEqual(list(reversed(linked_list)), [])
- with self.assertRaises(IndexError):
- _ = linked_list[0]
- with self.assertRaises(IndexError):
- _ = linked_list[-1]
-
- def test_append_single_element(self):
- linked_list = _linked_list.DoublyLinkedSet()
- elem = _TestElement(0)
- linked_list.append(elem)
-
- self.assertEqual(len(linked_list), 1)
- self.assertEqual(linked_list[0], elem)
- self.assertEqual(linked_list[-1], elem)
- self.assertEqual(list(linked_list), [elem])
- self.assertEqual(list(reversed(linked_list)), [elem])
- with self.assertRaises(IndexError):
- _ = linked_list[1]
- with self.assertRaises(IndexError):
- _ = linked_list[-2]
-
- def test_append_multiple_elements(self):
- linked_list = _linked_list.DoublyLinkedSet()
- elems = [_TestElement(i) for i in range(3)]
- for elem in elems:
- linked_list.append(elem)
-
- self.assertEqual(len(linked_list), 3)
- self.assertEqual(linked_list[0], elems[0])
- self.assertEqual(linked_list[1], elems[1])
- self.assertEqual(linked_list[2], elems[2])
- self.assertEqual(linked_list[-1], elems[2])
- self.assertEqual(linked_list[-2], elems[1])
- self.assertEqual(linked_list[-3], elems[0])
- self.assertEqual(list(linked_list), elems)
- self.assertEqual(list(reversed(linked_list)), list(reversed(elems)))
-
- def test_extend(self):
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- self.assertEqual(len(linked_list), 3)
- self.assertEqual(linked_list[0], elems[0])
- self.assertEqual(linked_list[1], elems[1])
- self.assertEqual(linked_list[2], elems[2])
- self.assertEqual(linked_list[-1], elems[2])
- self.assertEqual(linked_list[-2], elems[1])
- self.assertEqual(linked_list[-3], elems[0])
- self.assertEqual(list(linked_list), elems)
- self.assertEqual(list(reversed(linked_list)), list(reversed(elems)))
-
- @parameterized.parameterized.expand(
- [
- ("single_element", [0], 0, [1], [0, 1]),
- ("single_element_negative_index", [0], -1, [1], [0, 1]),
- ("multiple_elements", [0], 0, [1, 2], [0, 1, 2]),
- ("multiple_elements_negative_index", [0], -1, [1, 2], [0, 1, 2]),
- (
- "multiple_original_elements_insert_at_start",
- [0, 1, 2],
- 0,
- [42, 43],
- [0, 42, 43, 1, 2],
- ),
- (
- "multiple_original_elements_insert_at_middle",
- [0, 1, 2],
- 1,
- [42, 43],
- [0, 1, 42, 43, 2],
- ),
- (
- "multiple_original_elements_insert_at_end",
- [0, 1, 2],
- 2,
- [42, 43],
- [0, 1, 2, 42, 43],
- ),
- ]
- )
- def test_insert_after(
- self, _: str, original: list[int], location: int, insertion: list[int], expected: list
- ) -> None:
- # Construct the original list
- elems = [_TestElement(i) for i in original]
- linked_list = _linked_list.DoublyLinkedSet(elems)
-
- # Create the new elements
- new_elements = [_TestElement(i) for i in insertion]
- linked_list.insert_after(elems[location], new_elements)
-
- # Check the list
- self.assertEqual(len(linked_list), len(expected))
- self.assertEqual([elem.value for elem in linked_list], expected)
-
- @parameterized.parameterized.expand(
- [
- ("single_element", [0], 0, [1], [1, 0]),
- ("single_element_negative_index", [0], -1, [1], [1, 0]),
- ("multiple_elements", [0], 0, [1, 3], [1, 3, 0]),
- ("multiple_elements_negative_index", [0], -1, [1, 3], [1, 3, 0]),
- (
- "multiple_original_elements_insert_at_start",
- [0, 1, 2],
- 0,
- [42, 43],
- [42, 43, 0, 1, 2],
- ),
- (
- "multiple_original_elements_insert_at_middle",
- [0, 1, 2],
- 1,
- [42, 43],
- [0, 42, 43, 1, 2],
- ),
- (
- "multiple_original_elements_insert_at_end",
- [0, 1, 2],
- 2,
- [42, 43],
- [0, 1, 42, 43, 2],
- ),
- ]
- )
- def test_insert_before(
- self, _: str, original: list[int], location: int, insertion: list[int], expected: list
- ) -> None:
- # Construct the original list
- elems = [_TestElement(i) for i in original]
- linked_list = _linked_list.DoublyLinkedSet(elems)
-
- # Create the new elements
- new_elements = [_TestElement(i) for i in insertion]
- linked_list.insert_before(elems[location], new_elements)
-
- # Check the list
- self.assertEqual(len(linked_list), len(expected))
- self.assertEqual([elem.value for elem in linked_list], expected)
- self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1])
-
- @parameterized.parameterized.expand(
- [
- ("start", 0, [1, 2]),
- ("middle", 1, [0, 2]),
- ("end", 2, [0, 1]),
- ("start_negative", -1, [0, 1]),
- ("middle_negative", -2, [0, 2]),
- ("end_negative", -3, [1, 2]),
- ]
- )
- def test_remove(self, _: str, index: int, expected: list[int]) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
-
- linked_list.remove(elems[index])
-
- self.assertEqual(len(linked_list), 2)
- self.assertEqual([elem.value for elem in linked_list], expected)
- self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1])
-
- def test_remove_raises_when_element_not_found(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
-
- with self.assertRaises(ValueError):
- linked_list.remove(_TestElement(3))
-
- def test_remove_raises_when_element_is_already_removed(self) -> None:
- linked_list = _linked_list.DoublyLinkedSet()
- elem = _TestElement(0)
- linked_list.append(elem)
- linked_list.remove(elem)
-
- with self.assertRaises(ValueError):
- linked_list.remove(elem)
-
- def test_append_self_does_nothing(self) -> None:
- linked_list = _linked_list.DoublyLinkedSet()
- elem = _TestElement(0)
- linked_list.append(elem)
-
- linked_list.append(elem)
-
- self.assertEqual(len(linked_list), 1)
- self.assertEqual(linked_list[0], elem)
- self.assertEqual(list(linked_list), [elem])
- self.assertEqual(list(reversed(linked_list)), [elem])
-
- def test_append_supports_appending_element_from_the_same_list(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
-
- linked_list.append(elems[1])
-
- self.assertEqual(len(linked_list), 3)
- self.assertEqual([elem.value for elem in linked_list], [0, 2, 1])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [1, 2, 0])
-
- def test_extend_supports_extending_elements_from_the_same_list(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- linked_list.extend(elems[::-1])
-
- self.assertEqual(len(linked_list), 3)
- self.assertEqual([elem.value for elem in linked_list], [2, 1, 0])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [0, 1, 2])
-
- def test_insert_after_supports_inserting_element_from_the_same_list(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- linked_list.insert_after(elems[0], [elems[2]])
-
- self.assertEqual(len(linked_list), 3)
- self.assertEqual([elem.value for elem in linked_list], [0, 2, 1])
-
- def test_insert_before_supports_inserting_element_from_the_same_list(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- linked_list.insert_before(elems[0], [elems[2]])
-
- self.assertEqual(len(linked_list), 3)
- self.assertEqual([elem.value for elem in linked_list], [2, 0, 1])
-
- def test_iterator_supports_mutation_during_iteration_current_element(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- for elem in linked_list:
- if elem.value == 1:
- linked_list.remove(elem)
-
- self.assertEqual(len(linked_list), 2)
- self.assertEqual([elem.value for elem in linked_list], [0, 2])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0])
-
- def test_iterator_supports_mutation_during_iteration_previous_element(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- for elem in linked_list:
- if elem.value == 1:
- linked_list.remove(elem)
- linked_list.remove(elems[0])
-
- self.assertEqual(len(linked_list), 1)
- self.assertEqual([elem.value for elem in linked_list], [2])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [2])
-
- def test_iterator_supports_mutation_during_iteration_next_element(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- for elem in linked_list:
- if elem.value == 1:
- linked_list.remove(elems[2])
- linked_list.remove(elem)
-
- self.assertEqual(len(linked_list), 1)
- self.assertEqual([elem.value for elem in linked_list], [0])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [0])
-
- def test_iterator_supports_mutation_in_nested_iteration_right_of_iterator(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- iter1_visited = []
- iter2_visited = []
- for elem in linked_list:
- iter1_visited.append(elem.value)
- for elem2 in linked_list:
- iter2_visited.append(elem2.value)
- if elem2.value == 1:
- linked_list.remove(elem2)
-
- self.assertEqual(len(linked_list), 2)
- self.assertEqual(iter1_visited, [0, 2])
- self.assertEqual(iter2_visited, [0, 1, 2, 0, 2])
- self.assertEqual([elem.value for elem in linked_list], [0, 2])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0])
-
- def test_iterator_supports_mutation_in_nested_iteration_when_iter_is_self(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- iter1_visited = []
- iter2_visited = []
- for elem in linked_list:
- iter1_visited.append(elem.value)
- for elem2 in linked_list:
- iter2_visited.append(elem2.value)
- if elem2.value == 0: # Remove the element the current iterator points to
- linked_list.remove(elem2)
-
- self.assertEqual(len(linked_list), 2)
- self.assertEqual(iter1_visited, [0, 1, 2])
- self.assertEqual(iter2_visited, [0, 1, 2, 1, 2, 1, 2])
- self.assertEqual([elem.value for elem in linked_list], [1, 2])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1])
-
- def test_iterator_supports_mutation_in_nested_iteration_left_of_iterator(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- iter1_visited = []
- iter2_visited = []
- for elem in linked_list:
- iter1_visited.append(elem.value)
- for elem2 in linked_list:
- iter2_visited.append(elem2.value)
- if (
- elem.value == 1 and elem2.value == 0
- ): # Remove the element before the current iterator points to
- linked_list.remove(elems[0])
-
- self.assertEqual(len(linked_list), 2)
- self.assertEqual(iter1_visited, [0, 1, 2])
- self.assertEqual(iter2_visited, [0, 1, 2, 0, 1, 2, 1, 2])
- self.assertEqual([elem.value for elem in linked_list], [1, 2])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1])
-
- def test_insert_after_supports_element_from_different_list_during_iteration(self) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- other_linked_list = _linked_list.DoublyLinkedSet()
- other_elem = _TestElement(42)
- other_linked_list.append(other_elem)
-
- for elem in linked_list:
- if elem.value == 1:
- linked_list.insert_after(elem, [other_elem])
-
- self.assertEqual(len(linked_list), 4)
- self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0])
- # Other list remains unchanged
- self.assertEqual(len(other_linked_list), 1)
- self.assertEqual([elem.value for elem in other_linked_list], [42])
-
- def test_insert_after_supports_taking_elements_from_another_doubly_linked_list(
- self,
- ) -> None:
- elems = [_TestElement(i) for i in range(3)]
- linked_list = _linked_list.DoublyLinkedSet(elems)
- other_linked_list = _linked_list.DoublyLinkedSet()
- other_elem = _TestElement(42)
- other_linked_list.append(other_elem)
-
- linked_list.insert_after(elems[1], other_linked_list)
-
- self.assertEqual(len(linked_list), 4)
- self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2])
- self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0])
- # Other list remains unchanged
- self.assertEqual(len(other_linked_list), 1)
- self.assertEqual([elem.value for elem in other_linked_list], [42])
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/ir/_metadata.py b/onnxscript/ir/_metadata.py
deleted file mode 100644
index bbb01a9596..0000000000
--- a/onnxscript/ir/_metadata.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Class for storing metadata about the IR objects."""
-
-from __future__ import annotations
-
-import collections
-from typing import Any, Mapping
-
-
-class MetadataStore(collections.UserDict):
- """Class for storing metadata about the IR objects.
-
- Metadata is stored as key-value pairs. The keys are strings and the values
- can be any Python object.
-
- The metadata store also supports marking keys as invalid. This is useful
- when a pass wants to mark a key that needs to be recomputed.
- """
-
- def __init__(self, data: Mapping[str, Any] | None = None, /) -> None:
- super().__init__(data)
- self._invalid_keys: set[str] = set()
-
- def __setitem__(self, key: str, item: Any) -> None:
- self.data[key] = item
- self._invalid_keys.discard(key)
-
- def invalidate(self, key: str) -> None:
- self._invalid_keys.add(key)
-
- def is_valid(self, key: str) -> bool:
- """Returns whether the value is valid.
-
-        Note that default values (None) are not necessarily invalid. For example,
-        a shape that is unknown (None) may still be valid if shape inference has
-        determined that the shape is unknown.
-
- Whether a value is valid is solely determined by the user that sets the value.
- """
- return key not in self._invalid_keys
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self.data!r}, invalid_keys={self._invalid_keys!r})"
diff --git a/onnxscript/ir/_name_authority.py b/onnxscript/ir/_name_authority.py
deleted file mode 100644
index 856c86247e..0000000000
--- a/onnxscript/ir/_name_authority.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Auxiliary class for managing names in the IR."""
-
-from __future__ import annotations
-
-from onnxscript.ir import _core
-
-
-class NameAuthority:
- """Class for giving names to values and nodes in the IR.
-
- The names are generated in the format ``val_{value_counter}`` for values and
- ``node_{op_type}_{node_counter}`` for nodes. The counter is incremented each time
- a new value or node is named.
-
- The class does not keep track of the names it has given, so it is possible to
-    generate names that conflict with existing names. It is the responsibility of the
- user to ensure that the names are unique (typically by running a name-fixing pass
- on the graph).
- """
-
- def __init__(self):
- self._value_counter = 0
- self._node_counter = 0
-
- def name_value(self, value: _core.Value) -> None:
- value.name = f"val_{self._value_counter}"
- self._value_counter += 1
-
- def name_node(self, node: _core.Node) -> None:
- node.name = f"node_{node.op_type}_{self._node_counter}"
- self._node_counter += 1
diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py
deleted file mode 100644
index 7e5b791208..0000000000
--- a/onnxscript/ir/_protocols.py
+++ /dev/null
@@ -1,590 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Protocols for the ONNX IR.
-
-This file defines the interfaces for tools to interact with the IR. The interfaces
-are designed such that tools leveraging the IR can be decoupled from the IR
-implementation. This allows for the implementation to evolve independently of the
-tools.
-"""
-
-# NOTE: Why are we using protocols, instead of abstract base classes?
-#
-# Protocols are more flexible than abstract base classes. Users can define their
-# own classes that implement the protocols without having to inherit from a
-# specific base class. For example, a user can define a custom tensor class that
-# implements the TensorProtocol without explicitly inheriting, and the IR can
-# work with that class without any changes.
-#
-# `isinstance` checks can be slower with protocols. Avoid using `isinstance`
-# checks when you can. Always check for concrete classes first.
-#
-# NOTE: Why are we using protocols, instead of using concrete classes directly?
-#
-# Protocols define the interface that is typically more stable. If you find yourself
-# updating the protocols, pause 🛑, and carefully make sure it is absolutely needed
-# and will improve the design. If you are adding new methods, consider if the method
-# should be part of the protocol or if it should be a higher level convenience function
-# defined outside the protocol.
-
-from __future__ import annotations
-
-import typing
-from typing import (
- Any,
- Collection,
- Iterable,
- Iterator,
- Mapping,
- MutableMapping,
- MutableSequence,
- OrderedDict,
- Protocol,
- Sequence,
- Tuple,
-)
-
-from onnxscript.ir import _enums
-
-if typing.TYPE_CHECKING:
- import numpy as np
- from typing_extensions import TypeAlias
-
-# An identifier that will uniquely identify an operator. E.g (domain, op_type, overload)
-OperatorIdentifier: TypeAlias = Tuple[str, str, str]
-
-
-@typing.runtime_checkable
-class ArrayCompatible(Protocol):
- """Protocol for array-like objects.
-
- An example of an array-like object is a numpy ndarray or a PyTorch Tensor.
- Read more at https://numpy.org/devdocs/user/basics.interoperability.html
- """
-
- def __array__(self, dtype: Any) -> np.ndarray: ...
-
-
-@typing.runtime_checkable
-class DLPackCompatible(Protocol):
- """Protocol for objects that can support dlpack.
-
-    Computation backends can call __dlpack__ to obtain the underlying data in a
-    tensor without copying it. This allows us to use TensorFlow tensors etc.
-    without copying the data.
- """
-
- def __dlpack__(self, *, stream: Any = ...) -> Any:
- """Return PyCapsule."""
- ...
-
- def __dlpack_device__(self) -> Any:
- """Return the device."""
- ...
-
-
-@typing.runtime_checkable
-class TensorProtocol(ArrayCompatible, Protocol):
- """Concrete tensor backed by data.
-
- The protocol does not specify how the data is stored. That data is exposed
- through the :attr:`raw` attribute for examination, but accessing :attr:`raw`
- is typically not needed.
-
- To use the tensor as a numpy array, call :meth:`numpy`. To convert the tensor
- to a byte string for serialization, call :meth:`tobytes`.
-
-    It is recommended to check the size of the tensor before accessing the
-    underlying data, because doing so may be expensive and incur IO overhead.
-
- Attributes:
- name: The name of the tensor.
- shape: The shape of the tensor.
- dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
- doc_string: Documentation string.
- raw: The raw data behind this tensor. It can be anything.
- size: The number of elements in the tensor.
- nbytes: The number of bytes in the tensor.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- name: str
- shape: ShapeProtocol
- dtype: _enums.DataType
- doc_string: str | None
- raw: Any
- metadata_props: MutableMapping[str, str]
- meta: MutableMapping[str, Any]
-
- @property
- def size(self) -> int: ...
-
- @property
- def nbytes(self) -> int: ...
-
- def numpy(self) -> np.ndarray:
- """Return the tensor as a numpy array."""
- ...
-
- def __array__(self, dtype: Any = None) -> np.ndarray:
- """Return the tensor as a numpy array, compatible with np.array."""
- ...
-
- def tobytes(self) -> bytes:
- """Return the tensor as a byte string conformed to the ONNX specification, in little endian."""
- ...
-
-
-@typing.runtime_checkable
-class ValueProtocol(Protocol):
- """Protocol for values.
-
- A value is a named entity that can be used to represent an input or output of a graph,
- a function, or a node. The information it stores generalizes over ``ValueInfoProto``
- in the ONNX specification.
-
-    A :class:`Value` is either unowned or owned by exactly one node. When the value is
-    not owned, it must be an input of a graph or a function, and ``producer`` and
-    ``index`` are ``None``.
-
- When the value is owned by a node, it is an output of the node.
- The node that produces the value can be accessed with :meth:`producer`.
- The index of the output of the node that produces the value can be accessed with
- :meth:`index`.
-
- To find all the nodes that use this value as an input, call :meth:`uses`.
-
- To check if the value is an output of a graph, call :meth:`is_graph_output`.
-
- Attributes:
- name: The name of the value. A value is always named when it is part of a graph.
- shape: The shape of the value.
- type: The type of the value.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- doc_string: Documentation string.
- """
-
- name: str
- shape: ShapeProtocol | None
- type: TypeProtocol | None
- metadata_props: MutableMapping[str, str]
- meta: MutableMapping[str, Any]
- doc_string: str | None
-
- def producer(self) -> NodeProtocol | None:
- """The node that produces this value."""
- ...
-
- def index(self) -> int | None:
- """The index of the output of the node that produces this value."""
- ...
-
- def uses(self) -> Collection[tuple[NodeProtocol, int]]:
- """The set of (node, input_index) with node being those that use this value as an input."""
- ...
-
- def is_graph_output(self) -> bool:
- """Whether this value is an output of a graph."""
- ...
-
-
-@typing.runtime_checkable
-class NodeProtocol(Protocol):
- """Protocol for nodes.
-
- A node represents an invocation of an operation on the :class:`Value` s in
- the computational graph.
-
- A node can be optionally named. A name should typically be assigned when the
- node is added to a graph.
-
- :attr:`domain`, :attr:`op_type`, and :attr:`overload` together uniquely identify
- the operator, and are always strings. For ONNX operators, :attr:`domain` and :attr:`overload`
- are both empty strings.
-
- :attr:`inputs` and :attr:`outputs` are the input and output values of the node.
-
- :attr:`attributes` are the attributes of the node. The attributes are stored in an
- ordered dictionary to preserve the order of the attributes. This is a deviation from
- the current ONNX spec where attributes are unordered, but it is helpful for tools
- that rely on the order of the attributes, e.g. those converting to and from Python
- function keyword arguments.
-
- :attr:`version` is unique to the IR and is not specified in the ONNX spec. This
- allows the IR to represent a graph with mixed opset versions. Deserializers
- should decide how to reconcile the different versions within the graph. A typical
- graph will have a single version, declared in the :class:`Graph` object and
- the nodes will have ``None`` as the version.
-
- Attributes:
- domain: The domain of the operator. E.g. ``""`` for ONNX operators.
- op_type: The operator name.
- overload: The overload name when the node is invoking a function.
- inputs: Input values.
- outputs: Output values.
- attributes: The attributes of the operator.
- version: The version of the operator.
- doc_string: Documentation string.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- name: str | None
- domain: str
- op_type: str
- overload: str
- inputs: Sequence[ValueProtocol]
- outputs: Sequence[ValueProtocol]
- attributes: OrderedDict[str, AttributeProtocol | ReferenceAttributeProtocol]
- version: int | None
- doc_string: str | None
- metadata_props: MutableMapping[str, str]
- meta: MutableMapping[str, Any]
-
- def replace_input_with(self, index: int, value: ValueProtocol | None) -> None:
- """Set the input at the given index to the given value, replacing the original value."""
- ...
-
-
-@typing.runtime_checkable
-class GraphProtocol(Protocol):
- """Protocol for graphs.
-
- Graph represents a computation graph. In addition to the ONNX specification
- specified fields, it also contains a mapping of :attr:`opset_imports`. This
- allows different subgraphs to import different opsets. It is the responsibility
- of the deserializer to reconcile the different opsets.
-
- The nodes are not guaranteed to be topologically sorted. But the
- iteration order should be deterministic across different runs. It is the
- responsibility of the user to maintain a topological order of the nodes.
-
-    Note that there is no ``node`` attribute in the Graph. The Graph can be
- seen as a Sequence of nodes and should be used as such. For example, to obtain
- all nodes as a list, call ``list(graph)``.
-
- Attributes:
- name: The name of the graph.
- inputs: The input values of the graph.
- outputs: The output values of the graph.
- initializers: The initializers in the graph.
- doc_string: Documentation string.
- opset_imports: Opsets imported by the graph.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- # TODO(justinchuby): Support quantization_annotation
- name: str | None
- inputs: MutableSequence[ValueProtocol]
- outputs: MutableSequence[ValueProtocol]
- initializers: MutableMapping[str, TensorProtocol]
- doc_string: str
- opset_imports: MutableMapping[str, int]
- metadata_props: MutableMapping[str, str]
- meta: MutableMapping[str, Any]
-
- def __getitem__(self, index: int) -> NodeProtocol: ...
- def __len__(self) -> int: ...
- def __iter__(self) -> Iterator[NodeProtocol]: ...
- def __reversed__(self) -> Iterator[NodeProtocol]: ...
-
- # Mutation methods
- def append(self, node: NodeProtocol, /) -> None:
- """Append a node to the graph."""
- ...
-
- def extend(self, nodes: Iterable[NodeProtocol], /) -> None:
- """Extend the graph with the given nodes."""
- ...
-
- def remove(self, node: NodeProtocol, /) -> None:
- """Remove a node from the graph."""
- ...
-
- def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
- """Insert new nodes after the given node."""
- ...
-
- def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
- """Insert new nodes before the given node."""
- ...
-
- def sort(self) -> None:
- """Topologically sort the nodes in the graph."""
- ...
-
-
-@typing.runtime_checkable
-class GraphViewProtocol(Protocol):
- """Protocol for a read-only view on a graph.
-
- The GraphView is useful for analysis of a subgraph. It can be initialized
- with a subset of nodes from a :class:`Graph`. Creating GraphView does not
- change the ownership of the nodes, and so it is possible to create multiple
- GraphViews that contain the same nodes.
-
- Attributes:
- name: The name of the graph.
- inputs: The input values of the graph.
- outputs: The output values of the graph.
- initializers: The initializers in the graph.
- doc_string: Documentation string.
- opset_imports: Opsets imported by the graph.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- name: str | None
- inputs: Sequence[ValueProtocol]
- outputs: Sequence[ValueProtocol]
- initializers: Mapping[str, TensorProtocol]
- doc_string: str
- opset_imports: Mapping[str, int]
- metadata_props: MutableMapping[str, str]
- meta: MutableMapping[str, Any]
-
- def __getitem__(self, index: int) -> NodeProtocol: ...
- def __len__(self) -> int: ...
- def __iter__(self) -> Iterator[NodeProtocol]: ...
- def __reversed__(self) -> Iterator[NodeProtocol]: ...
-
-
-@typing.runtime_checkable
-class ModelProtocol(Protocol):
- """Protocol for models.
-
- A model is a container for a graph and metadata. It is the top-level object
- that represents an ONNX model.
-
- Attributes:
- graph: The graph of the model.
- ir_version: The version of the IR.
- producer_name: The name of the producer.
- producer_version: The version of the producer.
- domain: The domain of the model.
- model_version: The version of the model.
- doc_string: Documentation string.
- functions: The functions defined in the model.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- graph: GraphProtocol
- ir_version: int
- producer_name: str | None
- producer_version: str | None
- domain: str | None
- model_version: int | None
- doc_string: str | None
- functions: MutableMapping[str, FunctionProtocol]
- # TODO(justinchuby): Add training_info
- opset_imports: MutableMapping[str, int]
- metadata_props: MutableMapping[str, str]
- meta: MutableMapping[str, Any]
-
-
-@typing.runtime_checkable
-class AttributeProtocol(Protocol):
- """Protocol for ONNX attributes.
-
- Attributes:
- name: The name of the attribute.
- type: The type of the attribute.
- value: The value of the attribute.
- doc_string: Documentation string.
- """
-
- name: str
- type: _enums.AttributeType
- value: Any
- doc_string: str | None
-
-
-@typing.runtime_checkable
-class ReferenceAttributeProtocol(Protocol):
- """Protocol for a reference attribute.
-
- A reference attribute can only appear inside the definition body of a function.
-
- Attributes:
- name: The name of the attribute.
- ref_attr_name: The name of the attribute definition this attribute refers to.
- type: The type of the attribute.
- doc_string: Documentation string.
- """
-
- name: str
- ref_attr_name: str
- type: _enums.AttributeType
- doc_string: str | None
-
-
-@typing.runtime_checkable
-class SparseTensorProtocol(Protocol):
- values: TensorProtocol
- indices: TensorProtocol
- dims: Sequence[int]
-
-
-@typing.runtime_checkable
-class SymbolicDimProtocol(Protocol):
- """Value of a single symbolic/dynamic dimension in a shape.
-
- Attributes:
- value: The value of the dimension.
- """
-
- value: str | None # TODO(justinchuby): Maybe support sympy
-
-
-@typing.runtime_checkable
-class ShapeProtocol(Protocol):
- """Protocol for ONNX shapes.
-
- A shape is a sequence of dimensions.
-
- Attributes:
- dims: The dimensions of the shape.
- """
-
- dims: Sequence[int | SymbolicDimProtocol]
-
- def __len__(self) -> int: ...
- def __iter__(self) -> Iterator[int | SymbolicDimProtocol]: ...
- @typing.overload
- def __getitem__(self, index: int) -> int | SymbolicDimProtocol: ...
- @typing.overload
- def __getitem__(self, index: slice) -> tuple[int | SymbolicDimProtocol, ...]: ...
- def __setitem__(
- self, index: int, value: int | SymbolicDimProtocol | str | None
- ) -> None: ...
- def __eq__(self, other: object) -> bool: ...
- def __ne__(self, value: object) -> bool: ...
- def get_denotation(self, index: int) -> str | None: ...
- def set_denotation(self, index: int, denotation: str | None) -> None: ...
- def numpy(self) -> Sequence[int]: ...
- def rank(self) -> int: ...
-
-
-@typing.runtime_checkable
-class TypeProtocol(Protocol):
- """Protocol for ONNX tensors, Sequence tensors, Optional tensors and Sparse tensors.
-
-    These tensor types share the same attribute ``elem_type``, so they are
-    merged into the same interface. Unlike the ONNX TensorProto, shapes are not included
- in the type and should be stored in the :class:`Value`.
-
- Attributes:
- denotation: An optional denotation can be used to denote the whole
- type with a standard semantic description as to what is
- stored inside.
- Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
- for pre-defined type denotations.
- elem_type: The type of its elements for nested types like Sequence[Optional] tensors.
- Or the DataType if the type is not nested.
- dtype: The data type of the tensor or the nested tensor.
- """
-
- denotation: str | None
- elem_type: TypeProtocol | _enums.DataType
- dtype: _enums.DataType
-
- def __eq__(self, __value: object) -> bool: ...
-
-
-@typing.runtime_checkable
-class MapTypeProtocol(Protocol):
- """Protocol for ONNX map types.
-
- TODO: This protocol is not yet implemented in the ONNX IR.
- """
-
- key_type: typing.Literal[
- _enums.DataType.STRING,
- _enums.DataType.INT64,
- _enums.DataType.INT32,
- _enums.DataType.INT16,
- _enums.DataType.INT8,
- _enums.DataType.UINT64,
- _enums.DataType.UINT32,
- _enums.DataType.UINT16,
- _enums.DataType.UINT8,
- ]
- value_type: _enums.DataType
-
-
-@typing.runtime_checkable
-class FunctionProtocol(Protocol):
- """Protocol for ONNX functions.
-
- Like a graph, a function can have nodes that are not topologically sorted. It is
- the responsibility of the user to maintain a topological order of the nodes.
-
-    Note that there is no ``node`` attribute in the Function. The Function can be
- seen as a Sequence of nodes and should be used as such. For example, to obtain
- all nodes as a list, call ``list(function)``.
-
- Attributes:
- name: The function name.
- domain: The domain this function is defined in.
- overload: The overload name when the function is overloaded.
- inputs: The input values of the function.
- attributes: The attributes this function defines.
- outputs: The output values of the function.
- opset_imports: Opsets imported by the function.
- doc_string: Documentation string.
- metadata_props: Metadata that will be serialized to the ONNX file.
- meta: Metadata store for graph transform passes.
- """
-
- name: str
- domain: str
- overload: str
- inputs: Sequence[ValueProtocol]
- attributes: OrderedDict[str, AttributeProtocol]
- outputs: Sequence[ValueProtocol]
- doc_string: str
- opset_imports: MutableMapping[str, int]
- metadata_props: MutableMapping[str, str]
- meta: MutableMapping[str, Any]
-
- def __getitem__(self, index: int) -> NodeProtocol: ...
- def __len__(self) -> int: ...
- def __iter__(self) -> Iterator[NodeProtocol]: ...
- def __reversed__(self) -> Iterator[NodeProtocol]: ...
- def identifier(self) -> OperatorIdentifier:
- """Return the unique identifier of the function."""
- ...
-
- # Mutation methods
- def append(self, node: NodeProtocol, /) -> None:
- """Append a node to the function."""
- ...
-
- def extend(self, nodes: Iterable[NodeProtocol], /) -> None:
- """Extend the function with the given nodes."""
- ...
-
- def remove(self, node: NodeProtocol, /) -> None:
- """Remove a node from the function."""
- ...
-
- def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
- """Insert new nodes after the given node."""
- ...
-
- def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
- """Insert new nodes before the given node."""
- ...
-
- def sort(self) -> None:
- """Topologically sort the nodes in the function."""
- ...
diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py
new file mode 100644
index 0000000000..d4d88ab5bb
--- /dev/null
+++ b/onnxscript/ir/_schemas.py
@@ -0,0 +1,548 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import collections.abc
+import dataclasses
+import inspect
+import logging
+import types
+import typing
+from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union
+
+import onnx
+
+import onnxscript
+from onnxscript import ir
+
+logger = logging.getLogger(__name__)
+
+
+# A special value to indicate that the default value is not specified
+class _Empty:
+ def __repr__(self):
+ return "_EMPTY_DEFAULT"
+
+
+_EMPTY_DEFAULT = _Empty()
+
+# Map from python type to corresponding ONNX AttributeProto type
+_PY_TYPE_TO_ATTR_TYPE = {
+ float: ir.AttributeType.FLOAT,
+ int: ir.AttributeType.INT,
+ str: ir.AttributeType.STRING,
+ bool: ir.AttributeType.INT,
+ ir.Tensor: ir.AttributeType.TENSOR,
+ ir.TensorProtocol: ir.AttributeType.TENSOR,
+ ir.Graph: ir.AttributeType.GRAPH,
+ ir.GraphProtocol: ir.AttributeType.GRAPH,
+}
+
+# Map from python type to corresponding ONNX AttributeProto type,
+# for repeated (i.e., list of) values
+_LIST_TYPE_TO_ATTR_TYPE = {
+ float: ir.AttributeType.FLOATS,
+ int: ir.AttributeType.INTS,
+ str: ir.AttributeType.STRINGS,
+ bool: ir.AttributeType.INTS,
+ ir.Tensor: ir.AttributeType.TENSORS,
+ ir.TensorProtocol: ir.AttributeType.TENSORS,
+ ir.Graph: ir.AttributeType.GRAPHS,
+ ir.GraphProtocol: ir.AttributeType.GRAPHS,
+}
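+
+# Note: ``bool`` maps to INT/INTS above because ONNX attributes have no
+# dedicated boolean type; boolean values are stored as 0/1 integers.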
+
+_ALL_VALUE_TYPES = (
+ {ir.TensorType(dtype) for dtype in ir.DataType}
+ | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}
+ | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType}
+)
+
+# TypeAnnotationValue represents the (value of) valid type-annotations recognized
+# by ONNX Script. Currently, it supports
+# - float, int, str (primitive attribute types)
+# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
+# - Tensor types
+# - Sequence[Tensor] types
+# - Unions of the above
+# - TypeVars constrained or bounded by the above
+# - The above types with annotations attached
+TypeAnnotationValue = Any
+
+
+@dataclasses.dataclass(frozen=True)
+class TypeConstraintParam:
+ """Type constraint for a parameter.
+
+ Attributes:
+ name: Name of the parameter. E.g. "TFloat"
+ allowed_types: Allowed types for the parameter.
+ """
+
+ name: str
+ allowed_types: set[ir.TypeProtocol]
+ description: str = ""
+
+ def __hash__(self) -> int:
+ return hash((self.name, tuple(self.allowed_types)))
+
+ def __str__(self) -> str:
+ allowed_types_str = " | ".join(str(t) for t in self.allowed_types)
+ return f"{self.name}={allowed_types_str}"
+
+ @classmethod
+ def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
+ return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description)
+
+ @classmethod
+ def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
+ return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type]
+
+
+@dataclasses.dataclass(frozen=True)
+class Parameter:
+ """A formal parameter of an operator."""
+
+ name: str
+ type_constraint: TypeConstraintParam
+ required: bool
+ variadic: bool
+ default: Any = _EMPTY_DEFAULT
+ # TODO: Add other properties too
+
+ def __str__(self) -> str:
+ type_str = self.type_constraint.name
+ if self.has_default():
+ return f"{self.name}: {type_str} = {self.default}"
+ return f"{self.name}: {type_str}"
+
+ def has_default(self) -> bool:
+ return self.default is not _EMPTY_DEFAULT
+
+
+@dataclasses.dataclass(frozen=True)
+class AttributeParameter:
+ """A parameter in the function signature that represents an ONNX attribute."""
+
+ name: str
+ type: ir.AttributeType
+ required: bool
+ default: ir.Attr | None = None
+
+ def __str__(self) -> str:
+ type_str = self.type.name
+ if self.has_default():
+ return f"{self.name}: {type_str} = {self.default}"
+ return f"{self.name}: {type_str}"
+
+ def has_default(self) -> bool:
+ return self.default is not None
+
+
+def _get_type_from_str(
+ type_str: str,
+) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
+ """Converter a type_str from ONNX OpSchema to ir.TypeProtocol.
+
+ A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".
+ """
+ # Split the type_str a sequence types and dtypes
+ # 1. Remove the ending ")"
+ striped = type_str.rstrip(")")
+ # 2. Split the type_str by "("
+ type_parts = striped.split("(")
+
+ # Convert the dtype to ir.DataType
+ dtype = ir.DataType[type_parts[-1].upper()]
+
+    # Create a placeholder type first
+ type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED)
+
+ # Construct the type
+ for type_part in reversed(type_parts[:-1]):
+ if type_part == "tensor":
+ type_ = ir.TensorType(dtype)
+ elif type_part == "seq":
+ type_ = ir.SequenceType(type_)
+ elif type_part == "optional":
+ type_ = ir.OptionalType(type_)
+ else:
+ raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'")
+ return type_ # type: ignore[return-value]
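+
+# Illustrative examples of the parsing above:
+#   _get_type_from_str("tensor(float)")      -> ir.TensorType(ir.DataType.FLOAT)
+#   _get_type_from_str("seq(tensor(float))") -> ir.SequenceType(ir.TensorType(ir.DataType.FLOAT))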
+
+
+def _convert_formal_parameter(
+ param: onnx.defs.OpSchema.FormalParameter,
+ type_constraints: Mapping[str, TypeConstraintParam],
+) -> Parameter:
+ """Convert a formal parameter from ONNX OpSchema to Parameter."""
+ if param.type_str in type_constraints:
+ type_constraint = type_constraints[param.type_str]
+ else:
+ # param.type_str can be a plain type like 'int64'.
+ type_constraint = TypeConstraintParam(
+ name=param.name,
+ allowed_types={_get_type_from_str(param.type_str)},
+ )
+ return Parameter(
+ name=param.name,
+ type_constraint=type_constraint,
+ required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional,
+ variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic,
+ )
+
+
+def _is_optional(type_: type) -> bool:
+ """Returns whether a type_ is an Optional."""
+ origin_type = typing.get_origin(type_)
+ if origin_type is Union and type(None) in typing.get_args(type_):
+        # Optional[X] and Union[X, None] normalize to Union on all Python versions
+ return True
+ if origin_type is Optional:
+        # Defensive check; typing.get_origin normalizes Optional to Union
+ return True
+ if (
+ hasattr(types, "UnionType")
+ and origin_type is types.UnionType
+ and type(None) in typing.get_args(type_)
+ ):
+        # X | None, PEP 604 union syntax (Python >= 3.10)
+ return True
+ return False
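+
+# For example, _is_optional(Optional[int]) and _is_optional(int | None) (the
+# latter on Python >= 3.10) return True, while _is_optional(int) returns False.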
+
+
+def _get_attr_type(type_: type) -> ir.AttributeType:
+ """Obtain the type of the attribute from a Python class."""
+ try:
+ if type_ in _PY_TYPE_TO_ATTR_TYPE:
+ return _PY_TYPE_TO_ATTR_TYPE[type_]
+ origin_type = typing.get_origin(type_)
+ if origin_type is None:
+ return ir.AttributeType.UNDEFINED
+ if origin_type in (
+ collections.abc.Sequence,
+ Sequence,
+ typing.List,
+ list,
+ typing.Tuple,
+ tuple,
+ ):
+ inner_type = typing.get_args(type_)[0]
+ if inner_type in _LIST_TYPE_TO_ATTR_TYPE:
+ return _LIST_TYPE_TO_ATTR_TYPE[inner_type]
+ except TypeError:
+ logger.warning("TypeError when checking %s.", type_, exc_info=True)
+ return ir.AttributeType.UNDEFINED
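+
+# For example, _get_attr_type(int) -> ir.AttributeType.INT and
+# _get_attr_type(Sequence[float]) -> ir.AttributeType.FLOATS; unrecognized
+# annotations fall back to ir.AttributeType.UNDEFINED.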
+
+
+def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None:
+ """Returns the name of the type constraint for a given type annotation.
+
+ Args:
+ type_: A Python type.
+
+ Returns:
+ The name of the type constraint if it is a TypeVar.
+ - Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
+ """
+ if isinstance(type_, TypeVar):
+ return type_.__name__
+ if _is_optional(type_):
+ subtypes = typing.get_args(type_)
+ for subtype in subtypes:
+ if subtype is type(None):
+ continue
+ type_param_name = _get_type_constraint_name(subtype)
+ return type_param_name if type_param_name else None
+ origin_type = typing.get_origin(type_)
+ if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
+ subtypes = typing.get_args(type_)
+ type_param_name = _get_type_constraint_name(subtypes[0])
+ return f"Sequence_{type_param_name}" if type_param_name else None
+ return None
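+
+# For instance, with TFloat = TypeVar("TFloat", bound=FLOAT), this returns
+# "TFloat" for both TFloat and Optional[TFloat], "Sequence_TFloat" for
+# Sequence[TFloat], and None for plain types such as INT64.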
+
+
+def _get_allowed_types_from_type_annotation(
+ type_: TypeAnnotationValue,
+) -> set[ir.TypeProtocol]:
+ """Obtain the allowed types from a type annotation."""
+ if type_ is onnxscript.onnx_types.TensorType:
+ # Any tensor type
+ return {ir.TensorType(dtype) for dtype in ir.DataType}
+
+ allowed_types: set[ir.TypeProtocol]
+
+ if isinstance(type_, TypeVar):
+ allowed_types = set()
+ if constraints := type_.__constraints__:
+ for constraint in constraints:
+ allowed_types.update(_get_allowed_types_from_type_annotation(constraint))
+ else:
+ bound = type_.__bound__
+ if bound is None:
+ allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment]
+ else:
+ allowed_types.update(_get_allowed_types_from_type_annotation(bound))
+ return allowed_types
+ if hasattr(type_, "dtype"):
+ # A single tensor type like INT64, FLOAT, etc.
+ return {ir.TensorType(ir.DataType(type_.dtype))}
+ if _is_optional(type_):
+ allowed_types = set()
+ subtypes = typing.get_args(type_)
+ for subtype in subtypes:
+ if subtype is type(None):
+ continue
+ allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
+ # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful.
+ return allowed_types
+
+ origin_type = typing.get_origin(type_)
+ if origin_type is Union:
+ allowed_types = set()
+ subtypes = typing.get_args(type_)
+ for subtype in subtypes:
+ assert subtype is not type(None), (
+ "Union should not contain None type because it is handled by _is_optional."
+ )
+ allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
+ return allowed_types
+
+ if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
+ subtypes = typing.get_args(type_)
+ return {
+ ir.SequenceType(t) for t in _get_allowed_types_from_type_annotation(subtypes[0])
+ }
+
+ # Allow everything by default
+ return _ALL_VALUE_TYPES # type: ignore[return-value]
+
+
+@dataclasses.dataclass
+class OpSignature:
+ """Schema for an operator.
+
+ Attributes:
+ domain: Domain of the operator. E.g. "".
+ name: Name of the operator. E.g. "Add".
+ overload: Overload name of the operator.
+ params: Input parameters. When the op is an ONNX function definition,
+            the order follows the function signature. This means we can
+            interleave ONNX inputs and ONNX attributes in the list.
+ outputs: Output parameters.
+ """
+
+ domain: str
+ name: str
+ overload: str
+ params: Sequence[Parameter | AttributeParameter]
+ outputs: Sequence[Parameter]
+ params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
+ init=False, repr=False
+ )
+
+ def __post_init__(self):
+ self.params_map = {param.name: param for param in self.params}
+
+ def get(self, name: str) -> Parameter | AttributeParameter:
+ return self.params_map[name]
+
+ def __contains__(self, name: str) -> bool:
+ return name in self.params_map
+
+ def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
+ return iter(self.params)
+
+ def __str__(self) -> str:
+ domain = self.domain or "''"
+ # TODO: Double check the separator for overload
+ overload = f"::{self.overload}" if self.overload else ""
+ params = ", ".join(str(param) for param in self.params)
+ outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs)
+ type_constraints = {}
+ for param in self.params:
+ if isinstance(param, Parameter):
+ type_constraints[param.type_constraint.name] = param.type_constraint
+ for param in self.outputs:
+ type_constraints[param.type_constraint.name] = param.type_constraint
+ type_constraints_str = ", ".join(
+ str(type_constraint) for type_constraint in type_constraints.values()
+ )
+ return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}"
+
+ @classmethod
+ def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature:
+ """Produce an OpSignature from an ONNX OpSchema."""
+ type_constraints = {
+ constraint.type_param_str: TypeConstraintParam(
+ name=constraint.type_param_str,
+ allowed_types={
+ _get_type_from_str(type_str) for type_str in constraint.allowed_type_strs
+ },
+ description=constraint.description,
+ )
+ for constraint in op_schema.type_constraints
+ }
+
+ params = [
+ _convert_formal_parameter(param, type_constraints) for param in op_schema.inputs
+ ]
+
+ for param in op_schema.attributes.values():
+ default_attr = (
+ ir.serde.deserialize_attribute(param.default_value)
+ if param.default_value is not None
+ else None
+ )
+ if default_attr is not None:
+ # Set the name of the default attribute because it may have a different name from the parameter
+ default_attr.name = param.name
+ params.append(
+ AttributeParameter(
+ name=param.name,
+ type=ir.AttributeType(param.type), # type: ignore[arg-type]
+ required=param.required,
+ default=default_attr, # type: ignore[arg-type]
+ )
+ )
+
+ outputs = [
+ _convert_formal_parameter(param, type_constraints) for param in op_schema.outputs
+ ]
+
+ return cls(
+ domain=op_schema.domain,
+ name=op_schema.name,
+ overload="",
+ params=params,
+ outputs=outputs,
+ )
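+
+    # Illustrative example: OpSignature.from_op_schema(onnx.defs.get_schema("Add"))
+    # would print roughly as
+    # "''::Add(A: T, B: T) -> (T) where T=tensor(float) | tensor(int32) | ...".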
+
+ @classmethod
+ def from_function(
+ cls, func, domain: str, name: str | None = None, overload: str = ""
+ ) -> OpSignature:
+ """Produce an OpSignature from a function using type annotation."""
+
+ py_signature = inspect.signature(func)
+ # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases
+ # https://github.com/python/cpython/issues/102405
+ type_hints = typing.get_type_hints(func)
+
+ params: list[Parameter | AttributeParameter] = []
+ # Create a mapping from type to a unique name
+ type_constraints: dict[str, TypeConstraintParam] = {}
+
+ for param in py_signature.parameters.values():
+ if param.name not in type_hints:
+ logger.warning(
+ "Missing annotation for parameter '%s' from %s. Treating as an Input.",
+ param.name,
+ py_signature,
+ )
+ type_constraint = TypeConstraintParam.any_value(f"T_{param.name}")
+ type_constraints[param.name] = type_constraint
+ params.append(
+ Parameter(
+ name=param.name,
+ type_constraint=type_constraint,
+ required=param.default is inspect.Parameter.empty,
+ # TODO: Handle variadic
+ variadic=False,
+ default=param.default
+ if param.default is not inspect.Parameter.empty
+ else _EMPTY_DEFAULT,
+ )
+ )
+ else:
+ type_ = type_hints[param.name]
+ if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
+ # Construct the default attribute
+ if param.default is not inspect.Parameter.empty:
+ # TODO: Use ir_convenience instead to handle int as float
+ default = ir.Attr(param.name, attr_type, param.default)
+ else:
+ default = None
+ params.append(
+ AttributeParameter(
+ name=param.name,
+ type=attr_type,
+ required=param.default is inspect.Parameter.empty,
+ default=default,
+ )
+ )
+ else:
+ # Obtain the type constraint from the type annotation
+
+ # 1. Get a type constraint name from the type annotation
+ # If the type annotation is a TypeVar or Optional[TypeVar], get its name
+ # Otherwise, name it T_{param.name}
+ type_constraint_name = _get_type_constraint_name(type_)
+ if type_constraint_name is None:
+ type_constraint_name = f"T_{param.name}"
+
+ # 2. If the type constraint param is already initialized, use it
+ if type_constraint_name in type_constraints:
+ type_constraint = type_constraints[type_constraint_name]
+ else:
+ # 3. Otherwise, create a new TypeConstraintParam
+ type_constraint = TypeConstraintParam(
+ name=type_constraint_name,
+ allowed_types=_get_allowed_types_from_type_annotation(type_),
+ )
+ type_constraints[type_constraint_name] = type_constraint
+ # 4. Create Parameter
+ params.append(
+ Parameter(
+ name=param.name,
+ type_constraint=type_constraint,
+ required=param.default is inspect.Parameter.empty,
+ # TODO: Handle variadic
+ variadic=False,
+ default=param.default
+ if param.default is not inspect.Parameter.empty
+ else _EMPTY_DEFAULT,
+ )
+ )
+
+ return_type = type_hints.get("return")
+
+ outputs = []
+ if return_type is None:
+ # No returns
+ pass
+ else:
+ if typing.get_origin(return_type) is tuple:
+ # Multiple returns
+ return_types = typing.get_args(return_type)
+ else:
+ return_types = [return_type] # type: ignore[assignment]
+
+ for i, return_type_i in enumerate(return_types):
+ if (
+ return_param_name := _get_type_constraint_name(return_type_i)
+ ) in type_constraints:
+ type_constraint = type_constraints[return_param_name]
+ else:
+ return_param_name = f"TReturn{i}"
+ type_constraint = TypeConstraintParam(
+ name=return_param_name,
+ allowed_types=_get_allowed_types_from_type_annotation(return_type_i),
+ )
+ type_constraints[return_param_name] = type_constraint
+ outputs.append(
+ Parameter(
+ name=return_param_name,
+ type_constraint=type_constraint,
+ required=True,
+ variadic=False,
+ default=_EMPTY_DEFAULT,
+ )
+ )
+
+ return cls(
+ domain=domain,
+ name=name or func.__name__,
+ overload=overload,
+ params=params,
+ outputs=outputs,
+ )
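+
+
+# Illustrative usage (hypothetical function):
+#   def gelu(x: TFloat, approximate: str = "none") -> TFloat: ...
+#   sig = OpSignature.from_function(gelu, domain="com.example")
+# Here ``x`` becomes a Parameter constrained by the TypeVar name "TFloat", and
+# ``approximate`` becomes a STRING AttributeParameter with default "none".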
diff --git a/onnxscript/ir/_schemas_test.py b/onnxscript/ir/_schemas_test.py
new file mode 100644
index 0000000000..82082d031f
--- /dev/null
+++ b/onnxscript/ir/_schemas_test.py
@@ -0,0 +1,175 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+from typing import Any, Optional, Sequence, TypeVar, Union
+
+import parameterized
+
+import onnxscript
+from onnxscript import FLOAT, INT64, ir
+from onnxscript.ir import _schemas
+
+_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT)
+_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64)
+_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT])
+
+
+class TypeConversionFunctionsTest(unittest.TestCase):
+ @parameterized.parameterized.expand(
+ [
+ (
+ "tensor_type_all",
+ onnxscript.onnx_types.TensorType,
+ {ir.TensorType(dtype) for dtype in ir.DataType},
+ ),
+ ("tensor_type", INT64, {ir.TensorType(ir.DataType.INT64)}),
+ (
+ "tensor_type_union",
+ Union[INT64, FLOAT],
+ {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+ ),
+ (
+ "tensor_type_variadic_shape",
+ INT64[...],
+ {ir.TensorType(ir.DataType.INT64)},
+ ),
+ ("tensor_type_shape", INT64[10], {ir.TensorType(ir.DataType.INT64)}),
+ (
+ "type_var_constraints",
+ _TestTypeVarConstraints,
+ {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+ ),
+ (
+ "type_bound_one",
+ _TestTypeVarOneBound,
+ {ir.TensorType(ir.DataType.INT64)},
+ ),
+ (
+ "type_bound_two",
+ _TestTypeVarTwoBound,
+ {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+ ),
+ (
+ "optional_tensor_type_all",
+ Optional[onnxscript.onnx_types.TensorType],
+ {ir.TensorType(dtype) for dtype in ir.DataType},
+ ),
+ (
+ "optional_tensor_type",
+ Optional[INT64],
+ {ir.TensorType(ir.DataType.INT64)},
+ ),
+ (
+ "optional_tensor_type_union",
+ Optional[Union[INT64, FLOAT]],
+ {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+ ),
+ (
+ "optional_tensor_type_variadic_shape",
+ Optional[INT64[...]],
+ {ir.TensorType(ir.DataType.INT64)},
+ ),
+ (
+ "optional_tensor_type_shape",
+ Optional[INT64[10]],
+ {ir.TensorType(ir.DataType.INT64)},
+ ),
+ (
+ "optional_type_var_constraints",
+ Optional[_TestTypeVarConstraints],
+ {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+ ),
+ (
+ "optional_type_bound_one",
+ Optional[_TestTypeVarOneBound],
+ {ir.TensorType(ir.DataType.INT64)},
+ ),
+ (
+ "optional_type_bound_two",
+ Optional[_TestTypeVarTwoBound],
+ {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)},
+ ),
+ (
+ "sequence_type_all",
+ Sequence[onnxscript.onnx_types.TensorType],
+ {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType},
+ ),
+ (
+ "sequence_type",
+ Sequence[INT64],
+ {ir.SequenceType(ir.TensorType(ir.DataType.INT64))},
+ ),
+ (
+ "union_sequence_type",
+ Union[Sequence[INT64], Sequence[FLOAT]],
+ {
+ ir.SequenceType(ir.TensorType(ir.DataType.INT64)),
+ ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)),
+ },
+ ),
+ (
+ "sequence_type_variadic_shape",
+ Sequence[INT64[...]],
+ {ir.SequenceType(ir.TensorType(ir.DataType.INT64))},
+ ),
+ (
+ "sequence_type_shape",
+ Sequence[INT64[10]],
+ {ir.SequenceType(ir.TensorType(ir.DataType.INT64))},
+ ),
+ (
+ "sequence_type_var_constraints",
+ Sequence[_TestTypeVarConstraints],
+ {
+ ir.SequenceType(ir.TensorType(ir.DataType.INT64)),
+ ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)),
+ },
+ ),
+ (
+ "sequence_type_bound_one",
+ Sequence[_TestTypeVarOneBound],
+ {ir.SequenceType(ir.TensorType(ir.DataType.INT64))},
+ ),
+ (
+ "sequence_type_bound_two",
+ Sequence[_TestTypeVarTwoBound],
+ {
+ ir.SequenceType(ir.TensorType(ir.DataType.INT64)),
+ ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)),
+ },
+ ),
+ ]
+ )
+ def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]):
+ self.assertEqual(_schemas._get_allowed_types_from_type_annotation(pytype), expected) # pylint: disable=protected-access
+
+ @parameterized.parameterized.expand(
+ [
+ ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"),
+ ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"),
+ (
+ "optional_type_var",
+ Optional[_TestTypeVarOneBound],
+ "_TestTypeVarOneBound",
+ ),
+ (
+ "sequence_type_var",
+ Sequence[_TestTypeVarOneBound],
+ "Sequence__TestTypeVarOneBound",
+ ),
+ ("normal_type", INT64, None),
+ ("union_type", Union[INT64, FLOAT], None),
+ ("optional_type", Optional[INT64], None),
+ ("sequence_type", Sequence[INT64], None),
+ ("optional_sequence_type", Optional[Sequence[INT64]], None),
+ ("optional_union_type", Optional[Union[INT64, FLOAT]], None),
+ ]
+ )
+ def test_get_type_constraint_name(self, _: str, pytype: Any, expected: str | None):
+ self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) # pylint: disable=protected-access
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py
new file mode 100644
index 0000000000..78dce2739e
--- /dev/null
+++ b/onnxscript/ir/_tape.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Convenience methods for constructing the IR."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Optional, Sequence
+
+from onnx_ir import tape
+
+if TYPE_CHECKING:
+ import onnx_ir as ir
+
+
+# A type representing the domains/versions used in creating nodes in IR.
+UsedOpsets = set[tuple[str, Optional[int]]]
+
+
+class Builder(tape.Tape):
+ """An extension of the tape that provides a more convenient API for constructing the IR.
+
+ Example:
+ >>> from onnxscript import ir
+ >>> from onnxscript.ir import _tape
+ >>> op = _tape.Builder()
+ >>> input = ir.Value(name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)))
+ >>> relu_val = op.Relu(input, _name="relu_node", _domain="", _version=18, _outputs=["relu_out"])
+
+ Note: When passing `_name`, ensure it is unique to avoid duplicate node names.
+ """
+
+ def __getattr__(self, op_type: str) -> Any:
+ return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
+
+ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
+ domain = kwargs.pop("_domain", "")
+ version = kwargs.pop("_version", None)
+ outputs = kwargs.pop("_outputs", 1)
+ name = kwargs.pop("_name", None)
+
+ if isinstance(outputs, Sequence):
+ num_outputs = len(outputs)
+ else:
+ assert isinstance(outputs, int)
+ num_outputs = outputs
+
+ if num_outputs == 1:
+ value = super().op(
+ op_type,
+ inputs=inputs,
+ attributes=kwargs,
+ domain=domain,
+ version=version,
+ name=name,
+ )
+ if isinstance(outputs, Sequence):
+ value.name = outputs[0]
+ return value
+ values = super().op_multi_out(
+ op_type,
+ inputs=inputs,
+ attributes=kwargs,
+ domain=domain,
+ version=version,
+ name=name,
+ num_outputs=num_outputs,
+ )
+ if isinstance(outputs, Sequence):
+ for value, name in zip(values, outputs):
+ value.name = name
+ return values
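+
+
+# Illustrative usage (hypothetical op types):
+#   op = Builder()
+#   y = op.Relu(x, _name="relu_0")                   # single output Value
+#   a, b = op.Split(x, _outputs=["a", "b"], axis=0)  # two named output Values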
diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py
new file mode 100644
index 0000000000..f8210e7a0b
--- /dev/null
+++ b/onnxscript/ir/_tape_test.py
@@ -0,0 +1,104 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+from onnxscript import ir
+from onnxscript.ir import _tape
+
+
+class TestTape(unittest.TestCase):
+ def test_op(self):
+        # Define the inputs for a simple Add node
+ inputs = [
+ ir.Value(
+ name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+ ),
+ ir.Value(
+ name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+ ),
+ ]
+
+ tape = ir.tape.Tape()
+
+ _ = tape.op("Add", inputs=inputs)
+
+ self.assertEqual([n.op_type for n in tape.nodes], ["Add"])
+
+ def test_initializers(self):
+ inputs = [
+ ir.Value(
+ name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+ ),
+ ir.Value(
+ name="input_b",
+ type=ir.TensorType(ir.DataType.FLOAT),
+ shape=ir.Shape((2, 1)),
+ const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT),
+ ),
+ ]
+
+ tape = ir.tape.Tape()
+
+ # Shape and type are not explicitly set for the initializer but it should still work
+ initializer = tape.initializer(
+ ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT), name="initializer"
+ )
+ val_add = tape.op("Add", inputs=inputs)
+ _ = tape.op("Mul", inputs=[val_add, initializer])
+
+ self.assertEqual([n.op_type for n in tape.nodes], ["Add", "Mul"])
+ self.assertEqual(tape.initializers, (initializer,))
+
+ def test_op_multi_out(self):
+ inputs = [
+ ir.Value(
+ name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+ ),
+ ir.Value(
+ name="input_b",
+ type=ir.TensorType(ir.DataType.FLOAT),
+ shape=ir.Shape((2, 1)),
+ const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT),
+ ),
+ ]
+
+ tape = ir.tape.Tape()
+
+ out1, out2, out3 = tape.op_multi_out("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking
+ _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3])
+
+ self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"])
+
+
+class TestBuilder(unittest.TestCase):
+ def test_op_name(self):
+ op = _tape.Builder()
+
+ input_a = ir.Value(
+ name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+ )
+ input_b = ir.Value(
+ name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+ )
+
+ add = op.Add(input_a, input_b, _name="add_node")
+ _ = op.Relu(add, _name="relu_node")
+ self.assertEqual(op.nodes[0].name, "add_node")
+ self.assertEqual(op.nodes[1].name, "relu_node")
+
+ def test_op_name_multi_out(self):
+ op = _tape.Builder()
+
+ input_a = ir.Value(
+ name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
+ )
+
+ _ = op.CustomOp(input_a, _name="custom_node", _outputs=3)
+ self.assertEqual(op.nodes[0].name, "custom_node")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py
deleted file mode 100644
index abe825f84b..0000000000
--- a/onnxscript/ir/_type_casting.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""Numpy utilities for non-native type operation."""
-# TODO(justinchuby): Upstream the logic to onnx
-
-from __future__ import annotations
-
-import typing
-from typing import Sequence
-
-import numpy as np
-
-if typing.TYPE_CHECKING:
- import numpy.typing as npt
-
-
-def pack_int4(array: np.ndarray) -> npt.NDArray[np.uint8]:
- """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
- # Create a 1D copy
- array_flat = array.ravel().astype(np.uint8)
- size = array.size
- odd_sized = size % 2 == 1
- if odd_sized:
- array_flat.resize([size + 1], refcheck=False)
- array_flat &= 0x0F
- array_flat[1::2] <<= 4
- return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
-
-
-def unpack_uint4(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
- """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
-
- Args:
- data: A numpy array.
-        dims: The dimensions used to reshape the unpacked buffer.
-
- Returns:
- A numpy array of int8/uint8 reshaped to dims.
- """
- result = np.empty([data.size * 2], dtype=data.dtype)
- array_low = data & np.uint8(0x0F)
- array_high = data & np.uint8(0xF0)
- array_high >>= np.uint8(4)
- result[0::2] = array_low
- result[1::2] = array_high
- if result.size == np.prod(dims) + 1:
- # handle single-element padding due to odd number of elements
- result = result[:-1]
- result.resize(dims, refcheck=False)
- return result
-
-
-def _int4_to_int8(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]:
- """Extend 4-bit signed integer to 8-bit signed integer."""
- return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8)
-
-
-def unpack_int4(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.int8]:
- """Convert a packed (signed) int4 array to unpacked int4 array represented as int8.
-
- The sign bit is extended to the most significant bit of the int8.
-
- Args:
- data: A numpy array.
-        dims: The dimensions used to reshape the unpacked buffer.
-
- Returns:
- A numpy array of int8 reshaped to dims.
- """
- unpacked = unpack_uint4(data, dims)
- return _int4_to_int8(unpacked)
-
-
-def float32_to_bfloat16(array: npt.NDArray[np.float32]) -> npt.NDArray[np.uint16]:
- """Convert a numpy array to uint16 representation of bfloat16."""
- bfloat16_array = array.astype(np.float32).view(np.uint32)
-    # Drop the bottom 16 bits, rounding with round-to-nearest-even:
-    # the rounding bias is 0x7FFF plus the lowest kept bit (bit 16).
-    rounded = bfloat16_array >> 16
-    rounded &= 1
-    rounded += 0x7FFF
-    bfloat16_array += rounded  # type: ignore[arg-type]
-    bfloat16_array >>= 16
- # NaN requires at least 1 significant bit set
- bfloat16_array[np.isnan(array)] = 0x7FC0 # sign=0, exp=all-ones, sig=0b1000000
- return bfloat16_array.astype(np.uint16)
diff --git a/onnxscript/ir/_type_casting_test.py b/onnxscript/ir/_type_casting_test.py
deleted file mode 100644
index 544146e6b1..0000000000
--- a/onnxscript/ir/_type_casting_test.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import unittest
-
-import numpy as np
-import parameterized
-
-from onnxscript.ir import _type_casting
-
-
-class TypeCastingTest(unittest.TestCase):
- @parameterized.parameterized.expand(
- [
- ("signed", np.float32),
- ("unsigned", np.uint32),
- ]
- )
- def test_pack_int4_even_sized_array(self, _: str, dtype):
- array = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype)
- expected = np.array([0x21, 0x43, 0x65, 0x87], dtype=np.uint8)
- actual = _type_casting.pack_int4(array)
- np.testing.assert_array_equal(actual, expected)
-
- @parameterized.parameterized.expand(
- [
- ("signed", np.float32),
- ("unsigned", np.uint32),
- ]
- )
- def test_pack_int4_odd_sized_array(self, _: str, dtype):
- array = np.array([1, 2, 3, 4, 5], dtype=dtype)
- expected = np.array([0x21, 0x43, 0x5], dtype=np.uint8)
- actual = _type_casting.pack_int4(array)
- np.testing.assert_array_equal(actual, expected)
-
- @parameterized.parameterized.expand(
- [
- ("signed", np.float32),
- ("unsigned", np.uint32),
- ]
- )
- def test_pack_int4_returns_flatten_array(self, _: str, dtype):
- array = np.array([[[1, 2, 3, 4, 5]]], dtype=dtype)
- expected = np.array([0x21, 0x43, 0x5], dtype=np.uint8)
- actual = _type_casting.pack_int4(array)
- np.testing.assert_array_equal(actual, expected)
-
- @parameterized.parameterized.expand(
- [
- ("negative_infinity", np.uint16(0b1_11111111_0000000)),
- ("negative_min_normal", np.uint16(0b1_11111110_1111111)),
- ("negative_max_normal", np.uint16(0b1_00000001_0000000)),
- ("negative_min_subnormal", np.uint16(0b1_00000000_1111111)),
- ("negative_max_subnormal", np.uint16(0b1_00000000_0000001)),
- ("negative_zero", np.uint16(0b1_00000000_0000000)),
- ("positive_zero", np.uint16(0b0_00000000_0000000)),
- ("positive_min_subnormal", np.uint16(0b0_00000000_0000001)),
- ("positive_max_subnormal", np.uint16(0b0_00000000_1111111)),
- ("positive_min_normal", np.uint16(0b0_00000001_0000000)),
- ("positive_max_normal", np.uint16(0b0_11111110_1111111)),
- ("positive_infinity", np.uint16(0b0_11111111_0000000)),
- ("positive_nan", np.uint16(0b0_11111111_1000000)),
- ("positive_one", np.uint16(0b0_00111111_0000000)),
- ("negative_one", np.uint16(0b1_00111111_0000000)),
- ]
- )
- def test_float32_to_bfloat16(self, _: str, binary: np.uint16):
- value = np.array([binary << 16]).astype(np.uint32).view(np.float32)
- expected = np.array([binary])
- actual = _type_casting.float32_to_bfloat16(value)
- np.testing.assert_array_equal(actual, expected)
-
-
-if __name__ == "__main__":
- unittest.main(verbosity=2)
diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py
new file mode 100644
index 0000000000..e248a5fa84
--- /dev/null
+++ b/onnxscript/ir/convenience.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# pylint: disable=wildcard-import,unused-wildcard-import
+from onnx_ir.convenience import * # type: ignore # noqa: F403
diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py
new file mode 100644
index 0000000000..5310f1740a
--- /dev/null
+++ b/onnxscript/ir/passes/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+__all__ = [
+ "PassBase",
+ "PassResult",
+ "PassManager",
+ "Sequential",
+ "InPlacePass",
+ "FunctionalPass",
+ # Errors
+ "InvariantError",
+ "PreconditionError",
+ "PostconditionError",
+ "PassError",
+]
+
+from onnx_ir.passes import (
+ FunctionalPass,
+ InPlacePass,
+ InvariantError,
+ PassBase,
+ PassError,
+ PassManager,
+ PassResult,
+ PostconditionError,
+ PreconditionError,
+ Sequential,
+)
diff --git a/onnxscript/ir/passes/common/__init__.py b/onnxscript/ir/passes/common/__init__.py
new file mode 100644
index 0000000000..5a5ddbe52f
--- /dev/null
+++ b/onnxscript/ir/passes/common/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+__all__ = [
+ "AddInitializersToInputsPass",
+ "CheckerPass",
+ "ClearMetadataAndDocStringPass",
+ "CommonSubexpressionEliminationPass",
+ "InlinePass",
+ "LiftConstantsToInitializersPass",
+ "LiftSubgraphInitializersToMainGraphPass",
+ "RemoveInitializersFromInputsPass",
+ "RemoveUnusedFunctionsPass",
+ "RemoveUnusedNodesPass",
+ "RemoveUnusedOpsetsPass",
+ "ShapeInferencePass",
+ "TopologicalSortPass",
+]
+
+from onnx_ir.passes.common import (
+ AddInitializersToInputsPass,
+ CheckerPass,
+ ClearMetadataAndDocStringPass,
+ CommonSubexpressionEliminationPass,
+ InlinePass,
+ LiftConstantsToInitializersPass,
+ LiftSubgraphInitializersToMainGraphPass,
+ RemoveInitializersFromInputsPass,
+ RemoveUnusedFunctionsPass,
+ RemoveUnusedNodesPass,
+ RemoveUnusedOpsetsPass,
+ ShapeInferencePass,
+ TopologicalSortPass,
+)
diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py
deleted file mode 100644
index 6060c881bc..0000000000
--- a/onnxscript/ir/serde.py
+++ /dev/null
@@ -1,1350 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-"""Serialize and deserialize the intermediate representation to/from ONNX protos."""
-
-# NOTES for developers:
-# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead.
-#
-# NOTE: Protobuf serialization
-# Initializing a protobuf message with initialized protobuf messages incurs
-# a copy and is slow. Instead, use proto.add() to add to a repeated field,
-# or initialize the message first and then set the fields if the fields are
-# plain Python objects.
-
-from __future__ import annotations
-
-__all__ = [
- # Tensors
- "TensorProtoTensor",
- # Deserialization
- "deserialize_attribute",
- "deserialize_function",
- "deserialize_graph",
- "deserialize_model",
- "deserialize_node",
- "deserialize_opset_import",
- "deserialize_tensor",
- "deserialize_type_proto_for_shape",
- "deserialize_type_proto_for_type",
- "deserialize_value_info_proto",
- # Serialization
- "serialize_attribute_into",
- "serialize_attribute",
- "serialize_dimension_into",
- "serialize_function_into",
- "serialize_function",
- "serialize_graph_into",
- "serialize_graph",
- "serialize_model_into",
- "serialize_model",
- "serialize_node_into",
- "serialize_node",
- "serialize_shape_into",
- "serialize_reference_attribute_into",
- "serialize_tensor_into",
- "serialize_tensor",
- "serialize_type_into",
- "serialize_value_into",
- "serialize_value",
-]
-
-import collections
-import logging
-import os
-import typing
-from typing import Any, List, Mapping, Sequence
-
-import numpy as np
-import onnx
-import onnx.external_data_helper
-
-from onnxscript.ir import _core, _enums, _metadata, _protocols, _type_casting
-
-if typing.TYPE_CHECKING:
- import google.protobuf.internal.containers as proto_containers
- import numpy.typing as npt
-
-logger = logging.getLogger(__name__)
-
-_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = (
- 10 # ONNX IR version where value info in functions was introduced
-)
-
-
-def _little_endian_dtype(dtype) -> np.dtype:
- """Create a small endian dtype on all platforms.
-
- This is useful because ONNX always stores raw_data in small endian. On big
- endian platforms, we still need to interpret the raw_data in small endian.
- """
- return np.dtype(dtype).newbyteorder("<")
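-
-# For example (an illustrative sketch): _little_endian_dtype(np.int32) produces a
-# dtype with .str == "<i4" on both little- and big-endian platforms.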
-
-
-def _unflatten_complex(
- array: npt.NDArray[np.float32 | np.float64],
-) -> npt.NDArray[np.complex64 | np.complex128]:
- """Convert the real representation of a complex dtype to the complex dtype."""
- return array[::2] + 1j * array[1::2]
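-
-# For example (an illustrative sketch):
-#   _unflatten_complex(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
-#   # -> array([1.+2.j, 3.+4.j], dtype=complex64)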
-
-
-class TensorProtoTensor(_core.TensorBase):
- """A tensor initialized from a tensor proto."""
-
- def __init__(self, proto: onnx.TensorProto) -> None:
- self._proto = proto
- self._metadata_props: dict[str, str] | None = deserialize_metadata_props(
- proto.metadata_props
- )
- self._metadata: _metadata.MetadataStore | None = None
-
- @property
- def name(self) -> str:
- return self._proto.name
-
- @property
- def shape(self) -> _core.Shape:
- return _core.Shape(self._proto.dims, frozen=True)
-
- @property
- def dtype(self) -> _enums.DataType:
- return _enums.DataType(self._proto.data_type)
-
- @property
- def doc_string(self) -> str:
- return self._proto.doc_string
-
- @property
- def raw(self) -> onnx.TensorProto:
- return self._proto
-
- def __repr__(self) -> str:
-        # Displaying the content is hard because some dtypes are not supported by numpy.
-        # Ideally we would display the content when the tensor is small.
- return f"{self._repr_base()}(name={self.name!r})"
-
- def __array__(self, dtype: Any = None) -> np.ndarray:
- """Return the tensor as a numpy array, compatible with np.array."""
- return self.numpy().__array__(dtype)
-
- def numpy(self) -> np.ndarray:
- """Return the tensor as a numpy array.
-
- This is an improved version of onnx.numpy_helper.to_array.
- It first reads the data using the dtype corresponding to the tensor
- proto data field, then converts it to the correct dtype and shape.
-        Special cases are bfloat16, complex and int4, where we need to
-        reinterpret the data. Other types can simply be cast.
-
- When the data type is not supported by numpy, the value is the bit representation
- of the dtype:
-
- - ``int8`` for int4, with the sign bit extended to 8 bits.
- - ``uint8`` for uint4.
- - ``uint8`` for 8-bit data types like float8.
- - ``uint16`` for bfloat16.
-
- When the data type is a string, this method returns a numpy array
- of bytes instead of a numpy array of strings, to follow the ONNX
- specification.
-
- External tensors are not supported by this class. Use
- :class:`onnxscript.ir.ExternalTensor` instead.
-
- Raises:
- ValueError: If the data type is UNDEFINED.
- """
- dtype = self.dtype
- if dtype == _enums.DataType.UNDEFINED:
- raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
- if self._proto.data_location == onnx.TensorProto.EXTERNAL:
- raise ValueError(
- "Cannot convert external tensor to numpy array. "
- "Use ir.ExternalTensor instead."
- )
-
- if self._proto.HasField("raw_data"):
- array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<"))
- # Cannot return now, because we may need to unpack 4bit tensors
- elif dtype == _enums.DataType.STRING:
- return np.array(self._proto.string_data).reshape(self._proto.dims)
- elif self._proto.int32_data:
- array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
- if dtype == _enums.DataType.FLOAT16:
-                # Reinterpret the int32 as float16; bfloat16 is handled by the final astype below
- array = array.astype(np.uint16).view(np.float16)
- elif self._proto.int64_data:
- array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64))
- elif self._proto.uint64_data:
- array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
- elif self._proto.float_data:
- array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
- if dtype == _enums.DataType.COMPLEX64:
- array = _unflatten_complex(array)
- elif self._proto.double_data:
- array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
- if dtype == _enums.DataType.COMPLEX128:
- array = _unflatten_complex(array)
- else:
- # Empty tensor
- if not self._proto.dims:
-                # When dims is not present and there is no data, we return an empty array
- return np.array([], dtype=dtype.numpy())
- else:
- # Otherwise we return a size 0 array with the correct shape
- return np.zeros(self._proto.dims, dtype=dtype.numpy())
-
- if dtype == _enums.DataType.INT4:
- return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
- elif dtype == _enums.DataType.UINT4:
- return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
- else:
- # Otherwise convert to the correct dtype and reshape
- # Note we cannot use view() here because the storage dtype may not be the same size as the target
- return array.astype(dtype.numpy()).reshape(self._proto.dims)
-
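-    # For example (an illustrative sketch):
-    #   proto = onnx.helper.make_tensor("t", onnx.TensorProto.FLOAT, [3], [1.0, 2.0, 3.0])
-    #   TensorProtoTensor(proto).numpy()  # -> array([1., 2., 3.], dtype=float32)
-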
- def tobytes(self) -> bytes:
- """Return the tensor as a byte string conformed to the ONNX specification, in little endian.
-
- Raises:
- ValueError: If the tensor is a string tensor or an external tensor.
- ValueError: If the tensor is of UNDEFINED data type.
- """
- if self._proto.data_location == onnx.TensorProto.EXTERNAL:
- raise ValueError(
- "Cannot convert external tensor to bytes. Use ir.ExternalTensor instead."
- )
- if self.dtype == _enums.DataType.STRING:
- raise ValueError("Cannot convert string tensor to bytes.")
- if self.dtype == _enums.DataType.UNDEFINED:
- raise ValueError("Cannot convert UNDEFINED tensor to bytes.")
-
- if self._proto.HasField("raw_data"):
- return self._proto.raw_data
- if self._proto.float_data:
- return np.array(
- self._proto.float_data, dtype=_little_endian_dtype(np.float32)
- ).tobytes()
- if self._proto.int32_data:
- array = np.array(self._proto.int32_data, dtype=np.int32)
- if self.dtype in {
- _enums.DataType.INT16,
- _enums.DataType.UINT16,
- _enums.DataType.FLOAT16,
- _enums.DataType.BFLOAT16,
- }:
- return array.astype(_little_endian_dtype(np.uint16)).tobytes()
- if self.dtype in {
- _enums.DataType.INT8,
- _enums.DataType.UINT8,
- _enums.DataType.BOOL,
- _enums.DataType.FLOAT8E4M3FN,
- _enums.DataType.FLOAT8E4M3FNUZ,
- _enums.DataType.FLOAT8E5M2,
- _enums.DataType.FLOAT8E5M2FNUZ,
- _enums.DataType.INT4,
- _enums.DataType.UINT4,
- }:
-                # uint4 and int4 values are already packed, even when stored as int32,
-                # so we don't need to pack them again
- return array.astype(_little_endian_dtype(np.uint8)).tobytes()
- assert self.dtype == _enums.DataType.INT32
- return array.tobytes()
- if self._proto.int64_data:
- return np.array(
- self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
- ).tobytes()
- if self._proto.double_data:
- return np.array(
- self._proto.double_data, dtype=_little_endian_dtype(np.float64)
- ).tobytes()
- if self._proto.uint64_data:
- array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
- if self.dtype == _enums.DataType.UINT32:
- return array.astype(_little_endian_dtype(np.uint32)).tobytes()
- assert self.dtype == _enums.DataType.UINT64
- return array.tobytes()
- # The repeating fields can be empty and still valid.
- # For example, int32_data can be empty and still be a valid tensor.
- return b""
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
-        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
-
-def _get_field(proto: Any, field: str) -> Any:
- if proto.HasField(field):
- return getattr(proto, field)
- return None
-
-
-# Deserialization
-
-
-def deserialize_opset_import(
- protos: Sequence[onnx.OperatorSetIdProto],
-) -> dict[str, int]:
- return {opset.domain: opset.version for opset in protos}
-
-
-def _parse_experimental_function_value_info_name(
- name: str,
-) -> tuple[str, str, str] | None:
- """Get the function domain, name and value name if the value info is for a function.
-
- The experimental format is:
- {function_domain}::{function_name}/{value_name}
-
- Args:
- name: The name stored in the value info.
-
- Returns:
- A tuple of the function domain, function name and value name if the value info is for a function.
- None otherwise.
- """
- parts = name.split("/")
- expected_parts = 2
- if len(parts) != expected_parts:
- return None
- function, value_name = parts
- parts = function.split("::")
- if len(parts) != expected_parts:
- return None
-    # NOTE: There will be no overload because overloads were introduced in ONNX IR v10, which also
- # introduces the ValueInfoProto for functions
- function_domain, function_name = parts
- return function_domain, function_name, value_name
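-
-# For example (an illustrative sketch, with made-up names):
-#   _parse_experimental_function_value_info_name("custom.domain::MyFunc/x")
-#   # -> ("custom.domain", "MyFunc", "x"); a plain name like "x" returns None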
-
-
-def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
- graph = _deserialize_graph(proto.graph, [])
- graph.opset_imports.update(deserialize_opset_import(proto.opset_import))
-
- functions = []
- for func in proto.functions:
- functions.append(deserialize_function(func))
-
- model = _core.Model(
- graph,
- ir_version=proto.ir_version,
- producer_name=_get_field(proto, "producer_name"),
- producer_version=_get_field(proto, "producer_version"),
- domain=_get_field(proto, "domain"),
- model_version=_get_field(proto, "model_version"),
- doc_string=_get_field(proto, "doc_string"),
- functions=functions,
- meta_data_props=deserialize_metadata_props(proto.metadata_props),
- )
-
- # Handle experimental value info for functions created by the dynamo exporter in IR version 9
- if model.ir_version < _FUNCTION_VALUE_INFO_SUPPORTED_VERSION:
- _deserialized_experimental_value_info_for_function_ir9(
- model.functions, proto.graph.value_info
- )
-
- return model
-
-
-def _deserialized_experimental_value_info_for_function_ir9(
- functions: Mapping[_protocols.OperatorIdentifier, _core.Function],
- value_info_protos: Sequence[onnx.ValueInfoProto],
-) -> None:
- """Deserialize value info for functions when they are stored in an experimental format.
-
- The experimental format is:
- {function_domain}::{function_name}/{value_name}
- """
- # Parse value info for functions from the main graph
- function_value_value_info_mapping: collections.defaultdict[
- _protocols.OperatorIdentifier,
- dict[str, onnx.ValueInfoProto],
- ] = collections.defaultdict(dict)
- for value_info_proto in value_info_protos:
- if (
- parsed := _parse_experimental_function_value_info_name(value_info_proto.name)
- ) is None:
- continue
- function_domain, function_name, value_name = parsed
- function_overload = ""
- # TODO(justinchuby): Create a constructor for OperatorIdentifier so we don't create tuples manually
- function_id = (function_domain, function_name, function_overload)
- function = functions.get(function_id)
- if function is None:
- # Function not found
- logger.debug(
- "Function with ID '%s' not found in model functions. Value info '%s' will be ignored.",
- function_id,
- value_info_proto.name,
- )
- continue
- function_value_value_info_mapping[function_id][value_name] = value_info_proto
- for function_id, function in functions.items():
- for input in function.inputs:
- if input.name in function_value_value_info_mapping[function_id]:
- deserialize_value_info_proto(
- function_value_value_info_mapping[function_id][input.name], input
- )
- for node in function:
- for output in node.outputs:
- if output.name in function_value_value_info_mapping[function_id]:
- deserialize_value_info_proto(
- function_value_value_info_mapping[function_id][output.name],
- output,
- )
- # The function outputs are handled as well because they are also node outputs
-
-
-def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
- return _deserialize_graph(proto, [])
-
-
-def _deserialize_graph(
- proto: onnx.GraphProto, scoped_values: list[dict[str, _core.Value]]
-) -> _core.Graph:
- """Deserialize a graph proto, recursively if needed.
-
- Args:
- proto: The graph proto to deserialize.
- scoped_values: A list of dictionaries mapping value names to their corresponding Value objects.
- Every time we enter a new graph, a new scope is created and appended to this list to include
- all values defined in the scope.
- """
- # Create values for initializers and inputs
- initializers = [deserialize_tensor(tensor) for tensor in proto.initializer]
- inputs = [_core.Input(info.name) for info in proto.input]
- for info, value in zip(proto.input, inputs):
- deserialize_value_info_proto(info, value)
-
- # Initialize the values dictionary for this graph scope with the inputs and initializers
- values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
- scoped_values.append(values)
- for initializer in initializers:
- if initializer.name in values:
- # The initializer is for an input
- values[initializer.name].const_value = initializer
- else:
- # The initializer is for some other value. Create this value first
- initializer_value = _core.Value(
- None,
- index=None,
- name=initializer.name,
- # TODO(justinchuby): Fix type hinting for shape and dtype
- shape=initializer.shape, # type: ignore
- type=_core.TensorType(initializer.dtype),
- const_value=initializer,
- )
- values[initializer.name] = initializer_value
-
- # Add ValueInfos for this graph scope
- value_info = {info.name: info for info in proto.value_info}
-
- # Deserialize nodes with all known values
- nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node]
-
- # Fill in values for graph outputs
- outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]
- scoped_values.pop()
- return _core.Graph(
- inputs,
- outputs,
- nodes=nodes,
- # TODO(justinchuby): Attach the values associated with the initializers
- initializers=initializers,
- doc_string=_get_field(proto, "doc_string"),
- name=_get_field(proto, "name"),
- metadata_props=deserialize_metadata_props(proto.metadata_props),
- )
-
-
-def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
- inputs = [_core.Input(name) for name in proto.input]
- values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
- value_info = {info.name: info for info in getattr(proto, "value_info", [])}
-
- # TODO(justinchuby): Handle unsorted nodes
- nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node]
- outputs = [values[name] for name in proto.output]
- graph = _core.Graph(
- inputs,
- outputs,
- nodes=nodes,
- initializers=(),
- doc_string=_get_field(proto, "doc_string"),
- opset_imports=deserialize_opset_import(proto.opset_import),
-        name=(
-            f"{proto.name}_{proto.domain}"
-            + (f"__{proto.overload}" if getattr(proto, "overload", "") else "")
-        ),
- )
- attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto]
- # Attributes without defaults
- attributes += [
- _core.Attr(name, _enums.AttributeType.UNDEFINED, None) for name in proto.attribute
- ]
- return _core.Function(
- domain=proto.domain,
- name=proto.name,
- overload=getattr(proto, "overload", ""),
- graph=graph,
- attributes=typing.cast(List[_core.Attr], attributes),
- metadata_props=deserialize_metadata_props(proto.metadata_props),
- )
-
-
-def deserialize_value_info_proto(
- proto: onnx.ValueInfoProto, value: _core.Value | None
-) -> _core.Value:
- if value is None:
- value = _core.Value(None, index=None, name=proto.name)
- value.shape = deserialize_type_proto_for_shape(proto.type)
- value.type = deserialize_type_proto_for_type(proto.type)
- metadata_props = deserialize_metadata_props(proto.metadata_props)
- if metadata_props is not None:
- value.metadata_props.update(metadata_props)
- value.doc_string = _get_field(proto, "doc_string")
- return value
-
-
-def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
- if proto.HasField("tensor_type"):
- if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
- return None
- # This logic handles when the shape is [] as well
- dim_protos = shape_proto.dim
- deserialized_dim_denotations = [
- deserialize_dimension(dim_proto) for dim_proto in dim_protos
- ]
- dims = [dim for dim, _ in deserialized_dim_denotations]
- denotations = [denotation for _, denotation in deserialized_dim_denotations]
- return _core.Shape(dims, denotations=denotations, frozen=True)
- if proto.HasField("sparse_tensor_type"):
- if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None:
- return None
- dim_protos = shape_proto.dim
- deserialized_dim_denotations = [
- deserialize_dimension(dim_proto) for dim_proto in dim_protos
- ]
- dims = [dim for dim, _ in deserialized_dim_denotations]
- denotations = [denotation for _, denotation in deserialized_dim_denotations]
- return _core.Shape(dims, denotations=denotations, frozen=True)
- if proto.HasField("sequence_type"):
- if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
- return None
- return deserialize_type_proto_for_shape(elem_type)
- if proto.HasField("optional_type"):
- if (elem_type := _get_field(proto.optional_type, "elem_type")) is None:
- return None
- return deserialize_type_proto_for_shape(elem_type)
- if proto.HasField("map_type"):
- # TODO(justinchuby): Do we need to support map types?
- raise NotImplementedError("Map types are not supported yet")
-
- return None
-
-
-def deserialize_type_proto_for_type(
- proto: onnx.TypeProto,
-) -> _protocols.TypeProtocol | None:
- denotation = _get_field(proto, "denotation")
- if proto.HasField("tensor_type"):
- if (elem_type := _get_field(proto.tensor_type, "elem_type")) is None:
- return None
- return _core.TensorType(_enums.DataType(elem_type), denotation=denotation)
- if proto.HasField("sparse_tensor_type"):
- if (elem_type := _get_field(proto.sparse_tensor_type, "elem_type")) is None:
- return None
- return _core.SparseTensorType(_enums.DataType(elem_type), denotation=denotation)
- if proto.HasField("sequence_type"):
-        # FIXME(justinchuby): Allow nested types to be None
- if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
- raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}")
- nested_type = deserialize_type_proto_for_type(elem_type)
- if nested_type is None:
- raise ValueError(f"SequenceType must have elem_type set: {proto}")
- return _core.SequenceType(nested_type, denotation=denotation)
- if proto.HasField("optional_type"):
-        # FIXME(justinchuby): Allow nested types to be None
- if (elem_type := _get_field(proto.optional_type, "elem_type")) is None:
- raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}")
- nested_type = deserialize_type_proto_for_type(elem_type)
- if nested_type is None:
- raise ValueError(f"SequenceType must have elem_type set: {proto}")
- return _core.OptionalType(nested_type, denotation=denotation)
- if proto.HasField("map_type"):
- # TODO(justinchuby): Do we need to support map types?
- raise NotImplementedError("Map types are not supported yet")
-
- return None
-
-
-def deserialize_dimension(
- proto: onnx.TensorShapeProto.Dimension,
-) -> tuple[int | _core.SymbolicDim, str | None]:
- """Deserialize a dimension proto into (dimension, denotation).
-
- Args:
- proto: The dimension proto to deserialize.
-
- Returns:
- A tuple of the dimension and its denotation.
- """
- value_field = proto.WhichOneof("value")
- denotation = _get_field(proto, "denotation")
- if value_field is not None:
- value = getattr(proto, value_field)
- if value_field == "dim_value":
- return value, denotation
- if value_field == "dim_param":
- return _core.SymbolicDim(value), denotation
- return _core.SymbolicDim(None), denotation
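-
-# For example (an illustrative sketch): a dimension proto with dim_value=3
-# deserializes to (3, None); dim_param="N" yields (SymbolicDim("N"), None);
-# an unset dimension becomes (SymbolicDim(None), None).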
-
-
-def deserialize_tensor(
- proto: onnx.TensorProto, base_path: str | os.PathLike = ""
-) -> _protocols.TensorProtocol:
- # TODO: Sanitize base_path
- if proto.data_location == onnx.TensorProto.EXTERNAL:
- external_info = onnx.external_data_helper.ExternalDataInfo(proto)
- return _core.ExternalTensor(
- path=os.path.join(base_path, external_info.location),
- offset=external_info.offset,
- length=external_info.length,
- dtype=_enums.DataType(proto.data_type),
- name=proto.name,
- shape=_core.Shape(proto.dims),
- doc_string=proto.doc_string,
- metadata_props=deserialize_metadata_props(proto.metadata_props),
- )
- if proto.data_type == _enums.DataType.STRING:
- name = _get_field(proto, "name")
- doc_string = _get_field(proto, "doc_string")
- metadata_props = deserialize_metadata_props(proto.metadata_props)
- return _core.StringTensor(
- proto.string_data,
- shape=_core.Shape(proto.dims),
- name=name,
- doc_string=doc_string,
- metadata_props=metadata_props,
- )
- return TensorProtoTensor(proto)
-
-
-def deserialize_metadata_props(
- proto: Sequence[onnx.StringStringEntryProto],
-) -> dict[str, str] | None:
- if len(proto) == 0:
- # Avoid creating an empty dictionary to save memory
- return None
- return {entry.key: entry.value for entry in proto}
-
-
-def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr:
- return _deserialize_attribute(proto, [])
-
-
-def _deserialize_attribute(
- proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]]
-) -> _core.Attr | _core.RefAttr:
- name = proto.name
- doc_string = _get_field(proto, "doc_string")
- type_ = _enums.AttributeType(proto.type)
- ref_attr_name = _get_field(proto, "ref_attr_name")
- if ref_attr_name:
- return _core.RefAttr(name, ref_attr_name, type_, doc_string=doc_string)
-
- if type_ == _enums.AttributeType.INT:
- return _core.AttrInt64(name, proto.i, doc_string=doc_string)
- if type_ == _enums.AttributeType.FLOAT:
- return _core.AttrFloat32(name, proto.f, doc_string=doc_string)
- if type_ == _enums.AttributeType.STRING:
- return _core.AttrString(name, proto.s.decode("utf-8"), doc_string=doc_string)
- if type_ == _enums.AttributeType.INTS:
- return _core.AttrInt64s(name, proto.ints, doc_string=doc_string)
- if type_ == _enums.AttributeType.FLOATS:
- return _core.AttrFloat32s(name, proto.floats, doc_string=doc_string)
- if type_ == _enums.AttributeType.STRINGS:
- return _core.AttrStrings(
- name, [s.decode("utf-8") for s in proto.strings], doc_string=doc_string
- )
- if type_ == _enums.AttributeType.TENSOR:
- return _core.AttrTensor(name, deserialize_tensor(proto.t), doc_string=doc_string)
- if type_ == _enums.AttributeType.GRAPH:
- return _core.AttrGraph(
- name, _deserialize_graph(proto.g, scoped_values), doc_string=doc_string
- )
- if type_ == _enums.AttributeType.TENSORS:
- return _core.AttrTensors(
- name,
- [deserialize_tensor(t) for t in proto.tensors],
- doc_string=doc_string,
- )
- if type_ == _enums.AttributeType.GRAPHS:
- return _core.AttrGraphs(
- name,
- [_deserialize_graph(g, scoped_values) for g in proto.graphs],
- doc_string=doc_string,
- )
- if type_ == _enums.AttributeType.SPARSE_TENSOR:
- raise NotImplementedError("Sparse tensors are not supported yet")
- if type_ == _enums.AttributeType.SPARSE_TENSORS:
- raise NotImplementedError("Sparse tensors are not supported yet")
- if type_ == _enums.AttributeType.TYPE_PROTO:
- ir_type = deserialize_type_proto_for_type(proto.tp)
- shape = deserialize_type_proto_for_shape(proto.tp)
- return _core.AttrTypeProto(
- name, _core.TypeAndShape(ir_type, shape), doc_string=doc_string
- )
- if type_ == _enums.AttributeType.TYPE_PROTOS:
- type_and_shapes = []
- for type_proto in proto.type_protos:
- ir_type = deserialize_type_proto_for_type(type_proto)
- shape = deserialize_type_proto_for_shape(type_proto)
- type_and_shapes.append(_core.TypeAndShape(ir_type, shape))
- return _core.AttrTypeProtos(name, type_and_shapes, doc_string=doc_string)
- if type_ == _enums.AttributeType.UNDEFINED:
- return _core.Attr(name, type_, None, doc_string=doc_string)
- raise ValueError(f"Unsupported attribute type: '{type_}'")
-
-
-def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
- return _deserialize_node(proto, scoped_values=[], value_info={})
-
-
-def _deserialize_node(
- proto: onnx.NodeProto,
- scoped_values: list[dict[str, _core.Value]],
- value_info: dict[str, onnx.ValueInfoProto],
-) -> _core.Node:
- node_inputs: list[_core.Value | None] = []
- for input_name in proto.input:
- if input_name == "":
- # Empty input
- node_inputs.append(None)
- continue
-
- # Find the input in all value scopes
- found = False
- for values in reversed(scoped_values):
- if input_name not in values:
- continue
- node_inputs.append(values[input_name])
- found = True
- del values # Remove the reference so it is not used by mistake
- break
- if not found:
-            # If the input is not found, the graph may be unsorted: the input may be
-            # an intended initializer or the output of a node that comes later.
-            # Here we create the value with the name and add it to the current scope.
-            # Nodes need to check the value pool for potentially initialized outputs.
- logger.warning(
- "Input '%s' of node '%s(%s::%s:%s)' not found in any scope. "
- "The graph may be unsorted. Creating a new input (current depth: %s) .",
- input_name,
- proto.name,
- proto.domain,
- proto.op_type,
- getattr(proto, "overload", ""),
- len(scoped_values),
- )
- if len(scoped_values) > 1:
- logger.warning(
- "Caveat: The value is created in the subgraph. If "
- "the node is referencing a value that is not in the current graph, "
- "it is impossible to create it in the correct scope.",
- )
- value = _core.Value(None, index=None, name=input_name)
- # Fill in shape/type information if they exist
- if input_name in value_info:
- deserialize_value_info_proto(value_info[input_name], value)
- node_inputs.append(value)
- # We can only create the value in the current scope. If the subgraph is
- # referencing a value that is not in the current scope, it is impossible
- # to create it in the correct scope.
- scoped_values[-1][input_name] = value
-
- # Build the output values for the node.
- node_outputs: list[_core.Value] = []
- for output_name in proto.output:
- if output_name == "":
- # Empty output
- node_outputs.append(_core.Value(None, index=None, name=""))
- continue
-
- # 1. When the graph is unsorted, we may be able to find the output already created
- # as an input to some other nodes in the current scope.
- # Note that a value is always owned by the producing node. Even though a value
- # can be created when parsing inputs of other nodes, the new node created here
- # that produces the value will assume ownership. It is then impossible to transfer
- # the ownership to any other node.
-
- # The output can only be found in the current scope. It is impossible for
- # a node to produce an output that is not in its own scope.
- current_scope = scoped_values[-1]
- if output_name in current_scope:
- value = current_scope[output_name]
- else:
- # 2. Common scenario: the graph is sorted and this is the first time we see the output.
- # Create the value and add it to the current scope.
- value = _core.Value(None, index=None, name=output_name)
- current_scope[output_name] = value
- # Fill in shape/type information if they exist
- if output_name in value_info:
- deserialize_value_info_proto(value_info[output_name], value)
- else:
- logger.debug(
- "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
- output_name,
- proto.name,
- proto.op_type,
- )
- node_outputs.append(value)
- return _core.Node(
- proto.domain,
- proto.op_type,
- node_inputs,
- [_deserialize_attribute(a, scoped_values) for a in proto.attribute],
- overload=getattr(proto, "overload", ""),
- outputs=node_outputs,
- name=proto.name,
- doc_string=_get_field(proto, "doc_string"),
- metadata_props=deserialize_metadata_props(proto.metadata_props),
- )
-
-
-# Serialization
-
-
-def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto:
- return serialize_model_into(onnx.ModelProto(), from_=model)
-
-
-def serialize_model_into(
- model_proto: onnx.ModelProto, from_: _protocols.ModelProtocol
-) -> onnx.ModelProto:
- """Serialize an IR model to an ONNX model proto."""
- model_proto.ir_version = from_.ir_version
- if from_.producer_name:
- model_proto.producer_name = from_.producer_name
- if from_.producer_version:
- model_proto.producer_version = from_.producer_version
- if from_.domain:
- model_proto.domain = from_.domain
- if from_.model_version:
- model_proto.model_version = from_.model_version
- if from_.doc_string:
- model_proto.doc_string = from_.doc_string
- _serialize_opset_imports_into(model_proto.opset_import, from_.opset_imports)
- if from_.metadata_props:
- _serialize_metadata_props_into(model_proto.metadata_props, from_.metadata_props)
- serialize_graph_into(model_proto.graph, from_.graph)
-
- create_value_info_in_functions = from_.ir_version >= _FUNCTION_VALUE_INFO_SUPPORTED_VERSION
- for func in from_.functions.values():
- serialize_function_into(
- model_proto.functions.add(),
- from_=func,
- create_value_info=create_value_info_in_functions,
- )
- if not create_value_info_in_functions:
- # Create them in the main graph instead
- _serialize_experimental_value_info_for_function_ir9_into(model_proto.graph, func)
- return model_proto
-
-
-def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool:
- """Check if value info should be created for a value.
-
- Args:
- value: The value to check.
-
- Returns:
- True if value info should be created for the value.
- """
- # No need to serialize value info if it is not set
- return not (value.shape is None and value.type is None)
-
-
-def _serialize_experimental_value_info_for_function_ir9_into(
- graph_proto: onnx.GraphProto, function: _protocols.FunctionProtocol
-) -> None:
- """Serialize value info for functions in an experimental format for IR version 9.
-
-    Because IR v9 and older do not have ValueInfoProto for functions, we give the value info
- special names and store them in the main graph instead.
-
- The experimental format is:
- {function_domain}::{function_name}/{value_name}
-
- Args:
- graph_proto: The graph proto to create ValueInfoProto in.
- function: The function to serialize.
- """
- # TODO(justinchuby): In the future, we can decide if it is a good idea to simply iterate over
- # all values in the function and call serialize_value_into instead.
- function_qualified_name = f"{function.domain}::{function.name}"
-
- def format_name(value_name: str) -> str:
- return f"{function_qualified_name}/{value_name}"
-
- for input in function.inputs:
- if not input.name:
-            logger.warning(
- "Function '%s': Value name not set for function input: %s",
- function_qualified_name,
- input,
- )
- continue
- if not _should_create_value_info_for_value(input):
- # No need to serialize value info if it is not set
- continue
- serialize_value_into(graph_proto.value_info.add(), input, name=format_name(input.name))
- for node in function:
- for node_output in node.outputs:
- if not node_output.name:
-                logger.warning(
- "Function '%s': Value name not set for node output: %s",
- function_qualified_name,
- node_output,
- )
- continue
- if not _should_create_value_info_for_value(node_output):
- # No need to serialize value info if it is not set
- continue
- serialize_value_into(
- graph_proto.value_info.add(),
- node_output,
- name=format_name(node_output.name),
- )
-
-
-def _serialize_opset_imports_into(
- opset_ids: proto_containers.RepeatedCompositeFieldContainer[onnx.OperatorSetIdProto],
- from_: Mapping[str, int],
-) -> None:
- """Serialize opset imports into a repeated field of OperatorSetId protos.
-
- Args:
- opset_ids: The repeated field to serialize into.
- from_: The mapping of opset domains to versions to serialize.
- """
- for domain, version in from_.items():
- opset_ids.add(domain=domain, version=version)
-
-
-def _serialize_metadata_props_into(
- string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
- onnx.StringStringEntryProto
- ],
- from_: Mapping[str, str],
-) -> None:
- """Serialize metadata properties into a repeated field of string-string entries.
-
- Args:
- string_string_entries: The repeated field to serialize into.
- from_: The mapping of metadata properties to serialize.
- """
- # Sort names for deterministic serialization
- for key in sorted(from_):
- string_string_entries.add(key=key, value=from_[key])
-
-
-def serialize_graph(
- graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
-) -> onnx.GraphProto:
- graph_proto = onnx.GraphProto()
- serialize_graph_into(graph_proto, from_=graph)
- return graph_proto
-
-
-def serialize_graph_into(
- graph_proto: onnx.GraphProto,
- from_: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
-) -> None:
- if from_.name:
- graph_proto.name = from_.name
- if from_.doc_string:
- graph_proto.doc_string = from_.doc_string
- for input_ in from_.inputs:
- serialize_value_into(graph_proto.input.add(), input_)
- # TODO(justinchuby): Support sparse_initializer
- for initializer in from_.initializers.values():
- serialize_tensor_into(graph_proto.initializer.add(), from_=initializer)
- for node in from_:
- serialize_node_into(graph_proto.node.add(), from_=node)
- for node_output in node.outputs:
- if not _should_create_value_info_for_value(node_output):
- # No need to serialize value info if it is not set
- continue
- if node_output.is_graph_output():
- # No need to serialize value info for these outputs because they are also graph outputs
- continue
- serialize_value_into(graph_proto.value_info.add(), node_output)
- for output in from_.outputs:
- serialize_value_into(graph_proto.output.add(), from_=output)
- if from_.metadata_props:
- _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props)
-
-
-def serialize_function(
- function: _protocols.FunctionProtocol, *, create_value_info: bool = True
-) -> onnx.FunctionProto:
- """Serialize an IR function as a FunctionProto.
-
- Args:
- function: The function to serialize.
- create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported
- starting from ONNX IR version 10.
- """
- function_proto = onnx.FunctionProto()
- serialize_function_into(
- function_proto, from_=function, create_value_info=create_value_info
- )
- return function_proto
-
-
-def serialize_function_into(
- function_proto: onnx.FunctionProto,
- from_: _protocols.FunctionProtocol,
- *,
- create_value_info: bool = True,
-) -> None:
- """Serialize an IR function into a FunctionProto.
-
- Args:
- function_proto: The proto to serialize into.
- from_: The function to serialize.
- create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported
- starting from ONNX IR version 10.
- """
- if from_.domain:
- function_proto.domain = from_.domain
- if from_.name:
- function_proto.name = from_.name
- if from_.overload:
- function_proto.overload = from_.overload
- if from_.doc_string:
- function_proto.doc_string = from_.doc_string
- if from_.opset_imports:
- # A valid ONNX graph should have at least one opset import, that is
- # the default ONNX opset.
- # Here we check for emptiness before serializing to keep the logic consistent
- _serialize_opset_imports_into(function_proto.opset_import, from_.opset_imports)
- if from_.metadata_props:
- _serialize_metadata_props_into(function_proto.metadata_props, from_.metadata_props)
- for input_ in from_.inputs:
- function_proto.input.append(input_.name)
- if not _should_create_value_info_for_value(input_):
- # No need to serialize value info if it is not set
- continue
- if not create_value_info:
- continue
- serialize_value_into(function_proto.value_info.add(), input_)
- for attr in from_.attributes.values():
- if attr.value is not None:
- serialize_attribute_into(function_proto.attribute_proto.add(), from_=attr)
- else:
- # ONNX does not record type information if the attribute does not have a default
- function_proto.attribute.append(attr.name)
- for func_output in from_.outputs:
- function_proto.output.append(func_output.name)
- # No need to serialize value info for function outputs because they are
- # also node outputs
- for node in from_:
- serialize_node_into(function_proto.node.add(), from_=node)
- # Record value info for outputs
- for node_output in node.outputs:
- if not _should_create_value_info_for_value(node_output):
- # No need to serialize value info if it is not set
- continue
- if not create_value_info:
- continue
- serialize_value_into(function_proto.value_info.add(), node_output)
-
-
-def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
- node_proto = onnx.NodeProto()
- serialize_node_into(node_proto, from_=node)
- return node_proto
-
-
-def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
- node_proto.op_type = from_.op_type
- if from_.domain:
- # If the domain is "", we can assume the default domain and not set it
- node_proto.domain = from_.domain
- if from_.name:
- node_proto.name = from_.name
- if from_.overload:
- node_proto.overload = from_.overload
- if from_.doc_string:
- node_proto.doc_string = from_.doc_string
- if from_.metadata_props:
- _serialize_metadata_props_into(node_proto.metadata_props, from_.metadata_props)
- for input_ in from_.inputs:
- if input_ is None:
- node_proto.input.append("")
- else:
- node_proto.input.append(input_.name)
- for output in from_.outputs:
- node_proto.output.append(output.name)
- for attr in from_.attributes.values():
- if isinstance(attr, _core.Attr):
- serialize_attribute_into(node_proto.attribute.add(), from_=attr)
- elif isinstance(attr, _core.RefAttr):
- serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
- # Handle protocol attributes for completeness. We do not check them first because
- # calling isinstance on a protocol can be slow.
- # Most of the time, we will have Attr or RefAttr so the two branches below
- # will not be taken.
- elif isinstance(attr, _protocols.AttributeProtocol):
- serialize_attribute_into(node_proto.attribute.add(), from_=attr)
- elif isinstance(attr, _protocols.ReferenceAttributeProtocol):
- serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
- else:
- raise TypeError(f"Unsupported attribute type: {type(attr)}")
-
-
-def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:
- tensor_proto = onnx.TensorProto()
- serialize_tensor_into(tensor_proto, from_=tensor)
- return tensor_proto
-
-
-def serialize_tensor_into(
- tensor_proto: onnx.TensorProto, from_: _protocols.TensorProtocol
-) -> None:
- if isinstance(from_, TensorProtoTensor):
- # Directly copy from the tensor proto if it is available
- tensor_proto.CopyFrom(from_.raw)
- if from_.metadata_props:
- _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
- return
-
- tensor_proto.name = from_.name
- if from_.doc_string:
- tensor_proto.doc_string = from_.doc_string
- tensor_proto.data_type = from_.dtype.value
- tensor_proto.dims.extend(from_.shape.numpy())
- if isinstance(from_, _core.ExternalTensor):
- # Store external tensors as is
- tensor_proto.data_location = onnx.TensorProto.EXTERNAL
- for k, v in {
- "location": os.fspath(from_.path),
- "offset": from_.offset,
- "length": from_.length,
- }.items():
- if v is not None:
- entry = tensor_proto.external_data.add()
- entry.key = k
- entry.value = str(v)
- elif isinstance(from_, _core.StringTensor):
- tensor_proto.string_data.extend(from_.string_data())
- else:
- tensor_proto.raw_data = from_.tobytes()
- _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
-
-
-def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.AttributeProto:
- attribute_proto = onnx.AttributeProto()
- serialize_attribute_into(attribute_proto, from_=attribute)
- return attribute_proto
-
-
-def serialize_attribute_into(
- attribute_proto: onnx.AttributeProto, from_: _protocols.AttributeProtocol
-) -> None:
- attribute_proto.name = from_.name
- if from_.doc_string:
- attribute_proto.doc_string = from_.doc_string
- _fill_in_value_for_attribute(attribute_proto, from_.type, from_.value)
-
-
-def _fill_in_value_for_attribute(
- attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any
-) -> None:
- if type_ == _enums.AttributeType.INT:
- # value: int
- attribute_proto.i = value
- attribute_proto.type = onnx.AttributeProto.INT
- elif type_ == _enums.AttributeType.FLOAT:
- # value: float
- attribute_proto.f = value
- attribute_proto.type = onnx.AttributeProto.FLOAT
- elif type_ == _enums.AttributeType.STRING:
- # value: str
- attribute_proto.s = value.encode("utf-8")
- attribute_proto.type = onnx.AttributeProto.STRING
- elif type_ == _enums.AttributeType.INTS:
- # value: Sequence[int]
- attribute_proto.ints.extend(value)
- attribute_proto.type = onnx.AttributeProto.INTS
- elif type_ == _enums.AttributeType.FLOATS:
- # value: Sequence[float]
- attribute_proto.floats.extend(value)
- attribute_proto.type = onnx.AttributeProto.FLOATS
- elif type_ == _enums.AttributeType.STRINGS:
- # value: Sequence[str]
- attribute_proto.strings.extend([s.encode("utf-8") for s in value])
- attribute_proto.type = onnx.AttributeProto.STRINGS
- elif type_ == _enums.AttributeType.TENSOR:
- # value: _protocols.TensorProtocol
- serialize_tensor_into(attribute_proto.t, value)
- attribute_proto.type = onnx.AttributeProto.TENSOR
- elif type_ == _enums.AttributeType.GRAPH:
- # value: _protocols.GraphProtocol
- serialize_graph_into(attribute_proto.g, value)
- attribute_proto.type = onnx.AttributeProto.GRAPH
- elif type_ == _enums.AttributeType.TENSORS:
- # value: Sequence[_protocols.TensorProtocol]
- for tensor in value:
- serialize_tensor_into(attribute_proto.tensors.add(), tensor)
- attribute_proto.type = onnx.AttributeProto.TENSORS
- elif type_ == _enums.AttributeType.GRAPHS:
- # value: Sequence[_protocols.GraphProtocol]
- for graph in value:
- serialize_graph_into(attribute_proto.graphs.add(), graph)
- attribute_proto.type = onnx.AttributeProto.GRAPHS
- elif type_ == _enums.AttributeType.SPARSE_TENSOR:
- raise NotImplementedError("Sparse tensors are not supported yet")
- elif type_ == _enums.AttributeType.SPARSE_TENSORS:
- raise NotImplementedError("Sparse tensors are not supported yet")
- elif type_ == _enums.AttributeType.TYPE_PROTO:
- # value: _core.TypeAndShape
- if value.type is not None:
- serialize_type_into(attribute_proto.tp, value.type)
- # Need to create the type _before_ writing the shape
- if value.shape is not None:
- serialize_shape_into(attribute_proto.tp, value.shape)
- attribute_proto.type = onnx.AttributeProto.TYPE_PROTO
- elif type_ == _enums.AttributeType.TYPE_PROTOS:
- for ir_type in value:
- # ir_type: _core.TypeAndShape
- type_proto = attribute_proto.type_protos.add()
- if ir_type.type is not None:
- serialize_type_into(type_proto, ir_type.type)
- # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
- if ir_type.shape is not None:
- serialize_shape_into(type_proto, ir_type.shape)
- attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS
- else:
- raise TypeError(f"Unsupported attribute type: {type_}")
-
-
-def serialize_reference_attribute_into(
- attribute_proto: onnx.AttributeProto, from_: _protocols.ReferenceAttributeProtocol
-) -> None:
- attribute_proto.name = from_.name
- attribute_proto.ref_attr_name = from_.ref_attr_name
- if from_.doc_string:
- attribute_proto.doc_string = from_.doc_string
- attribute_proto.type = typing.cast(onnx.AttributeProto.AttributeType, from_.type.value)
-
-
-def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.ValueInfoProto:
- """Serialize a value into a ValueInfoProto.
-
- Args:
-        value: The value to serialize.
- name: A custom name to set for the value info. If not provided, the name from the value will be used.
- """
- value_info_proto = onnx.ValueInfoProto()
- serialize_value_into(value_info_proto, value, name=name)
- return value_info_proto
-
-
-def serialize_value_into(
- value_info_proto: onnx.ValueInfoProto,
- from_: _protocols.ValueProtocol,
- *,
- name: str = "",
-) -> None:
- """Serialize a value into a ValueInfoProto.
-
- Args:
- value_info_proto: The proto to serialize into.
- from_: The value to serialize.
- name: A custom name to set for the value info. If not provided, the name from the value will be used.
- """
- if name:
- value_info_proto.name = name
- else:
- value_info_proto.name = from_.name
- if from_.metadata_props:
- _serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props)
- if from_.type is not None:
- serialize_type_into(value_info_proto.type, from_.type)
- # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
- if from_.shape is not None:
- serialize_shape_into(value_info_proto.type, from_.shape)
- if from_.doc_string:
- value_info_proto.doc_string = from_.doc_string
-
-
-def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
- if from_.denotation:
- type_proto.denotation = from_.denotation
- if isinstance(from_, _core.TensorType):
- tensor_type_proto = type_proto.tensor_type
- tensor_type_proto.elem_type = from_.dtype.value
- elif isinstance(from_, _core.SparseTensorType):
- sparse_tensor_type_proto = type_proto.sparse_tensor_type
- sparse_tensor_type_proto.elem_type = from_.dtype.value
- elif isinstance(from_, _core.SequenceType):
- sequence_type_proto = type_proto.sequence_type
- serialize_type_into(sequence_type_proto.elem_type, from_.elem_type)
- elif isinstance(from_, _core.OptionalType):
- optional_type_proto = type_proto.optional_type
- serialize_type_into(optional_type_proto.elem_type, from_.elem_type)
- else:
- raise TypeError(f"Unsupported type: {from_}")
-
-
-def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
- value_field = type_proto.WhichOneof("value")
- tensor_type = getattr(type_proto, value_field)
- while not isinstance(tensor_type.elem_type, int):
- # Find the leaf type that has the shape field
- type_proto = tensor_type.elem_type
- value_field = type_proto.WhichOneof("value")
- tensor_type = getattr(type_proto, value_field)
-    # When from_ is empty, we still need to set the shape field to an empty list by touching it
- tensor_type.shape.ClearField("dim")
- for i, dim in enumerate(from_):
- denotation = from_.get_denotation(i)
- serialize_dimension_into(tensor_type.shape.dim.add(), dim, denotation)
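-
-# For example (an illustrative note): for a TypeProto describing
-# Sequence(Tensor(float)), serialize_shape_into descends through
-# sequence_type.elem_type to the leaf tensor_type and writes the dims there.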
-
-
-def serialize_dimension_into(
- dim_proto: onnx.TensorShapeProto.Dimension,
- dim: int | _protocols.SymbolicDimProtocol,
- denotation: str | None = None,
-) -> None:
- if denotation:
- dim_proto.denotation = denotation
- if isinstance(dim, int):
- dim_proto.dim_value = dim
- elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
- if dim.value is not None:
- # TODO(justinchuby): None is probably not a valid value for dim_param
- dim_proto.dim_param = str(dim.value)
diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py
deleted file mode 100644
index 64512d9066..0000000000
--- a/onnxscript/ir/serde_test.py
+++ /dev/null
@@ -1,211 +0,0 @@
-import unittest
-from typing import Callable
-
-import numpy as np
-import onnx
-import parameterized
-
-from onnxscript import ir
-from onnxscript.ir import serde
-
-
-class TensorProtoTensorTest(unittest.TestCase):
- @parameterized.parameterized.expand(
- [
- ("FLOAT", onnx.TensorProto.FLOAT),
- ("BOOL", onnx.TensorProto.BOOL),
- ("FLOAT16", onnx.TensorProto.FLOAT16),
- ("DOUBLE", onnx.TensorProto.DOUBLE),
- ]
- )
- def test_tensor_proto_tensor(self, _: str, dtype: int):
- tensor_proto = onnx.helper.make_tensor(
- "test_tensor", dtype, [1, 9], [-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]
- )
- tensor = serde.TensorProtoTensor(tensor_proto)
- expected_array = onnx.numpy_helper.to_array(tensor_proto)
- np.testing.assert_array_equal(tensor.numpy(), expected_array)
- raw_data = tensor.tobytes()
- tensor_proto_from_raw_data = onnx.TensorProto(
- dims=tensor_proto.dims,
- data_type=tensor_proto.data_type,
- raw_data=raw_data,
- )
- array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
- np.testing.assert_array_equal(array_from_raw_data, expected_array)
-
- def test_tensor_proto_tensor_bfloat16(self):
- expected_array = np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]])
- tensor_proto = onnx.helper.make_tensor(
- "test_tensor", onnx.TensorProto.BFLOAT16, [1, 9], expected_array
- )
- tensor = serde.TensorProtoTensor(tensor_proto)
- np.testing.assert_array_equal(
- onnx.numpy_helper.bfloat16_to_float32(tensor.numpy()), expected_array
- )
- raw_data = tensor.tobytes()
- tensor_proto_from_raw_data = onnx.TensorProto(
- dims=tensor_proto.dims,
- data_type=tensor_proto.data_type,
- raw_data=raw_data,
- )
- array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
- np.testing.assert_array_equal(array_from_raw_data, expected_array)
-
- @parameterized.parameterized.expand(
- [
- (
- "FLOAT8E4M3FN",
- onnx.TensorProto.FLOAT8E4M3FN,
- lambda x: onnx.numpy_helper.float8e4m3_to_float32(x, fn=True),
- ),
- (
- "FLOAT8E4M3FNUZ",
- onnx.TensorProto.FLOAT8E4M3FNUZ,
- lambda x: onnx.numpy_helper.float8e4m3_to_float32(x, fn=True, uz=True),
- ),
- (
- "FLOAT8E5M2",
- onnx.TensorProto.FLOAT8E5M2,
- onnx.numpy_helper.float8e5m2_to_float32,
- ),
- (
- "FLOAT8E5M2FNUZ",
- onnx.TensorProto.FLOAT8E5M2FNUZ,
- lambda x: onnx.numpy_helper.float8e5m2_to_float32(x, fn=True, uz=True),
- ),
- ]
- )
- def test_tensor_proto_tensor_float8(self, _: str, dtype: int, to_float32_func: Callable):
- expected_array = np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]])
- tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 9], expected_array)
- tensor = serde.TensorProtoTensor(tensor_proto)
- np.testing.assert_array_equal(to_float32_func(tensor.numpy()), expected_array)
- raw_data = tensor.tobytes()
- tensor_proto_from_raw_data = onnx.TensorProto(
- dims=tensor_proto.dims,
- data_type=tensor_proto.data_type,
- raw_data=raw_data,
- )
- if dtype in (onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FNUZ):
-            # TODO: Remove the skip when ONNX 1.17 releases
- self.skipTest("ONNX to_array fails: https://github.com/onnx/onnx/pull/6124")
- array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
- np.testing.assert_array_equal(array_from_raw_data, expected_array)
-
- @parameterized.parameterized.expand(
- [
- ("INT8", onnx.TensorProto.INT8),
- ("INT16", onnx.TensorProto.INT16),
- ("INT32", onnx.TensorProto.INT32),
- ("INT64", onnx.TensorProto.INT64),
- ("INT4", onnx.TensorProto.INT4),
- ]
- )
- def test_tensor_proto_tensor_int(self, _: str, dtype: int):
- tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 4], [-1, 0, 1, 8])
- tensor = serde.TensorProtoTensor(tensor_proto)
- expected_array = onnx.numpy_helper.to_array(
- tensor_proto
- ) # [-1, 0, 1, 7], 8 is clamped to 7
- np.testing.assert_array_equal(tensor.numpy(), expected_array)
- raw_data = tensor.tobytes()
- tensor_proto_from_raw_data = onnx.TensorProto(
- dims=tensor_proto.dims,
- data_type=tensor_proto.data_type,
- raw_data=raw_data,
- )
- array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
- np.testing.assert_array_equal(array_from_raw_data, expected_array)
-
- @parameterized.parameterized.expand(
- [
- ("UINT8", onnx.TensorProto.UINT8),
- ("UINT16", onnx.TensorProto.UINT16),
- ("UINT32", onnx.TensorProto.UINT32),
- ("UINT64", onnx.TensorProto.UINT64),
- ("UINT4", onnx.TensorProto.UINT4),
- ]
- )
- def test_tensor_proto_tensor_uint(self, _: str, dtype: int):
- tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 3], [0, 1, 8])
- tensor = serde.TensorProtoTensor(tensor_proto)
- expected_array = onnx.numpy_helper.to_array(tensor_proto)
- np.testing.assert_array_equal(tensor.numpy(), expected_array)
- raw_data = tensor.tobytes()
- tensor_proto_from_raw_data = onnx.TensorProto(
- dims=tensor_proto.dims,
- data_type=tensor_proto.data_type,
- raw_data=raw_data,
- )
- array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
- np.testing.assert_array_equal(array_from_raw_data, expected_array)
-
- @parameterized.parameterized.expand(
- [
- ("COMPLEX64", onnx.TensorProto.COMPLEX64, np.complex64),
- ("COMPLEX128", onnx.TensorProto.COMPLEX128, np.complex128),
- ]
- )
- def test_tensor_proto_tensor_complex(self, _: str, dtype: int, np_dtype: np.dtype):
- expected_array = np.array([[0.0 + 1j, 0.2 - 1j, 0.3]], dtype=np_dtype)
- tensor_proto = onnx.helper.make_tensor(
- "test_tensor", dtype, [1, 3], [0.0 + 1j, 0.2 - 1j, 0.3]
- )
- tensor = serde.TensorProtoTensor(tensor_proto)
- np.testing.assert_array_equal(tensor.numpy(), expected_array)
- raw_data = tensor.tobytes()
- tensor_proto_from_raw_data = onnx.TensorProto(
- dims=tensor_proto.dims,
- data_type=tensor_proto.data_type,
- raw_data=raw_data,
- )
- array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
- np.testing.assert_array_equal(array_from_raw_data, expected_array)
-
- def test_tensor_proto_tensor_empty_tensor(self):
- tensor_proto = onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [0], [])
- tensor = serde.TensorProtoTensor(tensor_proto)
- expected_array = onnx.numpy_helper.to_array(tensor_proto)
- np.testing.assert_array_equal(tensor.numpy(), expected_array)
- raw_data = tensor.tobytes()
- tensor_proto_from_raw_data = onnx.TensorProto(
- dims=tensor_proto.dims,
- data_type=tensor_proto.data_type,
- raw_data=raw_data,
- )
- array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
- np.testing.assert_array_equal(array_from_raw_data, expected_array)
-
-
-class DeserializeGraphTest(unittest.TestCase):
- def test_deserialize_graph_handles_unsorted_graph(self):
- node_0 = ir.Node(
- "",
- "Op_0",
- inputs=[ir.Input("input_0"), ir.Input("input_1")],
- num_outputs=2,
- name="node_0",
- )
- node_1 = ir.Node(
- "",
- "Op_1",
- inputs=[node_0.outputs[0]],
- num_outputs=1,
- name="node_1",
- )
- graph = ir.Graph(
- inputs=node_0.inputs, # type: ignore
- outputs=[node_1.outputs[0]],
- # Unsorted nodes
- nodes=[node_1, node_0],
- name="test_graph",
- )
- graph_proto = serde.serialize_graph(graph)
- deserialized_graph = serde.deserialize_graph(graph_proto)
- self.assertEqual(deserialized_graph[0].op_type, "Op_1")
- self.assertEqual(deserialized_graph[1].op_type, "Op_0")
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py
index 3940ba9297..76023ea002 100644
--- a/onnxscript/irbuilder.py
+++ b/onnxscript/irbuilder.py
@@ -1,7 +1,6 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
+# ruff: noqa: TID251
from __future__ import annotations
import dataclasses
@@ -215,7 +214,7 @@ def __str__(self):
def debug_print(self):
if logger.isEnabledFor(logging.DEBUG):
- logger.debug("%s: %s", type(self), str(self))
+ logger.debug("%s: %s", type(self), self)
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
n = helper.make_node(
@@ -321,6 +320,7 @@ def to_model_proto(
io_types: Optional[ONNXType] = None,
input_types: Optional[Sequence[ONNXType]] = None,
output_types: Optional[Sequence[ONNXType]] = None,
+ value_infos: dict[str, ONNXType] | None = None,
**kwargs,
) -> onnx.ModelProto:
"""Converts this instance into a `onnx.ModelProto`.
@@ -334,12 +334,24 @@ def to_model_proto(
are set to be of the corresponding type in this list.
output_types: When specified, all the outputs of the model
are set to be of the corresponding type in this list.
+ value_infos: A dictionary mapping intermediate variable names to ONNX types.
+ Used to set value_info for intermediate variables.
kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.
Returns:
An instance of :class:`onnx.ModelProto`.
"""
- graph, sub_functions = self.to_graph_and_functions(use_default_type=False)
+ value_infos = (
+ [
+ onnx.helper.make_value_info(name, type.to_type_proto())
+ for name, type in value_infos.items()
+ ]
+ if value_infos
+ else None
+ )
+ graph, sub_functions = self.to_graph_and_functions(
+ use_default_type=False, value_infos=value_infos
+ )
if io_types is not None:
for input in graph.input:
if not input.HasField("type"):
@@ -370,13 +382,19 @@ def to_proto(f):
for n in self.stmts:
if n.callee.opset.domain not in opsets:
opsets[n.callee.opset.domain] = n.callee.opset.version
+
+ for proto in functions:
+ if proto.domain not in opsets:
+ opsets[proto.domain] = 1
+ # TODO(rama): Handle conflicts with appropriate error/warning message.
+ for opset in proto.opset_import:
+ if opset.domain not in opsets:
+ opsets[opset.domain] = opset.version
+
if "" not in opsets:
# No operator is using the standard opset.
# A default value is given.
opsets[""] = onnx_opset_version()
- for proto in functions:
- if proto.domain not in opsets:
- opsets[proto.domain] = 1
if "ir_version" not in kwargs:
kwargs["ir_version"] = select_ir_version(opsets[""])
@@ -389,7 +407,9 @@ def to_proto(f):
)
def to_graph_and_functions(
- self, use_default_type: bool = True
+ self,
+ use_default_type: bool = True,
+ value_infos: Sequence[ValueInfoProto] | None = None,
) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
"""Converts this instance into a `onnx.GraphProto` and a map from
function-name to `onnx.FunctionProto`.
@@ -397,6 +417,8 @@ def to_graph_and_functions(
Args:
use_default_type: if True, the function uses a default type
for inputs and outputs that do not have a type
+ value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added
+ to the graph.
Returns:
a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto`
@@ -410,6 +432,7 @@ def to_graph_and_functions(
self.name,
[x.to_value_info(use_default_type) for x in self.inputs],
[y.to_value_info(use_default_type) for y in self.outputs],
+ value_info=value_infos,
)
return graph, called_functions
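The new `value_infos` parameter threads intermediate-value type information from `to_model_proto` down to `onnx.helper.make_graph`. A hypothetical usage sketch; the function body and the intermediate name `tmp` are illustrative, not taken from this repository:

```python
from onnxscript import FLOAT, script
from onnxscript import opset18 as op

@script()
def double_square(x: FLOAT["N"]) -> FLOAT["N"]:
    tmp = op.Mul(x, x)
    return op.Add(tmp, tmp)

# Attach a type for the intermediate value "tmp" in the exported model.
model = double_square.function_ir.to_model_proto(value_infos={"tmp": FLOAT["N"]})
```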
diff --git a/onnxscript/main.py b/onnxscript/main.py
index 51c180e275..3ea3e50f90 100644
--- a/onnxscript/main.py
+++ b/onnxscript/main.py
@@ -1,22 +1,22 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
# pylint: disable=protected-access

from __future__ import annotations
import ast
import inspect
import sys
-import types
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, TypeVar
-import onnx.helper
+from typing_extensions import ParamSpec
import onnxscript
-from onnxscript import converter, irbuilder, values
+from onnxscript import converter, ir, irbuilder, values
from onnxscript._internal import ast_utils
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
def script_check(
f: ast.FunctionDef,
@@ -42,7 +42,7 @@ def script(
opset: Optional[values.Opset] = None,
default_opset: Optional[values.Opset] = None,
**kwargs: Any,
-) -> Callable[[types.FunctionType], onnxscript.OnnxFunction]:
+) -> Callable[[Callable[_P, _R]], onnxscript.OnnxFunction[_P, _R]]:
"""Main decorator. Declares a function as an onnx function.
Args:
@@ -78,7 +78,7 @@ def log2(x):
"Script parameter must be an opset. Did you use @script instead of @script()?"
)
- def transform(f: types.FunctionType) -> onnxscript.OnnxFunction:
+ def transform(f: Callable[_P, _R]) -> onnxscript.OnnxFunction[_P, _R]:
if not inspect.isfunction(f):
raise TypeError("The ONNXScript decorator should be applied to functions only.")
@@ -98,7 +98,7 @@ def transform(f: types.FunctionType) -> onnxscript.OnnxFunction:
return transform
-def graph() -> Callable[[types.FunctionType], values.OnnxClosure]:
+def graph() -> Callable[[Callable], values.OnnxClosure]:
"""A parametric decorator used to annotate nested-functions that are used
as graph-attributes.
@@ -145,7 +145,7 @@ def Sum(sum_in, next):
onnx_function = wrapper_frame.f_locals["self"]
nested_functions = onnx_function.function_ir.nested_functions
- def transform(f: types.FunctionType) -> values.OnnxClosure:
+ def transform(f: Callable) -> values.OnnxClosure:
return values.OnnxClosure(nested_functions[f.__name__], function_frame, f)
return transform
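The switch above from `types.FunctionType` to `Callable[_P, _R]` makes `OnnxFunction` generic over the wrapped signature, so type checkers can flow parameter and return types through the `@script()` decorator. A minimal sketch of the underlying `ParamSpec` pattern, independent of onnxscript:

```python
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")

def decorate(f: Callable[_P, _R]) -> Callable[_P, _R]:
    # The wrapper re-exposes exactly the parameters and return type of f.
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return f(*args, **kwargs)
    return inner

@decorate
def add(a: int, b: int) -> int:
    return a + b

result = add(1, 2)  # inferred as int; add("1", 2) is now a type error
```

Note that `graph()` only gains a plain `Callable` annotation here, since `values.OnnxClosure` is not parameterized.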
@@ -160,11 +160,17 @@ def export_onnx_lib(functions: Sequence[values.OnnxFunction], filename: str) ->
# Since we don't yet have LibProto defined, we use a ModelProto as a temporary
# container for the list of functions exported as a library, with an empty graph
# and dummy opset_imports.
- model = onnx.helper.make_model(
- onnx.GraphProto(),
- functions=[f.to_function_proto() for f in functions],
+
+ # TODO(justinchuby): This function is not well supported. We should consider removing it
+ model = ir.Model(
+ ir.Graph(
+ inputs=[],
+ outputs=[],
+ nodes=[],
+ opset_imports={"": 15},
+ ),
+ functions=[ir.serde.deserialize_function(f.to_function_proto()) for f in functions],
+ ir_version=10,
producer_name="p2o",
- opset_imports=[onnx.helper.make_opsetid("", 15)],
)
-
- onnx.save(model, filename)
+ ir.save(model, filename)
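`export_onnx_lib` now routes through the IR instead of `onnx.helper.make_model`. A condensed sketch of the equivalent direct calls, with names taken from the diff (the opset and IR versions are the same hard-coded placeholders the function uses):

```python
from onnxscript import ir

model = ir.Model(
    ir.Graph(inputs=[], outputs=[], nodes=[], opset_imports={"": 15}),
    ir_version=10,
    producer_name="p2o",
)
ir.save(model, "lib.onnx")  # serializes the IR model to disk
```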
diff --git a/onnxscript/onnx_opset/__init__.py b/onnxscript/onnx_opset/__init__.py
index c84d95c0cd..9b6ed0915c 100644
--- a/onnxscript/onnx_opset/__init__.py
+++ b/onnxscript/onnx_opset/__init__.py
@@ -2,13 +2,11 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -37,13 +35,15 @@
from onnxscript.onnx_opset._impl.opset18 import Opset18
from onnxscript.onnx_opset._impl.opset19 import Opset19
from onnxscript.onnx_opset._impl.opset20 import Opset20
+from onnxscript.onnx_opset._impl.opset21 import Opset21
+from onnxscript.onnx_opset._impl.opset22 import Opset22
+from onnxscript.onnx_opset._impl.opset23 import Opset23
+from onnxscript.onnx_opset._impl.opset24 import Opset24
from onnxscript.onnx_opset._impl.opset_ai_onnx_ml1 import Opset_ai_onnx_ml1
from onnxscript.onnx_opset._impl.opset_ai_onnx_ml2 import Opset_ai_onnx_ml2
from onnxscript.onnx_opset._impl.opset_ai_onnx_ml3 import Opset_ai_onnx_ml3
from onnxscript.onnx_opset._impl.opset_ai_onnx_ml4 import Opset_ai_onnx_ml4
-from onnxscript.onnx_opset._impl.opset_ai_onnx_preview_training1 import (
- Opset_ai_onnx_preview_training1,
-)
+from onnxscript.onnx_opset._impl.opset_ai_onnx_ml5 import Opset_ai_onnx_ml5
from onnxscript.values import Opset
__all__ = [
@@ -68,11 +68,15 @@
"opset18",
"opset19",
"opset20",
+ "opset21",
+ "opset22",
+ "opset23",
+ "opset24",
"opset_ai_onnx_ml1",
"opset_ai_onnx_ml2",
"opset_ai_onnx_ml3",
"opset_ai_onnx_ml4",
- "opset_ai_onnx_preview_training1",
+ "opset_ai_onnx_ml5",
]
@@ -102,11 +106,15 @@
opset18 = Opset18()
opset19 = Opset19()
opset20 = Opset20()
+opset21 = Opset21()
+opset22 = Opset22()
+opset23 = Opset23()
+opset24 = Opset24()
opset_ai_onnx_ml1 = Opset_ai_onnx_ml1()
opset_ai_onnx_ml2 = Opset_ai_onnx_ml2()
opset_ai_onnx_ml3 = Opset_ai_onnx_ml3()
opset_ai_onnx_ml4 = Opset_ai_onnx_ml4()
-opset_ai_onnx_preview_training1 = Opset_ai_onnx_preview_training1()
+opset_ai_onnx_ml5 = Opset_ai_onnx_ml5()
all_opsets: Mapping[Tuple[str, int], Opset] = {
(
"",
@@ -188,6 +196,22 @@
"",
20,
): opset20,
+ (
+ "",
+ 21,
+ ): opset21,
+ (
+ "",
+ 22,
+ ): opset22,
+ (
+ "",
+ 23,
+ ): opset23,
+ (
+ "",
+ 24,
+ ): opset24,
(
"ai.onnx.ml",
1,
@@ -205,7 +229,7 @@
4,
): opset_ai_onnx_ml4,
(
- "ai.onnx.preview.training",
- 1,
- ): opset_ai_onnx_preview_training1,
+ "ai.onnx.ml",
+ 5,
+ ): opset_ai_onnx_ml5,
}
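The generated registry now covers the default domain through opset 24 and `ai.onnx.ml` through version 5, and drops the retired `ai.onnx.preview.training` domain. A quick check against the public module, assuming the import paths shown in this file:

```python
from onnxscript.onnx_opset import all_opsets, opset24

assert all_opsets[("", 24)] is opset24
assert ("ai.onnx.ml", 5) in all_opsets
assert ("ai.onnx.preview.training", 1) not in all_opsets
```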
diff --git a/onnxscript/onnx_opset/_impl/opset1.py b/onnxscript/onnx_opset/_impl/opset1.py
index 756cc5a150..4af313184d 100644
--- a/onnxscript/onnx_opset/_impl/opset1.py
+++ b/onnxscript/onnx_opset/_impl/opset1.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D214, D402, D405, D411, D416, D417
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -398,7 +397,18 @@ def BatchNormalization(
)
T2_Cast: TypeAlias = Union[
- BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
]
def Cast(self, input: T1_Cast, *, to: str) -> T2_Cast:
@@ -837,7 +847,11 @@ def Dropout(
T_Elu = TypeVar("T_Elu", DOUBLE, FLOAT, FLOAT16)
def Elu(
- self, X: T_Elu, *, alpha: float = 1.0, consumed_inputs: Optional[Sequence[int]] = None
+ self,
+ X: T_Elu,
+ *,
+ alpha: float = 1.0,
+ consumed_inputs: Optional[Sequence[int]] = None,
) -> T_Elu:
r"""[🌐 Elu(1)](https://onnx.ai/onnx/operators/onnx__Elu.html#elu-1 "Online Documentation")
@@ -849,7 +863,7 @@ def Elu(
Args:
- X: 1D input tensor
+ X: Input tensor
alpha: Coefficient of ELU default to 1.0.
@@ -859,7 +873,9 @@ def Elu(
schema = get_schema("Elu", 1, "")
op = Op(self, "Elu", schema)
return op(
- *self._prepare_inputs(schema, X), alpha=alpha, consumed_inputs=consumed_inputs
+ *self._prepare_inputs(schema, X),
+ alpha=alpha,
+ consumed_inputs=consumed_inputs,
)
T_Equal = TypeVar("T_Equal", BOOL, INT32, INT64)
@@ -1338,7 +1354,12 @@ def GlobalMaxPool(self, X: T_GlobalMaxPool) -> T_GlobalMaxPool:
T1_Greater: TypeAlias = BOOL
def Greater(
- self, A: T_Greater, B: T_Greater, *, axis: Optional[int] = None, broadcast: int = 0
+ self,
+ A: T_Greater,
+ B: T_Greater,
+ *,
+ axis: Optional[int] = None,
+ broadcast: int = 0,
) -> T1_Greater:
r"""[🌐 Greater(1)](https://onnx.ai/onnx/operators/onnx__Greater.html#greater-1 "Online Documentation")
@@ -1603,7 +1624,11 @@ def LRN(
schema = get_schema("LRN", 1, "")
op = Op(self, "LRN", schema)
return op(
- *self._prepare_inputs(schema, X), alpha=alpha, beta=beta, bias=bias, size=size
+ *self._prepare_inputs(schema, X),
+ alpha=alpha,
+ beta=beta,
+ bias=bias,
+ size=size,
)
T_LSTM = TypeVar("T_LSTM", DOUBLE, FLOAT, FLOAT16)
@@ -1822,7 +1847,9 @@ def LeakyRelu(
schema = get_schema("LeakyRelu", 1, "")
op = Op(self, "LeakyRelu", schema)
return op(
- *self._prepare_inputs(schema, X), alpha=alpha, consumed_inputs=consumed_inputs
+ *self._prepare_inputs(schema, X),
+ alpha=alpha,
+ consumed_inputs=consumed_inputs,
)
T_Less = TypeVar("T_Less", DOUBLE, FLOAT, FLOAT16)
@@ -1935,7 +1962,11 @@ def LogSoftmax(self, input: T_LogSoftmax, *, axis: int = 1) -> T_LogSoftmax:
)
def Loop(
- self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
) -> V_Loop:
r"""[🌐 Loop(1)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-1 "Online Documentation")
@@ -1954,7 +1985,7 @@ def Loop(
This table summarizes the operating modes of this operator with equivalent
C-style code:
- Operator inputs defined as (max_trip_count, condition_var).
+ Operator inputs defined as (max_trip_count, condition_var).
input ("", ""):
for (int i=0; ; ++i) {
@@ -2171,7 +2202,7 @@ def MatMul(self, A: T_MatMul, B: T_MatMul) -> T_MatMul:
r"""[🌐 MatMul(1)](https://onnx.ai/onnx/operators/onnx__MatMul.html#matmul-1 "Online Documentation")
- Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
Args:
@@ -2493,7 +2524,11 @@ def Or(self, A: T_Or, B: T_Or, *, axis: Optional[int] = None, broadcast: int = 0
T_PRelu = TypeVar("T_PRelu", DOUBLE, FLOAT, FLOAT16)
def PRelu(
- self, X: T_PRelu, slope: T_PRelu, *, consumed_inputs: Optional[Sequence[int]] = None
+ self,
+ X: T_PRelu,
+ slope: T_PRelu,
+ *,
+ consumed_inputs: Optional[Sequence[int]] = None,
) -> T_PRelu:
r"""[🌐 PRelu(1)](https://onnx.ai/onnx/operators/onnx__PRelu.html#prelu-1 "Online Documentation")
@@ -2567,7 +2602,10 @@ def Pad(
schema = get_schema("Pad", 1, "")
op = Op(self, "Pad", schema)
return op(
- *self._prepare_inputs(schema, data), mode=mode, paddings=paddings, value=value
+ *self._prepare_inputs(schema, data),
+ mode=mode,
+ paddings=paddings,
+ value=value,
)
T_Pow = TypeVar("T_Pow", DOUBLE, FLOAT, FLOAT16)
@@ -2975,7 +3013,11 @@ def RandomUniformLike(
schema = get_schema("RandomUniformLike", 1, "")
op = Op(self, "RandomUniformLike", schema)
return op(
- *self._prepare_inputs(schema, input), dtype=dtype, high=high, low=low, seed=seed
+ *self._prepare_inputs(schema, input),
+ dtype=dtype,
+ high=high,
+ low=low,
+ seed=seed,
)
T_Reciprocal = TypeVar("T_Reciprocal", DOUBLE, FLOAT, FLOAT16)
@@ -3004,7 +3046,11 @@ def Reciprocal(
T_ReduceL1 = TypeVar("T_ReduceL1", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceL1(
- self, data: T_ReduceL1, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceL1,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceL1:
r"""[🌐 ReduceL1(1)](https://onnx.ai/onnx/operators/onnx__ReduceL1.html#reducel1-1 "Online Documentation")
@@ -3034,7 +3080,11 @@ def ReduceL1(
T_ReduceL2 = TypeVar("T_ReduceL2", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceL2(
- self, data: T_ReduceL2, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceL2,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceL2:
r"""[🌐 ReduceL2(1)](https://onnx.ai/onnx/operators/onnx__ReduceL2.html#reducel2-1 "Online Documentation")
@@ -3066,7 +3116,11 @@ def ReduceL2(
)
def ReduceLogSum(
- self, data: T_ReduceLogSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceLogSum,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceLogSum:
r"""[🌐 ReduceLogSum(1)](https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html#reducelogsum-1 "Online Documentation")
@@ -3132,7 +3186,11 @@ def ReduceLogSumExp(
T_ReduceMax = TypeVar("T_ReduceMax", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceMax(
- self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMax,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMax:
r"""[🌐 ReduceMax(1)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-1 "Online Documentation")
@@ -3164,7 +3222,11 @@ def ReduceMax(
)
def ReduceMean(
- self, data: T_ReduceMean, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMean,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMean:
r"""[🌐 ReduceMean(1)](https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-1 "Online Documentation")
@@ -3194,7 +3256,11 @@ def ReduceMean(
T_ReduceMin = TypeVar("T_ReduceMin", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceMin(
- self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMin,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMin:
r"""[🌐 ReduceMin(1)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-1 "Online Documentation")
@@ -3226,7 +3292,11 @@ def ReduceMin(
)
def ReduceProd(
- self, data: T_ReduceProd, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceProd,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceProd:
r"""[🌐 ReduceProd(1)](https://onnx.ai/onnx/operators/onnx__ReduceProd.html#reduceprod-1 "Online Documentation")
@@ -3256,7 +3326,11 @@ def ReduceProd(
T_ReduceSum = TypeVar("T_ReduceSum", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceSum(
- self, data: T_ReduceSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceSum,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceSum:
r"""[🌐 ReduceSum(1)](https://onnx.ai/onnx/operators/onnx__ReduceSum.html#reducesum-1 "Online Documentation")
@@ -3371,7 +3445,9 @@ def Reshape(
schema = get_schema("Reshape", 1, "")
op = Op(self, "Reshape", schema)
return op(
- *self._prepare_inputs(schema, data), consumed_inputs=consumed_inputs, shape=shape
+ *self._prepare_inputs(schema, data),
+ consumed_inputs=consumed_inputs,
+ shape=shape,
)
T_Selu = TypeVar("T_Selu", DOUBLE, FLOAT, FLOAT16)
@@ -3538,7 +3614,7 @@ def Slice(
Produces a slice of the input tensor along multiple axes. Similar to numpy:
- https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
+ https://numpy.org/doc/stable/reference/routines.indexing.html
Slices uses `axes`, `starts` and `ends` attributes to specify the start and end
dimension for each axis in the list of axes, it uses this information to
slice the input `data` tensor. If a negative value is passed for any of the
@@ -3632,7 +3708,7 @@ def Softplus(self, X: T_Softplus) -> T_Softplus:
Args:
- X: (differentiable) 1D input tensor
+ X: (differentiable) Input tensor
"""
schema = get_schema("Softplus", 1, "")
@@ -4019,7 +4095,12 @@ def Unsqueeze(self, data: T_Unsqueeze, *, axes: Sequence[int]) -> T_Unsqueeze:
T_Upsample = TypeVar("T_Upsample", BOOL, DOUBLE, FLOAT, FLOAT16, INT32, INT64)
def Upsample(
- self, X: T_Upsample, *, height_scale: float, mode: str = "nearest", width_scale: float
+ self,
+ X: T_Upsample,
+ *,
+ height_scale: float,
+ mode: str = "nearest",
+ width_scale: float,
) -> T_Upsample:
r"""[🌐 Upsample(1)](https://onnx.ai/onnx/operators/onnx__Upsample.html#upsample-1 "Online Documentation")
diff --git a/onnxscript/onnx_opset/_impl/opset10.py b/onnxscript/onnx_opset/_impl/opset10.py
index 65ea0013e3..ec1734b266 100644
--- a/onnxscript/onnx_opset/_impl/opset10.py
+++ b/onnxscript/onnx_opset/_impl/opset10.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -346,7 +345,7 @@ def MatMulInteger(
r"""[🌐 MatMulInteger(10)](https://onnx.ai/onnx/operators/onnx__MatMulInteger.html#matmulinteger-10 "Online Documentation")
- Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html.
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
The production MUST never overflow. The accumulation may overflow if and only if in 32 bits.
@@ -749,7 +748,7 @@ def QLinearMatMul(
r"""[🌐 QLinearMatMul(10)](https://onnx.ai/onnx/operators/onnx__QLinearMatMul.html#qlinearmatmul-10 "Online Documentation")
- Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html.
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
It consumes two quantized input tensors, their scales and zero points, scale and zero point of output,
and computes the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point).
For (x / y_scale), it is rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
@@ -1067,7 +1066,7 @@ def Slice(
Produces a slice of the input tensor along multiple axes. Similar to numpy:
- https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
+ https://numpy.org/doc/stable/reference/routines.indexing.html
Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end
dimension and step for each axis in the list of axes, it uses this information to
slice the input `data` tensor. If a negative value is passed for any of the
diff --git a/onnxscript/onnx_opset/_impl/opset11.py b/onnxscript/onnx_opset/_impl/opset11.py
index bb54cbeb02..6538ac3afb 100644
--- a/onnxscript/onnx_opset/_impl/opset11.py
+++ b/onnxscript/onnx_opset/_impl/opset11.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: E741, D214, D402, D405, D411, D416
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -1465,7 +1464,11 @@ def LogSoftmax(self, input: T_LogSoftmax, *, axis: int = 1) -> T_LogSoftmax:
)
def Loop(
- self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
) -> V_Loop:
r"""[🌐 Loop(11)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-11 "Online Documentation")
@@ -1484,7 +1487,7 @@ def Loop(
This table summarizes the operating modes of this operator with equivalent
C-style code:
- Operator inputs defined as (max_trip_count, condition_var).
+ Operator inputs defined as (max_trip_count, condition_var).
input ("", ""):
for (int i=0; ; ++i) {
@@ -2238,7 +2241,11 @@ def Range(self, start: T_Range, limit: T_Range, delta: T_Range) -> T_Range:
T_ReduceL1 = TypeVar("T_ReduceL1", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceL1(
- self, data: T_ReduceL1, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceL1,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceL1:
r"""[🌐 ReduceL1(11)](https://onnx.ai/onnx/operators/onnx__ReduceL1.html#reducel1-11 "Online Documentation")
@@ -2268,7 +2275,11 @@ def ReduceL1(
T_ReduceL2 = TypeVar("T_ReduceL2", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceL2(
- self, data: T_ReduceL2, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceL2,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceL2:
r"""[🌐 ReduceL2(11)](https://onnx.ai/onnx/operators/onnx__ReduceL2.html#reducel2-11 "Online Documentation")
@@ -2300,7 +2311,11 @@ def ReduceL2(
)
def ReduceLogSum(
- self, data: T_ReduceLogSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceLogSum,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceLogSum:
r"""[🌐 ReduceLogSum(11)](https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html#reducelogsum-11 "Online Documentation")
@@ -2366,7 +2381,11 @@ def ReduceLogSumExp(
T_ReduceMax = TypeVar("T_ReduceMax", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceMax(
- self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMax,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMax:
r"""[🌐 ReduceMax(11)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-11 "Online Documentation")
@@ -2399,7 +2418,11 @@ def ReduceMax(
)
def ReduceMean(
- self, data: T_ReduceMean, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMean,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMean:
r"""[🌐 ReduceMean(11)](https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-11 "Online Documentation")
@@ -2429,7 +2452,11 @@ def ReduceMean(
T_ReduceMin = TypeVar("T_ReduceMin", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceMin(
- self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMin,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMin:
r"""[🌐 ReduceMin(11)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-11 "Online Documentation")
@@ -2462,7 +2489,11 @@ def ReduceMin(
)
def ReduceProd(
- self, data: T_ReduceProd, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceProd,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceProd:
r"""[🌐 ReduceProd(11)](https://onnx.ai/onnx/operators/onnx__ReduceProd.html#reduceprod-11 "Online Documentation")
@@ -2492,7 +2523,11 @@ def ReduceProd(
T_ReduceSum = TypeVar("T_ReduceSum", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)
def ReduceSum(
- self, data: T_ReduceSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceSum,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceSum:
r"""[🌐 ReduceSum(11)](https://onnx.ai/onnx/operators/onnx__ReduceSum.html#reducesum-11 "Online Documentation")
@@ -3314,7 +3349,9 @@ def SequenceEmpty(self, *, dtype: Optional[int] = None) -> S_SequenceEmpty:
I_SequenceErase = TypeVar("I_SequenceErase", INT32, INT64)
def SequenceErase(
- self, input_sequence: S_SequenceErase, position: Optional[I_SequenceErase] = None
+ self,
+ input_sequence: S_SequenceErase,
+ position: Optional[I_SequenceErase] = None,
) -> S_SequenceErase:
r"""[🌐 SequenceErase(11)](https://onnx.ai/onnx/operators/onnx__SequenceErase.html#sequenceerase-11 "Online Documentation")
@@ -3481,7 +3518,7 @@ def Slice(
Produces a slice of the input tensor along multiple axes. Similar to numpy:
- https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
+ https://numpy.org/doc/stable/reference/routines.indexing.html
Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end
dimension and step for each axis in the list of axes, it uses this information to
slice the input `data` tensor. If a negative value is passed for any of the
@@ -3798,7 +3835,10 @@ def TopK(
schema = get_schema("TopK", 11, "")
op = Op(self, "TopK", schema)
return op(
- *self._prepare_inputs(schema, X, K), axis=axis, largest=largest, sorted=sorted
+ *self._prepare_inputs(schema, X, K),
+ axis=axis,
+ largest=largest,
+ sorted=sorted,
)
T_Unique = TypeVar(
diff --git a/onnxscript/onnx_opset/_impl/opset12.py b/onnxscript/onnx_opset/_impl/opset12.py
index ede4fb34a7..95b2ea83c5 100644
--- a/onnxscript/onnx_opset/_impl/opset12.py
+++ b/onnxscript/onnx_opset/_impl/opset12.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -60,7 +59,12 @@ def __new__(cls):
)
def ArgMax(
- self, data: T_ArgMax, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
+ self,
+ data: T_ArgMax,
+ *,
+ axis: int = 0,
+ keepdims: int = 1,
+ select_last_index: int = 0,
) -> INT64:
r"""[🌐 ArgMax(12)](https://onnx.ai/onnx/operators/onnx__ArgMax.html#argmax-12 "Online Documentation")
@@ -111,7 +115,12 @@ def ArgMax(
)
def ArgMin(
- self, data: T_ArgMin, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
+ self,
+ data: T_ArgMin,
+ *,
+ axis: int = 0,
+ keepdims: int = 1,
+ select_last_index: int = 0,
) -> INT64:
r"""[🌐 ArgMin(12)](https://onnx.ai/onnx/operators/onnx__ArgMin.html#argmin-12 "Online Documentation")
@@ -674,7 +683,7 @@ def MaxPool(
```
output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1)
```
- if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored.
+ if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`.
`auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled:
```
@@ -938,7 +947,11 @@ def Pow(self, X: T_Pow, Y: T1_Pow) -> T_Pow:
)
def ReduceMax(
- self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMax,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMax:
r"""[🌐 ReduceMax(12)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-12 "Online Documentation")
@@ -970,7 +983,11 @@ def ReduceMax(
)
def ReduceMin(
- self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMin,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMin:
r"""[🌐 ReduceMin(12)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-12 "Online Documentation")
diff --git a/onnxscript/onnx_opset/_impl/opset13.py b/onnxscript/onnx_opset/_impl/opset13.py
index 616fe5ff69..5403df22cf 100644
--- a/onnxscript/onnx_opset/_impl/opset13.py
+++ b/onnxscript/onnx_opset/_impl/opset13.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D214, D402, D405, D411, D416, D417
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -116,7 +115,12 @@ def Add(self, A: T_Add, B: T_Add) -> T_Add:
)
def ArgMax(
- self, data: T_ArgMax, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
+ self,
+ data: T_ArgMax,
+ *,
+ axis: int = 0,
+ keepdims: int = 1,
+ select_last_index: int = 0,
) -> INT64:
r"""[🌐 ArgMax(13)](https://onnx.ai/onnx/operators/onnx__ArgMax.html#argmax-13 "Online Documentation")
@@ -168,7 +172,12 @@ def ArgMax(
)
def ArgMin(
- self, data: T_ArgMin, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
+ self,
+ data: T_ArgMin,
+ *,
+ axis: int = 0,
+ keepdims: int = 1,
+ select_last_index: int = 0,
) -> INT64:
r"""[🌐 ArgMin(13)](https://onnx.ai/onnx/operators/onnx__ArgMin.html#argmin-13 "Online Documentation")
@@ -334,6 +343,8 @@ def Clip(
Clip operator limits the given input within an interval. The interval is
specified by the inputs 'min' and 'max'. They default to
numeric_limits::lowest() and numeric_limits::max(), respectively.
+ When 'min' is greater than 'max', the clip operator sets all the 'input' values to
+ the value of 'max'. Thus, this is equivalent to 'Min(max, Max(input, min))'.
Args:
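The added sentences pin down the corner case where `min > max`: every element saturates to `max`, i.e. the op behaves as `Min(max, Max(input, min))`. A numpy illustration of that identity:

```python
import numpy as np

x = np.array([-2.0, 0.0, 3.0], dtype=np.float32)
lo, hi = 1.0, -1.0                           # min > max
clipped = np.minimum(hi, np.maximum(x, lo))  # Min(max, Max(input, min))
print(clipped)                               # [-1. -1. -1.] == hi everywhere
```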
@@ -875,7 +886,22 @@ def Gather(self, data: T_Gather, indices: Tind_Gather, *, axis: int = 0) -> T_Ga
entries of the axis dimension of `data` (by default outer-most one as axis=0) indexed by `indices`, and concatenates
them in an output tensor of rank q + (r - 1).
- If `axis = 0`, let `k = indices[i_{0}, ..., i_{q-1}]`
+ It is an indexing operation that indexes into the input `data` along a single (specified) axis.
+ Each entry in `indices` produces a `r-1` dimensional slice of the input tensor.
+ The entire operation produces, conceptually, a `q`-dimensional tensor of `r-1` dimensional slices,
+ which is arranged into a `q + (r-1)`-dimensional tensor, with the `q` dimensions taking the
+ place of the original `axis` that is being indexed into.
+
+ The following few examples illustrate how `Gather` works for specific shapes of `data`,
+ `indices`, and given value of `axis`:
+ | data shape | indices shape | axis | output shape | output equation |
+ | --- | --- | --- | --- | --- |
+ | (P, Q) | ( ) (a scalar) | 0 | (Q) | output[q] = data[indices, q] |
+ | (P, Q, R) | ( ) (a scalar) | 1 | (P, R) | output[p, r] = data[p, indices, r] |
+ | (P, Q) | (R, S) | 0 | (R, S, Q) | output[r, s, q] = data[indices[r, s], q] |
+ | (P, Q) | (R, S) | 1 | (P, R, S) | output[p, r, s] = data[p, indices[r, s]] |
+
+ More generally, if `axis = 0`, let `k = indices[i_{0}, ..., i_{q-1}]`
then `output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[k , j_{0}, ..., j_{r-2}]`:
::
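The new table makes the indexing rule concrete. For instance, the third row (2-D `data`, 2-D `indices`, `axis=0`) corresponds directly to `numpy.take`, as this sketch shows:

```python
import numpy as np

data = np.arange(12).reshape(3, 4)    # (P, Q) = (3, 4)
indices = np.array([[0, 2], [1, 1]])  # (R, S) = (2, 2)
out = np.take(data, indices, axis=0)  # (R, S, Q) = (2, 2, 4)
assert out.shape == (2, 2, 4)
# output[r, s, :] = data[indices[r, s], :]
assert (out[1, 0] == data[indices[1, 0]]).all()
```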
@@ -1462,7 +1488,11 @@ def LRN(
schema = get_schema("LRN", 13, "")
op = Op(self, "LRN", schema)
return op(
- *self._prepare_inputs(schema, X), alpha=alpha, beta=beta, bias=bias, size=size
+ *self._prepare_inputs(schema, X),
+ alpha=alpha,
+ beta=beta,
+ bias=bias,
+ size=size,
)
T_Less = TypeVar(
@@ -1589,7 +1619,11 @@ def LogSoftmax(self, input: T_LogSoftmax, *, axis: int = -1) -> T_LogSoftmax:
)
def Loop(
- self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
) -> V_Loop:
r"""[🌐 Loop(13)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-13 "Online Documentation")
@@ -1608,7 +1642,7 @@ def Loop(
This table summarizes the operating modes of this operator with equivalent
C-style code:
- Operator inputs defined as (max_trip_count, condition_var).
+ Operator inputs defined as (max_trip_count, condition_var).
input ("", ""):
for (int i=0; ; ++i) {
@@ -1762,7 +1796,7 @@ def MatMul(self, A: T_MatMul, B: T_MatMul) -> T_MatMul:
r"""[🌐 MatMul(13)](https://onnx.ai/onnx/operators/onnx__MatMul.html#matmul-13 "Online Documentation")
- Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
Args:
@@ -1907,19 +1941,23 @@ def Mod(self, A: T_Mod, B: T_Mod, *, fmod: int = 0) -> T_Mod:
r"""[🌐 Mod(13)](https://onnx.ai/onnx/operators/onnx__Mod.html#mod-13 "Online Documentation")
- Performs element-wise binary modulus (with Numpy-style broadcasting support).
- The sign of the remainder is the same as that of the Divisor.
-
- Mod operator can also behave like C fmod() or numpy.fmod. In this case, the sign of the remainder however, will be the same as the Dividend
- (in contrast to integer mod). To force a behavior like numpy.fmod() an 'fmod' Attribute is provided.
- This attribute is set to 0 by default causing the behavior to be like integer mod.
- Setting this attribute to 1 causes the remainder to be calculated similar to that of numpy.fmod().
+ Performs an element-wise binary modulo operation.
+ The semantics and supported data types depend on the value of the `fmod` attribute which must be `0` (default), or `1`.
- If the input type is floating point, then `fmod` attribute must be set to 1.
+ If the `fmod` attribute is set to `0`, `T` is constrained to integer data types and the semantics follow that of the Python `%`-operator.
+ The sign of the result is that of the divisor.
- In case of dividend being zero, the results will be platform dependent.
+ If `fmod` is set to `1`, the behavior of this operator follows that of the `fmod` function in C and `T` is constrained to floating point data types.
+ The result of this operator is the remainder of the division operation `x / y` where `x` and `y` are respective elements of `A` and `B`. The result is exactly the value `x - n * y`, where `n` is `x / y` with its fractional part truncated.
+ The returned value has the same sign as `x` (except if `x` is `-0`) and is less or equal to `|y|` in magnitude.
+ The following special cases apply when `fmod` is set to `1`:
+ - If `x` is `-0` and `y` is greater than zero, either `+0` or `-0` may be returned.
+ - If `x` is `±∞` and `y` is not `NaN`, `NaN` is returned.
+ - If `y` is `±0` and `x` is not `NaN`, `NaN` should be returned.
+ - If `y` is `±∞` and `x` is finite, `x` is returned.
+ - If either argument is `NaN`, `NaN` is returned.
- This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check `Broadcasting in ONNX <https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md>`_.
+ This operator supports **multidirectional (i.e., NumPy-style) broadcasting**; for more details please check `Broadcasting in ONNX <https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md>`_.
Args:
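The rewrite splits Mod into two precisely specified behaviors: `fmod=0` follows the Python `%` operator on integers (the result takes the divisor's sign), while `fmod=1` follows C `fmod` on floats (the result takes the dividend's sign). numpy exposes both, which makes the difference easy to see:

```python
import numpy as np

a = np.array([-4, 7, 5], dtype=np.int64)
b = np.array([3, -3, 8], dtype=np.int64)
print(np.mod(a, b))   # [ 2 -2  5]  -- fmod=0: sign follows the divisor
print(np.fmod(a, b))  # [-1  1  5]  -- fmod=1: sign follows the dividend
```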
@@ -2414,7 +2452,11 @@ def Reciprocal(self, X: T_Reciprocal) -> T_Reciprocal:
)
def ReduceL1(
- self, data: T_ReduceL1, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceL1,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceL1:
r"""[🌐 ReduceL1(13)](https://onnx.ai/onnx/operators/onnx__ReduceL1.html#reducel1-13 "Online Documentation")
@@ -2448,7 +2490,11 @@ def ReduceL1(
)
def ReduceL2(
- self, data: T_ReduceL2, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceL2,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceL2:
r"""[🌐 ReduceL2(13)](https://onnx.ai/onnx/operators/onnx__ReduceL2.html#reducel2-13 "Online Documentation")
@@ -2482,7 +2528,11 @@ def ReduceL2(
)
def ReduceLogSum(
- self, data: T_ReduceLogSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceLogSum,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceLogSum:
r"""[🌐 ReduceLogSum(13)](https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html#reducelogsum-13 "Online Documentation")
@@ -2512,7 +2562,15 @@ def ReduceLogSum(
return op(*self._prepare_inputs(schema, data), axes=axes, keepdims=keepdims)
T_ReduceLogSumExp = TypeVar(
- "T_ReduceLogSumExp", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64
+ "T_ReduceLogSumExp",
+ BFLOAT16,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT32,
+ INT64,
+ UINT32,
+ UINT64,
)
def ReduceLogSumExp(
@@ -2564,7 +2622,11 @@ def ReduceLogSumExp(
)
def ReduceMax(
- self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMax,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMax:
r"""[🌐 ReduceMax(13)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-13 "Online Documentation")
@@ -2598,7 +2660,11 @@ def ReduceMax(
)
def ReduceMean(
- self, data: T_ReduceMean, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMean,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMean:
r"""[🌐 ReduceMean(13)](https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13 "Online Documentation")
@@ -2642,7 +2708,11 @@ def ReduceMean(
)
def ReduceMin(
- self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceMin,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceMin:
r"""[🌐 ReduceMin(13)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-13 "Online Documentation")
@@ -2676,7 +2746,11 @@ def ReduceMin(
)
def ReduceProd(
- self, data: T_ReduceProd, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
+ self,
+ data: T_ReduceProd,
+ *,
+ axes: Optional[Sequence[int]] = None,
+ keepdims: int = 1,
) -> T_ReduceProd:
r"""[🌐 ReduceProd(13)](https://onnx.ai/onnx/operators/onnx__ReduceProd.html#reduceprod-13 "Online Documentation")
@@ -2733,18 +2807,20 @@ def ReduceSum(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+ which to reduce. The default is to reduce over empty axes. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceSum", 13, "")
@@ -2756,7 +2832,15 @@ def ReduceSum(
)
T_ReduceSumSquare = TypeVar(
- "T_ReduceSumSquare", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64
+ "T_ReduceSumSquare",
+ BFLOAT16,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT32,
+ INT64,
+ UINT32,
+ UINT64,
)
def ReduceSumSquare(
diff --git a/onnxscript/onnx_opset/_impl/opset14.py b/onnxscript/onnx_opset/_impl/opset14.py
index 21983c8a94..a9ec21f0d8 100644
--- a/onnxscript/onnx_opset/_impl/opset14.py
+++ b/onnxscript/onnx_opset/_impl/opset14.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402, D405
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset15.py b/onnxscript/onnx_opset/_impl/opset15.py
index 38c235bced..c0758999f0 100644
--- a/onnxscript/onnx_opset/_impl/opset15.py
+++ b/onnxscript/onnx_opset/_impl/opset15.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402, D412
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -291,36 +290,37 @@ def CastLike(self, input: T1_CastLike, target_type: T2_CastLike) -> T2_CastLike:
)
O_Optional: TypeAlias = Union[
- _Optional[Sequence[BOOL]],
- _Optional[Sequence[COMPLEX128]],
- _Optional[Sequence[COMPLEX64]],
- _Optional[Sequence[DOUBLE]],
- _Optional[Sequence[FLOAT]],
- _Optional[Sequence[FLOAT16]],
- _Optional[Sequence[INT16]],
- _Optional[Sequence[INT32]],
- _Optional[Sequence[INT64]],
- _Optional[Sequence[INT8]],
- _Optional[Sequence[STRING]],
- _Optional[Sequence[UINT16]],
- _Optional[Sequence[UINT32]],
- _Optional[Sequence[UINT64]],
- _Optional[Sequence[UINT8]],
- _Optional[BOOL],
- _Optional[COMPLEX128],
- _Optional[COMPLEX64],
- _Optional[DOUBLE],
- _Optional[FLOAT],
- _Optional[FLOAT16],
- _Optional[INT16],
- _Optional[INT32],
- _Optional[INT64],
- _Optional[INT8],
- _Optional[STRING],
- _Optional[UINT16],
- _Optional[UINT32],
- _Optional[UINT64],
- _Optional[UINT8],
+ None,
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
]
def Optional(
@@ -546,11 +546,11 @@ def Shape(self, data: T_Shape, *, end: _Optional[int] = None, start: int = 0) ->
The end axis, if specified, is exclusive (and the returned value will not include the size of that axis).
If the end axis is omitted, the axes up to the last one will be included.
Negative axes indicate counting back from the last axis.
- Note that axes will be clamped to the range [0, r-1], where r is the
+ Note that axes will be clamped to the range [0, r], where r is the
rank of the input tensor if they are out-of-range (after adding r in the case of
negative axis). Thus, specifying any end value > r is equivalent to specifying an end
value of r, and specifying any start value < -r is equivalent to specifying a start
- value of 0.
+ value of 0. If start > end, the result will be an empty shape.
Examples:
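The corrected text clamps `start`/`end` into `[0, r]` rather than `[0, r-1]` (so `end == r` can include the last axis) and states that `start > end` produces an empty shape. A plain-Python sketch of those slicing rules, with `shape_slice` as an illustrative helper name:

```python
def shape_slice(shape, start=0, end=None):
    r = len(shape)
    def clamp(axis):
        if axis < 0:
            axis += r
        return max(0, min(axis, r))      # clamp into [0, r], not [0, r-1]
    end = r if end is None else clamp(end)
    return shape[clamp(start):end]

print(shape_slice((2, 3, 5, 7), 1, 3))   # (3, 5)
print(shape_slice((2, 3, 5, 7), -1))     # (7,) -- end defaults to r
print(shape_slice((2, 3, 5, 7), 3, 1))   # ()   -- start > end: empty shape
```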
diff --git a/onnxscript/onnx_opset/_impl/opset16.py b/onnxscript/onnx_opset/_impl/opset16.py
index c90392d582..21a92a6026 100644
--- a/onnxscript/onnx_opset/_impl/opset16.py
+++ b/onnxscript/onnx_opset/_impl/opset16.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D214, D402, D405, D411, D416
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -253,38 +252,7 @@ def Identity(self, input: V_Identity) -> V_Identity:
B_If: TypeAlias = BOOL
V_If: TypeAlias = Union[
- Optional[Sequence[BFLOAT16]],
- Optional[Sequence[BOOL]],
- Optional[Sequence[COMPLEX128]],
- Optional[Sequence[COMPLEX64]],
- Optional[Sequence[DOUBLE]],
- Optional[Sequence[FLOAT]],
- Optional[Sequence[FLOAT16]],
- Optional[Sequence[INT16]],
- Optional[Sequence[INT32]],
- Optional[Sequence[INT64]],
- Optional[Sequence[INT8]],
- Optional[Sequence[STRING]],
- Optional[Sequence[UINT16]],
- Optional[Sequence[UINT32]],
- Optional[Sequence[UINT64]],
- Optional[Sequence[UINT8]],
- Optional[BFLOAT16],
- Optional[BOOL],
- Optional[COMPLEX128],
- Optional[COMPLEX64],
- Optional[DOUBLE],
- Optional[FLOAT],
- Optional[FLOAT16],
- Optional[INT16],
- Optional[INT32],
- Optional[INT64],
- Optional[INT8],
- Optional[STRING],
- Optional[UINT16],
- Optional[UINT32],
- Optional[UINT64],
- Optional[UINT8],
+ None,
Sequence[BFLOAT16],
Sequence[BOOL],
Sequence[COMPLEX128],
@@ -476,7 +444,11 @@ def LessOrEqual(self, A: T_LessOrEqual, B: T_LessOrEqual) -> T1_LessOrEqual:
)
def Loop(
- self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
) -> V_Loop:
r"""[🌐 Loop(16)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-16 "Online Documentation")
diff --git a/onnxscript/onnx_opset/_impl/opset17.py b/onnxscript/onnx_opset/_impl/opset17.py
index 80b4b457c0..092658a502 100644
--- a/onnxscript/onnx_opset/_impl/opset17.py
+++ b/onnxscript/onnx_opset/_impl/opset17.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset18.py b/onnxscript/onnx_opset/_impl/opset18.py
index c4154635d9..a795391355 100644
--- a/onnxscript/onnx_opset/_impl/opset18.py
+++ b/onnxscript/onnx_opset/_impl/opset18.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402, D405
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -169,12 +168,18 @@ def CenterCropPad(
Center crop or pad an input to given dimensions.
- The crop/pad dimensions can be specified for a subset of the `axes`. Non-specified dimensions will not be
- cropped or padded.
+ The crop/pad dimensions can be specified for a subset of the `axes`; unspecified dimensions will remain unchanged.
- If the input dimensions are bigger than the crop shape, a centered cropping window is extracted from the input.
- If the input dimensions are smaller than the crop shape, the input is padded on each side equally,
- so that the input is centered in the output.
+ If the input dimensions are larger than the target crop dimensions, a centered cropping window will be extracted
+ from the input. The starting value for the cropping window is rounded down, which means that if the difference
+ between the input shape and the crop shape is odd, the cropping window will be shifted half a pixel to the left
+ of the input center.
+
+ If the input dimensions are smaller than the target crop dimensions, the input will be padded equally on both sides
+ to center it in the output. In cases where the total number of padding pixels is odd, an additional pixel will be
+ added to the right side.
+
+ The padding value used is zero.
Args:
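The rewritten docstring resolves two ambiguities: for odd crop differences the window start rounds down (shifting it half a pixel left of center), and for odd pad totals the extra element goes on the right, with zero as the pad value. A 1-D numpy sketch of both rules; `center_crop_pad_1d` is an illustrative name, not the operator's implementation:

```python
import numpy as np

def center_crop_pad_1d(x, target):
    n = x.shape[0]
    if n >= target:                      # crop: window start rounds down
        start = (n - target) // 2
        return x[start:start + target]
    pad = target - n                     # pad: extra element on the right
    left = pad // 2
    return np.pad(x, (left, pad - left))

print(center_crop_pad_1d(np.arange(5), 2))  # [1 2] -- shifted half left
print(center_crop_pad_1d(np.arange(2), 5))  # [0 0 1 0 0]
```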
@@ -286,65 +291,6 @@ def Col2Im(
strides=strides,
)
- T_GroupNormalization = TypeVar("T_GroupNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
-
- def GroupNormalization(
- self,
- X: T_GroupNormalization,
- scale: T_GroupNormalization,
- bias: T_GroupNormalization,
- *,
- epsilon: float = 9.999999747378752e-06,
- num_groups: int,
- ) -> T_GroupNormalization:
- r"""[🌐 GroupNormalization(18)](https://onnx.ai/onnx/operators/onnx__GroupNormalization.html#groupnormalization-18 "Online Documentation")
-
-
- A GroupNormalization function. Carries out group normalization as described in
- the paper https://arxiv.org/abs/1803.08494
-
- This operator transforms input according to
- ::
-
- y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
-
-
- where the mean and variance are computed per instance per group of channels, and
- `scale` and `bias` should be specified for each group of channels. The number of
- groups `num_groups` should be divisible by the number of channels so that there are
- an equal number of channels per group.
-
- When the number of groups is the same as the number of channels, this operator is
- equivalent to InstanceNormalization. When there is only one group, this operator
- is equivalent to LayerNormalization.
-
-
- Args:
- X: (differentiable) Input data tensor. Dimensions for image cases are `(N x
- C x H x W)`, where `N` is the batch size, `C` is the number of channels,
- and `H` and `W` are the height and width of the data. Statistics are
- computed for every group of channels over `C`, `H`, and `W`. For
- non-image cases, the dimensions are in the form of `(N x C x D1 x D2 ...
- Dn)`.
-
- scale: (differentiable) Scale tensor of shape `(num_groups)`.
-
- bias: (differentiable) Bias tensor of shape `(num_groups)`.
-
- epsilon: The epsilon value to use to avoid division by zero.
-
- num_groups: The number of groups of channels. It should be a divisor of the
- number of channels `C`.
- """
-
- schema = get_schema("GroupNormalization", 18, "")
- op = Op(self, "GroupNormalization", schema)
- return op(
- *self._prepare_inputs(schema, X, scale, bias),
- epsilon=epsilon,
- num_groups=num_groups,
- )
-
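The generated opset-18 `GroupNormalization` stub is removed above. For reference, a numpy sketch of the computation its docstring described: statistics per sample and per group of channels, then the `(num_groups)`-shaped scale and bias specific to this version of the op. The helper name is illustrative:

```python
import numpy as np

def group_norm(x, scale, bias, num_groups, epsilon=1e-5):
    n, c = x.shape[:2]
    g = x.reshape(n, num_groups, c // num_groups, *x.shape[2:])
    axes = tuple(range(2, g.ndim))                 # stats per (sample, group)
    mean = g.mean(axis=axes, keepdims=True)
    var = g.var(axis=axes, keepdims=True)
    y = (g - mean) / np.sqrt(var + epsilon)
    shape = (1, num_groups) + (1,) * (g.ndim - 2)  # one scale/bias per group
    return (y * scale.reshape(shape) + bias.reshape(shape)).reshape(x.shape)

out = group_norm(np.ones((2, 6, 4, 4)), np.ones(3), np.zeros(3), num_groups=3)
```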
T_LpPool = TypeVar("T_LpPool", DOUBLE, FLOAT, FLOAT16)
def LpPool(
@@ -838,18 +784,20 @@ def ReduceL1(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+ which to reduce. The default is to reduce over empty axes. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceL1", 18, "")
@@ -888,18 +836,20 @@ def ReduceL2(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceL2", 18, "")
@@ -938,18 +888,20 @@ def ReduceLogSum(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceLogSum", 18, "")
@@ -961,7 +913,15 @@ def ReduceLogSum(
)
T_ReduceLogSumExp = TypeVar(
- "T_ReduceLogSumExp", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64
+ "T_ReduceLogSumExp",
+ BFLOAT16,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT32,
+ INT64,
+ UINT32,
+ UINT64,
)
def ReduceLogSumExp(
@@ -988,18 +948,20 @@ def ReduceLogSumExp(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceLogSumExp", 18, "")
@@ -1048,18 +1010,20 @@ def ReduceMax(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceMax", 18, "")
@@ -1098,18 +1062,20 @@ def ReduceMean(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceMean", 18, "")
@@ -1158,18 +1124,20 @@ def ReduceMin(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceMin", 18, "")
@@ -1208,18 +1176,20 @@ def ReduceProd(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceProd", 18, "")
@@ -1231,7 +1201,15 @@ def ReduceProd(
)
T_ReduceSumSquare = TypeVar(
- "T_ReduceSumSquare", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64
+ "T_ReduceSumSquare",
+ BFLOAT16,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT32,
+ INT64,
+ UINT32,
+ UINT64,
)
def ReduceSumSquare(
@@ -1258,18 +1236,20 @@ def ReduceSumSquare(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceSumSquare", 18, "")
@@ -1434,13 +1414,13 @@ def Resize(
keeping the original aspect ratio:
`scale = Min(sizes[i] /
in_size[d])`
- `out_size[d] = round_int(scale * in_size[i])`
+ `out_size[d] = round_int(scale * in_size[d])`
If `keep_aspect_ratio_policy` is `"not_smaller"`, the sizes are adjusted
so that no extent of the output is smaller than the specified size,
while keeping the original aspect ratio:
`scale = Max(sizes[i] /
in_size[d])`
- `out_size[d] = round_int(scale * in_size[i])`
+ `out_size[d] = round_int(scale * in_size[d])`
For non-resizable axes (those not specified in `axes`), the output size
will be equal to the input size.
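A quick worked example of the corrected formula (note `in_size[d]`, not `in_size[i]`), assuming `axes` covers both spatial dimensions:

```
in_size = (100, 60)
sizes = (50, 50)
scale_not_larger = min(s / d for s, d in zip(sizes, in_size))   # 0.5
scale_not_smaller = max(s / d for s, d in zip(sizes, in_size))  # ~0.833
print(tuple(round(scale_not_larger * d) for d in in_size))   # (50, 30)
print(tuple(round(scale_not_smaller * d) for d in in_size))  # (83, 50)
```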
@@ -1799,5 +1779,7 @@ def Split(
schema = get_schema("Split", 18, "")
op = Op(self, "Split", schema)
return op(
- *self._prepare_inputs(schema, input, split), axis=axis, num_outputs=num_outputs
+ *self._prepare_inputs(schema, input, split),
+ axis=axis,
+ num_outputs=num_outputs,
)
diff --git a/onnxscript/onnx_opset/_impl/opset19.py b/onnxscript/onnx_opset/_impl/opset19.py
index 467c23917e..18a7cba17a 100644
--- a/onnxscript/onnx_opset/_impl/opset19.py
+++ b/onnxscript/onnx_opset/_impl/opset19.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D214, D402, D405, D411, D412, D416
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -80,7 +79,7 @@ def AveragePool(
```
output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1)
```
- if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored.
+ if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`.
`auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled:
```
@@ -245,28 +244,31 @@ def Cast(self, input: T1_Cast, *, saturate: int = 1, to: int) -> T2_Cast:
to the following rules. `[x]` means the value rounded to
the target mantissa width.
- | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
- |------|----|----|----|----|
- | 0 | 0 | 0 | 0 | 0 |
- |-0 | -0 | 0 | -0 | 0 |
- | NaN | NaN | NaN | NaN | NaN |
- | +/- Inf | +/- FLT_MAX | NaN | FLT_MAX | NaN |
- | [x] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX |
- | [x] < -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX |
- | else | RNE | RNE | RNE | RNE |
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | -------- | -------- | -------- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | Inf | FLT_MAX | NaN | FLT_MAX | NaN |
+ | -Inf | -FLT_MAX | NaN | -FLT_MAX | NaN |
+ | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX |
+ | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX |
+ | else | RNE | RNE | RNE | RNE |
The behavior changes if the parameter 'saturate' is set to False.
The rules then become:
- | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
- |------|----|----|----|----|
- | 0 | 0 | 0 | 0 | 0 |
- |-0 | -0 | 0 | -0 | 0 |
- | NaN | NaN | NaN | NaN | NaN |
- | +/- Inf | NaN | NaN | +/- Inf | NaN |
- | [x] > FLT_MAX | NaN | NaN | Inf | NaN |
- | [x] < -FLT_MAX | NaN | NaN | -Inf | NaN |
- | else | RNE | RNE | RNE | RNE |
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | ------ | -------- | ---- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | -NaN | -NaN | NaN | -NaN | NaN |
+ | Inf | NaN | NaN | Inf | NaN |
+ | -Inf | -NaN | NaN | -Inf | NaN |
+ | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN |
+ | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN |
+ | else | RNE | RNE | RNE | RNE |
Args:
@@ -566,9 +568,10 @@ def DequantizeLinear(
It's optional. Zero point is 0 when it's not specified.
axis: (Optional) The axis of the dequantizing dimension of the input tensor.
- Ignored for per-tensor quantization. Negative value means counting
- dimensions from the back. Accepted range is [-r, r-1] where r =
- rank(input).
+ Used only for per-axis quantization. Negative value means counting
+ dimensions from the back. Accepted range is `[-r, r-1]` where `r =
+ rank(input)`. When the rank of the input is 1, per-tensor quantization
+ is applied, rendering the axis unnecessary in this scenario.
"""
schema = get_schema("DequantizeLinear", 19, "")
@@ -700,42 +703,7 @@ def Identity(self, input: V_Identity) -> V_Identity:
B_If: TypeAlias = BOOL
V_If: TypeAlias = Union[
- Optional[Sequence[BFLOAT16]],
- Optional[Sequence[BOOL]],
- Optional[Sequence[COMPLEX128]],
- Optional[Sequence[COMPLEX64]],
- Optional[Sequence[DOUBLE]],
- Optional[Sequence[FLOAT]],
- Optional[Sequence[FLOAT16]],
- Optional[Sequence[INT16]],
- Optional[Sequence[INT32]],
- Optional[Sequence[INT64]],
- Optional[Sequence[INT8]],
- Optional[Sequence[STRING]],
- Optional[Sequence[UINT16]],
- Optional[Sequence[UINT32]],
- Optional[Sequence[UINT64]],
- Optional[Sequence[UINT8]],
- Optional[BFLOAT16],
- Optional[BOOL],
- Optional[COMPLEX128],
- Optional[COMPLEX64],
- Optional[DOUBLE],
- Optional[FLOAT],
- Optional[FLOAT16],
- Optional[FLOAT8E4M3FN],
- Optional[FLOAT8E4M3FNUZ],
- Optional[FLOAT8E5M2],
- Optional[FLOAT8E5M2FNUZ],
- Optional[INT16],
- Optional[INT32],
- Optional[INT64],
- Optional[INT8],
- Optional[STRING],
- Optional[UINT16],
- Optional[UINT32],
- Optional[UINT64],
- Optional[UINT8],
+ None,
Sequence[BFLOAT16],
Sequence[BOOL],
Sequence[COMPLEX128],
@@ -743,10 +711,6 @@ def Identity(self, input: V_Identity) -> V_Identity:
Sequence[DOUBLE],
Sequence[FLOAT],
Sequence[FLOAT16],
- Sequence[FLOAT8E4M3FN],
- Sequence[FLOAT8E4M3FNUZ],
- Sequence[FLOAT8E5M2],
- Sequence[FLOAT8E5M2FNUZ],
Sequence[INT16],
Sequence[INT32],
Sequence[INT64],
@@ -776,6 +740,10 @@ def Identity(self, input: V_Identity) -> V_Identity:
UINT32,
UINT64,
UINT8,
+ Sequence[FLOAT8E4M3FN],
+ Sequence[FLOAT8E4M3FNUZ],
+ Sequence[FLOAT8E5M2],
+ Sequence[FLOAT8E5M2FNUZ],
]
def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If:
@@ -888,7 +856,11 @@ def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) ->
)
def Loop(
- self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
) -> V_Loop:
r"""[🌐 Loop(19)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-19 "Online Documentation")
@@ -1546,7 +1518,7 @@ def Resize(
```
scale = Min(sizes[i] /
in_size[d])
- out_size[d] = round_int(scale * in_size[i])
+ out_size[d] = round_int(scale * in_size[d])
```
If
@@ -1556,7 +1528,7 @@ def Resize(
```
scale = Max(sizes[i] /
in_size[d])
- out_size[d] = round_int(scale * in_size[i])
+ out_size[d] = round_int(scale * in_size[d])
```
For
@@ -1842,11 +1814,11 @@ def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) ->
The end axis, if specified, is exclusive (and the returned value will not include the size of that axis).
         If the end axis is omitted, the axes up to the last one will be included.
Negative axes indicate counting back from the last axis.
- Note that axes will be clamped to the range [0, r-1], where r is the
+ Note that axes will be clamped to the range [0, r], where r is the
rank of the input tensor if they are out-of-range (after adding r in the case of
negative axis). Thus, specifying any end value > r is equivalent to specifying an end
value of r, and specifying any start value < -r is equivalent to specifying a start
- value of 0.
+ value of 0. If start > end, the result will be an empty shape.
Examples:
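A NumPy sketch of the clamping and empty-shape rules described above (assumed semantics for illustration):

```
import numpy as np

def shape_slice(data, start=0, end=None):
    r = data.ndim
    end = r if end is None else end
    start = start + r if start < 0 else start
    end = end + r if end < 0 else end
    start = min(max(start, 0), r)  # clamp to [0, r]
    end = min(max(end, 0), r)
    return np.array(data.shape[start:end], dtype=np.int64)  # empty if start > end

x = np.zeros((2, 3, 4))
print(shape_slice(x, start=1))         # [3 4]
print(shape_slice(x, start=5))         # []  (5 is clamped to r = 3)
print(shape_slice(x, start=2, end=1))  # []  (start > end -> empty shape)
```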
diff --git a/onnxscript/onnx_opset/_impl/opset2.py b/onnxscript/onnx_opset/_impl/opset2.py
index e04537c5f4..a4a0e7f291 100644
--- a/onnxscript/onnx_opset/_impl/opset2.py
+++ b/onnxscript/onnx_opset/_impl/opset2.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402, D411
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -132,7 +131,12 @@ def LpPool(
T_Pad = TypeVar("T_Pad", DOUBLE, FLOAT, FLOAT16)
def Pad(
- self, data: T_Pad, *, mode: str = "constant", pads: Sequence[int], value: float = 0.0
+ self,
+ data: T_Pad,
+ *,
+ mode: str = "constant",
+ pads: Sequence[int],
+ value: float = 0.0,
) -> T_Pad:
r"""[🌐 Pad(2)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-2 "Online Documentation")
diff --git a/onnxscript/onnx_opset/_impl/opset20.py b/onnxscript/onnx_opset/_impl/opset20.py
index e05b5018a4..2f3f264c2a 100644
--- a/onnxscript/onnx_opset/_impl/opset20.py
+++ b/onnxscript/onnx_opset/_impl/opset20.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -513,18 +512,20 @@ def ReduceMax(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceMax", 20, "")
@@ -576,18 +577,20 @@ def ReduceMin(
data: (differentiable) An input tensor.
axes: (optional, non-differentiable) Optional input list of integers, along
- which to reduce. The default is to reduce over all the dimensions of the
- input tensor if 'noop_with_empty_axes' is false, else act as an Identity
- op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1]
- where r = rank(data).
+            which to reduce. By default, 'axes' is empty. When axes is
+ empty (either not provided or explicitly empty), behavior depends on
+ 'noop_with_empty_axes': reduction over all axes if
+ 'noop_with_empty_axes' is false, or no reduction is applied if
+ 'noop_with_empty_axes' is true (but other operations will be performed).
+ Accepted range is [-r, r-1] where r = rank(data).
keepdims: Keep the reduced dimension or not, default 1 means keep reduced
dimension.
- noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior
- with 'false' is to reduce all axes. When axes is empty and this
- attribute is set to true, input tensor will not be reduced,and the
- output tensor would be equivalent to input tensor.
+ noop_with_empty_axes: Defines behavior when axes is not provided or is
+ empty. If false (default), reduction happens over all axes. If true, no
+ reduction is applied, but other operations will be performed. For
+ example, ReduceSumSquare acts as a vanilla Square.
"""
schema = get_schema("ReduceMin", 20, "")
diff --git a/onnxscript/onnx_opset/_impl/opset21.py b/onnxscript/onnx_opset/_impl/opset21.py
new file mode 100644
index 0000000000..b0ae5a2e9c
--- /dev/null
+++ b/onnxscript/onnx_opset/_impl/opset21.py
@@ -0,0 +1,1940 @@
+# --------------------------------------------------------------------------
+# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
+# ⚙️ Generated by 'python -m opgen'
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# pylint: disable=W0221,W0222,R0901,W0237
+# mypy: disable-error-code=override
+# ruff: noqa: D214, D402, D405, D411, D412, D416
+# --------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from typing import Optional, Sequence, TypeVar, Union
+
+from onnx import GraphProto, SparseTensorProto, TensorProto
+from onnx.defs import get_schema
+from typing_extensions import TypeAlias
+
+from onnxscript.onnx_opset._impl.opset20 import Opset20
+from onnxscript.onnx_types import (
+ BFLOAT16,
+ BOOL,
+ COMPLEX64,
+ COMPLEX128,
+ DOUBLE,
+ FLOAT,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT16,
+ INT4,
+ INT8,
+ INT16,
+ INT32,
+ INT64,
+ STRING,
+ UINT4,
+ UINT8,
+ UINT16,
+ UINT32,
+ UINT64,
+)
+from onnxscript.values import Op, Opset
+
+
+class Opset21(Opset20):
+ def __new__(cls):
+ return Opset.__new__(cls, "", 21)
+
+ T1_Cast = TypeVar(
+ "T1_Cast",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T2_Cast: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def Cast(self, input: T1_Cast, *, saturate: int = 1, to: int) -> T2_Cast:
+ r"""[🌐 Cast(21)](https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 "Online Documentation")
+
+
+ The operator casts the elements of a given input tensor to a data type
+ specified by the 'to' argument and returns an output tensor of the same size in
+ the converted type. The 'to' argument must be one of the data types specified
+ in the 'DataType' enum field in the TensorProto message.
+
+ Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations
+ (e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may
+ yield result 100. There are some string literals reserved for special floating-point values;
+ "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively.
+        Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinity. Similarly,
+ this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors
+ to string tensors, plain floating-point representation (such as "314.15926") would be used.
+ Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases
+ of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior.
+
+ Conversion from a numerical type to any numerical type is always allowed.
+ User must be aware of precision loss and value change caused by range difference between two types.
+        For example, a 64-bit float 3.1415926459 may be rounded to a 32-bit float 3.141592. Similarly, converting
+ an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type.
+
+ In more detail, the conversion among numerical types should follow these rules
+ if the destination type is not a float 8 type.
+
+ * Casting from floating point to:
+ * floating point: +/- infinity if OOR (out of range).
+ * fixed point: undefined if OOR.
+ * bool: +/- 0.0 to False; all else to True.
+ * Casting from fixed point to:
+ * floating point: +/- infinity if OOR. (+ infinity in the case of uint)
+ * fixed point: when OOR, discard higher bits and reinterpret (with respect to two's complement representation for
+ signed types). For example, 200 (int16) -> -56 (int8).
+ * bool: zero to False; nonzero to True.
+ * Casting from bool to:
+ * floating point: `{1.0, 0.0}`.
+ * fixed point: `{1, 0}`.
+ * bool: no change.
+
+        Float 8 types were introduced to speed up the training of
+        deep models. By default, the conversion of a float *x* obeys
+        the following rules. `[x]` means the value rounded to
+ the target mantissa width.
+
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | -------- | -------- | -------- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | Inf | FLT_MAX | NaN | FLT_MAX | NaN |
+ | -Inf | -FLT_MAX | NaN | -FLT_MAX | NaN |
+ | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX |
+ | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX |
+ | else | RNE | RNE | RNE | RNE |
+
+ The behavior changes if the parameter 'saturate' is set to False.
+ The rules then become:
+
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | ------ | -------- | ---- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | -NaN | -NaN | NaN | -NaN | NaN |
+ | Inf | NaN | NaN | Inf | NaN |
+ | -Inf | -NaN | NaN | -Inf | NaN |
+ | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN |
+ | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN |
+ | else | RNE | RNE | RNE | RNE |
+
+
+ Args:
+ input: (differentiable) Input tensor to be cast.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz).
+ It is true by default. All cases are fully described in two tables
+ inserted in the operator description.
+
+ to: The data type to which the elements of the input tensor are cast.
+ Strictly must be one of the types from DataType enum in TensorProto
+ """
+
+ schema = get_schema("Cast", 21, "")
+ op = Op(self, "Cast", schema)
+ return op(*self._prepare_inputs(schema, input), saturate=saturate, to=to)
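+
+        A usage sketch for the generated binding, assuming this opset is
+        re-exported as `onnxscript.opset21` the way earlier opsets are:
+
+        ```
+        from onnx import TensorProto
+        from onnxscript import FLOAT, INT64, script
+        from onnxscript import opset21 as op  # assumed export path
+
+        @script()
+        def to_int64(x: FLOAT[...]) -> INT64[...]:
+            # 'to' takes the TensorProto.DataType enum value (INT64 == 7).
+            return op.Cast(x, to=TensorProto.INT64)
+        ```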
+
+ T1_CastLike = TypeVar(
+ "T1_CastLike",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T2_CastLike = TypeVar(
+ "T2_CastLike",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def CastLike(
+ self, input: T1_CastLike, target_type: T2_CastLike, *, saturate: int = 1
+ ) -> T2_CastLike:
+ r"""[🌐 CastLike(21)](https://onnx.ai/onnx/operators/onnx__CastLike.html#castlike-21 "Online Documentation")
+
+
+ The operator casts the elements of a given input tensor (the first input) to
+ the same data type as the elements of the second input tensor.
+ See documentation of the Cast operator for further details.
+
+
+ Args:
+ input: (differentiable) Input tensor to be cast.
+
+ target_type: (non-differentiable) The (first) input tensor will be cast to
+ produce a tensor of the same type as this (second input) tensor.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz).
+ It is true by default. Please refer to operator Cast description for
+ further details.
+ """
+
+ schema = get_schema("CastLike", 21, "")
+ op = Op(self, "CastLike", schema)
+ return op(*self._prepare_inputs(schema, input, target_type), saturate=saturate)
+
+ T_Constant: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def Constant(
+ self,
+ *,
+ sparse_value: Optional[SparseTensorProto] = None,
+ value: Optional[TensorProto] = None,
+ value_float: Optional[float] = None,
+ value_floats: Optional[Sequence[float]] = None,
+ value_int: Optional[int] = None,
+ value_ints: Optional[Sequence[int]] = None,
+ value_string: Optional[str] = None,
+ value_strings: Optional[Sequence[str]] = None,
+ ) -> T_Constant:
+ r"""[🌐 Constant(21)](https://onnx.ai/onnx/operators/onnx__Constant.html#constant-21 "Online Documentation")
+
+
+ This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value,
+ or value_* must be specified.
+
+
+ Args:
+ sparse_value: The value for the elements of the output tensor in sparse
+ format.
+
+ value: The value for the elements of the output tensor.
+
+ value_float: The value for the sole element for the scalar, float32, output
+ tensor.
+
+ value_floats: The values for the elements for the 1D, float32, output
+ tensor.
+
+ value_int: The value for the sole element for the scalar, int64, output
+ tensor.
+
+ value_ints: The values for the elements for the 1D, int64, output tensor.
+
+ value_string: The value for the sole element for the scalar, UTF-8 string,
+ output tensor.
+
+ value_strings: The values for the elements for the 1D, UTF-8 string, output
+ tensor.
+ """
+
+ schema = get_schema("Constant", 21, "")
+ op = Op(self, "Constant", schema)
+ return op(
+ sparse_value=sparse_value,
+ value=value,
+ value_float=value_float,
+ value_floats=value_floats,
+ value_int=value_int,
+ value_ints=value_ints,
+ value_string=value_string,
+ value_strings=value_strings,
+ )
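+
+        A usage sketch (same assumed `onnxscript.opset21` export); exactly one
+        `value_*` attribute may be supplied:
+
+        ```
+        from onnxscript import FLOAT, script
+        from onnxscript import opset21 as op  # assumed export path
+
+        @script()
+        def halves() -> FLOAT[3]:
+            return op.Constant(value_floats=[0.5, 1.0, 1.5])
+        ```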
+
+ T1_ConstantOfShape: TypeAlias = INT64
+
+ T2_ConstantOfShape: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def ConstantOfShape(
+ self, input: T1_ConstantOfShape, *, value: Optional[TensorProto] = None
+ ) -> T2_ConstantOfShape:
+ r"""[🌐 ConstantOfShape(21)](https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html#constantofshape-21 "Online Documentation")
+
+
+ Generate a tensor with given value and shape.
+
+
+ Args:
+            input: 1D tensor. The shape of the expected output tensor. If an empty
+                tensor is given, the output is a scalar. All values must be >= 0.
+
+            value: (Optional) The value of the output elements. Should be a one-element
+                tensor. If not specified, it defaults to a tensor of value 0 and
+                datatype float32.
+ """
+
+ schema = get_schema("ConstantOfShape", 21, "")
+ op = Op(self, "ConstantOfShape", schema)
+ return op(*self._prepare_inputs(schema, input), value=value)
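+
+        The semantics in NumPy terms (a sketch of the assumed behavior):
+
+        ```
+        import numpy as np
+
+        shape = np.array([2, 3], dtype=np.int64)   # the 1D 'input' tensor
+        value = np.array([1.5], dtype=np.float32)  # the one-element 'value' tensor
+        print(np.full(tuple(shape), value[0]))     # omitting 'value' yields float32 zeros
+        ```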
+
+ T1_DequantizeLinear = TypeVar(
+ "T1_DequantizeLinear",
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT8,
+ UINT16,
+ UINT4,
+ UINT8,
+ )
+
+ T2_DequantizeLinear = TypeVar("T2_DequantizeLinear", BFLOAT16, FLOAT, FLOAT16)
+
+ def DequantizeLinear(
+ self,
+ x: T1_DequantizeLinear,
+ x_scale: T2_DequantizeLinear,
+ x_zero_point: Optional[T1_DequantizeLinear] = None,
+ *,
+ axis: int = 1,
+ block_size: int = 0,
+ ) -> T2_DequantizeLinear:
+ r"""[🌐 DequantizeLinear(21)](https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html#dequantizelinear-21 "Online Documentation")
+
+
+ The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the
+ full-precision tensor. The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point`
+ must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization,
+ a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization.
+ See QuantizeLinear for details on quantization granularity.
+ `x_zero_point` and `x` must have the same type. `x` and `y` must have the same shape. In the case of dequantizing
+ `int32`, there's no zero point (zero point is supposed to be 0).
+        `zero-point` is usually not used in the case of float8 quantization, but the dequantization formula remains the same
+ for consistency, and `x_scale` still determines the output type.
+
+
+ Args:
+ x: N-D quantized input tensor to be de-quantized.
+
+ x_scale: Scale for input `x`. For per-tensor/layer dequantization the scale
+                is a scalar, for per-axis dequantization it is a 1-D tensor, and for
+ blocked dequantization it has the same shape as the input, except for
+ one dimension in which blocking is performed.
+
+ x_zero_point: (optional) Zero point for input `x`. Shape must match x_scale.
+ It's optional. Zero point is 0 when it's not specified.
+
+ axis: (Optional) The axis of the dequantizing dimension of the input tensor.
+ Used for per-axis and blocked quantization. Negative value means
+ counting dimensions from the back. Accepted range is `[-r, r-1]` where
+ `r = rank(input)`.
+
+ block_size: (Optional) The size of the quantization block (number of times
+ every scale is replicated). Used only for blocked quantization. The
+ block size is a positive integer. Given `x` shape `(D0, ..., Di, ...,
+ Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted
+ range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]`
+ """
+
+ schema = get_schema("DequantizeLinear", 21, "")
+ op = Op(self, "DequantizeLinear", schema)
+ return op(
+ *self._prepare_inputs(schema, x, x_scale, x_zero_point),
+ axis=axis,
+ block_size=block_size,
+ )
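+
+        The formula `y = (x - x_zero_point) * x_scale` for the per-axis case,
+        sketched in NumPy (assumed reference behavior, not the generated binding):
+
+        ```
+        import numpy as np
+
+        def dequantize_per_axis(x, x_scale, x_zero_point, axis=1):
+            shape = [1] * x.ndim
+            shape[axis] = -1  # broadcast the 1-D scale/zero-point along `axis`
+            scale = x_scale.reshape(shape)
+            zp = x_zero_point.reshape(shape).astype(np.int32)
+            return (x.astype(np.int32) - zp) * scale
+
+        x = np.array([[0, 128], [255, 64]], dtype=np.uint8)
+        scales = np.array([0.1, 0.2], dtype=np.float32)
+        zero_points = np.array([128, 0], dtype=np.uint8)
+        print(dequantize_per_axis(x, scales, zero_points))  # [[-12.8 25.6] [12.7 12.8]]
+        ```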
+
+ T_Flatten = TypeVar(
+ "T_Flatten",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Flatten(self, input: T_Flatten, *, axis: int = 1) -> T_Flatten:
+ r"""[🌐 Flatten(21)](https://onnx.ai/onnx/operators/onnx__Flatten.html#flatten-21 "Online Documentation")
+
+
+ Flattens the input tensor into a 2D matrix. If input tensor has shape
+ (d_0, d_1, ... d_n) then the output will have shape
+        (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X d_n).
+
+
+ Args:
+ input: (differentiable) A tensor of rank >= axis.
+
+ axis: Indicate up to which input dimensions (exclusive) should be flattened
+ to the outer dimension of the output. The value for axis must be in the
+ range [-r, r], where r is the rank of the input tensor. Negative value
+ means counting dimensions from the back. When axis = 0, the shape of the
+            output tensor is (1, (d_0 X d_1 ... d_n)), where the shape of the input
+ tensor is (d_0, d_1, ... d_n).
+ """
+
+ schema = get_schema("Flatten", 21, "")
+ op = Op(self, "Flatten", schema)
+ return op(*self._prepare_inputs(schema, input), axis=axis)
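+
+        The shape arithmetic from the description, checked with NumPy:
+
+        ```
+        import numpy as np
+
+        x = np.zeros((2, 3, 4, 5))
+        axis = 2
+        outer = int(np.prod(x.shape[:axis]))  # d_0 * ... * d_(axis-1) = 6
+        inner = int(np.prod(x.shape[axis:]))  # d_axis * ... * d_n     = 20
+        print(x.reshape(outer, inner).shape)  # (6, 20)
+        ```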
+
+ T_GroupNormalization = TypeVar("T_GroupNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def GroupNormalization(
+ self,
+ X: T_GroupNormalization,
+ scale: T_GroupNormalization,
+ bias: T_GroupNormalization,
+ *,
+ epsilon: float = 9.999999747378752e-06,
+ num_groups: int,
+ stash_type: int = 1,
+ ) -> T_GroupNormalization:
+ r"""[🌐 GroupNormalization(21)](https://onnx.ai/onnx/operators/onnx__GroupNormalization.html#groupnormalization-21 "Online Documentation")
+
+
+ A GroupNormalization function. Carries out group normalization as described in
+ the paper https://arxiv.org/abs/1803.08494
+
+ This operator transforms input according to
+ ::
+
+ y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
+
+
+ where the mean and variance are computed per instance per group of channels, and
+ `scale` and `bias` should be specified for each channel. The number of
+ groups `num_groups` should be divisible by the number of channels so that there are
+ an equal number of channels per group.
+
+ The overall computation has two stages: the first stage normalizes the elements to
+ have zero mean and unit variance for each instance in each group, and the second
+ stage scales and shifts the results of the first stage. The floating-point precision
+ used in the first stage is determined by the `stash_type` attribute. For example,
+ if `stash_type` is 1, the operator casts all input variables to 32-bit float,
+ performs the computation, and finally casts the normalized results back to the
+ original type of `X`. The second stage does not depend on `stash_type`.
+
+ When the number of groups is the same as the number of channels, this operator is
+ equivalent to InstanceNormalization. When there is only one group, this operator
+ is equivalent to LayerNormalization.
+
+
+ Args:
+ X: (differentiable) Input data tensor. Dimensions for image cases are `(N x
+ C x H x W)`, where `N` is the batch size, `C` is the number of channels,
+ and `H` and `W` are the height and width of the data. Statistics are
+ computed for every group of channels over `C`, `H`, and `W`. For
+ non-image cases, the dimensions are in the form of `(N x C x D1 x D2 ...
+ Dn)`.
+
+ scale: (differentiable) Scale tensor of shape `(C)`.
+
+ bias: (differentiable) Bias tensor of shape `(C)`.
+
+ epsilon: The epsilon value to use to avoid division by zero.
+
+ num_groups: The number of groups of channels. It should be a divisor of the
+ number of channels `C`.
+
+ stash_type: The floating-point precision used in stage one of the
+ computation.
+ """
+
+ schema = get_schema("GroupNormalization", 21, "")
+ op = Op(self, "GroupNormalization", schema)
+ return op(
+ *self._prepare_inputs(schema, X, scale, bias),
+ epsilon=epsilon,
+ num_groups=num_groups,
+ stash_type=stash_type,
+ )
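+
+        A NumPy reference for the two-stage computation described above (a
+        sketch of the assumed semantics; the `stash_type` cast is omitted for
+        brevity):
+
+        ```
+        import numpy as np
+
+        def group_norm(x, scale, bias, num_groups, epsilon=1e-5):
+            n, c = x.shape[:2]
+            g = x.reshape(n, num_groups, c // num_groups, *x.shape[2:])
+            axes = tuple(range(2, g.ndim))  # stats per instance, per group
+            mean = g.mean(axis=axes, keepdims=True)
+            var = g.var(axis=axes, keepdims=True)
+            y = ((g - mean) / np.sqrt(var + epsilon)).reshape(x.shape)
+            ch = (1, c) + (1,) * (x.ndim - 2)  # stage two: per-channel scale/bias (C,)
+            return y * scale.reshape(ch) + bias.reshape(ch)
+
+        x = np.random.randn(2, 4, 8, 8).astype(np.float32)
+        out = group_norm(x, np.ones(4, np.float32), np.zeros(4, np.float32), num_groups=2)
+        print(out.shape)  # (2, 4, 8, 8)
+        ```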
+
+ V_Identity = TypeVar(
+ "V_Identity",
+ Optional[Sequence[BOOL]],
+ Optional[Sequence[COMPLEX128]],
+ Optional[Sequence[COMPLEX64]],
+ Optional[Sequence[DOUBLE]],
+ Optional[Sequence[FLOAT]],
+ Optional[Sequence[FLOAT16]],
+ Optional[Sequence[INT16]],
+ Optional[Sequence[INT32]],
+ Optional[Sequence[INT64]],
+ Optional[Sequence[INT8]],
+ Optional[Sequence[STRING]],
+ Optional[Sequence[UINT16]],
+ Optional[Sequence[UINT32]],
+ Optional[Sequence[UINT64]],
+ Optional[Sequence[UINT8]],
+ Optional[BOOL],
+ Optional[COMPLEX128],
+ Optional[COMPLEX64],
+ Optional[DOUBLE],
+ Optional[FLOAT],
+ Optional[FLOAT16],
+ Optional[INT16],
+ Optional[INT32],
+ Optional[INT64],
+ Optional[INT8],
+ Optional[STRING],
+ Optional[UINT16],
+ Optional[UINT32],
+ Optional[UINT64],
+ Optional[UINT8],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Identity(self, input: V_Identity) -> V_Identity:
+ r"""[🌐 Identity(21)](https://onnx.ai/onnx/operators/onnx__Identity.html#identity-21 "Online Documentation")
+
+ Identity operator
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Identity", 21, "")
+ op = Op(self, "Identity", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ B_If: TypeAlias = BOOL
+
+ V_If: TypeAlias = Union[
+ None,
+ Sequence[BFLOAT16],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ Sequence[FLOAT8E4M3FN],
+ Sequence[FLOAT8E4M3FNUZ],
+ Sequence[FLOAT8E5M2],
+ Sequence[FLOAT8E5M2FNUZ],
+ Sequence[INT4],
+ Sequence[UINT4],
+ ]
+
+ def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If:
+ r"""[🌐 If(21)](https://onnx.ai/onnx/operators/onnx__If.html#if-21 "Online Documentation")
+
+ If conditional
+
+ Args:
+ cond: Condition for the if. The tensor must contain a single element.
+
+ else_branch: Graph to run if condition is false. Has N outputs: values you
+ wish to be live-out to the enclosing scope. The number of outputs must
+ match the number of outputs in the then_branch.
+
+ then_branch: Graph to run if condition is true. Has N outputs: values you
+ wish to be live-out to the enclosing scope. The number of outputs must
+ match the number of outputs in the else_branch.
+ """
+
+ schema = get_schema("If", 21, "")
+ op = Op(self, "If", schema)
+ return op(
+ *self._prepare_inputs(schema, cond),
+ else_branch=else_branch,
+ then_branch=then_branch,
+ )
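+
+        A usage sketch (same assumed `onnxscript.opset21` export); onnxscript
+        lowers a Python `if` on a tensor condition into this operator, with
+        both branches yielding the same live-out variables:
+
+        ```
+        from onnxscript import BOOL, FLOAT, script
+        from onnxscript import opset21 as op  # assumed export path
+
+        @script()
+        def abs_or_neg(flag: BOOL, x: FLOAT[...]) -> FLOAT[...]:
+            if flag:
+                y = op.Abs(x)
+            else:
+                y = op.Neg(x)
+            return y
+        ```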
+
+ I_Loop: TypeAlias = INT64
+
+ B_Loop: TypeAlias = BOOL
+
+ V_Loop = TypeVar(
+ "V_Loop",
+ Optional[Sequence[BFLOAT16]],
+ Optional[Sequence[BOOL]],
+ Optional[Sequence[COMPLEX128]],
+ Optional[Sequence[COMPLEX64]],
+ Optional[Sequence[DOUBLE]],
+ Optional[Sequence[FLOAT]],
+ Optional[Sequence[FLOAT16]],
+ Optional[Sequence[INT16]],
+ Optional[Sequence[INT32]],
+ Optional[Sequence[INT64]],
+ Optional[Sequence[INT8]],
+ Optional[Sequence[STRING]],
+ Optional[Sequence[UINT16]],
+ Optional[Sequence[UINT32]],
+ Optional[Sequence[UINT64]],
+ Optional[Sequence[UINT8]],
+ Optional[BFLOAT16],
+ Optional[BOOL],
+ Optional[COMPLEX128],
+ Optional[COMPLEX64],
+ Optional[DOUBLE],
+ Optional[FLOAT],
+ Optional[FLOAT16],
+ Optional[FLOAT8E4M3FN],
+ Optional[FLOAT8E4M3FNUZ],
+ Optional[FLOAT8E5M2],
+ Optional[FLOAT8E5M2FNUZ],
+ Optional[INT16],
+ Optional[INT32],
+ Optional[INT4],
+ Optional[INT64],
+ Optional[INT8],
+ Optional[STRING],
+ Optional[UINT16],
+ Optional[UINT32],
+ Optional[UINT4],
+ Optional[UINT64],
+ Optional[UINT8],
+ Sequence[BFLOAT16],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[FLOAT8E4M3FN],
+ Sequence[FLOAT8E4M3FNUZ],
+ Sequence[FLOAT8E5M2],
+ Sequence[FLOAT8E5M2FNUZ],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT4],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT4],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Loop(
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
+ ) -> V_Loop:
+ r"""[🌐 Loop(21)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-21 "Online Documentation")
+
+
+ Generic Looping construct. This loop has multiple termination conditions:
+
+ 1) Trip count. Iteration count specified at runtime. Set by
+ specifying the input M. Optional. Set to empty string to omit.
+ Note that a static trip count (specified at graph construction time) can be
+ specified by passing in a constant node for input M.
+ 2) Loop termination condition. This is an input to the op that determines
+ whether to run the first iteration and also a loop-carried dependency for
+ the body graph. The body graph must yield a value for the condition variable,
+ whether this input is provided or not.
+
+ This table summarizes the operating modes of this operator with equivalent
+ C-style code:
+
+ Operator inputs defined as (max_trip_count, condition_var).
+
+ * input ("", ""):
+ for (int i=0; ; ++i) {
+ cond = ... // Note this value is ignored, but is required in the body
+ }
+
+ * input ("", cond) // Note this is analogous to a while loop
+ bool cond = ...;
+ for (int i=0; cond; ++i) {
+ cond = ...;
+ }
+
+ * input ("", 1) // Note this is analogous to a do-while loop
+ bool cond = true
+ for (int i=0; cond; ++i) {
+ cond = ...;
+ }
+
+ * input (trip_count, "") // Note this is analogous to a for loop
+ int trip_count = ...
+ for (int i=0; i < trip_count; ++i) {
+ cond = ...; // ignored
+ }
+
+ * input (trip_count, cond)
+ int trip_count = ...;
+ bool cond = ...;
+ for (int i=0; i < trip_count && cond; ++i) {
+ cond = ...;
+ }
+
+
+ *Sample usage - cond as well as trip count*
+
+ graph predict-net {
+            %a = Constant[value = <Scalar Tensor [3]>]()
+            %b = Constant[value = <Scalar Tensor [6]>]()
+            %keepgoing = Constant[value = <Scalar Tensor [true]>]()
+            %max_trip_count = Constant[value = <Scalar Tensor [10]>]()
+            %keepgoing_out, %b_out, %user_defined_vals = Loop[body = <graph body-net>](%max_trip_count, %keepgoing, %b)
+ return
+ }
+
+ graph body-net (
+ %i[INT32, scalar] // iteration number
+ %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used
+ %b_in[INT32, scalar] // incoming value of loop-carried-dependency b
+ ) {
+ %my_local = Add(%a, %b_in)
+ %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b
+ %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition
+ %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated
+ return %keepgoing_out, %b_out, %user_defined_val
+ }
+
+ *Sample equivalent C code*
+
+ {
+ /* User-defined code (enclosing scope) */
+ int a = 3, b = 6;
+ bool keepgoing = true; // Analogous to input cond
+ /* End user-defined code */
+
+ /* Implicitly-defined code */
+ const int max_trip_count = 10; // Analogous to input M
+ int user_defined_vals[]; // Imagine this is resizable
+ /* End implicitly-defined code */
+ /* initialize loop-carried variables and scan-output variables */
+ bool keepgoing_out = keepgoing
+ int b_out = b
+
+ for (int i=0; i < max_trip_count && keepgoing_out; ++i) {
+ /* Implicitly-defined code: bind actual parameter values
+ to formal parameter variables of loop-body */
+ bool keepgoing_in = keepgoing_out;
+                int b_in = b_out;
+
+ /* User-defined code (loop body) */
+ int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine
+ b_out = a - b_in;
+ keepgoing_out = my_local > b_out;
+ user_defined_val = b_in + b_in; // b_in and b_out are different variables
+ /* End user-defined code */
+
+ /* Implicitly defined-code */
+ user_defined_vals[i] = user_defined_val // accumulate scan-output values
+ }
+ // int t = my_local; // Can't do this. my_local is not accessible here.
+
+ // The values below are bound to the output variables of the loop and therefore accessible
+ // b_out; user_defined_vals; keepgoing_out;
+ }
+
+ There are several things of note in this code snippet:
+
+ 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can
+ be referenced in the inputs of the loop.
+        2) Any values computed in the loop body that need to be used in a subsequent
+ iteration or after the loop are modelled using a pair of variables in the loop-body,
+ consisting of an input variable (eg., b_in) and an output variable (eg., b_out).
+        These are referred to as loop-carried dependencies. The loop operation node
+ supplies the input value of the input variable for the first iteration, and
+ returns the output value of the output variable produced by the final
+ iteration.
+ 3) Scan_output variables are used to implicitly concatenate values computed across
+ all the iterations. In the above example, the value of user_defined_val computed
+ over all iterations are concatenated and returned as the value of user_defined_vals
+ after the loop.
+ 4) Values created in the body cannot be accessed in the enclosing scope,
+ except using the mechanism described above.
+
+ Note that the semantics of this op support "diagonal" or "wavefront" execution.
+ (See Step 3 here for an example:
+ https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/).
+ Frontends should emit multi-layer RNNs as a series of While operators (with
+ time being the inner looping dimension), with each successive layer consuming
+ the scan_outputs from the previous layer, possibly going through several
+ point-wise operators (e.g. dropout, residual connections, linear layer).
+
+        Input/output matching for the subgraph (produced by the loop node) is based on order instead of name. The implementation will figure out the names based on this order.
+
+
+ Args:
+ M: (optional) A maximum trip-count for the loop specified at runtime.
+ Optional. Pass empty string to skip.
+
+ cond: (optional) A boolean termination condition. Optional. Pass empty
+ string to skip.
+
+ v_initial: (variadic, heterogeneous) The initial values of any loop-carried
+ dependencies (values that change across loop iterations)
+
+ body: The graph run each iteration. It has 2+N inputs: (iteration_num,
+ condition, loop carried dependencies...). It has 1+N+K outputs:
+ (condition, loop carried dependencies..., scan_outputs...). Each
+ scan_output is created by concatenating the value of the specified
+ output value at the end of each iteration of the loop. It is an error if
+ the dimensions or data type of these scan_outputs change across loop
+ iterations.
+ """
+
+ schema = get_schema("Loop", 21, "")
+ op = Op(self, "Loop", schema)
+ return op(*self._prepare_inputs(schema, M, cond, *v_initial), body=body)
+
+ T_Pad = TypeVar(
+ "T_Pad",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ Tind_Pad = TypeVar("Tind_Pad", INT32, INT64)
+
+ def Pad(
+ self,
+ data: T_Pad,
+ pads: INT64,
+ constant_value: Optional[T_Pad] = None,
+ axes: Optional[Tind_Pad] = None,
+ *,
+ mode: str = "constant",
+ ) -> T_Pad:
+ r"""[🌐 Pad(21)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-21 "Online Documentation")
+
+
+ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`,
+ a padded tensor (`output`) is generated.
+
+ The three supported `modes` are (similar to corresponding modes supported by `numpy.pad`):
+
+        1) `constant` (default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False)
+
+ 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis
+
+ 3) `edge` - pads with the edge values of the array
+
+ 4) `wrap` - wrap-around padding as if the data tensor forms a torus
+
+
+ Example 1 (`constant` mode):
+
+ Insert 0 pads to the beginning of the second dimension.
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'constant'
+
+ constant_value = 0.0
+
+ output = [
+ [0.0, 0.0, 1.0, 1.2],
+ [0.0, 0.0, 2.3, 3.4],
+ [0.0, 0.0, 4.5, 5.7],
+ ]
+
+
+
+ Example 2 (`reflect` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'reflect'
+
+ output = [
+ [1.0, 1.2, 1.0, 1.2],
+ [2.3, 3.4, 2.3, 3.4],
+ [4.5, 5.7, 4.5, 5.7],
+ ]
+
+
+
+ Example 3 (`edge` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'edge'
+
+ output = [
+ [1.0, 1.0, 1.0, 1.2],
+ [2.3, 2.3, 2.3, 3.4],
+ [4.5, 4.5, 4.5, 5.7],
+ ]
+
+
+
+ Example 4 (`wrap` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [2, 1, 1, 1]
+
+ mode = 'wrap'
+
+ output = [
+ [3.4, 2.3, 3.4, 2.3],
+ [5.7, 4.5, 5.7, 4.5],
+ [1.2, 1.0, 1.2, 1.0],
+ [3.4, 2.3, 3.4, 2.3],
+ [5.7, 4.5, 5.7, 4.5],
+ [1.2, 1.0, 1.2, 1.0],
+ ]
+
+
+
+
+ Args:
+ data: (differentiable) Input tensor.
+
+ pads: (non-differentiable) Tensor of integers indicating the number of
+ padding elements to add or remove (if negative) at the beginning and end
+ of each axis. For 2D input tensor, it is the number of pixels. `pads`
+ should be a 1D tensor of shape [2 * num_axes] where `num_axes` refers to
+ the number of elements in the `axes` input or the input rank if `axes`
+ are not provided explicitly. `pads` format should be: [x1_begin,
+ x2_begin, ..., x1_end, x2_end,...], where xi_begin is the number of pad
+ values added at the beginning of axis `axes[i]` and xi_end, the number
+ of pad values added at the end of axis `axes[i]`.
+
+ constant_value: (optional, non-differentiable) (Optional) A scalar value to
+ be used if the mode chosen is `constant` (by default it is 0, empty
+ string or False).
+
+ axes: (optional, non-differentiable) 1-D tensor of axes that `pads` apply
+ to. Negative value means counting dimensions from the back. Accepted
+ range is [-r, r-1] where r = rank(data). Behavior is undefined if an
+ axis is repeated. If not provided, all axes are assumed (`[0, 1, ...,
+ input_rank-1]`).
+
+ mode: Supported modes: `constant` (default), `reflect`, `edge`, `wrap`
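+
+        A minimal numpy sketch of Example 1 above (illustration only; note that
+        `np.pad` takes per-axis (begin, end) pairs rather than the ONNX `pads`
+        layout):
+
+        ::
+
+            import numpy as np
+
+            data = np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]])
+            # pads = [0, 2, 0, 0] -> per-axis pairs: axis 0: (0, 0), axis 1: (2, 0)
+            output = np.pad(data, ((0, 0), (2, 0)), mode="constant", constant_values=0.0)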
+ """
+
+ schema = get_schema("Pad", 21, "")
+ op = Op(self, "Pad", schema)
+ return op(*self._prepare_inputs(schema, data, pads, constant_value, axes), mode=mode)
+
+ T1_QLinearMatMul = TypeVar(
+ "T1_QLinearMatMul",
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT8,
+ UINT8,
+ )
+
+ TS_QLinearMatMul = TypeVar("TS_QLinearMatMul", BFLOAT16, FLOAT, FLOAT16)
+
+ T2_QLinearMatMul = TypeVar(
+ "T2_QLinearMatMul",
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT8,
+ UINT8,
+ )
+
+ T3_QLinearMatMul = TypeVar(
+ "T3_QLinearMatMul",
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT8,
+ UINT8,
+ )
+
+ def QLinearMatMul(
+ self,
+ a: T1_QLinearMatMul,
+ a_scale: TS_QLinearMatMul,
+ a_zero_point: T1_QLinearMatMul,
+ b: T2_QLinearMatMul,
+ b_scale: TS_QLinearMatMul,
+ b_zero_point: T2_QLinearMatMul,
+ y_scale: TS_QLinearMatMul,
+ y_zero_point: T3_QLinearMatMul,
+ ) -> T3_QLinearMatMul:
+ r"""[🌐 QLinearMatMul(21)](https://onnx.ai/onnx/operators/onnx__QLinearMatMul.html#qlinearmatmul-21 "Online Documentation")
+
+
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
+ It consumes two quantized input tensors, their scales and zero points, scale and zero point of output,
+ and computes the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point).
+ For (x / y_scale), rounding is to nearest, ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+ Scale and zero point must have same shape. They must be either scalar (per tensor) or N-D tensor
+ (per row for 'a' and per column for 'b'). Scalar refers to per tensor quantization whereas N-D refers to per row
+ or per column quantization. If the input is 2D of shape [M, K] then zero point and scale tensor may be
+ an M-element vector [v_1, v_2, ..., v_M] for per-row quantization and a K-element vector [v_1, v_2, ..., v_K]
+ for per column quantization. If the input is N-D tensor with shape [D1, D2, M, K] then zero point and scale tensor may
+ have shape [D1, D2, M, 1] for per row quantization and shape [D1, D2, 1, K] for per column quantization.
+ The product must never overflow; accumulation may overflow if and only if it is performed in 32 bits.
+
+
+ Args:
+ a: (non-differentiable) N-dimensional quantized matrix a
+
+ a_scale: (non-differentiable) scale of quantized input a
+
+ a_zero_point: (non-differentiable) zero point of quantized input a
+
+ b: (non-differentiable) N-dimensional quantized matrix b
+
+ b_scale: (non-differentiable) scale of quantized input b
+
+ b_zero_point: (non-differentiable) zero point of quantized input b
+
+ y_scale: (non-differentiable) scale of quantized output y
+
+ y_zero_point: (non-differentiable) zero point of quantized output y
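+
+        A per-tensor numpy sketch of the computation above (illustration only;
+        the uint8 clip range is an assumption for this example):
+
+        ::
+
+            import numpy as np
+
+            real_a = (a.astype(np.int32) - a_zero_point) * a_scale  # dequantize
+            real_b = (b.astype(np.int32) - b_zero_point) * b_scale
+            y_real = np.matmul(real_a, real_b)
+            # np.rint rounds half to even, matching the spec
+            y = np.clip(np.rint(y_real / y_scale) + y_zero_point, 0, 255).astype(np.uint8)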
+ """
+
+ schema = get_schema("QLinearMatMul", 21, "")
+ op = Op(self, "QLinearMatMul", schema)
+ return op(
+ *self._prepare_inputs(
+ schema,
+ a,
+ a_scale,
+ a_zero_point,
+ b,
+ b_scale,
+ b_zero_point,
+ y_scale,
+ y_zero_point,
+ )
+ )
+
+ T1_QuantizeLinear = TypeVar("T1_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32)
+
+ T2_QuantizeLinear = TypeVar(
+ "T2_QuantizeLinear",
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT4,
+ INT8,
+ UINT16,
+ UINT4,
+ UINT8,
+ )
+
+ def QuantizeLinear(
+ self,
+ x: T1_QuantizeLinear,
+ y_scale: T1_QuantizeLinear,
+ y_zero_point: Optional[T2_QuantizeLinear] = None,
+ *,
+ axis: int = 1,
+ block_size: int = 0,
+ output_dtype: int = 0,
+ saturate: int = 1,
+ ) -> T2_QuantizeLinear:
+ r"""[🌐 QuantizeLinear(21)](https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html#quantizelinear-21 "Online Documentation")
+
+
+ The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the
+ low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization
+ granularity. The quantization formula is `y = saturate((x / y_scale) + y_zero_point)`.
+ Saturation is done according to:
+ - uint16: [0, 65535]
+ - int16: [-32768, 32767]
+ - uint8: [0, 255]
+ - int8: [-128, 127]
+ - uint4: [0, 15]
+ - int4: [-8, 7]
+ For `(x / y_scale)`, rounding is to nearest, ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+ `y_zero_point` and `y` must have the same type. `y_zero_point` is usually not used for quantization to float8 types, but the quantization
+ formula remains the same for consistency, and the type of the input `y_zero_point` still determines the quantization type.
+ There are three supported quantization granularities, determined by the shape of `y_scale`.
+ In all cases, `y_zero_point` must have the same shape as `y_scale`.
+ - Per-tensor (per-layer) quantization: `y_scale` is a scalar.
+ - Per-axis quantization: The scale must be a 1-D tensor, with the length of the quantization axis. For an input shape
+ `(D0, ..., Di, ..., Dn)` and `axis=i`, `y_scale` is a 1-D tensor of length `Di`.
+ - Blocked quantization: The scale's shape is identical to the input's shape, except for one dimension, in which
+ blocking is performed. Given `x` shape `(D0, ..., Di, ..., Dn)`, `axis=i`, and block size `B`: `y_scale` shape is
+ `(D0, ..., ceil(Di/B), ..., Dn)`.
+
+
+ Args:
+ x: N-D full-precision input tensor to be quantized.
+
+ y_scale: Scale for doing quantization to get `y`. For per-tensor/layer
+ quantization the scale is a scalar, for per-axis quantization it is a
+ 1-D Tensor and for blocked quantization it has the same shape as the
+ input, except for one dimension in which blocking is performed.
+
+ y_zero_point: (optional) Zero point for doing quantization to get `y`. Shape
+ must match `y_scale`. Default is uint8 with a zero point of 0 if not
+ specified.
+
+ axis: (Optional) The axis of the quantization dimension of the input tensor.
+ Used only for per-axis and blocked quantization. Negative value means
+ counting dimensions from the back. Accepted range is `[-r, r-1]` where
+ `r = rank(input)`. When the rank of the input is 1, per-tensor
+ quantization is applied, rendering the axis unnecessary in this
+ scenario.
+
+ block_size: (Optional) The size of the quantization block (number of times
+ every scale is replicated). Used only for blocked quantization. The
+ block size is a positive integer. Given `x` shape `(D0, ..., Di, ...,
+ Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted
+ range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]`
+
+ output_dtype: (Optional) The output data type. If not supplied, the output
+ data type is inferred from `y_zero_point` data type (`T2`). If neither
+ `output_dtype` nor `y_zero_point` are supplied, output data type is
+ uint8. If both `output_dtype` and `y_zero_point` are specified,
+ `output_dtype` must be `T2`.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ quantization (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz).
+ It is true by default. All cases are fully described in two tables
+ inserted in the operator description.
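+
+        A per-tensor numpy sketch of the quantization formula (illustration
+        only; the uint8 saturation range is an assumption for this example):
+
+        ::
+
+            import numpy as np
+
+            # np.rint rounds half to even, matching the spec
+            y = np.clip(np.rint(x / y_scale) + y_zero_point, 0, 255).astype(np.uint8)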
+ """
+
+ schema = get_schema("QuantizeLinear", 21, "")
+ op = Op(self, "QuantizeLinear", schema)
+ return op(
+ *self._prepare_inputs(schema, x, y_scale, y_zero_point),
+ axis=axis,
+ block_size=block_size,
+ output_dtype=output_dtype,
+ saturate=saturate,
+ )
+
+ T_Reshape = TypeVar(
+ "T_Reshape",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Reshape(self, data: T_Reshape, shape: INT64, *, allowzero: int = 0) -> T_Reshape:
+ r"""[🌐 Reshape(21)](https://onnx.ai/onnx/operators/onnx__Reshape.html#reshape-21 "Online Documentation")
+
+
+ Reshape the input tensor similar to numpy.reshape.
+ First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor.
+ At most one dimension of the new shape can be -1. In this case, the value is
+ inferred from the size of the tensor and the remaining dimensions. A dimension
+ could also be 0, in which case the actual dimension value is unchanged (i.e. taken
+ from the input tensor). If 'allowzero' is set, and the new shape includes 0, the
+ dimension will be set explicitly to zero (i.e. not taken from input tensor).
+ Shape (second input) could be an empty shape, which means converting to a scalar.
+ The input tensor's shape and the output tensor's shape are required to have the same number of elements.
+
+ If the attribute 'allowzero' is set, it is invalid for the specified shape to
+ contain both a zero value and -1, as the value of the dimension corresponding
+ to -1 cannot be determined uniquely.
+
+
+ Args:
+ data: (differentiable) An input tensor.
+
+ shape: (non-differentiable) Specified shape for output.
+
+ allowzero: (Optional) By default, when any value in the 'shape' input is
+ equal to zero the corresponding dimension value is copied from the input
+ tensor dynamically. allowzero=1 indicates that if any value in the
+ 'shape' input is set to zero, the zero value is honored, similar to
+ NumPy.
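+
+        For example (illustration only):
+
+        ::
+
+            data.shape == (2, 3, 4), shape = [0, -1]
+            allowzero = 0  ->  output.shape == (2, 12)  # 0 copies dim 0 from data
+            allowzero = 1  ->  invalid: shape contains both a zero and -1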
+ """
+
+ schema = get_schema("Reshape", 21, "")
+ op = Op(self, "Reshape", schema)
+ return op(*self._prepare_inputs(schema, data, shape), allowzero=allowzero)
+
+ V_Scan = TypeVar(
+ "V_Scan",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Scan(
+ self,
+ *initial_state_and_scan_inputs: V_Scan,
+ body: GraphProto,
+ num_scan_inputs: int,
+ scan_input_axes: Optional[Sequence[int]] = None,
+ scan_input_directions: Optional[Sequence[int]] = None,
+ scan_output_axes: Optional[Sequence[int]] = None,
+ scan_output_directions: Optional[Sequence[int]] = None,
+ ) -> V_Scan:
+ r"""[🌐 Scan(21)](https://onnx.ai/onnx/operators/onnx__Scan.html#scan-21 "Online Documentation")
+
+
+ Scan can be used to iterate over one or more scan_input tensors,
+ constructing zero or more scan_output tensors. It combines ideas from general recurrences,
+ functional programming constructs such as scan, fold, map, and zip, and is intended to enable
+ generalizations of RNN-like constructs for sequence-to-sequence processing.
+ Other tensors (referred to as state_variables here) can be used to carry a state
+ when iterating from one element to another (similar to hidden-state in RNNs, also referred
+ to as loop-carried dependences in the context of loops).
+ Many common usages involve a single scan_input tensor (where functionality
+ similar to scan, fold and map can be obtained). When more than one scan_input is used,
+ a behavior similar to zip is obtained.
+
+ The attribute body must be a graph, specifying the computation to be performed in
+ every iteration. It takes as input the current values of the state_variables and
+ the current iterated element of the scan_inputs. It must return the (updated) values
+ of the state_variables and zero or more scan_output_element tensors. The values of the
+ scan_output_element tensors are concatenated over all the iterations to produce the
+ scan_output values of the scan construct (similar to the concatenated intermediate
+ hidden-state values of RNN-like constructs). All the output tensors (state_variables as
+ well as scan_output_element tensors) are required to have the same shape in each iteration
+ of the loop (a restriction imposed to enable efficient memory allocation).
+
+ Note that the iterated element passed to the body subgraph does not have a sequence
+ axis. It will have a rank one less than the rank of the corresponding scan_input.
+
+ The scan operation returns the final values of the state_variables as well as the
+ scan_outputs.
+
+ The optional attribute scan_input_directions specifies the direction (forward or backward)
+ for each scan input. If this attribute is omitted, all sequences are scanned in the forward
+ direction. A bidirectional scan may be performed by specifying the same tensor input twice
+ in the scan_inputs, once with a forward direction, and once with a backward direction.
+
+ The scan_output of the operation is produced by concatenating the scan_output_element
+ values produced by the body in each iteration. The optional attribute scan_output_directions
+ specifies the direction in which scan_output is constructed (by appending or prepending the
+ scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
+ is omitted, the scan_output_element is appended to the scan_output in each iteration.
+
+ The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
+ If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
+ batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
+ Note that scanning a non-zero axis may be less efficient than scanning axis zero.
+
+ The optional attribute scan_output_axes specifies the axis along which the scan_outputs
+ are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
+ scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
+ value of 1.
+
+ Note that because of the ONNX restriction that only the last parameter of an operator can
+ be variadic, the initial-states and scan-inputs are listed together as one input parameter.
+ Similarly, the final-states and scan-outputs are listed together as one output parameter.
+ The attribute num_scan_inputs indicates the number M of scan-inputs.
+
+ The behavior of
+
+ Scan <
+ num_scan_inputs = m,
+ body = loop-body,
+ scan_input_axes = [axis_1, ..., axis_m]
+ > (init_1, ..., init_n, scan_1, ..., scan_m)
+
+ is equivalent to the following pseudo-code:
+
+ // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
+ // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
+ sequence_length = scan_1.shape[axis_1];
+
+ // initialize state-variables
+ st_1 = init_1; ... st_n = init_n;
+ // initialize scan-output variables: [] denotes an empty tensor
+ scan_out_1 = []; ...; scan_out_k = [];
+ // identify number of iterations:
+
+ // execute loop
+ for (int t = 0; t < sequence_length; ++t) {
+ // generate the scan-input elements: the notation T[t] indicates the sub-tensor
+ // of rank one less than T obtained by indexing T at position t along axis k.
+ si_1 = scan_1[t];
+ ... ;
+ si_m = scan_m[t];
+ // execute loop-body
+ st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
+ // accumulate the scan-output elements
+ scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k);
+ }
+
+ return st_1, ..., st_n, scan_out_1, ..., scan_out_k;
+
+ *Sample usage: Encoding RNN using a Scan*
+
+ The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
+ recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
+ be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes
+ %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
+ values are computed in the outer graph, they need to be passed in as extra state_variables.
+
+ graph rnn-encoding {
+ %H_0 = ...
+ %X = ...
+ %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
+ return %Y, %Y_h
+ }
+
+ graph rnn-cell-1 (
+ %H_tminus1[FLOAT, tensor]
+ %X_t[FLOAT, tensor]
+ ) {
+ %Wi = ...
+ %Ri = ...
+ %Wbi = ...
+ %Rbi = ...
+ %t1 = X_t * (Wi^T)
+ %t2 = H_tminus1*(Ri^T)
+ %t3 = Add(%t1, %t2)
+ %t4 = Add(%t3, %Wbi)
+ %t5 = Add(%t4, %Rbi)
+ %Ht = Tanh(%t5)
+ %Accumulate = Identity(%Ht)
+ return %Ht, %Accumulate
+ }
+
+
+
+ Args:
+ initial_state_and_scan_inputs: (variadic, heterogeneous) Initial values of
+ the loop's N state variables followed by M scan_inputs
+
+ body: The graph run each iteration. It has N+M inputs: (loop state
+ variables..., scan_input_elts...). It has N+K outputs: (loop state
+ variables..., scan_output_elts...). Each scan_output is created by
+ concatenating the value of the specified scan_output_elt value at the
+ end of each iteration of the loop. It is an error if the dimensions of
+ these values change across loop iterations.
+
+ num_scan_inputs: An attribute specifying the number of scan_inputs M.
+
+ scan_input_axes: An optional list of M flags. The i-th element of the list
+ specifies the axis to be scanned (the sequence axis) for the i-th
+ scan_input. If omitted, 0 will be used as the scan axis for every
+ scan_input. Negative value for an axis means counting dimensions from
+ the back. Accepted range is [-r, r-1] where r = rank(input).
+
+ scan_input_directions: An optional list of M flags. The i-th element of the
+ list specifies the direction to be scanned for the i-th scan_input
+ tensor: 0 indicates forward direction and 1 indicates reverse direction.
+ If omitted, all scan_input tensors will be scanned in the forward
+ direction.
+
+ scan_output_axes: An optional list of K flags. The i-th element of the list
+ specifies the axis for the i-th scan_output. The scan outputs are
+ accumulated along the specified axis. If omitted, 0 will be used as the
+ scan axis for every scan_output. Negative value for an axis means
+ counting dimensions from the back. Accepted range is [-r, r-1].
+
+ scan_output_directions: An optional list of K flags, one for each
+ scan_output. The i-th element of the list specifies whether the i-th
+ scan_output should be constructed by appending or prepending a new value
+ in each iteration: 0 indicates appending and 1 indicates prepending. If
+ omitted, all scan_output tensors will be produced by appending a value
+ in each iteration.
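+
+        A minimal numpy sketch of the pseudo-code above (illustration only): a
+        running sum with one state variable and one scan input/output:
+
+        ::
+
+            import numpy as np
+
+            def body(st, si):  # loop-body: updated state, scan-output element
+                new_st = st + si
+                return new_st, new_st
+
+            st = np.zeros(3)  # init_1
+            xs = np.arange(12.0).reshape(4, 3)  # scan_1, scanned along axis 0
+            outs = []
+            for t in range(xs.shape[0]):
+                st, so = body(st, xs[t])
+                outs.append(so)
+            scan_out = np.stack(outs)  # st holds the final state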
+ """
+
+ schema = get_schema("Scan", 21, "")
+ op = Op(self, "Scan", schema)
+ return op(
+ *self._prepare_inputs(schema, *initial_state_and_scan_inputs),
+ body=body,
+ num_scan_inputs=num_scan_inputs,
+ scan_input_axes=scan_input_axes,
+ scan_input_directions=scan_input_directions,
+ scan_output_axes=scan_output_axes,
+ scan_output_directions=scan_output_directions,
+ )
+
+ T_Shape = TypeVar(
+ "T_Shape",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T1_Shape: TypeAlias = INT64
+
+ def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) -> T1_Shape:
+ r"""[🌐 Shape(21)](https://onnx.ai/onnx/operators/onnx__Shape.html#shape-21 "Online Documentation")
+
+
+ Takes a tensor as input and outputs a 1D int64 tensor containing the shape of the input tensor.
+ Optional attributes start and end can be used to compute a slice of the input tensor's shape.
+ If start axis is omitted, the slice starts from axis 0.
+ The end axis, if specified, is exclusive (and the returned value will not include the size of that axis).
+ If the end axis is omitted, the axes up to the last one will be included.
+ Negative axes indicate counting back from the last axis.
+ Note that axes will be clamped to the range [0, r], where r is the
+ rank of the input tensor if they are out-of-range (after adding r in the case of
+ negative axis). Thus, specifying any end value > r is equivalent to specifying an end
+ value of r, and specifying any start value < -r is equivalent to specifying a start
+ value of 0. If start > end, the result will be an empty shape.
+
+ Examples:
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ No attributes specified.
+ Output: [2, 3, 4]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ start: -1
+ Output: [4]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ end: -1
+ Output: [2, 3]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ start: 1
+ end: 2
+ Output: [3]
+
+
+
+
+ Args:
+ data: (non-differentiable) An input tensor.
+
+ end: (Optional) Ending axis for slicing the shape. Negative value means
+ counting dimensions from the back. If omitted, sizes of all axes up to
+ (and including) the last one will be included.
+
+ start: (Optional) Starting axis for slicing the shape. Default value is
+ 0. Negative value means counting dimensions from the back.
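+
+        Equivalently, in numpy terms (a sketch of the clamping rules above,
+        illustration only):
+
+        ::
+
+            import numpy as np
+
+            def shape_op(data, start=0, end=None):
+                r = data.ndim
+                lo = min(max(start + r if start < 0 else start, 0), r)
+                hi = r if end is None else min(max(end + r if end < 0 else end, 0), r)
+                return np.asarray(data.shape[lo:hi], dtype=np.int64)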
+ """
+
+ schema = get_schema("Shape", 21, "")
+ op = Op(self, "Shape", schema)
+ return op(*self._prepare_inputs(schema, data), end=end, start=start)
+
+ T_Size = TypeVar(
+ "T_Size",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T1_Size: TypeAlias = INT64
+
+ def Size(self, data: T_Size) -> T1_Size:
+ r"""[🌐 Size(21)](https://onnx.ai/onnx/operators/onnx__Size.html#size-21 "Online Documentation")
+
+
+ Takes a tensor as input and outputs an int64 scalar equal to the total number of elements of the input tensor.
+
+
+ Args:
+ data: (non-differentiable) An input tensor.
+ """
+
+ schema = get_schema("Size", 21, "")
+ op = Op(self, "Size", schema)
+ return op(*self._prepare_inputs(schema, data))
+
+ T_Squeeze = TypeVar(
+ "T_Squeeze",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Squeeze(self, data: T_Squeeze, axes: Optional[INT64] = None) -> T_Squeeze:
+ r"""[🌐 Squeeze(21)](https://onnx.ai/onnx/operators/onnx__Squeeze.html#squeeze-21 "Online Documentation")
+
+
+ Remove single-dimensional entries from the shape of a tensor.
+ Takes an input `axes` with a list of axes to squeeze.
+ If `axes` is not provided, all the single dimensions will be removed from
+ the shape. If an axis is selected with shape entry not equal to one, an error is raised.
+
+
+ Args:
+ data: (differentiable) Tensors with at least max(dims) dimensions.
+
+ axes: (optional, non-differentiable) List of integers indicating the
+ dimensions to squeeze. Negative value means counting dimensions from the
+ back. Accepted range is [-r, r-1] where r = rank(data).
+ """
+
+ schema = get_schema("Squeeze", 21, "")
+ op = Op(self, "Squeeze", schema)
+ return op(*self._prepare_inputs(schema, data, axes))
+
+ T_Transpose = TypeVar(
+ "T_Transpose",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Transpose(
+ self, data: T_Transpose, *, perm: Optional[Sequence[int]] = None
+ ) -> T_Transpose:
+ r"""[🌐 Transpose(21)](https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21 "Online Documentation")
+
+
+ Transpose the input tensor similar to numpy.transpose. For example, when
+ perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape
+ will be (2, 1, 3).
+
+
+ Args:
+ data: (differentiable) An input tensor.
+
+ perm: A list of integers. By default, reverse the dimensions, otherwise
+ permute the axes according to the values given. Its length must be equal
+ to the rank of the input.
+ """
+
+ schema = get_schema("Transpose", 21, "")
+ op = Op(self, "Transpose", schema)
+ return op(*self._prepare_inputs(schema, data), perm=perm)
+
+ T_Unsqueeze = TypeVar(
+ "T_Unsqueeze",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Unsqueeze(self, data: T_Unsqueeze, axes: INT64) -> T_Unsqueeze:
+ r"""[🌐 Unsqueeze(21)](https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#unsqueeze-21 "Online Documentation")
+
+
+ Insert single-dimensional entries to the shape of an input tensor (`data`).
+ Takes one required input `axes`, which contains a list of dimension indices; this operator inserts a dimension of value `1` at each listed index of the output tensor (`expanded`).
+
+ For example, given an input tensor (`data`) of shape [3, 4, 5], then
+ Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing the same data as `data` but with shape [1, 3, 4, 5, 1].
+
+ The input `axes` must not contain any duplicate entries; duplicates are an error.
+ The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`.
+ Each value in `axes` should be within the (inclusive) range [-output_rank, output_rank - 1].
+ The values in `axes` may appear in any order.
+
+
+ Args:
+ data: (differentiable) Original tensor
+
+ axes: (non-differentiable) List of integers indicating the dimensions to be
+ inserted. Negative value means counting dimensions from the back.
+ Accepted range is [-r, r-1] where r = rank(expanded).
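+
+        A minimal numpy sketch of these rules (illustration only):
+
+        ::
+
+            import numpy as np
+
+            def unsqueeze(data, axes):
+                out_rank = data.ndim + len(axes)
+                axes = sorted(a + out_rank if a < 0 else a for a in axes)
+                out = data
+                for a in axes:  # insert the size-1 axes in increasing order
+                    out = np.expand_dims(out, a)
+                return out
+
+            unsqueeze(np.zeros((3, 4, 5)), [0, 4]).shape  # -> (1, 3, 4, 5, 1)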
+ """
+
+ schema = get_schema("Unsqueeze", 21, "")
+ op = Op(self, "Unsqueeze", schema)
+ return op(*self._prepare_inputs(schema, data, axes))
diff --git a/onnxscript/onnx_opset/_impl/opset22.py b/onnxscript/onnx_opset/_impl/opset22.py
new file mode 100644
index 0000000000..2b1656ed2a
--- /dev/null
+++ b/onnxscript/onnx_opset/_impl/opset22.py
@@ -0,0 +1,2593 @@
+# --------------------------------------------------------------------------
+# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
+# ⚙️ Generated by 'python -m opgen'
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# pylint: disable=W0221,W0222,R0901,W0237
+# mypy: disable-error-code=override
+# ruff: noqa: E741, D402, D405
+# --------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from typing import Optional, Sequence, Tuple, TypeVar, Union
+
+from onnx.defs import get_schema
+from typing_extensions import TypeAlias
+
+from onnxscript.onnx_opset._impl.opset21 import Opset21
+from onnxscript.onnx_types import (
+ BFLOAT16,
+ BOOL,
+ COMPLEX64,
+ COMPLEX128,
+ DOUBLE,
+ FLOAT,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT16,
+ INT8,
+ INT16,
+ INT32,
+ INT64,
+ STRING,
+ UINT8,
+ UINT16,
+ UINT32,
+ UINT64,
+)
+from onnxscript.values import Op, Opset
+
+
+class Opset22(Opset21):
+ def __new__(cls):
+ return Opset.__new__(cls, "", 22)
+
+ T_Acos = TypeVar("T_Acos", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Acos(self, input: T_Acos) -> T_Acos:
+ r"""[🌐 Acos(22)](https://onnx.ai/onnx/operators/onnx__Acos.html#acos-22 "Online Documentation")
+
+
+ Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Acos", 22, "")
+ op = Op(self, "Acos", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Acosh = TypeVar("T_Acosh", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Acosh(self, input: T_Acosh) -> T_Acosh:
+ r"""[🌐 Acosh(22)](https://onnx.ai/onnx/operators/onnx__Acosh.html#acosh-22 "Online Documentation")
+
+
+ Calculates the hyperbolic arccosine of the given input tensor element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Acosh", 22, "")
+ op = Op(self, "Acosh", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Asin = TypeVar("T_Asin", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Asin(self, input: T_Asin) -> T_Asin:
+ r"""[🌐 Asin(22)](https://onnx.ai/onnx/operators/onnx__Asin.html#asin-22 "Online Documentation")
+
+
+ Calculates the arcsine (inverse of sine) of the given input tensor, element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Asin", 22, "")
+ op = Op(self, "Asin", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Asinh = TypeVar("T_Asinh", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Asinh(self, input: T_Asinh) -> T_Asinh:
+ r"""[🌐 Asinh(22)](https://onnx.ai/onnx/operators/onnx__Asinh.html#asinh-22 "Online Documentation")
+
+
+ Calculates the hyperbolic arcsine of the given input tensor element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Asinh", 22, "")
+ op = Op(self, "Asinh", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Atan = TypeVar("T_Atan", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Atan(self, input: T_Atan) -> T_Atan:
+ r"""[🌐 Atan(22)](https://onnx.ai/onnx/operators/onnx__Atan.html#atan-22 "Online Documentation")
+
+
+ Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Atan", 22, "")
+ op = Op(self, "Atan", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Atanh = TypeVar("T_Atanh", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Atanh(self, input: T_Atanh) -> T_Atanh:
+ r"""[🌐 Atanh(22)](https://onnx.ai/onnx/operators/onnx__Atanh.html#atanh-22 "Online Documentation")
+
+
+ Calculates the hyperbolic arctangent of the given input tensor element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Atanh", 22, "")
+ op = Op(self, "Atanh", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_AveragePool = TypeVar("T_AveragePool", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def AveragePool(
+ self,
+ X: T_AveragePool,
+ *,
+ auto_pad: str = "NOTSET",
+ ceil_mode: int = 0,
+ count_include_pad: int = 0,
+ dilations: Optional[Sequence[int]] = None,
+ kernel_shape: Sequence[int],
+ pads: Optional[Sequence[int]] = None,
+ strides: Optional[Sequence[int]] = None,
+ ) -> T_AveragePool:
+ r"""[🌐 AveragePool(22)](https://onnx.ai/onnx/operators/onnx__AveragePool.html#averagepool-22 "Online Documentation")
+
+
+ AveragePool consumes an input tensor X and applies average pooling across
+ the tensor according to kernel sizes, stride sizes, and pad lengths.
+ Average pooling consists of computing the average over all values of a
+ subset of the input tensor according to the kernel size and downsampling the
+ data into the output tensor Y for further processing. The output spatial shape is calculated differently
+ depending on whether explicit padding (via pads) or auto padding (via auto_pad) is used.
+ With explicit padding (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html?highlight=maxpool#torch.nn.MaxPool2d):
+ ```
+ output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1)
+ ```
+ or
+ ```
+ output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1)
+ ```
+ if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored.
+
+ `auto_pad` is a DEPRECATED attribute. If you are using it currently, the output spatial shape will be as follows when ceil_mode is enabled:
+ ```
+ VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) + 1) / strides_spatial_shape[i])
+ SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i])
+ ```
+ or when ceil_mode is disabled (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AveragePooling2D):
+ ```
+ VALID: output_spatial_shape[i] = floor((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i]) + 1
+ SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = floor((input_spatial_shape[i] - 1) / strides_spatial_shape[i]) + 1
+ ```
+ And the pad shape will be as follows if `SAME_UPPER` or `SAME_LOWER`:
+ ```
+ pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]
+ ```
+ The output of each pooling window is divided by the number of elements (excluding padding when the attribute count_include_pad is zero).
+
+
+ Args:
+ X: (differentiable) Input data tensor from the previous operator; dimensions
+ for image case are (N x C x H x W), where N is the batch size, C is the
+ number of channels, and H and W are the height and the width of the
+ data. For non image case, the dimensions are in the form of (N x C x D1
+ x D2 ... Dn), where N is the batch size. Optionally, if dimension
+ denotation is in effect, the operation expects the input data tensor to
+ arrive with the dimension denotation of [DATA_BATCH, DATA_CHANNEL,
+ DATA_FEATURE, DATA_FEATURE ...].
+
+ auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID.
+ Where default value is NOTSET, which means explicit padding is used.
+ SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] =
+ ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is
+ split between the two sides equally or almost equally (depending on
+ whether it is even or odd). In case the padding is an odd number, the
+ extra padding is added at the end for SAME_UPPER and at the beginning
+ for SAME_LOWER.
+
+ ceil_mode: Whether to use ceil or floor (default) to compute the output
+ shape.
+
+ count_include_pad: Whether to include pad pixels when calculating values for
+ the edges. Default is 0 (do not include padding).
+
+ dilations: Dilation value along each spatial axis of filter. If not present,
+ the dilation defaults to 1 along each spatial axis.
+
+ kernel_shape: The size of the kernel along each axis.
+
+ pads: Padding for the beginning and ending along each spatial axis, it can
+ take any value greater than or equal to 0. The value represent the
+ number of pixels added to the beginning and end part of the
+ corresponding axis. `pads` format should be as follow [x1_begin,
+ x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels
+ added at the beginning of axis `i` and xi_end, the number of pixels
+ added at the end of axis `i`. This attribute cannot be used
+ simultaneously with auto_pad attribute. If not present, the padding
+ defaults to 0 along start and end of each spatial axis.
+
+ strides: Stride along each spatial axis. If not present, the stride defaults
+ to 1 along each spatial axis.
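+
+        A worked instance of the explicit-padding formula above (values chosen
+        for illustration): input_spatial_shape = [32], kernel_shape = [3],
+        strides = [2], pads = [1, 1], dilations = [1], ceil_mode = 0:
+
+        ::
+
+            output_spatial_shape[0] = floor((32 + 2 - 1*(3 - 1) - 1) / 2 + 1)
+                                    = floor(15.5 + 1) = 16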
+ """
+
+ schema = get_schema("AveragePool", 22, "")
+ op = Op(self, "AveragePool", schema)
+ return op(
+ *self._prepare_inputs(schema, X),
+ auto_pad=auto_pad,
+ ceil_mode=ceil_mode,
+ count_include_pad=count_include_pad,
+ dilations=dilations,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ T1_Bernoulli = TypeVar("T1_Bernoulli", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T2_Bernoulli: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ ]
+
+ def Bernoulli(
+ self,
+ input: T1_Bernoulli,
+ *,
+ dtype: Optional[int] = None,
+ seed: Optional[float] = None,
+ ) -> T2_Bernoulli:
+ r"""[🌐 Bernoulli(22)](https://onnx.ai/onnx/operators/onnx__Bernoulli.html#bernoulli-22 "Online Documentation")
+
+
+ Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor
+ containing probabilities p (a value in the range [0,1]) to be used for drawing the binary random number,
+ where an output of 1 is produced with probability p and an output of 0 is produced with probability (1-p).
+
+ This operator is non-deterministic and may not produce the same values in different
+ implementations (even if a seed is specified).
+
+
+ Args:
+ input: All values in input have to be in the range [0, 1].
+
+ dtype: The data type for the elements of the output tensor. If not
+ specified, the data type of the input tensor is used.
+
+ seed: (Optional) Seed for the random generator; if not specified, one is
+ auto-generated.
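+
+        A minimal numpy sketch of the sampling (illustration only; the
+        generator stands in for the optional `seed`):
+
+        ::
+
+            import numpy as np
+
+            rng = np.random.default_rng()
+            output = (rng.random(input.shape) < input).astype(input.dtype)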
+ """
+
+ schema = get_schema("Bernoulli", 22, "")
+ op = Op(self, "Bernoulli", schema)
+ return op(*self._prepare_inputs(schema, input), dtype=dtype, seed=seed)
+
+ T_Conv = TypeVar("T_Conv", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Conv(
+ self,
+ X: T_Conv,
+ W: T_Conv,
+ B: Optional[T_Conv] = None,
+ *,
+ auto_pad: str = "NOTSET",
+ dilations: Optional[Sequence[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[Sequence[int]] = None,
+ pads: Optional[Sequence[int]] = None,
+ strides: Optional[Sequence[int]] = None,
+ ) -> T_Conv:
+ r"""[🌐 Conv(22)](https://onnx.ai/onnx/operators/onnx__Conv.html#conv-22 "Online Documentation")
+
+
+ The convolution operator consumes an input tensor and a filter, and
+ computes the output.
+
+ Args:
+ X: (differentiable) Input data tensor from previous layer; has size (N x C x
+ H x W), where N is the batch size, C is the number of channels, and H
+ and W are the height and width. Note that this is for the 2D image.
+ Otherwise the size is (N x C x D1 x D2 ... x Dn). Optionally, if
+ dimension denotation is in effect, the operation expects input data
+ tensor to arrive with the dimension denotation of [DATA_BATCH,
+ DATA_CHANNEL, DATA_FEATURE, DATA_FEATURE ...].
+
+ W: (differentiable) The weight tensor that will be used in the convolutions;
+ has size (M x C/group x kH x kW), where C is the number of channels, and
+ kH and kW are the height and width of the kernel, and M is the number of
+ feature maps. For more than 2 dimensions, the kernel shape will be (M x
+ C/group x k1 x k2 x ... x kn), where (k1 x k2 x ... kn) is the dimension
+ of the kernel. Optionally, if dimension denotation is in effect, the
+ operation expects the weight tensor to arrive with the dimension
+ denotation of [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, FILTER_SPATIAL,
+ FILTER_SPATIAL ...]. Assuming zero based indices for the shape array,
+ X.shape[1] == (W.shape[1] * group) == C and W.shape[0] mod G == 0. Or in
+ other words FILTER_IN_CHANNEL multiplied by the number of groups should
+ be equal to DATA_CHANNEL and the number of feature maps M should be a
+ multiple of the number of groups G.
+
+ B: (optional, differentiable) Optional 1D bias to be added to the
+ convolution, has size of M.
+
+ auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID.
+ Where default value is NOTSET, which means explicit padding is used.
+ SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] =
+ ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is
+ split between the two sides equally or almost equally (depending on
+ whether it is even or odd). In case the padding is an odd number, the
+ extra padding is added at the end for SAME_UPPER and at the beginning
+ for SAME_LOWER.
+
+ dilations: dilation value along each spatial axis of the filter. If not
+ present, the dilation defaults to 1 along each spatial axis.
+
+ group: number of groups input channels and output channels are divided into.
+
+ kernel_shape: The shape of the convolution kernel. If not present, should be
+ inferred from input W.
+
+ pads: Padding for the beginning and ending along each spatial axis, it can
+ take any value greater than or equal to 0. The value represent the
+ number of pixels added to the beginning and end part of the
+ corresponding axis. `pads` format should be as follow [x1_begin,
+ x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels
+ added at the beginning of axis `i` and xi_end, the number of pixels
+ added at the end of axis `i`. This attribute cannot be used
+ simultaneously with auto_pad attribute. If not present, the padding
+ defaults to 0 along start and end of each spatial axis.
+
+ strides: Stride along each spatial axis. If not present, the stride defaults
+ to 1 along each spatial axis.
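+
+        A worked instance of the shape constraints above (values chosen for
+        illustration):
+
+        ::
+
+            C = 8, group = 2, M = 4
+            W.shape == (M, C/group, kH, kW) == (4, 4, kH, kW)
+            X.shape[1] == W.shape[1] * group == 8 == C
+            M mod group == 0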
+ """
+
+ schema = get_schema("Conv", 22, "")
+ op = Op(self, "Conv", schema)
+ return op(
+ *self._prepare_inputs(schema, X, W, B),
+ auto_pad=auto_pad,
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ T_ConvTranspose = TypeVar("T_ConvTranspose", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def ConvTranspose(
+ self,
+ X: T_ConvTranspose,
+ W: T_ConvTranspose,
+ B: Optional[T_ConvTranspose] = None,
+ *,
+ auto_pad: str = "NOTSET",
+ dilations: Optional[Sequence[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[Sequence[int]] = None,
+ output_padding: Optional[Sequence[int]] = None,
+ output_shape: Optional[Sequence[int]] = None,
+ pads: Optional[Sequence[int]] = None,
+ strides: Optional[Sequence[int]] = None,
+ ) -> T_ConvTranspose:
+ r"""[🌐 ConvTranspose(22)](https://onnx.ai/onnx/operators/onnx__ConvTranspose.html#convtranspose-22 "Online Documentation")
+
+
+ The convolution transpose operator consumes an input tensor and a filter,
+ and computes the output.
+
+ If the pads parameter is provided the shape of the output is calculated via the following equation:
+
+ output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i]
+
+ output_shape can also be explicitly specified in which case pads values are auto generated using these equations:
+
+ total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]
+ If (auto_pads == SAME_UPPER): pads[start_i] = total_padding[i]/2; pads[end_i] = total_padding[i] - (total_padding[i]/2)
+ Else: pads[start_i] = total_padding[i] - (total_padding[i]/2); pads[end_i] = (total_padding[i]/2).
+
+
+
+ Args:
+ X: (differentiable) Input data tensor from previous layer; has size (N x C x
+ H x W), where N is the batch size, C is the number of channels, and H
+ and W are the height and width. Note that this is for the 2D image.
+ Otherwise the size is (N x C x D1 x D2 ... x Dn)
+
+ W: (differentiable) The weight tensor that will be used in the convolutions;
+ has size (C x M/group x kH x kW), where C is the number of channels, and
+ kH and kW are the height and width of the kernel, and M is the number of
+ feature maps. For more than 2 dimensions, the weight shape will be (C x
+ M/group x k1 x k2 x ... x kn), where (k1 x k2 x ... x kn) is the
+ dimension of the kernel. The number of channels in the output should be
+ equal to W.shape[1] * group (assuming zero based indices of the shape
+ array)
+
+ B: (optional, differentiable) Optional 1D bias to be added to the
+ convolution, has size of M.
+
+ auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID.
+ Where default value is NOTSET, which means explicit padding is used.
+ SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] =
+ input_shape[i] * strides[i]` for each axis `i`. The padding is split
+ between the two sides equally or almost equally (depending on whether it
+ is even or odd). In case the padding is an odd number, the extra padding
+ is added at the end for SAME_UPPER and at the beginning for SAME_LOWER.
+
+ dilations: dilation value along each spatial axis of the filter. If not
+ present, the dilation defaults to 1 along each spatial axis.
+
+ group: number of groups input channels and output channels are divided into.
+
+ kernel_shape: The shape of the convolution kernel. If not present, should be
+ inferred from input W.
+
+ output_padding: Additional elements added to the side with higher coordinate
+ indices in the output. Each padding value in "output_padding" must be
+ less than the corresponding stride/dilation dimension. By default, this
+ attribute is a zero vector. Note that this attribute doesn't directly
+ affect the computed output values. It only controls the selection of the
+ computed values, so changing this attribute only adds or removes output
+ elements. If "output_shape" is explicitly provided, "output_padding"
+ does not contribute additional size to "output_shape" but participates
+ in the computation of the needed padding amount. This is also called
+ adjs or adjustment in some frameworks.
+
+ output_shape: The shape of the output can be explicitly set which will cause
+ pads values to be auto generated. If output_shape is specified pads
+ values are ignored. See doc for details for equations to generate pads.
+ Note that the output_shape attribute value should not include dimensions
+ for batch size and channels, which are automatically inferred.
+
+ pads: Padding for the beginning and ending along each spatial axis, it can
+ take any value greater than or equal to 0. The value represent the
+ number of pixels added to the beginning and end part of the
+ corresponding axis. `pads` format should be as follow [x1_begin,
+ x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels
+ added at the beginning of axis `i` and xi_end, the number of pixels
+ added at the end of axis `i`. This attribute cannot be used
+ simultaneously with auto_pad attribute. If not present, the padding
+ defaults to 0 along start and end of each spatial axis.
+
+ strides: Stride along each spatial axis. If not present, the stride defaults
+ to 1 along each spatial axis.
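+
+        A worked instance of the output-shape equation above (values chosen for
+        illustration):
+
+        ::
+
+            input_size = [4], strides = [2], kernel_shape = [3],
+            dilations = [1], output_padding = [0], pads = [1, 1]
+
+            output_shape[0] = 2*(4 - 1) + 0 + ((3 - 1)*1 + 1) - 1 - 1 = 7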
+ """
+
+ schema = get_schema("ConvTranspose", 22, "")
+ op = Op(self, "ConvTranspose", schema)
+ return op(
+ *self._prepare_inputs(schema, X, W, B),
+ auto_pad=auto_pad,
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ output_padding=output_padding,
+ output_shape=output_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ T_Cos = TypeVar("T_Cos", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Cos(self, input: T_Cos) -> T_Cos:
+ r"""[🌐 Cos(22)](https://onnx.ai/onnx/operators/onnx__Cos.html#cos-22 "Online Documentation")
+
+
+ Calculates the cosine of the given input tensor, element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Cos", 22, "")
+ op = Op(self, "Cos", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Cosh = TypeVar("T_Cosh", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Cosh(self, input: T_Cosh) -> T_Cosh:
+ r"""[🌐 Cosh(22)](https://onnx.ai/onnx/operators/onnx__Cosh.html#cosh-22 "Online Documentation")
+
+
+ Calculates the hyperbolic cosine of the given input tensor element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Cosh", 22, "")
+ op = Op(self, "Cosh", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_DeformConv = TypeVar("T_DeformConv", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def DeformConv(
+ self,
+ X: T_DeformConv,
+ W: T_DeformConv,
+ offset: T_DeformConv,
+ B: Optional[T_DeformConv] = None,
+ mask: Optional[T_DeformConv] = None,
+ *,
+ dilations: Optional[Sequence[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[Sequence[int]] = None,
+ offset_group: int = 1,
+ pads: Optional[Sequence[int]] = None,
+ strides: Optional[Sequence[int]] = None,
+ ) -> T_DeformConv:
+ r"""[🌐 DeformConv(22)](https://onnx.ai/onnx/operators/onnx__DeformConv.html#deformconv-22 "Online Documentation")
+
+
+ Performs deformable convolution as described in https://arxiv.org/abs/1703.06211 and https://arxiv.org/abs/1811.11168.
+ This operator specification supports the general N-D case. Note that most common use cases have 2D or 3D data.
+
+
+ Args:
+ X: Input data tensor. For 2D image data, it has shape (N, C, H, W) where N
+ is the batch size, C is the number of input channels, and H and W are
+ the height and width. In general, the shape is (N, C, D1, D2, ... , Dn)
+ for n-dimensional data, where D1 to Dn are the spatial dimension sizes.
+ Most common use cases have n = 2 or 3.
+
+ W: Weight tensor that will be used in the convolutions. It has shape (oC,
+ C/group, kH, kW), where oC is the number of output channels and kH and
+ kW are the kernel height and width. For more than 2 dimensions, it has
+ shape (oC, C/group, k1, k2, ... , kn).
+
+ offset: Offset tensor denoting the offset for the sampling locations in the
+ convolution kernel. It has shape (N, offset_group * kH * kW * 2, oH, oW)
+ for 2D data or (N, offset_group * k1 * k2 * ... * kn * n, o1, o2, ... ,
+ on) for nD data. Use linear interpolation for fractional offset values.
+ Sampling locations outside of the padded input tensor give zero.
+
+ B: (optional) Optional 1D bias of length oC to be added to the convolution.
+ Default is a tensor of zeros.
+
+ mask: (optional) The mask tensor to be applied to each position in the
+ convolution kernel. It has shape (N, offset_group * kH * kW, oH, oW) for
+ 2D data or (N, offset_group * k1 * k2 * ... * kn * n, o1, o2, ... , on)
+ for nD data. Default is a tensor of ones.
+
+ dilations: Dilation value along each spatial axis of the kernel. Default is
+ 1 along each axis.
+
+ group: Number of groups the input and output channels, C and oC, are divided
+ into. C and oC must both be divisible by group. Default is 1.
+
+ kernel_shape: Shape of the convolution kernel. If not present, it is
+ inferred from the shape of input W.
+
+ offset_group: Number of groups of offset. C must be divisible by
+ offset_group. Default is 1.
+
+ pads: Padding for the beginning and end along each spatial axis. The values
+ represent the number of pixels added to the beginning and end of the
+ corresponding axis and can take any nonnegative value. The format should
+ be as follows: [x1_begin, x2_begin, ..., x1_end, x2_end, ...], where
+ xi_begin is the number of pixels added at the beginning of axis `i` and
+ xi_end is the number of pixels added at the end of axis `i`. Default is
+ 0 along each axis.
+
+ strides: Stride along each spatial axis. Default is 1 along each axis.
+ """
+
+ schema = get_schema("DeformConv", 22, "")
+ op = Op(self, "DeformConv", schema)
+ return op(
+ *self._prepare_inputs(schema, X, W, offset, B, mask),
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ offset_group=offset_group,
+ pads=pads,
+ strides=strides,
+ )
+
+ T_Det = TypeVar("T_Det", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Det(self, X: T_Det) -> T_Det:
+ r"""[🌐 Det(22)](https://onnx.ai/onnx/operators/onnx__Det.html#det-22 "Online Documentation")
+
+
+ Det calculates determinant of a square matrix or batches of square matrices.
+ Det takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions,
+ and the inner-most 2 dimensions form square matrices.
+ The output is a tensor of shape `[*]`, containing the determinants of all input submatrices.
+ e.g., when the input is 2-D, the output is a scalar (shape is empty: `[]`).
+
+
+ Args:
+ X: (differentiable) Input tensor
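+
+        Equivalently, in numpy terms (illustration only):
+
+        ::
+
+            import numpy as np
+
+            y = np.linalg.det(X)  # shape [*, M, M] -> [*]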
+ """
+
+ schema = get_schema("Det", 22, "")
+ op = Op(self, "Det", schema)
+ return op(*self._prepare_inputs(schema, X))
+
+ T_Dropout = TypeVar(
+ "T_Dropout",
+ BFLOAT16,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ )
+
+ T1_Dropout = TypeVar(
+ "T1_Dropout",
+ BFLOAT16,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ )
+
+ T2_Dropout: TypeAlias = BOOL
+
+ def Dropout(
+ self,
+ data: T_Dropout,
+ ratio: Optional[T1_Dropout] = None,
+ training_mode: Optional[T2_Dropout] = None,
+ *,
+ seed: Optional[int] = None,
+ ) -> Tuple[T_Dropout, T2_Dropout]:
+ r"""[🌐 Dropout(22)](https://onnx.ai/onnx/operators/onnx__Dropout.html#dropout-22 "Online Documentation")
+
+
+ Dropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar). It produces two tensor outputs,
+ output (floating-point tensor) and mask (optional `Tensor<bool>`). If `training_mode` is true then the output Y will be a random dropout;
+ Note that this Dropout scales the masked input data by the following equation, so to convert the trained model into inference mode,
+ the user can simply not pass `training_mode` input or set it to false.
+ ::
+
+ output = scale * data * mask,
+
+
+ where
+ ::
+
+ scale = 1. / (1. - ratio).
+
+
+ This operator has **optional** inputs/outputs. See `ONNX <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_ for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
+
+
+ Args:
+ data: (differentiable) The input data as Tensor.
+
+ ratio: (optional, non-differentiable) The ratio of random dropout, with
+ value in [0, 1). If set to 0, the output would be a simple copy of the
+ input. If it's non-zero, output will be a random dropout of the scaled
+ input, which is typically the case during training. It is an optional
+ value, if not specified it will default to 0.5.
+
+ training_mode: (optional, non-differentiable) If set to true then it
+ indicates dropout is being used for training. It is an optional value
+ hence unless specified explicitly, it is false. If it is false, ratio is
+ ignored and the operation mimics inference mode where nothing will be
+ dropped from the input data and if mask is requested as output it will
+ contain all ones.
+
+ seed: (Optional) Seed for the random generator; if not specified, one is
+ auto-generated.
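+
+        A minimal numpy sketch of the training-mode behavior (illustration
+        only; the generator stands in for the optional `seed`):
+
+        ::
+
+            import numpy as np
+
+            rng = np.random.default_rng()
+            mask = rng.random(data.shape) >= ratio  # keep with probability 1 - ratio
+            scale = 1.0 / (1.0 - ratio)
+            output = scale * data * mask  # inference mode simply returns data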
+ """
+
+ schema = get_schema("Dropout", 22, "")
+ op = Op(self, "Dropout", schema)
+ return op(*self._prepare_inputs(schema, data, ratio, training_mode), seed=seed)
+
+ T_Elu = TypeVar("T_Elu", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Elu(self, X: T_Elu, *, alpha: float = 1.0) -> T_Elu:
+ r"""[🌐 Elu(22)](https://onnx.ai/onnx/operators/onnx__Elu.html#elu-22 "Online Documentation")
+
+
+ Elu takes one input data (Tensor) and produces one output data
+ (Tensor) where the function `f(x) = alpha * (exp(x) - 1.) for x <
+ 0`, `f(x) = x for x >= 0`, is applied to the tensor elementwise.
+
+
+
+ Args:
+ X: (differentiable) Input tensor
+
+ alpha: Coefficient of ELU.
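+
+        Equivalently, in numpy terms (illustration only):
+
+        ::
+
+            import numpy as np
+
+            Y = np.where(X < 0, alpha * (np.exp(X) - 1.0), X)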
+ """
+
+ schema = get_schema("Elu", 22, "")
+ op = Op(self, "Elu", schema)
+ return op(*self._prepare_inputs(schema, X), alpha=alpha)
+
+ T1_EyeLike = TypeVar(
+ "T1_EyeLike",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ T2_EyeLike: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ ]
+
+ def EyeLike(
+ self, input: T1_EyeLike, *, dtype: Optional[int] = None, k: int = 0
+ ) -> T2_EyeLike:
+ r"""[🌐 EyeLike(22)](https://onnx.ai/onnx/operators/onnx__EyeLike.html#eyelike-22 "Online Documentation")
+
+
+ Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D
+ tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the
+ same as the input tensor. The data type can be specified by the 'dtype' argument. If
+ 'dtype' is not specified, then the type of input tensor is used. By default, the main diagonal
+ is populated with ones, but attribute 'k' can be used to populate upper or lower diagonals.
+ The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
+ TensorProto message and be valid as an output type.
+
+
+ Args:
+ input: 2D input tensor to copy shape, and optionally, type information from.
+
+ dtype: (Optional) The data type for the elements of the output tensor. If
+ not specified, the data type of the input tensor T1 is used.
+
+ k: (Optional) Index of the diagonal to be populated with ones. Default is 0.
+ If T2 is the output, this op sets T2[i, i+k] = 1. k = 0 populates the
+ main diagonal, k > 0 populates an upper diagonal, and k < 0 populates a
+ lower diagonal.
+ """
+
+ schema = get_schema("EyeLike", 22, "")
+ op = Op(self, "EyeLike", schema)
+ return op(*self._prepare_inputs(schema, input), dtype=dtype, k=k)
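+
+ # A sketch of the 'k' attribute semantics (T2[i, i+k] = 1); np.eye mirrors the
+ # behavior for a hypothetical 3x4 input (illustration only, not part of the
+ # generated opset):
+ #
+ #   import numpy as np
+ #   y = np.eye(3, 4, k=1)  # ones on the first upper diagonal
+ #   # [[0., 1., 0., 0.],
+ #   #  [0., 0., 1., 0.],
+ #   #  [0., 0., 0., 1.]]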
+
+ T_GRU = TypeVar("T_GRU", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T1_GRU: TypeAlias = INT32
+
+ def GRU(
+ self,
+ X: T_GRU,
+ W: T_GRU,
+ R: T_GRU,
+ B: Optional[T_GRU] = None,
+ sequence_lens: Optional[T1_GRU] = None,
+ initial_h: Optional[T_GRU] = None,
+ *,
+ activation_alpha: Optional[Sequence[float]] = None,
+ activation_beta: Optional[Sequence[float]] = None,
+ activations: Optional[Sequence[str]] = None,
+ clip: Optional[float] = None,
+ direction: str = "forward",
+ hidden_size: Optional[int] = None,
+ layout: int = 0,
+ linear_before_reset: int = 0,
+ ) -> Tuple[T_GRU, T_GRU]:
+ r"""[🌐 GRU(22)](https://onnx.ai/onnx/operators/onnx__GRU.html#gru-22 "Online Documentation")
+
+
+ Computes a one-layer GRU. This operator is usually supported via some custom
+ implementation such as CuDNN.
+
+ Notations:
+
+ * `X` - input tensor
+ * `z` - update gate
+ * `r` - reset gate
+ * `h` - hidden gate
+ * `t` - time step (t-1 means previous time step)
+ * `W[zrh]` - W parameter weight matrix for update, reset, and hidden gates
+ * `R[zrh]` - R recurrence weight matrix for update, reset, and hidden gates
+ * `Wb[zrh]` - W bias vectors for update, reset, and hidden gates
+ * `Rb[zrh]` - R bias vectors for update, reset, and hidden gates
+ * `WB[zrh]` - W parameter weight matrix for backward update, reset, and hidden gates
+ * `RB[zrh]` - R recurrence weight matrix for backward update, reset, and hidden gates
+ * `WBb[zrh]` - W bias vectors for backward update, reset, and hidden gates
+ * `RBb[zrh]` - R bias vectors for backward update, reset, and hidden gates
+ * `H` - Hidden state
+ * `num_directions` - 2 if direction == bidirectional else 1
+
+ Activation functions:
+
+ * Relu(x) - max(0, x)
+ * Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x})
+ * Sigmoid(x) - 1/(1 + e^{-x})
+
+ NOTE: Below are optional
+
+ * Affine(x) - alpha * x + beta
+ * LeakyRelu(x) - x if x >= 0 else alpha * x
+ * ThresholdedRelu(x) - x if x >= alpha else 0
+ * ScaledTanh(x) - alpha * Tanh(beta * x)
+ * HardSigmoid(x) - min(max(alpha * x + beta, 0), 1)
+ * Elu(x) - x if x >= 0 else alpha * (e^x - 1)
+ * Softsign(x) - x/(1 + |x|)
+ * Softplus(x) - log(1 + e^x)
+
+ Equations (Default: f=Sigmoid, g=Tanh):
+
+ * zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
+ * rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
+ * ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0
+ * ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0
+ * Ht = (1 - zt) (.) ht + zt (.) Ht-1
+
+ This operator has **optional** inputs/outputs. See the ONNX IR specification for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
+
+
+ Args:
+ X: (differentiable) The input sequences packed (and potentially padded) into
+ one 3-D tensor with the shape of `[seq_length, batch_size, input_size]`.
+
+ W: (differentiable) The weight tensor for the gates. Concatenation of
+ `W[zrh]` and `WB[zrh]` (if bidirectional) along dimension 0. This tensor
+ has shape `[num_directions, 3*hidden_size, input_size]`.
+
+ R: (differentiable) The recurrence weight tensor. Concatenation of `R[zrh]`
+ and `RB[zrh]` (if bidirectional) along dimension 0. This tensor has
+ shape `[num_directions, 3*hidden_size, hidden_size]`.
+
+ B: (optional, differentiable) The bias tensor for the gates. Concatenation
+ of `[Wb[zrh], Rb[zrh]]` and `[WBb[zrh], RBb[zrh]]` (if bidirectional)
+ along dimension 0. This tensor has shape `[num_directions,
+ 6*hidden_size]`. Optional: If not specified - assumed to be 0
+
+ sequence_lens: (optional, non-differentiable) Optional tensor specifying
+ lengths of the sequences in a batch. If not specified - assumed all
+ sequences in the batch to have length `seq_length`. It has shape
+ `[batch_size]`.
+
+ initial_h: (optional, non-differentiable) Optional initial value of the
+ hidden. If not specified - assumed to be 0. It has shape
+ `[num_directions, batch_size, hidden_size]`.
+
+ activation_alpha: Optional scaling values used by some activation functions.
+ The values are consumed in the order of activation functions, for
+ example (f, g, h) in LSTM. Default values are the same as of
+ corresponding ONNX operators. For example, with LeakyRelu, the default
+ alpha is 0.01.
+
+ activation_beta: Optional scaling values used by some activation functions.
+ The values are consumed in the order of activation functions, for
+ example (f, g, h) in LSTM. Default values are the same as of
+ corresponding ONNX operators.
+
+ activations: A list of 2 (or 4 if bidirectional) activation functions for
+ update, reset, and hidden gates. The activation functions must be one of
+ the activation functions specified above. Optional: See the equations
+ for default if not specified.
+
+ clip: Cell clip threshold. Clipping bounds the elements of a tensor in the
+ range of [-threshold, +threshold] and is applied to the input of
+ activations. No clip if not specified.
+
+ direction: Specify if the RNN is forward, reverse, or bidirectional. Must be
+ one of forward (default), reverse, or bidirectional.
+
+ hidden_size: Number of neurons in the hidden layer
+
+ layout: The shape format of inputs X, initial_h and outputs Y, Y_h. If 0,
+ the following shapes are expected: X.shape = [seq_length, batch_size,
+ input_size], Y.shape = [seq_length, num_directions, batch_size,
+ hidden_size], initial_h.shape = Y_h.shape = [num_directions, batch_size,
+ hidden_size]. If 1, the following shapes are expected: X.shape =
+ [batch_size, seq_length, input_size], Y.shape = [batch_size, seq_length,
+ num_directions, hidden_size], initial_h.shape = Y_h.shape = [batch_size,
+ num_directions, hidden_size].
+
+ linear_before_reset: When computing the output of the hidden gate, apply the
+ linear transformation before multiplying by the output of the reset
+ gate.
+ """
+
+ schema = get_schema("GRU", 22, "")
+ op = Op(self, "GRU", schema)
+ return op(
+ *self._prepare_inputs(schema, X, W, R, B, sequence_lens, initial_h),
+ activation_alpha=activation_alpha,
+ activation_beta=activation_beta,
+ activations=activations,
+ clip=clip,
+ direction=direction,
+ hidden_size=hidden_size,
+ layout=layout,
+ linear_before_reset=linear_before_reset,
+ )
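+
+ # A shape sketch for the GRU inputs documented above (hypothetical sizes,
+ # illustration only):
+ #
+ #   import numpy as np
+ #   seq_length, batch_size, input_size, hidden_size = 7, 2, 4, 3
+ #   num_directions = 1  # direction="forward"
+ #   X = np.zeros((seq_length, batch_size, input_size), dtype=np.float32)
+ #   W = np.zeros((num_directions, 3 * hidden_size, input_size), dtype=np.float32)
+ #   R = np.zeros((num_directions, 3 * hidden_size, hidden_size), dtype=np.float32)
+ #   # output Y_h has shape [num_directions, batch_size, hidden_size]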
+
+ T_GlobalAveragePool = TypeVar("T_GlobalAveragePool", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def GlobalAveragePool(self, X: T_GlobalAveragePool) -> T_GlobalAveragePool:
+ r"""[🌐 GlobalAveragePool(22)](https://onnx.ai/onnx/operators/onnx__GlobalAveragePool.html#globalaveragepool-22 "Online Documentation")
+
+
+ GlobalAveragePool consumes an input tensor X and applies average pooling across
+ the values in the same channel. This is equivalent to AveragePool with kernel size
+ equal to the spatial dimension of input tensor.
+
+ Args:
+ X: (differentiable) Input data tensor from the previous operator; dimensions
+ for image case are (N x C x H x W), where N is the batch size, C is the
+ number of channels, and H and W are the height and the width of the
+ data. For non image case, the dimensions are in the form of (N x C x D1
+ x D2 ... Dn), where N is the batch size.
+ """
+
+ schema = get_schema("GlobalAveragePool", 22, "")
+ op = Op(self, "GlobalAveragePool", schema)
+ return op(*self._prepare_inputs(schema, X))
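+
+ # The documented equivalence, sketched with NumPy (illustration only): global
+ # average pooling is a mean over all spatial axes, keeping the N and C axes:
+ #
+ #   import numpy as np
+ #   X = np.random.rand(2, 3, 5, 5).astype(np.float32)  # N x C x H x W
+ #   Y = X.mean(axis=(2, 3), keepdims=True)             # N x C x 1 x 1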
+
+ T_GlobalLpPool = TypeVar("T_GlobalLpPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def GlobalLpPool(self, X: T_GlobalLpPool, *, p: int = 2) -> T_GlobalLpPool:
+ r"""[🌐 GlobalLpPool(22)](https://onnx.ai/onnx/operators/onnx__GlobalLpPool.html#globallppool-22 "Online Documentation")
+
+
+ GlobalLpPool consumes an input tensor X and applies lp pool pooling across
+ the values in the same channel. This is equivalent to LpPool with kernel size
+ equal to the spatial dimension of input tensor.
+
+ Args:
+ X: (differentiable) Input data tensor from the previous operator; dimensions
+ for image case are (N x C x H x W), where N is the batch size, C is the
+ number of channels, and H and W are the height and the width of the
+ data. For non image case, the dimensions are in the form of (N x C x D1
+ x D2 ... Dn), where N is the batch size.
+
+ p: p value of the Lp norm used to pool over the input data.
+ """
+
+ schema = get_schema("GlobalLpPool", 22, "")
+ op = Op(self, "GlobalLpPool", schema)
+ return op(*self._prepare_inputs(schema, X), p=p)
+
+ T_GlobalMaxPool = TypeVar("T_GlobalMaxPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def GlobalMaxPool(self, X: T_GlobalMaxPool) -> T_GlobalMaxPool:
+ r"""[🌐 GlobalMaxPool(22)](https://onnx.ai/onnx/operators/onnx__GlobalMaxPool.html#globalmaxpool-22 "Online Documentation")
+
+
+ GlobalMaxPool consumes an input tensor X and applies max pooling across
+ the values in the same channel. This is equivalent to MaxPool with kernel size
+ equal to the spatial dimension of input tensor.
+
+ Args:
+ X: (differentiable) Input data tensor from the previous operator; dimensions
+ for image case are (N x C x H x W), where N is the batch size, C is the
+ number of channels, and H and W are the height and the width of the
+ data. For non image case, the dimensions are in the form of (N x C x D1
+ x D2 ... Dn), where N is the batch size.
+ """
+
+ schema = get_schema("GlobalMaxPool", 22, "")
+ op = Op(self, "GlobalMaxPool", schema)
+ return op(*self._prepare_inputs(schema, X))
+
+ T1_GridSample = TypeVar(
+ "T1_GridSample",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ T2_GridSample = TypeVar("T2_GridSample", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def GridSample(
+ self,
+ X: T1_GridSample,
+ grid: T2_GridSample,
+ *,
+ align_corners: int = 0,
+ mode: str = "linear",
+ padding_mode: str = "zeros",
+ ) -> T1_GridSample:
+ r"""[🌐 GridSample(22)](https://onnx.ai/onnx/operators/onnx__GridSample.html#gridsample-22 "Online Documentation")
+
+
+ Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
+ For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2),
+ the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W),
+ the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out).
+ More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr),
+ the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out).
+
+ The tensor `X` contains values at the centers of square pixels (voxels, etc.) at locations such as (n, c, d1_in, d2_in, ..., dr_in).
+ The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values
+ at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode)
+ and a padding mode (for `grid` positions falling outside the 2-dimensional image).
+
+ For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`.
+ They are used to interpolate output values of `Y[n, c, h_out, w_out]`.
+
+ The GridSample operator is often used as the grid generator and sampler in
+ [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
+ See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
+
+
+ Args:
+ X: (differentiable) Input tensor of rank r+2 that has shape (N, C, D1, D2,
+ ..., Dr), where N is the batch size, C is the number of channels, D1,
+ D2, ..., Dr are the spatial dimensions.
+
+ grid: (non-differentiable) Input offset of shape (N, D1_out, D2_out, ...,
+ Dr_out, r), where D1_out, D2_out, ..., Dr_out are the spatial dimensions
+ of the grid and output, and r is the number of spatial dimensions. Grid
+ specifies the sampling locations normalized by the input spatial
+ dimensions. Therefore, it should have most values in the range of [-1,
+ 1]. If the grid has values outside the range of [-1, 1], the
+ corresponding outputs will be handled as defined by padding_mode.
+ Following computer vision convention, the coordinates in the length-r
+ location vector are listed from the innermost tensor dimension to the
+ outermost, the opposite of regular tensor indexing.
+
+ align_corners: If align_corners=1, the extrema (-1 and 1) are considered as
+ referring to the center points of the input's corner pixels (voxels,
+ etc.). If align_corners=0, they are instead considered as referring to
+ the corner points of the input's corner pixels (voxels, etc.), making
+ the sampling more resolution agnostic.
+
+ mode: Three interpolation modes: linear (default), nearest and cubic. The
+ "linear" mode includes linear and N-linear interpolation modes depending
+ on the number of spatial dimensions of the input tensor (i.e. linear for
+ 1 spatial dimension, bilinear for 2 spatial dimensions, etc.). The
+ "cubic" mode also includes N-cubic interpolation modes following the
+ same rules. The "nearest" mode rounds to the nearest even index when the
+ sampling point falls halfway between two indices.
+
+ padding_mode: Support padding modes for outside grid values:
+ `zeros`(default), `border`, `reflection`. zeros: use 0 for out-of-bound
+ grid locations, border: use border values for out-of-bound grid
+ locations, reflection: use values at locations reflected by the border
+ for out-of-bound grid locations. If index 0 represents the margin pixel,
+ the reflected value at index -1 will be the same as the value at index
+ 1. For locations far away from the border, the value keeps being reflected
+ until it becomes in bound. If pixel location x = -3.5 reflects by border
+ -1 and becomes x' = 1.5, then reflects by border 1 and becomes x'' =
+ 0.5.
+ """
+
+ schema = get_schema("GridSample", 22, "")
+ op = Op(self, "GridSample", schema)
+ return op(
+ *self._prepare_inputs(schema, X, grid),
+ align_corners=align_corners,
+ mode=mode,
+ padding_mode=padding_mode,
+ )
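+
+ # A sketch of the normalized-coordinate convention described above, assuming
+ # NumPy (hypothetical values, illustration only):
+ #
+ #   import numpy as np
+ #   N, C, H, W = 1, 1, 4, 4
+ #   X = np.arange(16, dtype=np.float32).reshape(N, C, H, W)
+ #   grid = np.zeros((N, 1, 1, 2), dtype=np.float32)  # one sample at (x, y) = (0, 0)
+ #   # normalized (0, 0) is the image center; for this 4x4 input it maps to
+ #   # pixel space (1.5, 1.5), so "linear" mode averages the four center pixels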
+
+ T_HardSigmoid = TypeVar("T_HardSigmoid", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def HardSigmoid(
+ self, X: T_HardSigmoid, *, alpha: float = 0.20000000298023224, beta: float = 0.5
+ ) -> T_HardSigmoid:
+ r"""[🌐 HardSigmoid(22)](https://onnx.ai/onnx/operators/onnx__HardSigmoid.html#hardsigmoid-22 "Online Documentation")
+
+
+ HardSigmoid takes one input data (Tensor) and produces one output data
+ (Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)),
+ is applied to the tensor elementwise.
+
+
+ Args:
+ X: (differentiable) Input tensor
+
+ alpha: Value of alpha.
+
+ beta: Value of beta.
+ """
+
+ schema = get_schema("HardSigmoid", 22, "")
+ op = Op(self, "HardSigmoid", schema)
+ return op(*self._prepare_inputs(schema, X), alpha=alpha, beta=beta)
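+
+ # The HardSigmoid formula above, sketched with NumPy (illustration only):
+ #
+ #   import numpy as np
+ #   x = np.array([-5.0, 0.0, 5.0], dtype=np.float32)
+ #   alpha, beta = 0.2, 0.5
+ #   y = np.clip(alpha * x + beta, 0.0, 1.0)  # -> [0.0, 0.5, 1.0]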
+
+ T_HardSwish = TypeVar("T_HardSwish", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def HardSwish(self, X: T_HardSwish) -> T_HardSwish:
+ r"""[🌐 HardSwish(22)](https://onnx.ai/onnx/operators/onnx__HardSwish.html#hardswish-22 "Online Documentation")
+
+
+ HardSwish takes one input data (Tensor) and produces one output data (Tensor) where
+ the HardSwish function, y = x * max(0, min(1, alpha * x + beta)) = x * HardSigmoid(x),
+ where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise.
+
+
+ Args:
+ X: (differentiable) Input tensor
+ """
+
+ schema = get_schema("HardSwish", 22, "")
+ op = Op(self, "HardSwish", schema)
+ return op(*self._prepare_inputs(schema, X))
+
+ T_InstanceNormalization = TypeVar(
+ "T_InstanceNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16
+ )
+
+ def InstanceNormalization(
+ self,
+ input: T_InstanceNormalization,
+ scale: T_InstanceNormalization,
+ B: T_InstanceNormalization,
+ *,
+ epsilon: float = 9.999999747378752e-06,
+ ) -> T_InstanceNormalization:
+ r"""[🌐 InstanceNormalization(22)](https://onnx.ai/onnx/operators/onnx__InstanceNormalization.html#instancenormalization-22 "Online Documentation")
+
+
+ Carries out instance normalization as described in the paper
+ https://arxiv.org/abs/1607.08022.
+
+ y = scale * (x - mean) / sqrt(variance + epsilon) + B,
+ where mean and variance are computed per instance per channel.
+
+
+
+ Args:
+ input: (differentiable) Input data tensor from the previous operator;
+ dimensions for image case are (N x C x H x W), where N is the batch
+ size, C is the number of channels, and H and W are the height and the
+ width of the data. For non image case, the dimensions are in the form of
+ (N x C x D1 x D2 ... Dn), where N is the batch size.
+
+ scale: (differentiable) The input 1-dimensional scale tensor of size C.
+
+ B: (differentiable) The input 1-dimensional bias tensor of size C.
+
+ epsilon: The epsilon value to use to avoid division by zero.
+ """
+
+ schema = get_schema("InstanceNormalization", 22, "")
+ op = Op(self, "InstanceNormalization", schema)
+ return op(*self._prepare_inputs(schema, input, scale, B), epsilon=epsilon)
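+
+ # A NumPy sketch of the per-instance, per-channel statistics described above
+ # (hypothetical shapes, illustration only):
+ #
+ #   import numpy as np
+ #   x = np.random.rand(2, 3, 4, 4).astype(np.float32)  # N x C x H x W
+ #   mean = x.mean(axis=(2, 3), keepdims=True)          # one value per (n, c)
+ #   var = x.var(axis=(2, 3), keepdims=True)
+ #   scale = np.ones((1, 3, 1, 1), dtype=np.float32)    # per-channel scale
+ #   B = np.zeros((1, 3, 1, 1), dtype=np.float32)       # per-channel bias
+ #   y = scale * (x - mean) / np.sqrt(var + 1e-5) + B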
+
+ T_LSTM = TypeVar("T_LSTM", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T1_LSTM: TypeAlias = INT32
+
+ def LSTM(
+ self,
+ X: T_LSTM,
+ W: T_LSTM,
+ R: T_LSTM,
+ B: Optional[T_LSTM] = None,
+ sequence_lens: Optional[T1_LSTM] = None,
+ initial_h: Optional[T_LSTM] = None,
+ initial_c: Optional[T_LSTM] = None,
+ P: Optional[T_LSTM] = None,
+ *,
+ activation_alpha: Optional[Sequence[float]] = None,
+ activation_beta: Optional[Sequence[float]] = None,
+ activations: Optional[Sequence[str]] = None,
+ clip: Optional[float] = None,
+ direction: str = "forward",
+ hidden_size: Optional[int] = None,
+ input_forget: int = 0,
+ layout: int = 0,
+ ) -> Tuple[T_LSTM, T_LSTM, T_LSTM]:
+ r"""[🌐 LSTM(22)](https://onnx.ai/onnx/operators/onnx__LSTM.html#lstm-22 "Online Documentation")
+
+
+ Computes a one-layer LSTM. This operator is usually supported via some
+ custom implementation such as CuDNN.
+
+ Notations:
+
+ * `X` - input tensor
+ * `i` - input gate
+ * `o` - output gate
+ * `f` - forget gate
+ * `c` - cell gate
+ * `t` - time step (t-1 means previous time step)
+ * `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates
+ * `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates
+ * `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates
+ * `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates
+ * `P[iof]` - P peephole weight vector for input, output, and forget gates
+ * `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates
+ * `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates
+ * `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates
+ * `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates
+ * `PB[iof]` - P peephole weight vector for backward input, output, and forget gates
+ * `H` - Hidden state
+ * `num_directions` - 2 if direction == bidirectional else 1
+
+ Activation functions:
+
+ * Relu(x) - max(0, x)
+ * Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x})
+ * Sigmoid(x) - 1/(1 + e^{-x})
+
+ NOTE: Below are optional
+
+ * Affine(x) - alpha*x + beta
+ * LeakyRelu(x) - x if x >= 0 else alpha * x
+ * ThresholdedRelu(x) - x if x >= alpha else 0
+ * ScaledTanh(x) - alpha*Tanh(beta*x)
+ * HardSigmoid(x) - min(max(alpha*x + beta, 0), 1)
+ * Elu(x) - x if x >= 0 else alpha*(e^x - 1)
+ * Softsign(x) - x/(1 + |x|)
+ * Softplus(x) - log(1 + e^x)
+
+ Equations (Default: f=Sigmoid, g=Tanh, h=Tanh):
+
+ * it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
+ * ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
+ * ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
+ * Ct = ft (.) Ct-1 + it (.) ct
+ * ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
+ * Ht = ot (.) h(Ct)
+
+ This operator has **optional** inputs/outputs. See the ONNX IR specification for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
+
+
+ Args:
+ X: (differentiable) The input sequences packed (and potentially padded) into
+ one 3-D tensor with the shape of `[seq_length, batch_size, input_size]`.
+
+ W: (differentiable) The weight tensor for the gates. Concatenation of
+ `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0. The
+ tensor has shape `[num_directions, 4*hidden_size, input_size]`.
+
+ R: (differentiable) The recurrence weight tensor. Concatenation of `R[iofc]`
+ and `RB[iofc]` (if bidirectional) along dimension 0. This tensor has
+ shape `[num_directions, 4*hidden_size, hidden_size]`.
+
+ B: (optional, differentiable) The bias tensor for input gate. Concatenation
+ of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if
+ bidirectional) along dimension 0. This tensor has shape
+ `[num_directions, 8*hidden_size]`. Optional: If not specified - assumed
+ to be 0.
+
+ sequence_lens: (optional, non-differentiable) Optional tensor specifying
+ lengths of the sequences in a batch. If not specified - assumed all
+ sequences in the batch to have length `seq_length`. It has shape
+ `[batch_size]`.
+
+ initial_h: (optional, non-differentiable) Optional initial value of the
+ hidden. If not specified - assumed to be 0. It has shape
+ `[num_directions, batch_size, hidden_size]`.
+
+ initial_c: (optional, non-differentiable) Optional initial value of the
+ cell. If not specified - assumed to be 0. It has shape `[num_directions,
+ batch_size, hidden_size]`.
+
+ P: (optional, differentiable) The weight tensor for peepholes. Concatenation
+ of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0. It has
+ shape `[num_directions, 3*hidden_size]`. Optional: If not specified -
+ assumed to be 0.
+
+ activation_alpha: Optional scaling values used by some activation functions.
+ The values are consumed in the order of activation functions, for
+ example (f, g, h) in LSTM. Default values are the same as of
+ corresponding ONNX operators. For example, with LeakyRelu, the default
+ alpha is 0.01.
+
+ activation_beta: Optional scaling values used by some activation functions.
+ The values are consumed in the order of activation functions, for
+ example (f, g, h) in LSTM. Default values are the same as of
+ corresponding ONNX operators.
+
+ activations: A list of 3 (or 6 if bidirectional) activation functions for
+ input, output, forget, cell, and hidden. The activation functions must
+ be one of the activation functions specified above. Optional: See the
+ equations for default if not specified.
+
+ clip: Cell clip threshold. Clipping bounds the elements of a tensor in the
+ range of [-threshold, +threshold] and is applied to the input of
+ activations. No clip if not specified.
+
+ direction: Specify if the RNN is forward, reverse, or bidirectional. Must be
+ one of forward (default), reverse, or bidirectional.
+
+ hidden_size: Number of neurons in the hidden layer
+
+ input_forget: Couple the input and forget gates if 1.
+
+ layout: The shape format of inputs X, initial_h, initial_c and outputs Y,
+ Y_h, Y_c. If 0, the following shapes are expected: X.shape =
+ [seq_length, batch_size, input_size], Y.shape = [seq_length,
+ num_directions, batch_size, hidden_size], initial_h.shape = Y_h.shape =
+ initial_c.shape = Y_c.shape = [num_directions, batch_size, hidden_size].
+ If 1, the following shapes are expected: X.shape = [batch_size,
+ seq_length, input_size], Y.shape = [batch_size, seq_length,
+ num_directions, hidden_size], initial_h.shape = Y_h.shape =
+ initial_c.shape = Y_c.shape = [batch_size, num_directions, hidden_size].
+ """
+
+ schema = get_schema("LSTM", 22, "")
+ op = Op(self, "LSTM", schema)
+ return op(
+ *self._prepare_inputs(schema, X, W, R, B, sequence_lens, initial_h, initial_c, P),
+ activation_alpha=activation_alpha,
+ activation_beta=activation_beta,
+ activations=activations,
+ clip=clip,
+ direction=direction,
+ hidden_size=hidden_size,
+ input_forget=input_forget,
+ layout=layout,
+ )
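+
+ # A shape sketch for the LSTM weights documented above (hypothetical sizes,
+ # illustration only); note the 4*hidden_size gate packing vs. GRU's 3*:
+ #
+ #   import numpy as np
+ #   num_directions, hidden_size, input_size = 1, 3, 4
+ #   W = np.zeros((num_directions, 4 * hidden_size, input_size), dtype=np.float32)
+ #   R = np.zeros((num_directions, 4 * hidden_size, hidden_size), dtype=np.float32)
+ #   B = np.zeros((num_directions, 8 * hidden_size), dtype=np.float32)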
+
+ T_LpNormalization = TypeVar("T_LpNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def LpNormalization(
+ self, input: T_LpNormalization, *, axis: int = -1, p: int = 2
+ ) -> T_LpNormalization:
+ r"""[🌐 LpNormalization(22)](https://onnx.ai/onnx/operators/onnx__LpNormalization.html#lpnormalization-22 "Online Documentation")
+
+
+ Given a matrix, apply Lp-normalization along the provided axis.
+
+
+ Args:
+ input: (differentiable) Input matrix
+
+ axis: The axis on which to apply normalization; -1 means the last axis.
+
+ p: The order of the normalization, only 1 or 2 are supported.
+ """
+
+ schema = get_schema("LpNormalization", 22, "")
+ op = Op(self, "LpNormalization", schema)
+ return op(*self._prepare_inputs(schema, input), axis=axis, p=p)
+
+ T_LpPool = TypeVar("T_LpPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def LpPool(
+ self,
+ X: T_LpPool,
+ *,
+ auto_pad: str = "NOTSET",
+ ceil_mode: int = 0,
+ dilations: Optional[Sequence[int]] = None,
+ kernel_shape: Sequence[int],
+ p: int = 2,
+ pads: Optional[Sequence[int]] = None,
+ strides: Optional[Sequence[int]] = None,
+ ) -> T_LpPool:
+ r"""[🌐 LpPool(22)](https://onnx.ai/onnx/operators/onnx__LpPool.html#lppool-22 "Online Documentation")
+
+
+ LpPool consumes an input tensor X and applies Lp pooling across
+ the tensor according to kernel sizes, stride sizes, and pad lengths.
+ Lp pooling consists of computing the Lp norm on all values of a subset
+ of the input tensor according to the kernel size and downsampling the
+ data into the output tensor Y for further processing. The output spatial shape will be the following:
+ ```
+ output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - {kernelSpatialShape}) / strides_spatial_shape[i] + 1)
+ ```
+ or
+ ```
+ output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - {kernelSpatialShape}) / strides_spatial_shape[i] + 1)
+ ```
+ if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`.
+
+ `auto_pad` is a DEPRECATED attribute. If you are still using it, the output spatial shape will be the following:
+ ```
+ VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - {kernelSpatialShape} + 1) / strides_spatial_shape[i])
+ SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i])
+ ```
+ And the pad shape will be the following if `SAME_UPPER` or `SAME_LOWER`:
+ ```
+ pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + {kernelSpatialShape} - input_spatial_shape[i]
+ ```
+
+ Args:
+ X: (differentiable) Input data tensor from the previous operator; dimensions
+ for image case are (N x C x H x W), where N is the batch size, C is the
+ number of channels, and H and W are the height and the width of the
+ data. For non image case, the dimensions are in the form of (N x C x D1
+ x D2 ... Dn), where N is the batch size.
+
+ auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID.
+ The default value is NOTSET, which means explicit padding is used.
+ SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] =
+ ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is
+ split between the two sides equally or almost equally (depending on
+ whether it is even or odd). In case the padding is an odd number, the
+ extra padding is added at the end for SAME_UPPER and at the beginning
+ for SAME_LOWER.
+
+ ceil_mode: Whether to use ceil or floor (default) to compute the output
+ shape.
+
+ dilations: Dilation value along each spatial axis of the filter. If not
+ present, the dilation defaults to 1 along each spatial axis.
+
+ kernel_shape: The size of the kernel along each axis.
+
+ p: p value of the Lp norm used to pool over the input data.
+
+ pads: Padding for the beginning and ending along each spatial axis; it can
+ take any value greater than or equal to 0. The value represents the
+ number of pixels added to the beginning and end part of the
+ corresponding axis. The `pads` format should be as follows: [x1_begin,
+ x2_begin...x1_end, x2_end,...], where xi_begin is the number of pixels
+ added at the beginning of axis `i` and xi_end the number of pixels
+ added at the end of axis `i`. This attribute cannot be used
+ simultaneously with the auto_pad attribute. If not present, the padding
+ defaults to 0 along the start and end of each spatial axis.
+
+ strides: Stride along each spatial axis. If not present, the stride defaults
+ to 1 along each spatial axis.
+ """
+
+ schema = get_schema("LpPool", 22, "")
+ op = Op(self, "LpPool", schema)
+ return op(
+ *self._prepare_inputs(schema, X),
+ auto_pad=auto_pad,
+ ceil_mode=ceil_mode,
+ dilations=dilations,
+ kernel_shape=kernel_shape,
+ p=p,
+ pads=pads,
+ strides=strides,
+ )
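+
+ # A worked instance of the floor-mode output-shape formula above (hypothetical
+ # sizes, illustration only):
+ #
+ #   import math
+ #   input_spatial, kernel, stride = 32, 3, 2
+ #   pad_total = 0  # sum of pads along this axis
+ #   out = math.floor((input_spatial + pad_total - kernel) / stride + 1)  # -> 15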
+
+ T_MaxPool = TypeVar("T_MaxPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT8, UINT8)
+
+ I_MaxPool: TypeAlias = INT64
+
+ def MaxPool(
+ self,
+ X: T_MaxPool,
+ *,
+ auto_pad: str = "NOTSET",
+ ceil_mode: int = 0,
+ dilations: Optional[Sequence[int]] = None,
+ kernel_shape: Sequence[int],
+ pads: Optional[Sequence[int]] = None,
+ storage_order: int = 0,
+ strides: Optional[Sequence[int]] = None,
+ ) -> Tuple[T_MaxPool, I_MaxPool]:
+ r"""[🌐 MaxPool(22)](https://onnx.ai/onnx/operators/onnx__MaxPool.html#maxpool-22 "Online Documentation")
+
+
+ MaxPool consumes an input tensor X and applies max pooling across
+ the tensor according to kernel sizes, stride sizes, and pad lengths.
+ Max pooling consists of computing the max on all values of a
+ subset of the input tensor according to the kernel size and downsampling the
+ data into the output tensor Y for further processing. The output spatial shape is calculated differently
+ depending on whether explicit padding (the pads attribute) or auto padding (the auto_pad attribute) is used.
+ With explicit padding (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html?highlight=maxpool#torch.nn.MaxPool2d):
+ ```
+ output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1)
+ ```
+ or
+ ```
+ output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1)
+ ```
+ if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored.
+
+ `auto_pad` is a DEPRECATED attribute. If you are still using it, the output spatial shape will be the following when ceil_mode is enabled:
+ ```
+ VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) + 1) / strides_spatial_shape[i])
+ SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i])
+ ```
+ or when ceil_mode is disabled (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AveragePooling2D):
+ ```
+ VALID: output_spatial_shape[i] = floor((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i]) + 1
+ SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = floor((input_spatial_shape[i] - 1) / strides_spatial_shape[i]) + 1
+ ```
+ And the pad shape will be the following if `SAME_UPPER` or `SAME_LOWER`:
+ ```
+ pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]
+ ```
+ The output of each pooling window is the maximum of the elements in the window, excluding padding.
+
+
+ Args:
+ X: (differentiable) Input data tensor from the previous operator; dimensions
+ for image case are (N x C x H x W), where N is the batch size, C is the
+ number of channels, and H and W are the height and the width of the
+ data. For non image case, the dimensions are in the form of (N x C x D1
+ x D2 ... Dn), where N is the batch size. Optionally, if dimension
+ denotation is in effect, the operation expects the input data tensor to
+ arrive with the dimension denotation of [DATA_BATCH, DATA_CHANNEL,
+ DATA_FEATURE, DATA_FEATURE ...].
+
+ auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID.
+ The default value is NOTSET, which means explicit padding is used.
+ SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] =
+ ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is
+ split between the two sides equally or almost equally (depending on
+ whether it is even or odd). In case the padding is an odd number, the
+ extra padding is added at the end for SAME_UPPER and at the beginning
+ for SAME_LOWER.
+
+ ceil_mode: Whether to use ceil or floor (default) to compute the output
+ shape.
+
+ dilations: Dilation value along each spatial axis of the filter. If not
+ present, the dilation defaults to 1 along each spatial axis.
+
+ kernel_shape: The size of the kernel along each axis.
+
+ pads: Padding for the beginning and ending along each spatial axis; it can
+ take any value greater than or equal to 0. The value represents the
+ number of pixels added to the beginning and end part of the
+ corresponding axis. The `pads` format should be as follows: [x1_begin,
+ x2_begin...x1_end, x2_end,...], where xi_begin is the number of pixels
+ added at the beginning of axis `i` and xi_end the number of pixels
+ added at the end of axis `i`. This attribute cannot be used
+ simultaneously with the auto_pad attribute. If not present, the padding
+ defaults to 0 along the start and end of each spatial axis.
+
+ storage_order: The storage order of the tensor. 0 is row major, and 1 is
+ column major. This attribute is used only to convert an n-tuple index
+ value into a single integer value for producing the second output.
+
+ strides: Stride along each spatial axis. If not present, the stride defaults
+ to 1 along each spatial axis.
+ """
+
+ schema = get_schema("MaxPool", 22, "")
+ op = Op(self, "MaxPool", schema)
+ return op(
+ *self._prepare_inputs(schema, X),
+ auto_pad=auto_pad,
+ ceil_mode=ceil_mode,
+ dilations=dilations,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ storage_order=storage_order,
+ strides=strides,
+ )
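+
+ # A worked instance of the explicit-padding output-shape formula above,
+ # including dilation (hypothetical sizes, illustration only):
+ #
+ #   import math
+ #   i, k, d, s, pad_total = 32, 3, 2, 2, 0
+ #   out = math.floor((i + pad_total - d * (k - 1) - 1) / s + 1)  # -> 14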
+
+ T_MaxRoiPool = TypeVar("T_MaxRoiPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def MaxRoiPool(
+ self,
+ X: T_MaxRoiPool,
+ rois: T_MaxRoiPool,
+ *,
+ pooled_shape: Sequence[int],
+ spatial_scale: float = 1.0,
+ ) -> T_MaxRoiPool:
+ r"""[🌐 MaxRoiPool(22)](https://onnx.ai/onnx/operators/onnx__MaxRoiPool.html#maxroipool-22 "Online Documentation")
+
+
+ ROI max pool consumes an input tensor X and regions of interest (RoIs) to
+ apply max pooling across each RoI, producing an output 4-D tensor of shape
+ (num_rois, channels, pooled_shape[0], pooled_shape[1]).
+
+ Args:
+ X: (differentiable) Input data tensor from the previous operator; dimensions
+ for image case are (N x C x H x W), where N is the batch size, C is the
+ number of channels, and H and W are the height and the width of the
+ data.
+
+ rois: (non-differentiable) RoIs (Regions of Interest) to pool over. Should
+ be a 2-D tensor of shape (num_rois, 5) given as [[batch_id, x1, y1, x2,
+ y2], ...].
+
+ pooled_shape: ROI pool output shape (height, width).
+
+ spatial_scale: Multiplicative spatial scale factor to translate ROI
+ coordinates from their input scale to the scale used when pooling.
+ """
+
+ schema = get_schema("MaxRoiPool", 22, "")
+ op = Op(self, "MaxRoiPool", schema)
+ return op(
+ *self._prepare_inputs(schema, X, rois),
+ pooled_shape=pooled_shape,
+ spatial_scale=spatial_scale,
+ )
+
+ T1_MaxUnpool = TypeVar("T1_MaxUnpool", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T2_MaxUnpool: TypeAlias = INT64
+
+ def MaxUnpool(
+ self,
+ X: T1_MaxUnpool,
+ I: T2_MaxUnpool,
+ output_shape: Optional[T2_MaxUnpool] = None,
+ *,
+ kernel_shape: Sequence[int],
+ pads: Optional[Sequence[int]] = None,
+ strides: Optional[Sequence[int]] = None,
+ ) -> T1_MaxUnpool:
+ r"""[🌐 MaxUnpool(22)](https://onnx.ai/onnx/operators/onnx__MaxUnpool.html#maxunpool-22 "Online Documentation")
+
+
+ MaxUnpool essentially computes the partial inverse of the MaxPool op.
+ The input information to this op is typically the output information from a MaxPool op. The first
+ input tensor X is the tensor that needs to be unpooled, which is typically the pooled tensor (first output)
+ from MaxPool. The second input tensor, I, contains the indices to the (locally maximal) elements corresponding
+ to the elements in the first input tensor X. Input tensor I is typically the second output of the MaxPool op.
+ The third (optional) input is a tensor that specifies the output size of the unpooling operation.
+
+ MaxUnpool is intended to compute the 'partial' inverse of the MaxPool op: 'partial' because all the non-maximal
+ values from the original input to MaxPool are set to zero in the output of the MaxUnpool op. Pooling
+ the result of an unpooling operation should give back the original input to the unpooling op.
+
+ MaxUnpool can produce the same output size for several input sizes, which makes the unpooling op ambiguous.
+ The third input argument, output_size, is meant to disambiguate the op and produce an output tensor of
+ known/predictable size.
+
+ In addition to the inputs, MaxUnpool takes three attributes, namely kernel_shape, strides, and pads,
+ which define the exact unpooling op. The attributes typically have the same values as the corresponding
+ pooling op that the unpooling op is trying to invert.
+
+
+ Args:
+ X: (differentiable) Input data tensor that has to be unpooled. This tensor
+ is typically the first output of the MaxPool op. Dimensions for image
+ case are (N x C x H x W), where N is the batch size, C is the number of
+ channels, and H and W are the height and the width of the data. For
+ non-image case, the dimensions are in the form of (N x C x D1 x D2 ...
+ Dn), where N is the batch size. Optionally, if dimension denotation is
+ in effect, the operation expects the input data tensor to arrive with
+ the dimension denotation of [DATA_BATCH, DATA_CHANNEL, DATA_FEATURE,
+ DATA_FEATURE ...].
+
+ I: (non-differentiable) Input data tensor containing the indices
+ corresponding to elements in the first input tensor X. This tensor is
+ typically the second output of the MaxPool op. Dimensions must be the
+ same as input tensor X. The indices are linear, i.e. computed
+ considering the tensor as flattened 1-D tensor, assuming row-major
+ storage. Also, the linear indices should not consider padding. So the
+ values in indices are in the range [0, N x C x D1 x ... x Dn).
+
+ output_shape: (optional, non-differentiable) The shape of the output can be
+ explicitly set, which will cause the pads values to be auto-generated. If
+ 'output_shape' is specified, 'pads' values are ignored.
+
+ kernel_shape: The size of the kernel along each axis.
+
+ pads: Padding for the beginning and ending along each spatial axis; it can
+ take any value greater than or equal to 0. The value represents the
+ number of pixels added to the beginning and end part of the
+ corresponding axis. The `pads` format should be as follows: [x1_begin,
+ x2_begin...x1_end, x2_end,...], where xi_begin is the number of pixels
+ added at the beginning of axis `i` and xi_end the number of pixels
+ added at the end of axis `i`. This attribute cannot be used
+ simultaneously with the auto_pad attribute. If not present, the padding
+ defaults to 0 along the start and end of each spatial axis.
+
+ strides: Stride along each spatial axis. If not present, the stride defaults
+ to 1 along each spatial axis.
+ """
+
+ schema = get_schema("MaxUnpool", 22, "")
+ op = Op(self, "MaxUnpool", schema)
+ return op(
+ *self._prepare_inputs(schema, X, I, output_shape),
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ T_Mish = TypeVar("T_Mish", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Mish(self, X: T_Mish) -> T_Mish:
+ r"""[🌐 Mish(22)](https://onnx.ai/onnx/operators/onnx__Mish.html#mish-22 "Online Documentation")
+
+
+ Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+ Performs the linear unit element-wise on the input tensor X using the formula:
+
+ ::
+
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
+
+
+
+
+ Args:
+ X: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Mish", 22, "")
+ op = Op(self, "Mish", schema)
+ return op(*self._prepare_inputs(schema, X))
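+
+ # The Mish formula above, sketched with NumPy (illustration only); log1p(exp(x))
+ # is softplus(x):
+ #
+ #   import numpy as np
+ #   x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+ #   y = x * np.tanh(np.log1p(np.exp(x)))  # x * tanh(softplus(x))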
+
+ T1_Multinomial = TypeVar("T1_Multinomial", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T2_Multinomial: TypeAlias = Union[INT32, INT64]
+
+ def Multinomial(
+ self,
+ input: T1_Multinomial,
+ *,
+ dtype: int = 6,
+ sample_size: int = 1,
+ seed: Optional[float] = None,
+ ) -> T2_Multinomial:
+ r"""[🌐 Multinomial(22)](https://onnx.ai/onnx/operators/onnx__Multinomial.html#multinomial-22 "Online Documentation")
+
+
+ Generate a tensor of samples from a multinomial distribution according to the probabilities
+ of each of the possible outcomes.
+
+
+ Args:
+ input: Input tensor with shape [batch_size, class_size], where class_size is
+ the number of all possible outcomes. Each value along the axis zero
+ represents the unnormalized log-probability of each corresponding
+ outcome in a batch.
+
+ dtype: (Optional) The data type for the elements of the output tensor; if
+ not specified, int32 is used.
+
+ sample_size: Number of times to sample.
+
+ seed: (Optional) Seed to the random generator; if not specified, one is
+ auto-generated.
+ """
+
+ schema = get_schema("Multinomial", 22, "")
+ op = Op(self, "Multinomial", schema)
+ return op(
+ *self._prepare_inputs(schema, input),
+ dtype=dtype,
+ sample_size=sample_size,
+ seed=seed,
+ )
+
+ T_NegativeLogLikelihoodLoss = TypeVar(
+ "T_NegativeLogLikelihoodLoss", BFLOAT16, DOUBLE, FLOAT, FLOAT16
+ )
+
+ Tind_NegativeLogLikelihoodLoss = TypeVar("Tind_NegativeLogLikelihoodLoss", INT32, INT64)
+
+ def NegativeLogLikelihoodLoss(
+ self,
+ input: T_NegativeLogLikelihoodLoss,
+ target: Tind_NegativeLogLikelihoodLoss,
+ weight: Optional[T_NegativeLogLikelihoodLoss] = None,
+ *,
+ ignore_index: Optional[int] = None,
+ reduction: str = "mean",
+ ) -> T_NegativeLogLikelihoodLoss:
+ r"""[🌐 NegativeLogLikelihoodLoss(22)](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html#negativeloglikelihoodloss-22 "Online Documentation")
+
+
+ A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss.
+ Its "input" tensor has the shape of (N, C, d1, d2, ..., dk) where k >= 0.
+ The "input" tensor contains log-probabilities for input[n, :, d_1, d_2,..., d_k] being in a class of [0, C).
+ The operator's "target" input tensor has the shape of (N, d1, d2, ..., dk). It encodes class labels (one of C classes)
+ or it may contain a special value (indicated by an attribute ignore_index) for N x d1 x d2 x ... x dk samples.
+ The loss value for input[n, :, d_1, d_2,...d_k] being classified as class c = target[n][d_1][d_2]...[d_k] is computed as:
+
+ ::
+
+ loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k].
+
+
+
+ When an optional "weight" is provided, the sample loss is calculated as:
+
+ ::
+
+ loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k] * weight[c].
+
+
+
+ The loss is zero when the target value equals ignore_index.
+
+ ::
+
+ loss[n][d_1][d_2]...[d_k] = 0, when target[n][d_1][d_2]...[d_k] = ignore_index
+
+
+
+ If "reduction" attribute is set to "none", the operator's output will be the above loss with shape (N, d1, d2, ..., dk).
+ If "reduction" attribute is set to "mean" (the default attribute value), the output loss is (weight) averaged:
+
+ ::
+
+ mean(loss), if "weight" is not provided,
+
+
+
+ or if weight is provided,
+
+ ::
+
+ sum(loss) / sum(weight[target[n][d_1][d_2]...[d_k]]), for all samples.
+
+
+
+ If "reduction" attribute is set to "sum", the output is a scalar: `sum(loss)`.
+
+ See also https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss.
+
+ Example 1:
+
+ ::
+
+ // negative log likelihood loss, "none" reduction
+ N, C, d1 = 2, 3, 2
+ input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
+ [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
+ target = [[2, 1], [0, 2]]
+
+ loss = np.zeros((N, d1))
+ for n in range(N):
+ for d_1 in range(d1):
+ c = target[n][d_1]
+ loss[n][d_1] = -input[n][c][d_1]
+
+ // print(loss)
+ // [[-3. -2.]
+ // [-0. -2.]]
+
+
+
+ Example 2:
+
+ ::
+
+ // weighted negative log likelihood loss, sum reduction
+ N, C, d1 = 2, 3, 2
+ input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
+ [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
+ target = [[2, 1], [0, 2]]
+ weight = [0.2, 0.3, 0.1]
+ loss = np.zeros((N, d1))
+ for n in range(N):
+ for d_1 in range(d1):
+ c = target[n][d_1]
+ loss[n][d_1] = -input[n][c][d_1] * weight[c]
+
+ loss = np.sum(loss)
+ // print(loss)
+ // -1.1
+
+
+
+ Example 3:
+
+ ::
+
+ // weighted negative log likelihood loss, mean reduction
+ N, C, d1 = 2, 3, 2
+ input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
+ [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
+ target = [[2, 1], [0, 2]]
+ weight = [0.2, 0.3, 0.1]
+ loss = np.zeros((N, d1))
+ weight_total = 0
+ for n in range(N):
+ for d_1 in range(d1):
+ c = target[n][d_1]
+ loss[n][d_1] = -input[n][c][d_1] * weight[c]
+ weight_total = weight_total + weight[c]
+
+ loss = np.sum(loss) / weight_total
+ // print(loss)
+ // -1.57
+
+
+
+
+ Args:
+ input: (differentiable) Input tensor of shape (N, C) or (N, C, d1, d2, ...,
+ dk).
+
+ target: (non-differentiable) Target tensor of shape (N) or (N, d1, d2, ...,
+ dk). Target element value shall be in range of [0, C). If ignore_index
+ is specified, it may have a value outside [0, C) and the target values
+ should either be in the range [0, C) or have the value ignore_index.
+
+ weight: (optional, non-differentiable) Optional rescaling weight tensor. If
+ given, it has to be a tensor of size C. Otherwise, it is treated as if
+ having all ones.
+
+ ignore_index: Specifies a target value that is ignored and does not
+ contribute to the input gradient. It's an optional value.
+
+ reduction: Type of reduction to apply to loss: none, sum, mean (default).
+ 'none': the output is the loss for each sample. 'sum': the output will
+ be summed. 'mean': the sum of the output will be divided by the sum of
+ applied weights.
+ """
+
+ schema = get_schema("NegativeLogLikelihoodLoss", 22, "")
+ op = Op(self, "NegativeLogLikelihoodLoss", schema)
+ return op(
+ *self._prepare_inputs(schema, input, target, weight),
+ ignore_index=ignore_index,
+ reduction=reduction,
+ )
+
+ T_RNN = TypeVar("T_RNN", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T1_RNN: TypeAlias = INT32
+
+ def RNN(
+ self,
+ X: T_RNN,
+ W: T_RNN,
+ R: T_RNN,
+ B: Optional[T_RNN] = None,
+ sequence_lens: Optional[T1_RNN] = None,
+ initial_h: Optional[T_RNN] = None,
+ *,
+ activation_alpha: Optional[Sequence[float]] = None,
+ activation_beta: Optional[Sequence[float]] = None,
+ activations: Sequence[str] = ("Tanh", "Tanh"),
+ clip: Optional[float] = None,
+ direction: str = "forward",
+ hidden_size: Optional[int] = None,
+ layout: int = 0,
+ ) -> Tuple[T_RNN, T_RNN]:
+ r"""[🌐 RNN(22)](https://onnx.ai/onnx/operators/onnx__RNN.html#rnn-22 "Online Documentation")
+
+
+ Computes a one-layer simple RNN. This operator is usually supported
+ via some custom implementation such as CuDNN.
+
+ Notations:
+
+ * `X` - input tensor
+ * `i` - input gate
+ * `t` - time step (t-1 means previous time step)
+ * `Wi` - W parameter weight matrix for input gate
+ * `Ri` - R recurrence weight matrix for input gate
+ * `Wbi` - W parameter bias vector for input gate
+ * `Rbi` - R parameter bias vector for input gate
+ * `WBi` - W parameter weight matrix for backward input gate
+ * `RBi` - R recurrence weight matrix for backward input gate
+ * `WBbi` - WR bias vectors for backward input gate
+ * `RBbi` - RR bias vectors for backward input gate
+ * `H` - Hidden state
+ * `num_directions` - 2 if direction == bidirectional else 1
+
+ Activation functions:
+
+ * Relu(x) - max(0, x)
+ * Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x})
+ * Sigmoid(x) - 1/(1 + e^{-x})
+
+ NOTE: Below are optional
+
+ * Affine(x) - alpha*x + beta
+ * LeakyRelu(x) - x if x >= 0 else alpha * x
+ * ThresholdedRelu(x) - x if x >= alpha else 0
+ * ScaledTanh(x) - alpha*Tanh(beta*x)
+ * HardSigmoid(x) - min(max(alpha*x + beta, 0), 1)
+ * Elu(x) - x if x >= 0 else alpha*(e^x - 1)
+ * Softsign(x) - x/(1 + |x|)
+ * Softplus(x) - log(1 + e^x)
+
+ Equations (Default: f=Tanh):
+
+ * Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
+
+ This operator has **optional** inputs/outputs. See the ONNX IR specification for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
+
+
+ Args:
+ X: (differentiable) The input sequences packed (and potentially padded) into
+ one 3-D tensor with the shape of `[seq_length, batch_size, input_size]`.
+
+ W: (differentiable) The weight tensor for input gate. Concatenation of `Wi`
+ and `WBi` (if bidirectional). The tensor has shape `[num_directions,
+ hidden_size, input_size]`.
+
+ R: (differentiable) The recurrence weight tensor. Concatenation of `Ri` and
+ `RBi` (if bidirectional). The tensor has shape `[num_directions,
+ hidden_size, hidden_size]`.
+
+ B: (optional, differentiable) The bias tensor for input gate. Concatenation
+ of `[Wbi, Rbi]` and `[WBbi, RBbi]` (if bidirectional). The tensor has
+ shape `[num_directions, 2*hidden_size]`. Optional: If not specified -
+ assumed to be 0.
+
+ sequence_lens: (optional, non-differentiable) Optional tensor specifying
+ lengths of the sequences in a batch. If not specified - assumed all
+ sequences in the batch to have length `seq_length`. It has shape
+ `[batch_size]`.
+
+ initial_h: (optional, non-differentiable) Optional initial value of the
+ hidden. If not specified - assumed to be 0. It has shape
+ `[num_directions, batch_size, hidden_size]`.
+
+ activation_alpha: Optional scaling values used by some activation functions.
+ The values are consumed in the order of activation functions, for
+ example (f, g, h) in LSTM. Default values are the same as of
+ corresponding ONNX operators. For example, with LeakyRelu, the default
+ alpha is 0.01.
+
+ activation_beta: Optional scaling values used by some activation functions.
+ The values are consumed in the order of activation functions, for
+ example (f, g, h) in LSTM. Default values are the same as of
+ corresponding ONNX operators.
+
+ activations: One (or two if bidirectional) activation function for input
+ gate. The activation function must be one of the activation functions
+ specified above. Optional: Default `Tanh` if not specified.
+
+ clip: Cell clip threshold. Clipping bounds the elements of a tensor in the
+ range of [-threshold, +threshold] and is applied to the input of
+ activations. No clip if not specified.
+
+ direction: Specify if the RNN is forward, reverse, or bidirectional. Must be
+ one of forward (default), reverse, or bidirectional.
+
+ hidden_size: Number of neurons in the hidden layer
+
+ layout: The shape format of inputs X, initial_h and outputs Y, Y_h. If 0,
+ the following shapes are expected: X.shape = [seq_length, batch_size,
+ input_size], Y.shape = [seq_length, num_directions, batch_size,
+ hidden_size], initial_h.shape = Y_h.shape = [num_directions, batch_size,
+ hidden_size]. If 1, the following shapes are expected: X.shape =
+ [batch_size, seq_length, input_size], Y.shape = [batch_size, seq_length,
+ num_directions, hidden_size], initial_h.shape = Y_h.shape = [batch_size,
+ num_directions, hidden_size].
+ """
+
+ schema = get_schema("RNN", 22, "")
+ op = Op(self, "RNN", schema)
+ return op(
+ *self._prepare_inputs(schema, X, W, R, B, sequence_lens, initial_h),
+ activation_alpha=activation_alpha,
+ activation_beta=activation_beta,
+ activations=activations,
+ clip=clip,
+ direction=direction,
+ hidden_size=hidden_size,
+ layout=layout,
+ )
+
+ T_RandomNormal: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16]
+
+ def RandomNormal(
+ self,
+ *,
+ dtype: int = 1,
+ mean: float = 0.0,
+ scale: float = 1.0,
+ seed: Optional[float] = None,
+ shape: Sequence[int],
+ ) -> T_RandomNormal:
+ r"""[🌐 RandomNormal(22)](https://onnx.ai/onnx/operators/onnx__RandomNormal.html#randomnormal-22 "Online Documentation")
+
+
+ Generate a tensor with random values drawn from a normal distribution. The shape
+ of the tensor is specified by the `shape` argument and the parameters of the normal distribution
+ are specified by `mean` and `scale`.
+
+ The data type is specified by the 'dtype' argument. The 'dtype' argument must
+ be one of the data types specified in the 'DataType' enum field in the
+ TensorProto message.
+
+
+ Args:
+ dtype: The data type for the elements of the output tensor. Default is
+ TensorProto::FLOAT.
+
+ mean: The mean of the normal distribution.
+
+ scale: The standard deviation of the normal distribution.
+
+ seed: (Optional) Seed to the random generator; if not specified, one is
+ auto-generated.
+
+ shape: The shape of the output tensor.
+ """
+
+ schema = get_schema("RandomNormal", 22, "")
+ op = Op(self, "RandomNormal", schema)
+ return op(dtype=dtype, mean=mean, scale=scale, seed=seed, shape=shape)
+
+ T1_RandomNormalLike = TypeVar(
+ "T1_RandomNormalLike",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ T2_RandomNormalLike: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16]
+
+ def RandomNormalLike(
+ self,
+ input: T1_RandomNormalLike,
+ *,
+ dtype: Optional[int] = None,
+ mean: float = 0.0,
+ scale: float = 1.0,
+ seed: Optional[float] = None,
+ ) -> T2_RandomNormalLike:
+ r"""[🌐 RandomNormalLike(22)](https://onnx.ai/onnx/operators/onnx__RandomNormalLike.html#randomnormallike-22 "Online Documentation")
+
+
+ Generate a tensor with random values drawn from a normal distribution.
+ The shape of the output tensor is copied from the shape of the input tensor,
+ and the parameters of the normal distribution are specified by `mean` and `scale`.
+
+ The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided.
+ The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
+ TensorProto message, and be valid as an output type.
+
+
+ Args:
+ input: Input tensor to copy shape and optionally type information from.
+
+ dtype: (Optional) The data type for the elements of the output tensor; if
+ not specified, the data type of the input tensor is used.
+
+ mean: The mean of the normal distribution.
+
+ scale: The standard deviation of the normal distribution.
+
+ seed: (Optional) Seed to the random generator; if not specified, one is
+ auto-generated.
+ """
+
+ schema = get_schema("RandomNormalLike", 22, "")
+ op = Op(self, "RandomNormalLike", schema)
+ return op(
+ *self._prepare_inputs(schema, input),
+ dtype=dtype,
+ mean=mean,
+ scale=scale,
+ seed=seed,
+ )
+
+ T_RandomUniform: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16]
+
+ def RandomUniform(
+ self,
+ *,
+ dtype: int = 1,
+ high: float = 1.0,
+ low: float = 0.0,
+ seed: Optional[float] = None,
+ shape: Sequence[int],
+ ) -> T_RandomUniform:
+ r"""[🌐 RandomUniform(22)](https://onnx.ai/onnx/operators/onnx__RandomUniform.html#randomuniform-22 "Online Documentation")
+
+
+ Generate a tensor with random values drawn from a uniform distribution. The shape
+ of the tensor is specified by the `shape` argument and the range by `low` and `high`.
+
+ The data type is specified by the 'dtype' argument. The 'dtype' argument must
+ be one of the data types specified in the 'DataType' enum field in the
+ TensorProto message.
+
+
+ Args:
+ dtype: The data type for the elements of the output tensor. If not
+ specified, default is TensorProto::FLOAT.
+
+ high: Upper boundary of the output values.
+
+ low: Lower boundary of the output values.
+
+ seed: (Optional) Seed for the random generator; if not specified, one is
+ auto-generated.
+
+ shape: The shape of the output tensor.
+ """
+
+ schema = get_schema("RandomUniform", 22, "")
+ op = Op(self, "RandomUniform", schema)
+ return op(dtype=dtype, high=high, low=low, seed=seed, shape=shape)
+
+ T1_RandomUniformLike = TypeVar(
+ "T1_RandomUniformLike",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ T2_RandomUniformLike: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16]
+
+ def RandomUniformLike(
+ self,
+ input: T1_RandomUniformLike,
+ *,
+ dtype: Optional[int] = None,
+ high: float = 1.0,
+ low: float = 0.0,
+ seed: Optional[float] = None,
+ ) -> T2_RandomUniformLike:
+ r"""[🌐 RandomUniformLike(22)](https://onnx.ai/onnx/operators/onnx__RandomUniformLike.html#randomuniformlike-22 "Online Documentation")
+
+
+ Generate a tensor with random values drawn from a uniform distribution.
+ The shape of the output tensor is copied from the shape of the input tensor,
+ and the parameters of the uniform distribution are specified by `low` and `high`.
+
+ The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided.
+ The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
+ TensorProto message and be valid as an output type.
+
+
+ Args:
+ input: Input tensor to copy shape and optionally type information from.
+
+ dtype: (Optional) The data type for the elements of the output tensor; if
+ not specified, the data type of the input tensor is used.
+
+ high: Upper boundary of the output values.
+
+ low: Lower boundary of the output values.
+
+ seed: (Optional) Seed for the random generator; if not specified, one is
+ auto-generated.
+ """
+
+ schema = get_schema("RandomUniformLike", 22, "")
+ op = Op(self, "RandomUniformLike", schema)
+ return op(
+ *self._prepare_inputs(schema, input),
+ dtype=dtype,
+ high=high,
+ low=low,
+ seed=seed,
+ )
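+
+ # A minimal NumPy sketch of the "Like" semantics above (illustration only,
+ # not part of the generated API): the shape, and the dtype when `dtype` is
+ # omitted, come from the input tensor:
+ #
+ #   import numpy as np
+ #   x = np.zeros((4, 5), dtype=np.float16)
+ #   rng = np.random.default_rng()
+ #   y = rng.uniform(low=0.0, high=1.0, size=x.shape).astype(x.dtype)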
+
+ T1_RoiAlign = TypeVar("T1_RoiAlign", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T2_RoiAlign: TypeAlias = INT64
+
+ def RoiAlign(
+ self,
+ X: T1_RoiAlign,
+ rois: T1_RoiAlign,
+ batch_indices: T2_RoiAlign,
+ *,
+ coordinate_transformation_mode: str = "half_pixel",
+ mode: str = "avg",
+ output_height: int = 1,
+ output_width: int = 1,
+ sampling_ratio: int = 0,
+ spatial_scale: float = 1.0,
+ ) -> T1_RoiAlign:
+ r"""[🌐 RoiAlign(22)](https://onnx.ai/onnx/operators/onnx__RoiAlign.html#roialign-22 "Online Documentation")
+
+
+ Region of Interest (RoI) align operation described in the
+ [Mask R-CNN paper](https://arxiv.org/abs/1703.06870).
+ RoiAlign consumes an input tensor X and region of interests (rois)
+ to apply pooling across each RoI; it produces a 4-D tensor of shape
+ (num_rois, C, output_height, output_width).
+
+ RoiAlign is proposed to avoid misalignment by removing
+ quantizations while converting from the original image into the feature
+ map and from the feature map into the RoI feature; in each RoI bin,
+ the values of the sampled locations are computed directly
+ through bilinear interpolation.
+
+
+ Args:
+ X: Input data tensor from the previous operator; 4-D feature map of shape
+ (N, C, H, W), where N is the batch size, C is the number of channels,
+ and H and W are the height and the width of the data.
+
+ rois: RoIs (Regions of Interest) to pool over; rois is 2-D input of shape
+ (num_rois, 4) given as [[x1, y1, x2, y2], ...]. The RoIs' coordinates
+ are in the coordinate system of the input image. Each coordinate set has
+ a 1:1 correspondence with the 'batch_indices' input.
+
+ batch_indices: 1-D tensor of shape (num_rois,) with each element denoting
+ the index of the corresponding image in the batch.
+
+ coordinate_transformation_mode: Allowed values are 'half_pixel' and
+ 'output_half_pixel'. Use the value 'half_pixel' to pixel shift the input
+ coordinates by -0.5 (the recommended behavior). Use the value
+ 'output_half_pixel' to omit the pixel shift for the input (use this for
+ a backward-compatible behavior).
+
+ mode: The pooling method. Two modes are supported: 'avg' and 'max'. Default
+ is 'avg'.
+
+ output_height: default 1; Pooled output Y's height.
+
+ output_width: default 1; Pooled output Y's width.
+
+ sampling_ratio: Number of sampling points in the interpolation grid used to
+ compute the output value of each pooled output bin. If > 0, then exactly
+ sampling_ratio x sampling_ratio grid points are used. If == 0, then an
+ adaptive number of grid points are used (computed as ceil(roi_width /
+ output_width), and likewise for height). Default is 0.
+
+ spatial_scale: Multiplicative spatial scale factor to translate ROI
+ coordinates from their input spatial scale to the scale used when
+ pooling, i.e., spatial scale of the input feature map X relative to the
+ input image. Default is 1.0.
+ """
+
+ schema = get_schema("RoiAlign", 22, "")
+ op = Op(self, "RoiAlign", schema)
+ return op(
+ *self._prepare_inputs(schema, X, rois, batch_indices),
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ mode=mode,
+ output_height=output_height,
+ output_width=output_width,
+ sampling_ratio=sampling_ratio,
+ spatial_scale=spatial_scale,
+ )
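+
+ # A hedged usage sketch via onnxscript's `@script` decorator (illustration
+ # only; assumes `opset22` is re-exported from the package root like earlier
+ # opsets). Y has shape (num_rois, C, output_height, output_width):
+ #
+ #   from onnxscript import script, FLOAT, INT64
+ #   from onnxscript import opset22 as op
+ #
+ #   @script()
+ #   def pool_rois(
+ #       X: FLOAT[1, 3, 32, 32], rois: FLOAT[2, 4], idx: INT64[2]
+ #   ) -> FLOAT[2, 3, 4, 4]:
+ #       return op.RoiAlign(X, rois, idx, output_height=4, output_width=4, mode="avg")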
+
+ T_Round = TypeVar("T_Round", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Round(self, X: T_Round) -> T_Round:
+ r"""[🌐 Round(22)](https://onnx.ai/onnx/operators/onnx__Round.html#round-22 "Online Documentation")
+
+
+ Round takes one input Tensor and rounds the values, element-wise, meaning
+ it finds the nearest integer for each value.
+ In case of halves, the rule is to round them to the nearest even integer.
+ If input x is integral, +0, -0, NaN, or infinite, x itself is returned.
+ The output tensor has the same shape and type as the input.
+
+ Examples:
+ ::
+
+ round([0.9]) = [1.0]
+ round([2.5]) = [2.0]
+ round([2.3]) = [2.0]
+ round([1.5]) = [2.0]
+ round([-4.5]) = [-4.0]
+
+
+
+
+ Args:
+ X: (non-differentiable) Input tensor
+ """
+
+ schema = get_schema("Round", 22, "")
+ op = Op(self, "Round", schema)
+ return op(*self._prepare_inputs(schema, X))
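+
+ # NumPy's `np.round` also rounds half to even, so it can serve as a quick
+ # reference for the examples above (illustration only):
+ #
+ #   import numpy as np
+ #   x = np.array([0.9, 2.5, 2.3, 1.5, -4.5], dtype=np.float32)
+ #   np.round(x)  # -> array([ 1.,  2.,  2.,  2., -4.], dtype=float32)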
+
+ T_Selu = TypeVar("T_Selu", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Selu(
+ self,
+ X: T_Selu,
+ *,
+ alpha: float = 1.6732631921768188,
+ gamma: float = 1.0507010221481323,
+ ) -> T_Selu:
+ r"""[🌐 Selu(22)](https://onnx.ai/onnx/operators/onnx__Selu.html#selu-22 "Online Documentation")
+
+
+ Selu takes one input data (Tensor) and produces one output data
+ (Tensor) where the scaled exponential linear unit function,
+ `y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`,
+ is applied to the tensor elementwise.
+
+
+ Args:
+ X: (differentiable) Input tensor
+
+ alpha: Coefficient of SELU default to 1.67326319217681884765625 (i.e.,
+ float32 approximation of 1.6732632423543772848170429916717).
+
+ gamma: Coefficient of SELU default to 1.05070102214813232421875 (i.e.,
+ float32 approximation of 1.0507009873554804934193349852946).
+ """
+
+ schema = get_schema("Selu", 22, "")
+ op = Op(self, "Selu", schema)
+ return op(*self._prepare_inputs(schema, X), alpha=alpha, gamma=gamma)
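+
+ # A minimal NumPy sketch of the Selu formula above, using the same float32
+ # default coefficients (illustration only, not part of the generated API):
+ #
+ #   import numpy as np
+ #   def selu_ref(x, alpha=1.6732631921768188, gamma=1.0507010221481323):
+ #       # y = gamma * (alpha * e^x - alpha) for x <= 0; y = gamma * x for x > 0
+ #       return np.where(x > 0, gamma * x, gamma * alpha * (np.exp(x) - 1.0))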
+
+ T_Sin = TypeVar("T_Sin", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Sin(self, input: T_Sin) -> T_Sin:
+ r"""[🌐 Sin(22)](https://onnx.ai/onnx/operators/onnx__Sin.html#sin-22 "Online Documentation")
+
+
+ Calculates the sine of the given input tensor, element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Sin", 22, "")
+ op = Op(self, "Sin", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Sinh = TypeVar("T_Sinh", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Sinh(self, input: T_Sinh) -> T_Sinh:
+ r"""[🌐 Sinh(22)](https://onnx.ai/onnx/operators/onnx__Sinh.html#sinh-22 "Online Documentation")
+
+
+ Calculates the hyperbolic sine of the given input tensor element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Sinh", 22, "")
+ op = Op(self, "Sinh", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Softplus = TypeVar("T_Softplus", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Softplus(self, X: T_Softplus) -> T_Softplus:
+ r"""[🌐 Softplus(22)](https://onnx.ai/onnx/operators/onnx__Softplus.html#softplus-22 "Online Documentation")
+
+
+ Softplus takes one input data (Tensor) and produces one output data
+ (Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to
+ the tensor elementwise.
+
+
+ Args:
+ X: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Softplus", 22, "")
+ op = Op(self, "Softplus", schema)
+ return op(*self._prepare_inputs(schema, X))
+
+ T_Softsign = TypeVar("T_Softsign", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Softsign(self, input: T_Softsign) -> T_Softsign:
+ r"""[🌐 Softsign(22)](https://onnx.ai/onnx/operators/onnx__Softsign.html#softsign-22 "Online Documentation")
+
+
+ Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Softsign", 22, "")
+ op = Op(self, "Softsign", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_Tan = TypeVar("T_Tan", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Tan(self, input: T_Tan) -> T_Tan:
+ r"""[🌐 Tan(22)](https://onnx.ai/onnx/operators/onnx__Tan.html#tan-22 "Online Documentation")
+
+
+ Calculates the tangent of the given input tensor, element-wise.
+
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Tan", 22, "")
+ op = Op(self, "Tan", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ T_ThresholdedRelu = TypeVar("T_ThresholdedRelu", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def ThresholdedRelu(
+ self, X: T_ThresholdedRelu, *, alpha: float = 1.0
+ ) -> T_ThresholdedRelu:
+ r"""[🌐 ThresholdedRelu(22)](https://onnx.ai/onnx/operators/onnx__ThresholdedRelu.html#thresholdedrelu-22 "Online Documentation")
+
+
+ ThresholdedRelu takes one input data (Tensor) and produces one output data
+ (Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise,
+ is applied to the tensor elementwise.
+
+
+ Args:
+ X: (differentiable) Input tensor
+
+ alpha: Threshold value
+ """
+
+ schema = get_schema("ThresholdedRelu", 22, "")
+ op = Op(self, "ThresholdedRelu", schema)
+ return op(*self._prepare_inputs(schema, X), alpha=alpha)
diff --git a/onnxscript/onnx_opset/_impl/opset23.py b/onnxscript/onnx_opset/_impl/opset23.py
new file mode 100644
index 0000000000..73b7480073
--- /dev/null
+++ b/onnxscript/onnx_opset/_impl/opset23.py
@@ -0,0 +1,2210 @@
+# --------------------------------------------------------------------------
+# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
+# ⚙️ Generated by 'python -m opgen'
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# pylint: disable=W0221,W0222,R0901,W0237
+# mypy: disable-error-code=override
+# ruff: noqa: D214, D402, D405, D411, D412, D416
+# --------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from typing import Optional, Sequence, Tuple, TypeVar, Union
+
+from onnx import GraphProto, SparseTensorProto, TensorProto
+from onnx.defs import get_schema
+from typing_extensions import TypeAlias
+
+from onnxscript.onnx_opset._impl.opset22 import Opset22
+from onnxscript.onnx_types import (
+ BFLOAT16,
+ BOOL,
+ COMPLEX64,
+ COMPLEX128,
+ DOUBLE,
+ FLOAT,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT16,
+ INT4,
+ INT8,
+ INT16,
+ INT32,
+ INT64,
+ STRING,
+ UINT4,
+ UINT8,
+ UINT16,
+ UINT32,
+ UINT64,
+)
+from onnxscript.values import Op, Opset
+
+
+class Opset23(Opset22):
+ def __new__(cls):
+ return Opset.__new__(cls, "", 23)
+
+ T1_Attention = TypeVar("T1_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T2_Attention = TypeVar("T2_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ U_Attention = TypeVar(
+ "U_Attention",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ def Attention(
+ self,
+ Q: T1_Attention,
+ K: T1_Attention,
+ V: T2_Attention,
+ attn_mask: Optional[U_Attention] = None,
+ past_key: Optional[T1_Attention] = None,
+ past_value: Optional[T2_Attention] = None,
+ *,
+ is_causal: int = 0,
+ kv_num_heads: Optional[int] = None,
+ q_num_heads: Optional[int] = None,
+ qk_matmul_output_mode: int = 0,
+ scale: Optional[float] = None,
+ softcap: float = 0.0,
+ softmax_precision: Optional[int] = None,
+ ) -> Tuple[T1_Attention, T1_Attention, T2_Attention, T1_Attention]:
+ r"""[🌐 Attention(23)](https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23 "Online Documentation")
+
+
+
+ Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.
+
+ This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.
+
+ For self attention, `kv_sequence_length` equals `q_sequence_length`.
+
+ For cross attention, query and key might have different lengths.
+
+ This operator also covers the following 3 variants based on the number of heads:
+ 1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
+ 2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
+ 3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`.
+
+ The attention bias to be added is calculated from the `attn_mask` input and the `is_causal` attribute, only one of which can be provided.
+ 1) If `is_causal` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment.
+ 2) `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention or a float mask of the same type as query, key, value that is added to the attention score.
+
+ Both past and present state key/values are optional. They shall be used together; providing only one of them is not allowed.
+ The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided:
+
+ ::
+
+ The following pattern is applied by this operator:
+ Q K V
+ | | |
+ Q*sqrt(scale) K*sqrt(scale) |
+ | | |
+ | Transpose |
+ | | |
+ ---MatMul--- |
+ | |
+ attn_mask---Add |
+ | |
+ softcap (if provided) |
+ | |
+ Softmax |
+ | |
+ -----MatMul------
+ |
+ Y
+
+
+
+
+
+ Args:
+ Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads,
+ q_sequence_length, head_size)` or 3D tensor with shape `(batch_size,
+ q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor,
+ `q_hidden_size = q_num_heads * head_size`
+
+ K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads,
+ kv_sequence_length, head_size)` or 3D tensor with shape `(batch_size,
+ kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor,
+ `k_hidden_size = kv_num_heads * head_size`
+
+ V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads,
+ kv_sequence_length, v_head_size)` or 3D tensor with shape `(batch_size,
+ kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor,
+ `v_hidden_size = kv_num_heads * v_head_size`
+
+ attn_mask: (optional) Attention mask. Shape must be broadcastable to 4D
+ tensor with shape `(batch_size, q_num_heads, q_sequence_length,
+ total_sequence_length)` where `total_sequence_length =
+ past_sequence_length + kv_sequence_length.` Two types of masks are
+ supported. A boolean mask where a value of `True` indicates that the
+ element should take part in attention. Also supports a float mask of the
+ same type as query, key, value that is added to the attention score.
+
+ past_key: (optional) past state cache for key with shape `(batch_size,
+ kv_num_heads, past_sequence_length, head_size)`
+
+ past_value: (optional) past state cache for value with shape `(batch_size,
+ kv_num_heads, past_sequence_length, v_head_size)`
+
+ is_causal: If set to `1`, the attention masking is a lower triangular matrix
+ when the mask is a square matrix. The attention masking has the form of
+ the upper left causal bias due to the alignment.
+
+ kv_num_heads: Number of heads of key and value. Must be used with 3D inputs
+ of Q, K and V.
+
+ q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K
+ and V.
+
+ qk_matmul_output_mode: If set to `0`, qk_matmul_output is the output of qk
+ matmul. If set to `1`, qk_matmul_output includes the addition of the
+ attention mask to the output of qk matmul. If set to `2`,
+ qk_matmul_output is the output after the softcap operation. If set to
+ `3`, qk_matmul_output is the output after the softmax operation. Default
+ value is 0.
+
+ scale: Scaling factor applied to $Q*K^T$. Default value is
+ `1/sqrt(head_size)`. To prevent [numerical
+ overflow](https://tinyurl.com/sudb9s96), scale `Q`, `K` by `sqrt(scale)`
+ before matmul.
+
+ softcap: Softcap value for attention weights. Default value is 0.
+
+ softmax_precision: The floating-point precision used in softmax computation.
+ If softmax precision is not provided, the same precision as the input of
+ softmax (Q and K) is used.
+ """
+
+ schema = get_schema("Attention", 23, "")
+ op = Op(self, "Attention", schema)
+ return op(
+ *self._prepare_inputs(schema, Q, K, V, attn_mask, past_key, past_value),
+ is_causal=is_causal,
+ kv_num_heads=kv_num_heads,
+ q_num_heads=q_num_heads,
+ qk_matmul_output_mode=qk_matmul_output_mode,
+ scale=scale,
+ softcap=softcap,
+ softmax_precision=softmax_precision,
+ )
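+
+ # A minimal NumPy sketch of the diagram above for 4D inputs (illustration
+ # only): no mask, softcap, or KV cache; default scale 1/sqrt(head_size);
+ # GQA/MQA head broadcasting omitted:
+ #
+ #   import numpy as np
+ #   def attention_ref(Q, K, V):
+ #       # Q, K: (batch, heads, seq, head_size); V: (batch, heads, seq, v_head_size)
+ #       s = np.sqrt(1.0 / np.sqrt(Q.shape[-1]))        # sqrt(scale)
+ #       logits = (Q * s) @ np.swapaxes(K * s, -1, -2)  # scaled QK^T
+ #       w = np.exp(logits - logits.max(axis=-1, keepdims=True))
+ #       w /= w.sum(axis=-1, keepdims=True)             # softmax
+ #       return w @ V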
+
+ T1_Cast = TypeVar(
+ "T1_Cast",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T2_Cast: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def Cast(self, input: T1_Cast, *, saturate: int = 1, to: int) -> T2_Cast:
+ r"""[🌐 Cast(23)](https://onnx.ai/onnx/operators/onnx__Cast.html#cast-23 "Online Documentation")
+
+
+ The operator casts the elements of a given input tensor to a data type
+ specified by the 'to' argument and returns an output tensor of the same size in
+ the converted type. The 'to' argument must be one of the data types specified
+ in the 'DataType' enum field in the TensorProto message.
+
+ Casting from string tensors in plain (e.g., "3.14" and "1000") and scientific numeric representations
+ (e.g., "1e-5" and "1E8") to float types is supported. For example, converting the string "100.5" to an integer may
+ yield the result 100. There are some string literals reserved for special floating-point values;
+ "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively.
+ Any string which exactly matches "+INF" in a case-insensitive way is mapped to positive infinity. Similarly,
+ this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors
+ to string tensors, a plain floating-point representation (such as "314.15926") is used.
+ Converting a non-numerical-literal string such as "Hello World!" is undefined behavior. Converting
+ a string representing a floating-point value, such as "2.718", to INT is likewise undefined behavior.
+
+ Conversion from a numerical type to any numerical type is always allowed.
+ Users must be aware of precision loss and value changes caused by the range difference between two types.
+ For example, a 64-bit float 3.1415926459 may be rounded to a 32-bit float 3.141592. Similarly, converting
+ an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type.
+
+ In more detail, the conversion among numerical types should follow these rules
+ if the destination type is not a float 8 type.
+
+ * Casting from floating point to:
+ * floating point: +/- infinity if OOR (out of range).
+ * fixed point: undefined if OOR.
+ * bool: +/- 0.0 to False; all else to True.
+ * Casting from fixed point to:
+ * floating point: +/- infinity if OOR. (+ infinity in the case of uint)
+ * fixed point: when OOR, discard higher bits and reinterpret (with respect to two's complement representation for
+ signed types). For example, 200 (int16) -> -56 (int8).
+ * bool: zero to False; nonzero to True.
+ * Casting from bool to:
+ * floating point: `{1.0, 0.0}`.
+ * fixed point: `{1, 0}`.
+ * bool: no change.
+
+ Float 8 types were introduced to speed up the training of
+ deep models. By default, the conversion of a float *x* obeys
+ the following rules. `[x]` means the value rounded to
+ the target mantissa width.
+
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | -------- | -------- | -------- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | Inf | FLT_MAX | NaN | FLT_MAX | NaN |
+ | -Inf | -FLT_MAX | NaN | -FLT_MAX | NaN |
+ | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX |
+ | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX |
+ | else | RNE | RNE | RNE | RNE |
+
+ The behavior changes if the parameter 'saturate' is set to False.
+ The rules then become:
+
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | ------ | -------- | ---- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | -NaN | -NaN | NaN | -NaN | NaN |
+ | Inf | NaN | NaN | Inf | NaN |
+ | -Inf | -NaN | NaN | -Inf | NaN |
+ | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN |
+ | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN |
+ | else | RNE | RNE | RNE | RNE |
+
+
+ Args:
+ input: (differentiable) Input tensor to be cast.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz).
+ It is true by default. All cases are fully described in two tables
+ inserted in the operator description.
+
+ to: The data type to which the elements of the input tensor are cast.
+ Strictly must be one of the types from DataType enum in TensorProto
+ """
+
+ schema = get_schema("Cast", 23, "")
+ op = Op(self, "Cast", schema)
+ return op(*self._prepare_inputs(schema, input), saturate=saturate, to=to)
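+
+ # A hedged usage sketch (illustration only; assumes `opset23` is re-exported
+ # from the package root like earlier opsets): `to` takes a value from the
+ # TensorProto.DataType enum, here float32 -> float16:
+ #
+ #   from onnx import TensorProto
+ #   from onnxscript import script, FLOAT, FLOAT16
+ #   from onnxscript import opset23 as op
+ #
+ #   @script()
+ #   def to_fp16(x: FLOAT["N"]) -> FLOAT16["N"]:
+ #       return op.Cast(x, to=TensorProto.FLOAT16)  # TensorProto.FLOAT16 == 10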
+
+ T1_CastLike = TypeVar(
+ "T1_CastLike",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T2_CastLike = TypeVar(
+ "T2_CastLike",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def CastLike(
+ self, input: T1_CastLike, target_type: T2_CastLike, *, saturate: int = 1
+ ) -> T2_CastLike:
+ r"""[🌐 CastLike(23)](https://onnx.ai/onnx/operators/onnx__CastLike.html#castlike-23 "Online Documentation")
+
+
+ The operator casts the elements of a given input tensor (the first input) to
+ the same data type as the elements of the second input tensor.
+ See documentation of the Cast operator for further details.
+
+
+ Args:
+ input: (differentiable) Input tensor to be cast.
+
+ target_type: (non-differentiable) The (first) input tensor will be cast to
+ produce a tensor of the same type as this (second input) tensor.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz).
+ It is true by default. Please refer to operator Cast description for
+ further details.
+ """
+
+ schema = get_schema("CastLike", 23, "")
+ op = Op(self, "CastLike", schema)
+ return op(*self._prepare_inputs(schema, input, target_type), saturate=saturate)
+
+ T_Constant: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def Constant(
+ self,
+ *,
+ sparse_value: Optional[SparseTensorProto] = None,
+ value: Optional[TensorProto] = None,
+ value_float: Optional[float] = None,
+ value_floats: Optional[Sequence[float]] = None,
+ value_int: Optional[int] = None,
+ value_ints: Optional[Sequence[int]] = None,
+ value_string: Optional[str] = None,
+ value_strings: Optional[Sequence[str]] = None,
+ ) -> T_Constant:
+ r"""[🌐 Constant(23)](https://onnx.ai/onnx/operators/onnx__Constant.html#constant-23 "Online Documentation")
+
+
+ This operator produces a constant tensor. Exactly one of the provided attributes, either `value`, `sparse_value`,
+ or one of the `value_*` attributes, must be specified.
+
+
+ Args:
+ sparse_value: The value for the elements of the output tensor in sparse
+ format.
+
+ value: The value for the elements of the output tensor.
+
+ value_float: The value for the sole element for the scalar, float32, output
+ tensor.
+
+ value_floats: The values for the elements for the 1D, float32, output
+ tensor.
+
+ value_int: The value for the sole element for the scalar, int64, output
+ tensor.
+
+ value_ints: The values for the elements for the 1D, int64, output tensor.
+
+ value_string: The value for the sole element for the scalar, UTF-8 string,
+ output tensor.
+
+ value_strings: The values for the elements for the 1D, UTF-8 string, output
+ tensor.
+ """
+
+ schema = get_schema("Constant", 23, "")
+ op = Op(self, "Constant", schema)
+ return op(
+ sparse_value=sparse_value,
+ value=value,
+ value_float=value_float,
+ value_floats=value_floats,
+ value_int=value_int,
+ value_ints=value_ints,
+ value_string=value_string,
+ value_strings=value_strings,
+ )
+
+ T1_ConstantOfShape: TypeAlias = INT64
+
+ T2_ConstantOfShape: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def ConstantOfShape(
+ self, input: T1_ConstantOfShape, *, value: Optional[TensorProto] = None
+ ) -> T2_ConstantOfShape:
+ r"""[🌐 ConstantOfShape(23)](https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html#constantofshape-23 "Online Documentation")
+
+
+ Generate a tensor with given value and shape.
+
+
+ Args:
+ input: 1D tensor. The shape of the expected output tensor. If an empty
+ tensor is given, the output would be a scalar. All values must be >= 0.
+
+ value: (Optional) The value of the output elements. Should be a one-element
+ tensor. If not specified, it defaults to a tensor of value 0 and
+ datatype float32.
+ """
+
+ schema = get_schema("ConstantOfShape", 23, "")
+ op = Op(self, "ConstantOfShape", schema)
+ return op(*self._prepare_inputs(schema, input), value=value)
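+
+ # A minimal NumPy sketch of the semantics above (illustration only): `input`
+ # supplies the shape; the one-element `value` tensor supplies the fill value
+ # and output dtype:
+ #
+ #   import numpy as np
+ #   shape = np.array([2, 3], dtype=np.int64)  # the INT64 `input`
+ #   fill = np.array([-1], dtype=np.int32)     # the one-element `value`
+ #   y = np.full(tuple(shape), fill[0])        # 2x3 int32 tensor filled with -1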
+
+ T1_DequantizeLinear = TypeVar(
+ "T1_DequantizeLinear",
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT8,
+ UINT16,
+ UINT4,
+ UINT8,
+ )
+
+ T2_DequantizeLinear = TypeVar("T2_DequantizeLinear", BFLOAT16, FLOAT, FLOAT16)
+
+ T3_DequantizeLinear: TypeAlias = Union[BFLOAT16, FLOAT, FLOAT16]
+
+ def DequantizeLinear(
+ self,
+ x: T1_DequantizeLinear,
+ x_scale: T2_DequantizeLinear,
+ x_zero_point: Optional[T1_DequantizeLinear] = None,
+ *,
+ axis: int = 1,
+ block_size: int = 0,
+ output_dtype: int = 0,
+ ) -> T3_DequantizeLinear:
+ r"""[🌐 DequantizeLinear(23)](https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html#dequantizelinear-23 "Online Documentation")
+
+
+ The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the
+ full-precision tensor. The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point`
+ must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization,
+ a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization.
+ See QuantizeLinear for details on quantization granularity.
+
+ `x_zero_point` and `x` must have the same type. `x` and `y` must have the same shape. In the case of dequantizing
+ `int32`, there's no zero point (zero point is supposed to be 0).
+ `zero-point` is usually not used in the case of float8 and 4-bit types quantization, but the dequantization formula remains the same
+ for consistency. The output type is determined by the attribute `output_dtype`. If `output_dtype` is not supplied then the output type
+ is the same as `x_scale`. The output type also determines the precision of the multiplication operation.
+
+
+
+ Args:
+ x: N-D quantized input tensor to be de-quantized.
+
+ x_scale: Scale for input `x`. For per-tensor/layer dequantization the scale
+ is a scalar, for per-axis dequantization it is a 1-D tensor, and for
+ blocked dequantization it has the same shape as the input, except for
+ one dimension in which blocking is performed.
+
+ x_zero_point: (optional) Zero point for input `x`. Shape must match x_scale.
+ It's optional. Zero point is 0 when it's not specified.
+
+ axis: (Optional) The axis of the dequantizing dimension of the input tensor.
+ Used for per-axis and blocked quantization. Negative value means
+ counting dimensions from the back. Accepted range is `[-r, r-1]` where
+ `r = rank(input)`.
+
+ block_size: (Optional) The size of the quantization block (number of times
+ every scale is replicated). Used only for blocked quantization. The
+ block size is a positive integer. Given `x` shape `(D0, ..., Di, ...,
+ Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted
+ range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]`
+
+ output_dtype: (Optional) The output data type. If not supplied, the output
+ data type is inferred from `x_scale` data type (`T2`)
+ """
+
+ schema = get_schema("DequantizeLinear", 23, "")
+ op = Op(self, "DequantizeLinear", schema)
+ return op(
+ *self._prepare_inputs(schema, x, x_scale, x_zero_point),
+ axis=axis,
+ block_size=block_size,
+ output_dtype=output_dtype,
+ )
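+
+ # A minimal NumPy sketch of the per-tensor formula above,
+ # y = (x - x_zero_point) * x_scale (illustration only):
+ #
+ #   import numpy as np
+ #   x = np.array([0, 128, 255], dtype=np.uint8)
+ #   x_scale, x_zero_point = np.float32(0.05), np.uint8(128)
+ #   y = (x.astype(np.float32) - np.float32(x_zero_point)) * x_scale
+ #   # y ~ [-6.4, 0.0, 6.35], dtype float32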
+
+ T_Flatten = TypeVar(
+ "T_Flatten",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Flatten(self, input: T_Flatten, *, axis: int = 1) -> T_Flatten:
+ r"""[🌐 Flatten(23)](https://onnx.ai/onnx/operators/onnx__Flatten.html#flatten-23 "Online Documentation")
+
+
+ Flattens the input tensor into a 2D matrix. If the input tensor has shape
+ (d_0, d_1, ... d_n) then the output will have shape
+ (d_0 X d_1 ... X d_(axis-1), d_axis X d_(axis+1) ... X d_n).
+
+
+ Args:
+ input: (differentiable) A tensor of rank >= axis.
+
+ axis: Indicate up to which input dimensions (exclusive) should be flattened
+ to the outer dimension of the output. The value for axis must be in the
+ range [-r, r], where r is the rank of the input tensor. Negative value
+ means counting dimensions from the back. When axis = 0, the shape of the
+ output tensor is (1, d_0 X d_1 ... X d_n), where the shape of the input
+ tensor is (d_0, d_1, ... d_n).
+ """
+
+ schema = get_schema("Flatten", 23, "")
+ op = Op(self, "Flatten", schema)
+ return op(*self._prepare_inputs(schema, input), axis=axis)
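+
+ # A minimal NumPy sketch of the shape rule above (illustration only):
+ # collapse the dimensions before `axis` and the dimensions from `axis` on:
+ #
+ #   import numpy as np
+ #   x = np.zeros((2, 3, 4, 5))
+ #   axis = 2
+ #   y = x.reshape(int(np.prod(x.shape[:axis])), -1)  # -> shape (6, 20)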
+
+ V_Identity = TypeVar(
+ "V_Identity",
+ Optional[Sequence[BOOL]],
+ Optional[Sequence[COMPLEX128]],
+ Optional[Sequence[COMPLEX64]],
+ Optional[Sequence[DOUBLE]],
+ Optional[Sequence[FLOAT]],
+ Optional[Sequence[FLOAT16]],
+ Optional[Sequence[INT16]],
+ Optional[Sequence[INT32]],
+ Optional[Sequence[INT64]],
+ Optional[Sequence[INT8]],
+ Optional[Sequence[STRING]],
+ Optional[Sequence[UINT16]],
+ Optional[Sequence[UINT32]],
+ Optional[Sequence[UINT64]],
+ Optional[Sequence[UINT8]],
+ Optional[BOOL],
+ Optional[COMPLEX128],
+ Optional[COMPLEX64],
+ Optional[DOUBLE],
+ Optional[FLOAT],
+ Optional[FLOAT16],
+ Optional[INT16],
+ Optional[INT32],
+ Optional[INT64],
+ Optional[INT8],
+ Optional[STRING],
+ Optional[UINT16],
+ Optional[UINT32],
+ Optional[UINT64],
+ Optional[UINT8],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Identity(self, input: V_Identity) -> V_Identity:
+ r"""[🌐 Identity(23)](https://onnx.ai/onnx/operators/onnx__Identity.html#identity-23 "Online Documentation")
+
+ Identity operator
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Identity", 23, "")
+ op = Op(self, "Identity", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ B_If: TypeAlias = BOOL
+
+ V_If: TypeAlias = Union[
+ None,
+ Sequence[BFLOAT16],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ Sequence[FLOAT4E2M1],
+ Sequence[FLOAT8E4M3FN],
+ Sequence[FLOAT8E4M3FNUZ],
+ Sequence[FLOAT8E5M2],
+ Sequence[FLOAT8E5M2FNUZ],
+ Sequence[INT4],
+ Sequence[UINT4],
+ ]
+
+ def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If:
+ r"""[🌐 If(23)](https://onnx.ai/onnx/operators/onnx__If.html#if-23 "Online Documentation")
+
+ If conditional
+
+ Args:
+ cond: Condition for the if. The tensor must contain a single element.
+
+ else_branch: Graph to run if condition is false. Has N outputs: values you
+ wish to be live-out to the enclosing scope. The number of outputs must
+ match the number of outputs in the then_branch.
+
+ then_branch: Graph to run if condition is true. Has N outputs: values you
+ wish to be live-out to the enclosing scope. The number of outputs must
+ match the number of outputs in the else_branch.
+ """
+
+ schema = get_schema("If", 23, "")
+ op = Op(self, "If", schema)
+ return op(
+ *self._prepare_inputs(schema, cond),
+ else_branch=else_branch,
+ then_branch=then_branch,
+ )
+
+ I_Loop: TypeAlias = INT64
+
+ B_Loop: TypeAlias = BOOL
+
+ V_Loop = TypeVar(
+ "V_Loop",
+ Optional[Sequence[BFLOAT16]],
+ Optional[Sequence[BOOL]],
+ Optional[Sequence[COMPLEX128]],
+ Optional[Sequence[COMPLEX64]],
+ Optional[Sequence[DOUBLE]],
+ Optional[Sequence[FLOAT]],
+ Optional[Sequence[FLOAT16]],
+ Optional[Sequence[INT16]],
+ Optional[Sequence[INT32]],
+ Optional[Sequence[INT64]],
+ Optional[Sequence[INT8]],
+ Optional[Sequence[STRING]],
+ Optional[Sequence[UINT16]],
+ Optional[Sequence[UINT32]],
+ Optional[Sequence[UINT64]],
+ Optional[Sequence[UINT8]],
+ Optional[BFLOAT16],
+ Optional[BOOL],
+ Optional[COMPLEX128],
+ Optional[COMPLEX64],
+ Optional[DOUBLE],
+ Optional[FLOAT],
+ Optional[FLOAT16],
+ Optional[FLOAT4E2M1],
+ Optional[FLOAT8E4M3FN],
+ Optional[FLOAT8E4M3FNUZ],
+ Optional[FLOAT8E5M2],
+ Optional[FLOAT8E5M2FNUZ],
+ Optional[INT16],
+ Optional[INT32],
+ Optional[INT4],
+ Optional[INT64],
+ Optional[INT8],
+ Optional[STRING],
+ Optional[UINT16],
+ Optional[UINT32],
+ Optional[UINT4],
+ Optional[UINT64],
+ Optional[UINT8],
+ Sequence[BFLOAT16],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[FLOAT4E2M1],
+ Sequence[FLOAT8E4M3FN],
+ Sequence[FLOAT8E4M3FNUZ],
+ Sequence[FLOAT8E5M2],
+ Sequence[FLOAT8E5M2FNUZ],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT4],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT4],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Loop(
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
+ ) -> V_Loop:
+ r"""[🌐 Loop(23)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-23 "Online Documentation")
+
+
+ Generic Looping construct. This loop has multiple termination conditions:
+
+ 1) Trip count. Iteration count specified at runtime. Set by
+ specifying the input M. Optional. Set to empty string to omit.
+ Note that a static trip count (specified at graph construction time) can be
+ specified by passing in a constant node for input M.
+ 2) Loop termination condition. This is an input to the op that determines
+ whether to run the first iteration and also a loop-carried dependency for
+ the body graph. The body graph must yield a value for the condition variable,
+ whether this input is provided or not.
+
+ This table summarizes the operating modes of this operator with equivalent
+ C-style code:
+
+ Operator inputs defined as (max_trip_count, condition_var).
+
+ * input ("", ""):
+ for (int i=0; ; ++i) {
+ cond = ... // Note this value is ignored, but is required in the body
+ }
+
+ * input ("", cond) // Note this is analogous to a while loop
+ bool cond = ...;
+ for (int i=0; cond; ++i) {
+ cond = ...;
+ }
+
+ * input ("", 1) // Note this is analogous to a do-while loop
+ bool cond = true
+ for (int i=0; cond; ++i) {
+ cond = ...;
+ }
+
+ * input (trip_count, "") // Note this is analogous to a for loop
+ int trip_count = ...
+ for (int i=0; i < trip_count; ++i) {
+ cond = ...; // ignored
+ }
+
+ * input (trip_count, cond)
+ int trip_count = ...;
+ bool cond = ...;
+ for (int i=0; i < trip_count && cond; ++i) {
+ cond = ...;
+ }
+
+
+ *Sample usage - cond as well as trip count*
+
+ graph predict-net {
+ %a = Constant[value = <Scalar Tensor [3]>]()
+ %b = Constant[value = <Scalar Tensor [6]>]()
+ %keepgoing = Constant[value = <Scalar Tensor [1]>]()
+ %max_trip_count = Constant[value = <Scalar Tensor [10]>]()
+ %keepgoing_out, %b_out, %user_defined_vals = Loop[body = <graph body-net>](%max_trip_count, %keepgoing, %b)
+ return
+ }
+
+ graph body-net (
+ %i[INT32, scalar] // iteration number
+ %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used
+ %b_in[INT32, scalar] // incoming value of loop-carried-dependency b
+ ) {
+ %my_local = Add(%a, %b_in)
+ %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b
+ %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition
+ %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated
+ return %keepgoing_out, %b_out, %user_defined_val
+ }
+
+ *Sample equivalent C code*
+
+ {
+ /* User-defined code (enclosing scope) */
+ int a = 3, b = 6;
+ bool keepgoing = true; // Analogous to input cond
+ /* End user-defined code */
+
+ /* Implicitly-defined code */
+ const int max_trip_count = 10; // Analogous to input M
+ int user_defined_vals[]; // Imagine this is resizable
+ /* End implicitly-defined code */
+ /* initialize loop-carried variables and scan-output variables */
+ bool keepgoing_out = keepgoing;
+ int b_out = b;
+
+ for (int i=0; i < max_trip_count && keepgoing_out; ++i) {
+ /* Implicitly-defined code: bind actual parameter values
+ to formal parameter variables of loop-body */
+ bool keepgoing_in = keepgoing_out;
+ int b_in = b_out;
+
+ /* User-defined code (loop body) */
+ int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine
+ b_out = a - b_in;
+ keepgoing_out = my_local > b_out;
+ int user_defined_val = b_in + b_in; // b_in and b_out are different variables
+ /* End user-defined code */
+
+ /* Implicitly defined-code */
+ user_defined_vals[i] = user_defined_val; // accumulate scan-output values
+ }
+ // int t = my_local; // Can't do this. my_local is not accessible here.
+
+ // The values below are bound to the output variables of the loop and therefore accessible
+ // b_out; user_defined_vals; keepgoing_out;
+ }
+
+ There are several things of note in this code snippet:
+
+ 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can
+ be referenced in the inputs of the loop.
+ 2) Any values computed in the loop body that need to be used in a subsequent
+ iteration or after the loop are modelled using a pair of variables in the loop-body,
+ consisting of an input variable (e.g., b_in) and an output variable (e.g., b_out).
+ These are referred to as loop-carried dependencies. The loop operation node
+ supplies the input value of the input variable for the first iteration, and
+ returns the output value of the output variable produced by the final
+ iteration.
+ 3) Scan_output variables are used to implicitly concatenate values computed across
+ all the iterations. In the above example, the values of user_defined_val computed
+ over all iterations are concatenated and returned as the value of user_defined_vals
+ after the loop.
+ 4) Values created in the body cannot be accessed in the enclosing scope,
+ except using the mechanism described above.
+
+ Note that the semantics of this op support "diagonal" or "wavefront" execution.
+ (See Step 3 here for an example:
+ https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/).
+ Frontends should emit multi-layer RNNs as a series of While operators (with
+ time being the inner looping dimension), with each successive layer consuming
+ the scan_outputs from the previous layer, possibly going through several
+ point-wise operators (e.g. dropout, residual connections, linear layer).
+
+ Input/output matching for the subgraph (produced by the loop node) is based on order instead of name. The implementation will figure out the names based on this order.
+
+
+ Args:
+ M: (optional) A maximum trip-count for the loop specified at runtime.
+ Optional. Pass empty string to skip.
+
+ cond: (optional) A boolean termination condition. Optional. Pass empty
+ string to skip.
+
+ v_initial: (variadic, heterogeneous) The initial values of any loop-carried
+ dependencies (values that change across loop iterations)
+
+ body: The graph run each iteration. It has 2+N inputs: (iteration_num,
+ condition, loop carried dependencies...). It has 1+N+K outputs:
+ (condition, loop carried dependencies..., scan_outputs...). Each
+ scan_output is created by concatenating the value of the specified
+ output value at the end of each iteration of the loop. It is an error if
+ the dimensions or data type of these scan_outputs change across loop
+ iterations.
+ """
+
+ schema = get_schema("Loop", 23, "")
+ op = Op(self, "Loop", schema)
+ return op(*self._prepare_inputs(schema, M, cond, *v_initial), body=body)
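+
+ # A hedged onnxscript sketch (illustration only; assumes the documented
+ # for-range loop translation and that `opset23` is re-exported like earlier
+ # opsets): a Python loop in a @script function lowers to this Loop op, with
+ # `y` as a loop-carried dependency like b_in/b_out above:
+ #
+ #   from onnxscript import script, FLOAT, INT64
+ #   from onnxscript import opset23 as op
+ #
+ #   @script()
+ #   def double_m_times(x: FLOAT["N"], m: INT64) -> FLOAT["N"]:
+ #       y = op.Identity(x)
+ #       for i in range(m):
+ #           y = op.Add(y, y)
+ #       return y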
+
+ T_Pad = TypeVar(
+ "T_Pad",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ Tind_Pad = TypeVar("Tind_Pad", INT32, INT64)
+
+ def Pad(
+ self,
+ data: T_Pad,
+ pads: INT64,
+ constant_value: Optional[T_Pad] = None,
+ axes: Optional[Tind_Pad] = None,
+ *,
+ mode: str = "constant",
+ ) -> T_Pad:
+ r"""[🌐 Pad(23)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-23 "Online Documentation")
+
+
+ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for each axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`,
+ a padded tensor (`output`) is generated.
+
+ The four supported `modes` are (similar to corresponding modes supported by `numpy.pad`):
+
+ 1) `constant` (default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False)
+
+ 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis
+
+ 3) `edge` - pads with the edge values of the array
+
+ 4) `wrap` - wrap-around padding as if the data tensor forms a torus
+
+
+ Example 1 (`constant` mode):
+
+ Insert 0 pads to the beginning of the second dimension.
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'constant'
+
+ constant_value = 0.0
+
+ output = [
+ [0.0, 0.0, 1.0, 1.2],
+ [0.0, 0.0, 2.3, 3.4],
+ [0.0, 0.0, 4.5, 5.7],
+ ]
+
+
+
+ Example 2 (`reflect` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'reflect'
+
+ output = [
+ [1.0, 1.2, 1.0, 1.2],
+ [2.3, 3.4, 2.3, 3.4],
+ [4.5, 5.7, 4.5, 5.7],
+ ]
+
+
+
+ Example 3 (`edge` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'edge'
+
+ output = [
+ [1.0, 1.0, 1.0, 1.2],
+ [2.3, 2.3, 2.3, 3.4],
+ [4.5, 4.5, 4.5, 5.7],
+ ]
+
+
+
+ Example 4 (`wrap` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [2, 1, 1, 1]
+
+ mode = 'wrap'
+
+ output = [
+ [3.4, 2.3, 3.4, 2.3],
+ [5.7, 4.5, 5.7, 4.5],
+ [1.2, 1.0, 1.2, 1.0],
+ [3.4, 2.3, 3.4, 2.3],
+ [5.7, 4.5, 5.7, 4.5],
+ [1.2, 1.0, 1.2, 1.0],
+ ]
+
+
+
+
+ Args:
+ data: (differentiable) Input tensor.
+
+ pads: (non-differentiable) Tensor of integers indicating the number of
+ padding elements to add or remove (if negative) at the beginning and end
+ of each axis. For 2D input tensor, it is the number of pixels. `pads`
+ should be a 1D tensor of shape [2 * num_axes] where `num_axes` refers to
+ the number of elements in the `axes` input or the input rank if `axes`
+ are not provided explicitly. `pads` format should be: [x1_begin,
+ x2_begin, ..., x1_end, x2_end,...], where xi_begin is the number of pad
+ values added at the beginning of axis `axes[i]` and xi_end, the number
+ of pad values added at the end of axis `axes[i]`.
+
+ constant_value: (optional, non-differentiable) (Optional) A scalar value to
+ be used if the mode chosen is `constant` (by default it is 0, empty
+ string or False).
+
+ axes: (optional, non-differentiable) 1-D tensor of axes that `pads` apply
+ to. Negative value means counting dimensions from the back. Accepted
+ range is [-r, r-1] where r = rank(data). Behavior is undefined if an
+ axis is repeated. If not provided, all axes are assumed (`[0, 1, ...,
+ input_rank-1]`).
+
+ mode: Supported modes: `constant` (default), `reflect`, `edge`, `wrap`
+ """
+
+ schema = get_schema("Pad", 23, "")
+ op = Op(self, "Pad", schema)
+ return op(*self._prepare_inputs(schema, data, pads, constant_value, axes), mode=mode)
+
+ T1_QuantizeLinear = TypeVar("T1_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32)
+
+ T2_QuantizeLinear = TypeVar("T2_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32)
+
+ T3_QuantizeLinear = TypeVar(
+ "T3_QuantizeLinear",
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT4,
+ INT8,
+ UINT16,
+ UINT4,
+ UINT8,
+ )
+
+ def QuantizeLinear(
+ self,
+ x: T1_QuantizeLinear,
+ y_scale: T2_QuantizeLinear,
+ y_zero_point: Optional[T3_QuantizeLinear] = None,
+ *,
+ axis: int = 1,
+ block_size: int = 0,
+ output_dtype: int = 0,
+ precision: int = 0,
+ saturate: int = 1,
+ ) -> T3_QuantizeLinear:
+ r"""[🌐 QuantizeLinear(23)](https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html#quantizelinear-23 "Online Documentation")
+
+
+ The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the
+ low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization
+ granularity. The quantization formula is `y = saturate((x / y_scale) + y_zero_point)`.
+
+ Saturation is done according to:
+ - uint16: [0, 65535]
+ - int16: [-32768, 32767]
+ - uint8: [0, 255]
+ - int8: [-128, 127]
+ - uint4: [0, 15]
+ - int4: [-8, 7]
+
+ For `(x / y_scale)`, the result is rounded to the nearest even value. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+
+ `y_zero_point` and `y` must have the same type. `y_zero_point` is usually not used for quantization to float8 and 4-bit types, but the quantization
+ formula remains the same for consistency, and the type of the input `y_zero_point` still determines the quantization type.
+ `x` and `y_scale` are allowed to have different types. The type of `y_scale` determines the precision of the division operation between `x` and
+ `y_scale`, unless the `precision` attribute is specified.
+
+ There are three supported quantization granularities, determined by the shape of `y_scale`.
+ In all cases, `y_zero_point` must have the same shape as `y_scale`.
+ - Per-tensor (per-layer) quantization: `y_scale` is a scalar.
+ - Per-axis quantization: The scale must be a 1-D tensor, with the length of the quantization axis. For an input shape
+ `(D0, ..., Di, ..., Dn)` and `axis=i`, `y_scale` is a 1-D tensor of length `Di`.
+ - Blocked quantization: The scale's shape is identical to the input's shape, except for one dimension, in which
+ blocking is performed. Given `x` shape `(D0, ..., Di, ..., Dn)`, `axis=i`, and block size `B`: `y_scale` shape is
+ `(D0, ..., ceil(Di/B), ..., Dn)`.
+
+
+ Args:
+ x: N-D full precision Input tensor to be quantized.
+
+ y_scale: Scale for doing quantization to get `y`. For per-tensor/layer
+ quantization the scale is a scalar, for per-axis quantization it is a
+ 1-D Tensor and for blocked quantization it has the same shape as the
+ input, except for one dimension in which blocking is performed.
+
+ y_zero_point: (optional) Zero point for doing quantization to get `y`. Shape
+ must match `y_scale`. Default is uint8 with a zero point of 0 if it's not
+ specified.
+
+ axis: (Optional) The axis of the quantizing dimension of the input tensor.
+ Used only for per-axis and blocked quantization. Negative value means
+ counting dimensions from the back. Accepted range is `[-r, r-1]` where
+ `r = rank(input)`. When the rank of the input is 1, per-tensor
+ quantization is applied, rendering the axis unnecessary in this
+ scenario.
+
+ block_size: (Optional) The size of the quantization block (number of times
+ every scale is replicated). Used only for blocked quantization. The
+ block size is a positive integer. Given `x` shape `(D0, ..., Di, ...,
+ Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted
+ range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]`
+
+ output_dtype: (Optional) The output data type. If not supplied, the output
+ data type is inferred from `y_zero_point` data type (`T3`). If neither
+ `output_dtype` nor `y_zero_point` are supplied, output data type is
+ uint8. If both `output_dtype` and `y_zero_point` are specified,
+ `output_dtype` must be `T3`.
+
+ precision: (Optional) The precision of the division operation between `x`
+ and `y_scale`. If not provided, it will be the same as the type of
+ `y_scale`.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ quantization (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz).
+ It is true by default. All cases are fully described in two tables
+ inserted in the operator description.
+ """
+
+ schema = get_schema("QuantizeLinear", 23, "")
+ op = Op(self, "QuantizeLinear", schema)
+ return op(
+ *self._prepare_inputs(schema, x, y_scale, y_zero_point),
+ axis=axis,
+ block_size=block_size,
+ output_dtype=output_dtype,
+ precision=precision,
+ saturate=saturate,
+ )
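+
+ # A minimal NumPy sketch of the per-tensor formula above,
+ # y = saturate(round(x / y_scale) + y_zero_point), for uint8 output
+ # (illustration only; np.round gives the round-half-to-even described above):
+ #
+ #   import numpy as np
+ #   x = np.array([-1.0, 0.0, 6.35], dtype=np.float32)
+ #   y_scale, y_zero_point = np.float32(0.05), np.uint8(128)
+ #   q = np.round(x / y_scale) + y_zero_point
+ #   y = np.clip(q, 0, 255).astype(np.uint8)  # saturate to [0, 255] -> [108, 128, 255]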
+
+ T_RMSNormalization = TypeVar("T_RMSNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ V_RMSNormalization = TypeVar("V_RMSNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def RMSNormalization(
+ self,
+ X: T_RMSNormalization,
+ scale: V_RMSNormalization,
+ *,
+ axis: int = -1,
+ epsilon: float = 9.999999747378752e-06,
+ stash_type: int = 1,
+ ) -> V_RMSNormalization:
+ r"""[🌐 RMSNormalization(23)](https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#rmsnormalization-23 "Online Documentation")
+
+
+ This is RMS normalization, defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467.
+ The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions,
+ where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape),
+ the rms norm is computed over the last 2 dimensions of the input. The computation required by standardization can be
+ described by the following equations.
+ ```
+ XSquared = Mul(X, X)
+ XSquaredMean = ReduceMean(XSquared)
+ MeanSquareEpsilon = Add(XSquaredMean, epsilon)
+ RMS = Sqrt(MeanSquareEpsilon)
+ Normalized = Div(X, RMS)
+ ```
+ where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variable `RMS` stands for root mean square.
+ Depending on the `stash_type` attribute, the actual computation
+ must happen in a different floating-point precision.
+ For example, if `stash_type` is 1, this operator casts
+ all input variables to 32-bit float, performs the computation, and
+ finally casts `Normalized` back to the original type of `X`.
+ The second stage then scales the outcome of the first stage using:
+ ```
+ Y = Mul(Normalized, Scale)
+ ```
+ Let `d[i]` indicate the i-th dimension of `X`.
+ If `X`'s shape is `[d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]`,
+ the shape of `RMS` is `[d[0], ..., d[axis-1], 1, ..., 1]`.
+ `Y` and `X` have the same shape. This operator supports unidirectional broadcasting
+ (`Scale` should be unidirectionally broadcastable to tensor `X`);
+ for more details please check `Broadcasting in ONNX <https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md>`_.
+
+
+ Args:
+ X: The input tensor to be normalized. In general, the shape is (D1, D2, ...
+ , Dn) for n-dimensional data, where the root mean squared norm is taken
+ over the last D dimensions, where D is determined by the axis attribute.
+
+ scale: Scale tensor. Scale tensor shape should be broadcastable to the
+ normalized shape.
+
+ axis: The first normalization dimension. If rank(X) is r, axis' allowed
+ range is [-r, r). Negative value means counting dimensions from the
+ back.
+
+ epsilon: The epsilon value to use to avoid division by zero.
+
+ stash_type: The floating-point precision used in stage one of the
+ computation.
+ """
+
+ schema = get_schema("RMSNormalization", 23, "")
+ op = Op(self, "RMSNormalization", schema)
+ return op(
+ *self._prepare_inputs(schema, X, scale),
+ axis=axis,
+ epsilon=epsilon,
+ stash_type=stash_type,
+ )
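+
+ # A minimal NumPy sketch of the two stages above for axis=-1 (illustration
+ # only; `stash_type` precision handling omitted, float32 throughout):
+ #
+ #   import numpy as np
+ #   def rms_norm_ref(x, scale, epsilon=1e-5):
+ #       mean_sq = np.mean(x * x, axis=-1, keepdims=True)  # XSquaredMean
+ #       rms = np.sqrt(mean_sq + epsilon)                  # RMS
+ #       return (x / rms) * scale                          # Y = Normalized * Scale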
+
+ T_Reshape = TypeVar(
+ "T_Reshape",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Reshape(self, data: T_Reshape, shape: INT64, *, allowzero: int = 0) -> T_Reshape:
+ r"""[🌐 Reshape(23)](https://onnx.ai/onnx/operators/onnx__Reshape.html#reshape-23 "Online Documentation")
+
+
+ Reshape the input tensor similar to numpy.reshape.
+ First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor.
+ At most one dimension of the new shape can be -1. In this case, the value is
+ inferred from the size of the tensor and the remaining dimensions. A dimension
+ could also be 0, in which case the actual dimension value is unchanged (i.e. taken
+ from the input tensor). If 'allowzero' is set, and the new shape includes 0, the
+ dimension will be set explicitly to zero (i.e. not taken from input tensor).
+ Shape (second input) could be an empty shape, which means converting to a scalar.
+ The input tensor's shape and the output tensor's shape are required to have the same number of elements.
+
+ If the attribute 'allowzero' is set, it is invalid for the specified shape to
+ contain both a zero value and -1, as the value of the dimension corresponding
+ to -1 cannot be determined uniquely.
+
+
+ Args:
+ data: (differentiable) An input tensor.
+
+ shape: (non-differentiable) Specified shape for output.
+
+ allowzero: (Optional) By default, when any value in the 'shape' input is
+ equal to zero the corresponding dimension value is copied from the input
+ tensor dynamically. allowzero=1 indicates that if any value in the
+ 'shape' input is set to zero, the zero value is honored, similar to
+ NumPy.
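+
+ Example: an illustrative sketch (not the runtime implementation) of
+ how the output shape is resolved, using numpy only for the element count:
+
+ ::
+
+ import numpy as np
+
+ def resolve_reshape(data_shape, shape, allowzero=0):
+     # Replace 0 with the matching input dimension (unless allowzero=1),
+     # then infer the single -1 dimension from the remaining element count.
+     out = [d if (d != 0 or allowzero) else data_shape[i]
+            for i, d in enumerate(shape)]
+     if -1 in out:
+         known = int(np.prod([d for d in out if d != -1]))
+         out[out.index(-1)] = int(np.prod(data_shape)) // known
+     return out
+
+ resolve_reshape((2, 3, 4), (0, -1))     # -> [2, 12]
+ resolve_reshape((2, 0, 4), (0, 4), 1)   # -> [0, 4]; the zero is honored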
+ """
+
+ schema = get_schema("Reshape", 23, "")
+ op = Op(self, "Reshape", schema)
+ return op(*self._prepare_inputs(schema, data, shape), allowzero=allowzero)
+
+ T_RotaryEmbedding = TypeVar("T_RotaryEmbedding", BFLOAT16, FLOAT, FLOAT16)
+
+ M_RotaryEmbedding: TypeAlias = INT64
+
+ def RotaryEmbedding(
+ self,
+ X: T_RotaryEmbedding,
+ cos_cache: T_RotaryEmbedding,
+ sin_cache: T_RotaryEmbedding,
+ position_ids: Optional[M_RotaryEmbedding] = None,
+ *,
+ interleaved: int = 0,
+ num_heads: Optional[int] = None,
+ rotary_embedding_dim: int = 0,
+ ) -> T_RotaryEmbedding:
+ r"""[🌐 RotaryEmbedding(23)](https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23 "Online Documentation")
+
+
+ RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864.
+ The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances
+ between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids).
+
+ The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles.
+ For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the
+ embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector.
+ The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated
+ to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism.
+ The rotation ensures that the model captures both absolute and relative positional information.
+
+ Rotary embeddings are defined using the following algorithm:
+
+ ::
+
+ def compute_rotary_embedding(
+ input,
+ position_ids,
+ sin_cache,
+ cos_cache,
+ interleaved=0,
+ rotary_embedding_dim=0,
+ num_heads=0,
+ ):
+ # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
+ original_input_shape = input.shape
+ if len(input.shape) == 4:
+     input = np.transpose(input, (0, 2, 1, 3))
+ batch_size = input.shape[0]
+ sequence_length = input.shape[1]
+ if len(input.shape) == 3:
+ hidden_size = input.shape[2]
+ assert num_heads != 0
+ head_size = int(hidden_size / num_heads)
+ new_shape = [batch_size, sequence_length, num_heads, head_size]
+ input = np.reshape(input, new_shape)
+ assert len(input.shape) == 4
+ head_size = input.shape[3]
+
+ # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
+ if rotary_embedding_dim == 0:
+ # If rotary_embedding_dim not provided, perform full rotation by using head_size
+ rotary_embedding_dim = head_size
+ x_rotate = input[:, :, :, :rotary_embedding_dim]
+ x_not_rotate = input[:, :, :, rotary_embedding_dim:]
+ rotary_embedding_dim_half = int(rotary_embedding_dim / 2)
+
+ # Retrieve sin and cos caches using position ids
+ if position_ids is not None:
+ cos = cos_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2]
+ sin = sin_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2]
+ else:
+ cos = cos_cache
+ sin = sin_cache
+ cos = cos[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
+ sin = sin[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
+ cos = np.expand_dims(cos, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
+ sin = np.expand_dims(sin, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
+
+ # Either divide the input in halves or interleave (based on interleaved attribute)
+ if interleaved:
+ x1 = x_rotate[:, :, :, 0::2]
+ x2 = x_rotate[:, :, :, 1::2]
+ else:
+ x1, x2 = np.split(x_rotate, 2, axis=-1)
+
+ # Calculate real and imaginary values
+ real = cos * x1 - sin * x2
+ imag = sin * x1 + cos * x2
+
+ # Insert rotated embeddings back into the original input
+ if interleaved:
+ # x_rotate[:, :, :, 0::2] = real
+ # x_rotate[:, :, :, 1::2] = imag
+ real = np.expand_dims(real, axis=-1)
+ imag = np.expand_dims(imag, axis=-1)
+ x_rotate_concat = np.concatenate((real, imag), axis=-1)
+ x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
+ else:
+ x_rotate = np.concatenate((real, imag), axis=-1)
+ output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
+ if len(original_input_shape) == 3:
+     output = np.reshape(output, original_input_shape)
+ else:
+ output = np.transpose(output, (0, 2, 1, 3))
+ return output
+
+
+
+
+ Args:
+ X: The input tensor representing the token embeddings. 4D tensor with shape
+ `(batch_size, num_heads, sequence_length, head_size)` or 3D tensor with
+ shape `(batch_size, sequence_length, hidden_size)`. For cases with a 4D
+ input tensor, `head_size` has to be even. For cases with a 3D input
+ tensor, `num_heads` attribute must be provided and `hidden_size` must be
+ an even multiple of `num_heads` where `hidden_size = num_heads *
+ head_size`
+
+ cos_cache: The cosine values for the rotation. 2D tensor with shape
+ `(max_position_id_plus_1, head_size / 2)` for full rotation or
+ `(max_position_id_plus_1, rotary_embedding_dim / 2)` for partial
+ rotation when `position_ids` are provided. 3D tensor with shape
+ `(batch_size, sequence_length, head_size / 2)` for full rotation or
+ `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial
+ rotation when `position_ids` are not provided. `max_position_id_plus_1`
+ is a parameter to the model.
+
+ sin_cache: The sine values for the rotation. 2D tensor with shape
+ `(max_position_id_plus_1, head_size / 2)` for full rotation or
+ `(max_position_id_plus_1, rotary_embedding_dim / 2)` for partial
+ rotation when `position_ids` are provided. 3D tensor with shape
+ `(batch_size, sequence_length, head_size / 2)` for full rotation or
+ `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial
+ rotation when `position_ids` are not provided. `max_position_id_plus_1`
+ is a parameter to the model.
+
+ position_ids: (optional) The position indices for the tokens. 2D tensor with
+ shape `(batch_size, sequence_length)`
+
+ interleaved: Rotate using interleaved pattern. Default value is 0 (False).
+
+ num_heads: Number of attention heads. Must be provided when input is a 3D
+ tensor.
+
+ rotary_embedding_dim: Rotary embedding dimension used to apply partial
+ rotary embeddings.
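+
+ Example: one common way to build the caches (an illustrative assumption;
+ this operator only consumes precomputed caches, and `theta` below is a
+ conventional RoPE hyperparameter, not an attribute of the operator):
+
+ ::
+
+ import numpy as np
+
+ def make_rope_caches(max_position_id_plus_1, head_size, theta=10000.0):
+     # Frequencies as in the RoPE paper (https://arxiv.org/pdf/2104.09864).
+     inv_freq = 1.0 / theta ** (np.arange(0, head_size, 2) / head_size)
+     angles = np.outer(np.arange(max_position_id_plus_1), inv_freq)
+     # Both caches have shape (max_position_id_plus_1, head_size / 2).
+     return np.cos(angles), np.sin(angles)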
+ """
+
+ schema = get_schema("RotaryEmbedding", 23, "")
+ op = Op(self, "RotaryEmbedding", schema)
+ return op(
+ *self._prepare_inputs(schema, X, cos_cache, sin_cache, position_ids),
+ interleaved=interleaved,
+ num_heads=num_heads,
+ rotary_embedding_dim=rotary_embedding_dim,
+ )
+
+ V_Scan = TypeVar(
+ "V_Scan",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Scan(
+ self,
+ *initial_state_and_scan_inputs: V_Scan,
+ body: GraphProto,
+ num_scan_inputs: int,
+ scan_input_axes: Optional[Sequence[int]] = None,
+ scan_input_directions: Optional[Sequence[int]] = None,
+ scan_output_axes: Optional[Sequence[int]] = None,
+ scan_output_directions: Optional[Sequence[int]] = None,
+ ) -> V_Scan:
+ r"""[🌐 Scan(23)](https://onnx.ai/onnx/operators/onnx__Scan.html#scan-23 "Online Documentation")
+
+
+ Scan can be used to iterate over one or more scan_input tensors,
+ constructing zero or more scan_output tensors. It combines ideas from general recurrences,
+ functional programming constructs such as scan, fold, map, and zip, and is intended to enable
+ generalizations of RNN-like constructs for sequence-to-sequence processing.
+ Other tensors (referred to as state_variables here) can be used to carry a state
+ when iterating from one element to another (similar to hidden-state in RNNs, also referred
+ to as loop-carried dependences in the context of loops).
+ Many common usages involve a single scan_input tensor (where functionality
+ similar to scan, fold and map can be obtained). When more than one scan_input is used,
+ a behavior similar to zip is obtained.
+
+ The attribute body must be a graph, specifying the computation to be performed in
+ every iteration. It takes as input the current values of the state_variables and
+ the current iterated element of the scan_inputs. It must return the (updated) values
+ of the state_variables and zero or more scan_output_element tensors. The values of the
+ scan_output_element tensors are concatenated over all the iterations to produce the
+ scan_output values of the scan construct (similar to the concatenated intermediate
+ hidden-state values of RNN-like constructs). All the output tensors (state_variables as
+ well as scan_output_element tensors) are required to have the same shape in each iteration
+ of the loop (a restriction imposed to enable efficient memory allocation).
+
+ Note that the iterated element passed to the body subgraph does not have a sequence
+ axis. It will have a rank one less than the rank of the corresponding scan_input.
+
+ The scan operation returns the final values of the state_variables as well as the
+ scan_outputs.
+
+ The optional attribute scan_input_directions specifies the direction (forward or backward)
+ for each scan input. If this attribute is omitted, all sequences are scanned in the forward
+ direction. A bidirectional scan may be performed by specifying the same tensor input twice
+ in the scan_inputs, once with a forward direction, and once with a backward direction.
+
+ The scan_output of the operation is produced by concatenating the scan_output_element
+ values produced by the body in each iteration. The optional attribute scan_output_directions
+ specifies the direction in which scan_output is constructed (by appending or prepending the
+ scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
+ is omitted, the scan_output_element is appended to the scan_output in each iteration.
+
+ The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
+ If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
+ batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
+ Note that scanning a non-zero axis may be less efficient than scanning axis zero.
+
+ The optional attribute scan_output_axes specifies the axis along which the scan_outputs
+ are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
+ scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
+ value of 1.
+
+ Note that because of the ONNX restriction that only the last parameter of an operator can
+ be variadic, the initial-states and scan-inputs are listed together as one input parameter.
+ Similarly, the final-states and scan-outputs are listed together as one output parameter.
+ The attribute num_scan_inputs indicates the number M of scan-inputs.
+
+ The behavior of
+
+ Scan <
+ num_scan_inputs = m,
+ body = loop-body,
+ scan_input_axes = [axis_1, ..., axis_m]
+ > (init_1, ..., init_n, scan_1, ..., scan_m)
+
+ is equivalent to the following pseudo-code:
+
+ // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
+ // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
+ sequence_length = scan_1.shape[axis_1];
+
+ // initialize state-variables
+ st_1 = init_1; ... st_n = init_n;
+ // initialize scan-output variables: [] denotes an empty tensor
+ scan_out_1 = []; ...; scan_out_k = [];
+ // identify number of iterations:
+
+ // execute loop
+ for (int t = 0; t < sequence_length; ++t) {
+ // generate the scan-input elements: the notation T[t] indicates the sub-tensor
+ // of rank one less than T obtained by indexing T at position t along axis k.
+ si_1 = scan_1[t];
+ ... ;
+ si_m = scan_m[t];
+ // execute loop-body
+ st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
+ // accumulate the scan-output elements
+ scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k);
+ }
+
+ return st_1, ..., st_n, scan_out_1, ..., scan_out_k;
+
+ *Sample usage: Encoding RNN using a Scan*
+
+ The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
+ recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
+ be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes
+ %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
+ values are computed in the outer graph, they need to be passed in as extra state_variables.
+
+ graph rnn-encoding {
+ %H_0 = ...
+ %X = ...
+ %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
+ return %Y, %Y_h
+ }
+
+ graph rnn-cell-1 (
+ %H_tminus1[FLOAT, tensor]
+ %X_t[FLOAT, tensor]
+ ) {
+ %Wi = ...
+ %Ri = ...
+ %Wbi = ...
+ %Rbi = ...
+ %t1 = X_t * (Wi^T)
+ %t2 = H_tminus1*(Ri^T)
+ %t3 = Add(%t1, %t2)
+ %t4 = Add(%t3, %Wbi)
+ %t5 = Add(%t4, %Rbi)
+ %Ht = Tanh(%t5)
+ %Accumulate = Identity(%Ht)
+ return %Ht, %Accumulate
+ }
+
+
+
+ Args:
+ initial_state_and_scan_inputs: (variadic, heterogeneous) Initial values of
+ the loop's N state variables followed by M scan_inputs
+
+ body: The graph run each iteration. It has N+M inputs: (loop state
+ variables..., scan_input_elts...). It has N+K outputs: (loop state
+ variables..., scan_output_elts...). Each scan_output is created by
+ concatenating the value of the specified scan_output_elt value at the
+ end of each iteration of the loop. It is an error if the dimensions of
+ these values change across loop iterations.
+
+ num_scan_inputs: An attribute specifying the number of scan_inputs M.
+
+ scan_input_axes: An optional list of M flags. The i-th element of the list
+ specifies the axis to be scanned (the sequence axis) for the i-th
+ scan_input. If omitted, 0 will be used as the scan axis for every
+ scan_input. Negative value for an axis means counting dimensions from
+ the back. Accepted range is [-r, r-1] where r = rank(input).
+
+ scan_input_directions: An optional list of M flags. The i-th element of the
+ list specifies the direction to be scanned for the i-th scan_input
+ tensor: 0 indicates forward direction and 1 indicates reverse direction.
+ If omitted, all scan_input tensors will be scanned in the forward
+ direction.
+
+ scan_output_axes: An optional list of K flags. The i-th element of the list
+ specifies the axis for the i-th scan_output. The scan outputs are
+ accumulated along the specified axis. If omitted, 0 will be used as the
+ scan axis for every scan_output. Negative value for an axis means
+ counting dimensions from the back. Accepted range is [-r, r-1].
+
+ scan_output_directions: An optional list of K flags, one for each
+ scan_output. The i-th element of the list specifies whether the i-th
+ scan_output should be constructed by appending or prepending a new value
+ in each iteration: 0 indicates appending and 1 indicates prepending. If
+ omitted, all scan_output tensors will be produced by appending a value
+ in each iteration.
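+
+ Example: a numpy sketch of the pseudo-code above for a running sum
+ (one state variable, one scan input, one scan output; illustrative only):
+
+ ::
+
+ import numpy as np
+
+ def scan_running_sum(init, xs):
+     # body(state, x) -> (new_state, scan_output_element)
+     st, outs = init, []
+     for t in range(xs.shape[0]):   # scan_input_axes = [0]
+         st = st + xs[t]            # updated state variable
+         outs.append(st)            # accumulated scan output element
+     return st, np.stack(outs)      # final state, scan_output
+
+ xs = np.arange(12.0).reshape(4, 3)
+ final_state, cumsum = scan_running_sum(np.zeros(3), xs)
+ # cumsum equals np.cumsum(xs, axis=0)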
+ """
+
+ schema = get_schema("Scan", 23, "")
+ op = Op(self, "Scan", schema)
+ return op(
+ *self._prepare_inputs(schema, *initial_state_and_scan_inputs),
+ body=body,
+ num_scan_inputs=num_scan_inputs,
+ scan_input_axes=scan_input_axes,
+ scan_input_directions=scan_input_directions,
+ scan_output_axes=scan_output_axes,
+ scan_output_directions=scan_output_directions,
+ )
+
+ T_Shape = TypeVar(
+ "T_Shape",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T1_Shape: TypeAlias = INT64
+
+ def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) -> T1_Shape:
+ r"""[🌐 Shape(23)](https://onnx.ai/onnx/operators/onnx__Shape.html#shape-23 "Online Documentation")
+
+
+ Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor.
+ Optional attributes start and end can be used to compute a slice of the input tensor's shape.
+ If start axis is omitted, the slice starts from axis 0.
+ The end axis, if specified, is exclusive (and the returned value will not include the size of that axis).
+ If the end axis is omitted, all axes up to (and including) the last one will be included.
+ Negative axes indicate counting back from the last axis.
+ Note that axes will be clamped to the range [0, r], where r is the
+ rank of the input tensor if they are out-of-range (after adding r in the case of
+ negative axis). Thus, specifying any end value > r is equivalent to specifying an end
+ value of r, and specifying any start value < -r is equivalent to specifying a start
+ value of 0. If start > end, the result will be an empty shape.
+
+ Examples:
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ No attributes specified.
+ Output: [2, 3, 4]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ start: -1
+ Output: [4]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ end: -1
+ Output: [2, 3]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ start: 1
+ end: 2
+ Output: [3]
+
+
+
+
+ Args:
+ data: (non-differentiable) An input tensor.
+
+ end: (Optional) Ending axis for slicing the shape. Negative value means
+ counting dimensions from the back. If omitted, sizes of all axes up to
+ (and including) the last one will be included.
+
+ start: (Optional) Starting axis for slicing the shape. Default value is
+ 0. Negative value means counting dimensions from the back.
+ """
+
+ schema = get_schema("Shape", 23, "")
+ op = Op(self, "Shape", schema)
+ return op(*self._prepare_inputs(schema, data), end=end, start=start)
+
+ T_Size = TypeVar(
+ "T_Size",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T1_Size: TypeAlias = INT64
+
+ def Size(self, data: T_Size) -> T1_Size:
+ r"""[🌐 Size(23)](https://onnx.ai/onnx/operators/onnx__Size.html#size-23 "Online Documentation")
+
+
+ Takes a tensor as input and outputs an int64 scalar equal to the total number of elements of the input tensor.
+
+
+ Args:
+ data: (non-differentiable) An input tensor.
+ """
+
+ schema = get_schema("Size", 23, "")
+ op = Op(self, "Size", schema)
+ return op(*self._prepare_inputs(schema, data))
+
+ T_Squeeze = TypeVar(
+ "T_Squeeze",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Squeeze(self, data: T_Squeeze, axes: Optional[INT64] = None) -> T_Squeeze:
+ r"""[🌐 Squeeze(23)](https://onnx.ai/onnx/operators/onnx__Squeeze.html#squeeze-23 "Online Documentation")
+
+
+ Remove single-dimensional entries from the shape of a tensor.
+ Takes an input `axes` with a list of axes to squeeze.
+ If `axes` is not provided, all the single dimensions will be removed from
+ the shape. If an axis is selected with shape entry not equal to one, an error is raised.
+
+
+ Args:
+ data: (differentiable) Tensors with at least max(dims) dimensions.
+
+ axes: (optional, non-differentiable) 1D tensor of integers indicating the
+ dimensions to squeeze. Negative value means counting dimensions from the
+ back. Accepted range is [-r, r-1] where r = rank(data).
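+
+ Example: the numpy equivalent, shown for illustration:
+
+ ::
+
+ import numpy as np
+
+ x = np.zeros((1, 3, 1, 5))
+ np.squeeze(x).shape          # axes omitted: all size-1 dims removed -> (3, 5)
+ np.squeeze(x, axis=2).shape  # axes = [2] -> (1, 3, 5)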
+ """
+
+ schema = get_schema("Squeeze", 23, "")
+ op = Op(self, "Squeeze", schema)
+ return op(*self._prepare_inputs(schema, data, axes))
+
+ T_Transpose = TypeVar(
+ "T_Transpose",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Transpose(
+ self, data: T_Transpose, *, perm: Optional[Sequence[int]] = None
+ ) -> T_Transpose:
+ r"""[🌐 Transpose(23)](https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-23 "Online Documentation")
+
+
+ Transpose the input tensor similar to numpy.transpose. For example, when
+ perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape
+ will be (2, 1, 3).
+
+
+ Args:
+ data: (differentiable) An input tensor.
+
+ perm: A list of integers. By default, reverse the dimensions, otherwise
+ permute the axes according to the values given. Its length must be equal
+ to the rank of the input.
+ """
+
+ schema = get_schema("Transpose", 23, "")
+ op = Op(self, "Transpose", schema)
+ return op(*self._prepare_inputs(schema, data), perm=perm)
+
+ T_Unsqueeze = TypeVar(
+ "T_Unsqueeze",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Unsqueeze(self, data: T_Unsqueeze, axes: INT64) -> T_Unsqueeze:
+ r"""[🌐 Unsqueeze(23)](https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#unsqueeze-23 "Online Documentation")
+
+
+ Insert single-dimensional entries into the shape of an input tensor (`data`).
+ Takes one required input `axes`, which contains a list of dimension indices; this operator inserts a dimension of value `1` at the corresponding index of the output tensor (`expanded`).
+
+ For example, given an input tensor (`data`) of shape [3, 4, 5], then
+ Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing the same data as `data` but with shape [1, 3, 4, 5, 1].
+
+ The input `axes` must not contain any duplicate entries; duplicates are an error.
+ The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`.
+ Each value in `axes` should be within the (inclusive) range [-output_rank, output_rank - 1].
+ The values in `axes` may appear in any order.
+
+
+ Args:
+ data: (differentiable) Original tensor
+
+ axes: (non-differentiable) 1D tensor of integers indicating the dimensions
+ to be inserted. Negative value means counting dimensions from the back.
+ Accepted range is [-r, r-1] where r = rank(expanded).
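+
+ Example: the numpy equivalent, shown for illustration:
+
+ ::
+
+ import numpy as np
+
+ x = np.zeros((3, 4, 5))
+ np.expand_dims(x, (0, 4)).shape  # axes = [0, 4] -> (1, 3, 4, 5, 1)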
+ """
+
+ schema = get_schema("Unsqueeze", 23, "")
+ op = Op(self, "Unsqueeze", schema)
+ return op(*self._prepare_inputs(schema, data, axes))
diff --git a/onnxscript/onnx_opset/_impl/opset24.py b/onnxscript/onnx_opset/_impl/opset24.py
new file mode 100644
index 0000000000..d85fcaefe5
--- /dev/null
+++ b/onnxscript/onnx_opset/_impl/opset24.py
@@ -0,0 +1,2342 @@
+# --------------------------------------------------------------------------
+# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
+# ⚙️ Generated by 'python -m opgen'
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# pylint: disable=W0221,W0222,R0901,W0237
+# mypy: disable-error-code=override
+# ruff: noqa: D214, D402, D405, D411, D412, D416
+# --------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from typing import Optional, Sequence, Tuple, TypeVar, Union
+
+from onnx import GraphProto, SparseTensorProto, TensorProto
+from onnx.defs import get_schema
+from typing_extensions import TypeAlias
+
+from onnxscript.onnx_opset._impl.opset23 import Opset23
+from onnxscript.onnx_types import (
+ BFLOAT16,
+ BOOL,
+ COMPLEX64,
+ COMPLEX128,
+ DOUBLE,
+ FLOAT,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ FLOAT16,
+ INT4,
+ INT8,
+ INT16,
+ INT32,
+ INT64,
+ STRING,
+ UINT4,
+ UINT8,
+ UINT16,
+ UINT32,
+ UINT64,
+)
+from onnxscript.values import Op, Opset
+
+
+class Opset24(Opset23):
+ def __new__(cls):
+ return Opset.__new__(cls, "", 24)
+
+ T1_Attention = TypeVar("T1_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ T2_Attention = TypeVar("T2_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ U_Attention = TypeVar(
+ "U_Attention",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ def Attention(
+ self,
+ Q: T1_Attention,
+ K: T1_Attention,
+ V: T2_Attention,
+ attn_mask: Optional[U_Attention] = None,
+ past_key: Optional[T1_Attention] = None,
+ past_value: Optional[T2_Attention] = None,
+ nonpad_kv_seqlen: Optional[INT64] = None,
+ *,
+ is_causal: int = 0,
+ kv_num_heads: Optional[int] = None,
+ q_num_heads: Optional[int] = None,
+ qk_matmul_output_mode: int = 0,
+ scale: Optional[float] = None,
+ softcap: float = 0.0,
+ softmax_precision: Optional[int] = None,
+ ) -> Tuple[T1_Attention, T1_Attention, T2_Attention, T1_Attention]:
+ r"""[🌐 Attention(24)](https://onnx.ai/onnx/operators/onnx__Attention.html#attention-24 "Online Documentation")
+
+
+
+ Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.
+
+ This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.
+
+ For self attention, `kv_sequence_length` equals `q_sequence_length`.
+
+ For cross attention, query and key might have different lengths.
+
+ This operator also covers the 3 following variants based on the number of heads:
+ 1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
+ 2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
+ 3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`.
+
+ Attention bias to be added is calculated based on `attn_mask` input and `is_causal` attribute:
+ 1) `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention or a float mask of the same type as query, key, value that is added to the attention score.
+ 2) If `is_causal` is set to `1`, attention scores above the diagonal are masked out, regardless of the `attn_mask` input.
+
+ With respect to KV cache update, this operator allows the following two use cases:
+
+ 1) Cache update happens inside the Attention operator. In this case, the `K` and `V` inputs contain only the incoming
+ tokens for the current autoregressive step, and the four optional inputs/outputs past and present key and value are
+ all needed. The Attention op performs a Concat operation on the past and incoming key and value to form the present
+ key and value, respectively. Note that this only works correctly for the special case where the past key and value
+ do not contain padded tokens.
+ 2) Cache update happens outside the Attention operator (for example, through the `TensorScatter` operator). In this
+ case, the `K` and `V` inputs correspond to the entire cache tensor, so the four optional inputs/outputs past and
+ present key and value should not be used. An additional input `nonpad_kv_seqlen` of shape (batch_size,) may be
+ provided to indicate the number of non-padding tokens in each sample of the batch to save unnecessary computation.
+ Here, the kv_sequence dimension of `attn_mask` can be shorter than `K` and `V`, but still needs to be at least as long
+ as the maximum value of `nonpad_kv_seqlen`.
+
+ Both past and present state key/values are optional. They shall be used together; it is not allowed to provide only one of them.
+ The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided:
+
+ ::
+
+ The following pattern is applied by this operator:
+ Q K V
+ | | |
+ Q*sqrt(scale) K*sqrt(scale) |
+ | | |
+ | Transpose |
+ | | |
+ ---MatMul--- |
+ | |
+ attn_mask---Add |
+ | |
+ softcap (if provided) |
+ | |
+ Softmax |
+ | |
+ -----MatMul------
+ |
+ Y
+
+
+
+
+
+ Args:
+ Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads,
+ q_sequence_length, head_size)` or 3D tensor with shape `(batch_size,
+ q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor,
+ `q_hidden_size = q_num_heads * head_size`
+
+ K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads,
+ kv_sequence_length, head_size)` or 3D tensor with shape `(batch_size,
+ kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor,
+ `k_hidden_size = kv_num_heads * head_size`
+
+ V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads,
+ kv_sequence_length, v_head_size)` or 3D tensor with shape `(batch_size,
+ kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor,
+ `v_hidden_size = kv_num_heads * v_head_size`
+
+ attn_mask: (optional) Attention mask. Shape must be broadcastable to
+ `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)`
+ where `total_sequence_length = past_sequence_length +
+ kv_sequence_length.` The last dimension can also be shorter than
+ `total_sequence_length` and will be padded to `total_sequence_length`
+ with negative infinity. Two types of masks are supported: a boolean mask
+ where a value of `True` indicates that the element should take part in
+ attention, or a float mask of the same type as query, key, value that is
+ added to the attention score.
+
+ past_key: (optional) past state cache for key with shape `(batch_size,
+ kv_num_heads, past_sequence_length, head_size)`
+
+ past_value: (optional) past state cache for value with shape `(batch_size,
+ kv_num_heads, past_sequence_length, v_head_size)`
+
+ nonpad_kv_seqlen: (optional) A vector of integers of shape `(batch_size,)`
+ that indicates the number of valid (i.e., non-padding) tokens in each
+ sample. A padding mask can be derived from this. This should not be used
+ together with `past_key` and `past_value` inputs or `present_key` and
+ `present_value` outputs (See the KV cache use cases in the operator
+ description).
+
+ is_causal: If set to `1`, the attention masking is a lower triangular matrix
+ when the mask is a square matrix. The attention masking has the form of
+ the upper left causal bias due to the alignment.
+
+ kv_num_heads: Number of heads of key and value. Must be used with 3D inputs
+ of Q, K and V.
+
+ q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K
+ and V.
+
+ qk_matmul_output_mode: If set to `0`, qk_matmul_output is the output of qk
+ matmul. If set to `1`, qk_matmul_output includes the addition of the
+ attention mask to the output of qk matmul. If set to `2`,
+ qk_matmul_output is the output after the softcap operation. If set to
+ `3`, qk_matmul_output is the output after the softmax operation. Default
+ value is 0.
+
+ scale: Scaling factor applied to $Q*K^T$. Default value is
+ `1/sqrt(head_size)`. To prevent [numerical
+ overflow](https://tinyurl.com/sudb9s96), scale `Q`, `K` by `sqrt(scale)`
+ before matmul.
+
+ softcap: Softcap value for attention weights. Default value is 0.
+
+ softmax_precision: The floating-point precision used in softmax computation.
+ If softmax precision is not provided, the same precision as the input of
+ softmax (Q and K) is used.
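+
+ Example: a numpy sketch of the MHA pattern above for 4D inputs
+ (illustrative only; GQA/MQA head broadcasting, boolean masks, causal
+ masking and KV-cache handling are omitted, and the softcap formula
+ `softcap * tanh(scores / softcap)` is an assumption based on its
+ common definition):
+
+ ::
+
+ import numpy as np
+
+ def attention_reference(Q, K, V, attn_mask=None, scale=None, softcap=0.0):
+     # Q, K, V: (batch_size, num_heads, seq_len, head_size)
+     if scale is None:
+         scale = 1.0 / np.sqrt(Q.shape[-1])
+     scores = (Q * np.sqrt(scale)) @ np.swapaxes(K * np.sqrt(scale), -1, -2)
+     if attn_mask is not None:
+         scores = scores + attn_mask  # float mask added to the scores
+     if softcap > 0.0:
+         scores = softcap * np.tanh(scores / softcap)
+     # Numerically stable softmax over the key dimension.
+     weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
+     weights = weights / weights.sum(axis=-1, keepdims=True)
+     return weights @ V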
+ """
+
+ schema = get_schema("Attention", 24, "")
+ op = Op(self, "Attention", schema)
+ return op(
+ *self._prepare_inputs(
+ schema, Q, K, V, attn_mask, past_key, past_value, nonpad_kv_seqlen
+ ),
+ is_causal=is_causal,
+ kv_num_heads=kv_num_heads,
+ q_num_heads=q_num_heads,
+ qk_matmul_output_mode=qk_matmul_output_mode,
+ scale=scale,
+ softcap=softcap,
+ softmax_precision=softmax_precision,
+ )
+
+ T1_Cast = TypeVar(
+ "T1_Cast",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T2_Cast: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def Cast(
+ self, input: T1_Cast, *, round_mode: str = "up", saturate: int = 1, to: int
+ ) -> T2_Cast:
+ r"""[🌐 Cast(24)](https://onnx.ai/onnx/operators/onnx__Cast.html#cast-24 "Online Documentation")
+
+
+ The operator casts the elements of a given input tensor to a data type
+ specified by the 'to' argument and returns an output tensor of the same size in
+ the converted type. The 'to' argument must be one of the data types specified
+ in the 'DataType' enum field in the TensorProto message.
+
+ Casting from a string tensor in plain (e.g., "3.14" and "1000") or scientific numeric representation
+ (e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may
+ yield result 100. There are some string literals reserved for special floating-point values;
+ "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively.
+ Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinity. Similarly,
+ this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors
+ to string tensors, plain floating-point representation (such as "314.15926") would be used.
+ Converting a non-numerical-literal string such as "Hello World!" is undefined behavior. Converting a string
+ representing a floating-point value, such as "2.718", to INT is likewise undefined behavior.
+
+ Conversion from a numerical type to any numerical type is always allowed.
+ User must be aware of precision loss and value change caused by range difference between two types.
+ For example, a 64-bit float 3.1415926459 may be rounded to a 32-bit float 3.141592. Similarly, converting
+ an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type.
+
+ In more detail, the conversion among numerical types should follow these rules
+ if the destination type is not a float 8 type.
+
+ * Casting from floating point to:
+ * floating point: +/- infinity if OOR (out of range).
+ * fixed point: undefined if OOR.
+ * bool: +/- 0.0 to False; all else to True.
+ * Casting from fixed point to:
+ * floating point: +/- infinity if OOR. (+ infinity in the case of uint)
+ * fixed point: when OOR, discard higher bits and reinterpret (with respect to two's complement representation for
+ signed types). For example, 200 (int16) -> -56 (int8).
+ * bool: zero to False; nonzero to True.
+ * Casting from bool to:
+ * floating point: `{1.0, 0.0}`.
+ * fixed point: `{1, 0}`.
+ * bool: no change.
+
+ Float 8 types (E4M3FN, E4M3FNUZ, E5M2, E5M2FNUZ) were introduced to speed up the training of
+ deep models. By default the conversion of a float *x* obeys
+ the following rules. `[x]` means the value rounded to
+ the target mantissa width.
+
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | -------- | -------- | -------- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | Inf | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX |
+ | -Inf | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX |
+ | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX |
+ | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX |
+ | else | RNE | RNE | RNE | RNE |
+
+ The behavior changes if the parameter 'saturate' is set to False.
+ The rules then become:
+
+ | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ |
+ | ----------------- | ------ | -------- | ---- | -------- |
+ | 0 | 0 | 0 | 0 | 0 |
+ | -0 | -0 | 0 | -0 | 0 |
+ | NaN | NaN | NaN | NaN | NaN |
+ | -NaN | -NaN | NaN | -NaN | NaN |
+ | Inf | NaN | NaN | Inf | NaN |
+ | -Inf | -NaN | NaN | -Inf | NaN |
+ | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN |
+ | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN |
+ | else | RNE | RNE | RNE | RNE |
+
+ FLOAT8E8M0 type was introduced to enable [Microscaling (MX) formats](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
+ When casting to FLOAT8E8M0, the rounding behavior can be specified using the `round_mode` and `saturate` attributes.
+ The current CUDA behavior is to round up and saturate. Casting negative values to FLOAT8E8M0 gives undefined behavior.
+ The following table describes the casting behavior of special values to FLOAT8E8M0 in the two most common cases.
+
+ | x | saturate + up | non-saturate + nearest |
+ | ----------------- | ------------- | --------------------- |
+ | 0 | 0 | NaN |
+ | -0 | Unspecified | Unspecified |
+ | NaN | NaN | NaN |
+ | Inf | E8M0_MAX | NaN |
+ | x > E8M0_MAX | E8M0_MAX | NaN |
+ | x \< E8M0_MIN | E8M0_MIN | NaN |
+ | x \< 0 | Unspecified | Unspecified |
+
+
+ Args:
+ input: (differentiable) Input tensor to be cast.
+
+ round_mode: Rounding mode for conversion to float8e8m0. It only applies to
+ casting to float8e8m0 and is `up` by default. `up`: round to nearest
+ value away from zero, `down`: round to nearest value towards zero,
+ `nearest`: round to nearest value and ties round up.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz,
+ float8e8m0). It is true by default. All cases are fully described in the
+ tables inserted in the operator description.
+
+ to: The data type to which the elements of the input tensor are cast.
+ Strictly must be one of the types from DataType enum in TensorProto
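+
+ Example: an illustrative sketch of the `saturate=1` rule for
+ float8e4m3fn, whose largest finite value is 448 (rounding of in-range
+ values to the 3-bit mantissa is omitted):
+
+ ::
+
+ import math
+
+ E4M3FN_MAX = 448.0  # largest finite float8e4m3fn value
+
+ def saturate_to_e4m3fn_range(x: float) -> float:
+     # NaN stays NaN; +/-Inf and out-of-range values clamp to
+     # +/-FLT_MAX of the target type, per the first table above.
+     if math.isnan(x):
+         return math.nan
+     return max(-E4M3FN_MAX, min(x, E4M3FN_MAX))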
+ """
+
+ schema = get_schema("Cast", 24, "")
+ op = Op(self, "Cast", schema)
+ return op(
+ *self._prepare_inputs(schema, input),
+ round_mode=round_mode,
+ saturate=saturate,
+ to=to,
+ )
+
+ T1_CastLike = TypeVar(
+ "T1_CastLike",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T2_CastLike = TypeVar(
+ "T2_CastLike",
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def CastLike(
+ self,
+ input: T1_CastLike,
+ target_type: T2_CastLike,
+ *,
+ round_mode: str = "up",
+ saturate: int = 1,
+ ) -> T2_CastLike:
+ r"""[🌐 CastLike(24)](https://onnx.ai/onnx/operators/onnx__CastLike.html#castlike-24 "Online Documentation")
+
+
+ The operator casts the elements of a given input tensor (the first input) to
+ the same data type as the elements of the second input tensor.
+ See documentation of the Cast operator for further details.
+
+
+ Args:
+ input: (differentiable) Input tensor to be cast.
+
+ target_type: (non-differentiable) The (first) input tensor will be cast to
+ produce a tensor of the same type as this (second input) tensor.
+
+ round_mode: Rounding mode for conversion to float8e8m0. It only applies to
+ casting to float8e8m0 and is `up` by default. `up`: round to nearest
+ value away from zero, `down`: round to nearest value towards zero,
+ `nearest`: round to nearest value and ties round up. Please refer to
+ operator Cast description for further details.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz,
+ float8e8m0). It is true by default. Please refer to operator Cast
+ description for further details.
+ """
+
+ schema = get_schema("CastLike", 24, "")
+ op = Op(self, "CastLike", schema)
+ return op(
+ *self._prepare_inputs(schema, input, target_type),
+ round_mode=round_mode,
+ saturate=saturate,
+ )
+
+ T_Constant: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def Constant(
+ self,
+ *,
+ sparse_value: Optional[SparseTensorProto] = None,
+ value: Optional[TensorProto] = None,
+ value_float: Optional[float] = None,
+ value_floats: Optional[Sequence[float]] = None,
+ value_int: Optional[int] = None,
+ value_ints: Optional[Sequence[int]] = None,
+ value_string: Optional[str] = None,
+ value_strings: Optional[Sequence[str]] = None,
+ ) -> T_Constant:
+ r"""[🌐 Constant(24)](https://onnx.ai/onnx/operators/onnx__Constant.html#constant-24 "Online Documentation")
+
+
+ This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value,
+ or value_* must be specified.
+
+
+ Args:
+ sparse_value: The value for the elements of the output tensor in sparse
+ format.
+
+ value: The value for the elements of the output tensor.
+
+ value_float: The value for the sole element for the scalar, float32, output
+ tensor.
+
+ value_floats: The values for the elements for the 1D, float32, output
+ tensor.
+
+ value_int: The value for the sole element for the scalar, int64, output
+ tensor.
+
+ value_ints: The values for the elements for the 1D, int64, output tensor.
+
+ value_string: The value for the sole element for the scalar, UTF-8 string,
+ output tensor.
+
+ value_strings: The values for the elements for the 1D, UTF-8 string, output
+ tensor.
+ """
+
+ schema = get_schema("Constant", 24, "")
+ op = Op(self, "Constant", schema)
+ return op(
+ sparse_value=sparse_value,
+ value=value,
+ value_float=value_float,
+ value_floats=value_floats,
+ value_int=value_int,
+ value_ints=value_ints,
+ value_string=value_string,
+ value_strings=value_strings,
+ )
+
+ T1_ConstantOfShape: TypeAlias = INT64
+
+ T2_ConstantOfShape: TypeAlias = Union[
+ BFLOAT16,
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ ]
+
+ def ConstantOfShape(
+ self, input: T1_ConstantOfShape, *, value: Optional[TensorProto] = None
+ ) -> T2_ConstantOfShape:
+ r"""[🌐 ConstantOfShape(24)](https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html#constantofshape-24 "Online Documentation")
+
+
+ Generate a tensor with given value and shape.
+
+
+ Args:
+ input: 1D tensor. The shape of the expected output tensor. If empty tensor
+ is given, the output would be a scalar. All values must be >= 0.
+
+ value: (Optional) The value of the output elements. Should be a one-element
+ tensor. If not specified, it defaults to a tensor of value 0 and
+ datatype float32.
+ """
+
+ schema = get_schema("ConstantOfShape", 24, "")
+ op = Op(self, "ConstantOfShape", schema)
+ return op(*self._prepare_inputs(schema, input), value=value)
+
+ T1_DequantizeLinear = TypeVar(
+ "T1_DequantizeLinear",
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT32,
+ INT4,
+ INT8,
+ UINT16,
+ UINT4,
+ UINT8,
+ )
+
+ T2_DequantizeLinear = TypeVar("T2_DequantizeLinear", BFLOAT16, FLOAT, FLOAT16, FLOAT8E8M0)
+
+ T3_DequantizeLinear: TypeAlias = Union[BFLOAT16, FLOAT, FLOAT16]
+
+ def DequantizeLinear(
+ self,
+ x: T1_DequantizeLinear,
+ x_scale: T2_DequantizeLinear,
+ x_zero_point: Optional[T1_DequantizeLinear] = None,
+ *,
+ axis: int = 1,
+ block_size: int = 0,
+ output_dtype: int = 0,
+ ) -> T3_DequantizeLinear:
+ r"""[🌐 DequantizeLinear(24)](https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html#dequantizelinear-24 "Online Documentation")
+
+
+ The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the
+ full-precision tensor. The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point`
+ must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization,
+ a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization.
+ See QuantizeLinear for details on quantization granularity.
+
+ `x_zero_point` and `x` must have the same type. `x` and `y` must have the same shape. In the case of dequantizing
+ `int32`, there's no zero point (zero point is supposed to be 0).
+ `zero-point` is usually not used in the case of float8 and 4-bit types quantization, but the dequantization formula remains the same
+ for consistency. The output type is determined by the attribute `output_dtype`. If `output_dtype` is not supplied then the output type
+ is the same as `x_scale`. The output type also determines the precision of the multiplication operation.
+
+
+
+ Args:
+ x: N-D quantized input tensor to be de-quantized.
+
+ x_scale: Scale for input `x`. For per-tensor/layer dequantization the scale
+ is a scalar, for per-axis dequantization it is a 1-D Tensor and for
+ blocked dequantization it has the same shape as the input, except for
+ one dimension in which blocking is performed.
+
+ x_zero_point: (optional) Zero point for input `x`. Shape must match x_scale.
+ It's optional. Zero point is 0 when it's not specified.
+
+ axis: (Optional) The axis of the dequantizing dimension of the input tensor.
+ Used for per-axis and blocked quantization. Negative value means
+ counting dimensions from the back. Accepted range is `[-r, r-1]` where
+ `r = rank(input)`.
+
+ block_size: (Optional) The size of the quantization block (number of times
+ every scale is replicated). Used only for blocked quantization. The
+ block size is a positive integer. Given `x` shape `(D0, ..., Di, ...,
+ Dn)`, `x_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted
+ range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]`
+
+ output_dtype: (Optional) The output data type. If not supplied, the output
+ data type is inferred from `x_scale` data type (`T2`)
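+
+ Example: a numpy sketch of per-axis dequantization (illustrative only):
+
+ ::
+
+ import numpy as np
+
+ def dequantize_per_axis(x, x_scale, x_zero_point, axis=1):
+     # y = (x - x_zero_point) * x_scale, with the 1-D scale and
+     # zero point broadcast along `axis`.
+     shape = [1] * x.ndim
+     shape[axis] = -1
+     zp = x_zero_point.reshape(shape).astype(np.float32)
+     return (x.astype(np.float32) - zp) * x_scale.reshape(shape)
+
+ x = np.array([[0, 128, 255]], dtype=np.uint8)
+ dequantize_per_axis(x, np.array([0.1], dtype=np.float32),
+                     np.array([128], dtype=np.uint8), axis=0)
+ # -> [[-12.8, 0.0, 12.7]]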
+ """
+
+ schema = get_schema("DequantizeLinear", 24, "")
+ op = Op(self, "DequantizeLinear", schema)
+ return op(
+ *self._prepare_inputs(schema, x, x_scale, x_zero_point),
+ axis=axis,
+ block_size=block_size,
+ output_dtype=output_dtype,
+ )
+
+ T_Flatten = TypeVar(
+ "T_Flatten",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Flatten(self, input: T_Flatten, *, axis: int = 1) -> T_Flatten:
+ r"""[🌐 Flatten(24)](https://onnx.ai/onnx/operators/onnx__Flatten.html#flatten-24 "Online Documentation")
+
+
+ Flattens the input tensor into a 2D matrix. If the input tensor has shape
+ (d_0, d_1, ... d_n) then the output will have shape
+ (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X d_n).
+
+
+ Args:
+ input: (differentiable) A tensor of rank >= axis.
+
+ axis: Indicate up to which input dimensions (exclusive) should be flattened
+ to the outer dimension of the output. The value for axis must be in the
+ range [-r, r], where r is the rank of the input tensor. Negative value
+ means counting dimensions from the back. When axis = 0, the shape of the
+ output tensor is (1, d_0 X d_1 ... X d_n), where the shape of the input
+ tensor is (d_0, d_1, ... d_n).
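+
+ Example: the shape arithmetic, shown for illustration with numpy:
+
+ ::
+
+ import numpy as np
+
+ x = np.zeros((2, 3, 4, 5))
+ axis = 2
+ x.reshape(int(np.prod(x.shape[:axis])), -1).shape  # -> (6, 20)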
+ """
+
+ schema = get_schema("Flatten", 24, "")
+ op = Op(self, "Flatten", schema)
+ return op(*self._prepare_inputs(schema, input), axis=axis)
+
+ V_Identity = TypeVar(
+ "V_Identity",
+ Optional[Sequence[BOOL]],
+ Optional[Sequence[COMPLEX128]],
+ Optional[Sequence[COMPLEX64]],
+ Optional[Sequence[DOUBLE]],
+ Optional[Sequence[FLOAT]],
+ Optional[Sequence[FLOAT16]],
+ Optional[Sequence[INT16]],
+ Optional[Sequence[INT32]],
+ Optional[Sequence[INT64]],
+ Optional[Sequence[INT8]],
+ Optional[Sequence[STRING]],
+ Optional[Sequence[UINT16]],
+ Optional[Sequence[UINT32]],
+ Optional[Sequence[UINT64]],
+ Optional[Sequence[UINT8]],
+ Optional[BOOL],
+ Optional[COMPLEX128],
+ Optional[COMPLEX64],
+ Optional[DOUBLE],
+ Optional[FLOAT],
+ Optional[FLOAT16],
+ Optional[INT16],
+ Optional[INT32],
+ Optional[INT64],
+ Optional[INT8],
+ Optional[STRING],
+ Optional[UINT16],
+ Optional[UINT32],
+ Optional[UINT64],
+ Optional[UINT8],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Identity(self, input: V_Identity) -> V_Identity:
+ r"""[🌐 Identity(24)](https://onnx.ai/onnx/operators/onnx__Identity.html#identity-24 "Online Documentation")
+
+ Identity operator
+
+ Args:
+ input: (differentiable) Input tensor
+ """
+
+ schema = get_schema("Identity", 24, "")
+ op = Op(self, "Identity", schema)
+ return op(*self._prepare_inputs(schema, input))
+
+ B_If: TypeAlias = BOOL
+
+ V_If: TypeAlias = Union[
+ None,
+ Sequence[BFLOAT16],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ Sequence[FLOAT4E2M1],
+ Sequence[FLOAT8E4M3FN],
+ Sequence[FLOAT8E4M3FNUZ],
+ Sequence[FLOAT8E5M2],
+ Sequence[FLOAT8E5M2FNUZ],
+ Sequence[FLOAT8E8M0],
+ Sequence[INT4],
+ Sequence[UINT4],
+ ]
+
+ def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If:
+ r"""[🌐 If(24)](https://onnx.ai/onnx/operators/onnx__If.html#if-24 "Online Documentation")
+
+ If conditional
+
+ Args:
+ cond: Condition for the if. The tensor must contain a single element.
+
+ else_branch: Graph to run if condition is false. Has N outputs: values you
+ wish to be live-out to the enclosing scope. The number of outputs must
+ match the number of outputs in the then_branch.
+
+ then_branch: Graph to run if condition is true. Has N outputs: values you
+ wish to be live-out to the enclosing scope. The number of outputs must
+ match the number of outputs in the else_branch.
+ """
+
+ schema = get_schema("If", 24, "")
+ op = Op(self, "If", schema)
+ return op(
+ *self._prepare_inputs(schema, cond),
+ else_branch=else_branch,
+ then_branch=then_branch,
+ )
+
+ I_Loop: TypeAlias = INT64
+
+ B_Loop: TypeAlias = BOOL
+
+ V_Loop = TypeVar(
+ "V_Loop",
+ Optional[Sequence[BFLOAT16]],
+ Optional[Sequence[BOOL]],
+ Optional[Sequence[COMPLEX128]],
+ Optional[Sequence[COMPLEX64]],
+ Optional[Sequence[DOUBLE]],
+ Optional[Sequence[FLOAT]],
+ Optional[Sequence[FLOAT16]],
+ Optional[Sequence[INT16]],
+ Optional[Sequence[INT32]],
+ Optional[Sequence[INT64]],
+ Optional[Sequence[INT8]],
+ Optional[Sequence[STRING]],
+ Optional[Sequence[UINT16]],
+ Optional[Sequence[UINT32]],
+ Optional[Sequence[UINT64]],
+ Optional[Sequence[UINT8]],
+ Optional[BFLOAT16],
+ Optional[BOOL],
+ Optional[COMPLEX128],
+ Optional[COMPLEX64],
+ Optional[DOUBLE],
+ Optional[FLOAT],
+ Optional[FLOAT16],
+ Optional[FLOAT4E2M1],
+ Optional[FLOAT8E4M3FN],
+ Optional[FLOAT8E4M3FNUZ],
+ Optional[FLOAT8E5M2],
+ Optional[FLOAT8E5M2FNUZ],
+ Optional[FLOAT8E8M0],
+ Optional[INT16],
+ Optional[INT32],
+ Optional[INT4],
+ Optional[INT64],
+ Optional[INT8],
+ Optional[STRING],
+ Optional[UINT16],
+ Optional[UINT32],
+ Optional[UINT4],
+ Optional[UINT64],
+ Optional[UINT8],
+ Sequence[BFLOAT16],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[FLOAT4E2M1],
+ Sequence[FLOAT8E4M3FN],
+ Sequence[FLOAT8E4M3FNUZ],
+ Sequence[FLOAT8E5M2],
+ Sequence[FLOAT8E5M2FNUZ],
+ Sequence[FLOAT8E8M0],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT4],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT4],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Loop(
+ self,
+ M: Optional[I_Loop],
+ cond: Optional[B_Loop],
+ *v_initial: V_Loop,
+ body: GraphProto,
+ ) -> V_Loop:
+ r"""[🌐 Loop(24)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-24 "Online Documentation")
+
+
+ Generic Looping construct. This loop has multiple termination conditions:
+
+ 1) Trip count. Iteration count specified at runtime. Set by
+ specifying the input M. Optional. Set to empty string to omit.
+ Note that a static trip count (specified at graph construction time) can be
+ specified by passing in a constant node for input M.
+ 2) Loop termination condition. This is an input to the op that determines
+ whether to run the first iteration and also a loop-carried dependency for
+ the body graph. The body graph must yield a value for the condition variable,
+ whether this input is provided or not.
+
+ This table summarizes the operating modes of this operator with equivalent
+ C-style code:
+
+ Operator inputs defined as (max_trip_count, condition_var).
+
+ * input ("", ""):
+ for (int i=0; ; ++i) {
+ cond = ... // Note this value is ignored, but is required in the body
+ }
+
+ * input ("", cond) // Note this is analogous to a while loop
+ bool cond = ...;
+ for (int i=0; cond; ++i) {
+ cond = ...;
+ }
+
+ * input ("", 1) // Note this is analogous to a do-while loop
+ bool cond = true
+ for (int i=0; cond; ++i) {
+ cond = ...;
+ }
+
+ * input (trip_count, "") // Note this is analogous to a for loop
+ int trip_count = ...
+ for (int i=0; i < trip_count; ++i) {
+ cond = ...; // ignored
+ }
+
+ * input (trip_count, cond)
+ int trip_count = ...;
+ bool cond = ...;
+ for (int i=0; i < trip_count && cond; ++i) {
+ cond = ...;
+ }
+
+
+ *Sample usage - cond as well as trip count*
+
+ graph predict-net {
+ %a = Constant[value = <Scalar Tensor [3]>]()
+ %b = Constant[value = <Scalar Tensor [6]>]()
+ %keepgoing = Constant[value = <Scalar Tensor [1]>]()
+ %max_trip_count = Constant[value = <Scalar Tensor [10]>]()
+ %keepgoing_out, %b_out, %user_defined_vals = Loop[body = <graph body-net>](%max_trip_count, %keepgoing, %b)
+ return
+ }
+
+ graph body-net (
+ %i[INT32, scalar] // iteration number
+ %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used
+ %b_in[INT32, scalar] // incoming value of loop-carried-dependency b
+ ) {
+ %my_local = Add(%a, %b_in)
+ %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b
+ %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition
+ %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated
+ return %keepgoing_out, %b_out, %user_defined_val
+ }
+
+ *Sample equivalent C code*
+
+ {
+ /* User-defined code (enclosing scope) */
+ int a = 3, b = 6;
+ bool keepgoing = true; // Analogous to input cond
+ /* End user-defined code */
+
+ /* Implicitly-defined code */
+ const int max_trip_count = 10; // Analogous to input M
+ int user_defined_vals[]; // Imagine this is resizable
+ /* End implicitly-defined code */
+ /* initialize loop-carried variables and scan-output variables */
+ bool keepgoing_out = keepgoing;
+ int b_out = b;
+
+ for (int i=0; i < max_trip_count && keepgoing_out; ++i) {
+ /* Implicitly-defined code: bind actual parameter values
+ to formal parameter variables of loop-body */
+ bool keepgoing_in = keepgoing_out;
+ int b_in = b_out;
+
+ /* User-defined code (loop body) */
+ int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine
+ b_out = a - b_in;
+ keepgoing_out = my_local > b_out;
+ int user_defined_val = b_in + b_in; // b_in and b_out are different variables
+ /* End user-defined code */
+
+ /* Implicitly-defined code */
+ user_defined_vals[i] = user_defined_val; // accumulate scan-output values
+ }
+ // int t = my_local; // Can't do this. my_local is not accessible here.
+
+ // The values below are bound to the output variables of the loop and therefore accessible
+ // b_out; user_defined_vals; keepgoing_out;
+ }
+
+ There are several things of note in this code snippet:
+
+ 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can
+ be referenced in the inputs of the loop.
+ 2) Any values computed in the loop body that need to be used in a subsequent
+ iteration or after the loop are modelled using a pair of variables in the loop-body,
+ consisting of an input variable (e.g., b_in) and an output variable (e.g., b_out).
+ These are referred to as loop-carried dependences. The loop operation node
+ supplies the input value of the input variable for the first iteration, and
+ returns the output value of the output variable produced by the final
+ iteration.
+ 3) Scan_output variables are used to implicitly concatenate values computed across
+ all the iterations. In the above example, the values of user_defined_val computed
+ over all iterations are concatenated and returned as the value of user_defined_vals
+ after the loop.
+ 4) Values created in the body cannot be accessed in the enclosing scope,
+ except using the mechanism described above.
+
+ Note that the semantics of this op support "diagonal" or "wavefront" execution.
+ (See Step 3 here for an example:
+ https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/).
+ Frontends should emit multi-layer RNNs as a series of While operators (with
+ time being the inner looping dimension), with each successive layer consuming
+ the scan_outputs from the previous layer, possibly going through several
+ point-wise operators (e.g. dropout, residual connections, linear layer).
+
+ The matching of the subgraph's inputs/outputs (produced by the loop node) is based on order rather than name; the implementation infers the names from this order.
+
+
+ Args:
+ M: (optional) A maximum trip-count for the loop specified at runtime.
+ Optional. Pass empty string to skip.
+
+ cond: (optional) A boolean termination condition. Optional. Pass empty
+ string to skip.
+
+ v_initial: (variadic, heterogeneous) The initial values of any loop-carried
+ dependencies (values that change across loop iterations)
+
+ body: The graph run each iteration. It has 2+N inputs: (iteration_num,
+ condition, loop carried dependencies...). It has 1+N+K outputs:
+ (condition, loop carried dependencies..., scan_outputs...). Each
+ scan_output is created by concatenating the value of the specified
+ output value at the end of each iteration of the loop. It is an error if
+ the dimensions or data type of these scan_outputs change across loop
+ iterations.
+ """
+
+ schema = get_schema("Loop", 24, "")
+ op = Op(self, "Loop", schema)
+ return op(*self._prepare_inputs(schema, M, cond, *v_initial), body=body)
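+
+ # A minimal usage sketch (editor's illustration, not generated code; `op` is an
+ # assumed instance of this opset): a for-style loop with a static trip count and
+ # no early-exit condition. `trip_count` is an INT64 scalar, `x0` a loop-carried
+ # value, and `body_g` a GraphProto with inputs (iter_num, cond_in, x_in) and
+ # outputs (cond_out, x_out):
+ #
+ #     x_final = op.Loop(trip_count, None, x0, body=body_g)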
+
+ T_Pad = TypeVar(
+ "T_Pad",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ Tind_Pad = TypeVar("Tind_Pad", INT32, INT64)
+
+ def Pad(
+ self,
+ data: T_Pad,
+ pads: INT64,
+ constant_value: Optional[T_Pad] = None,
+ axes: Optional[Tind_Pad] = None,
+ *,
+ mode: str = "constant",
+ ) -> T_Pad:
+ r"""[🌐 Pad(24)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-24 "Online Documentation")
+
+
+ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`,
+ a padded tensor (`output`) is generated.
+
+ The three supported `modes` are (similar to corresponding modes supported by `numpy.pad`):
+
+ 1) `constant` (default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False)
+
+ 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis
+
+ 3) `edge` - pads with the edge values of the array
+
+ 4) `wrap` - wrap-around padding as if the data tensor forms a torus
+
+
+ Example 1 (`constant` mode):
+
+ Insert 0 pads to the beginning of the second dimension.
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'constant'
+
+ constant_value = 0.0
+
+ output = [
+ [0.0, 0.0, 1.0, 1.2],
+ [0.0, 0.0, 2.3, 3.4],
+ [0.0, 0.0, 4.5, 5.7],
+ ]
+
+
+
+ Example 2 (`reflect` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'reflect'
+
+ output = [
+ [1.0, 1.2, 1.0, 1.2],
+ [2.3, 3.4, 2.3, 3.4],
+ [4.5, 5.7, 4.5, 5.7],
+ ]
+
+
+
+ Example 3 (`edge` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [0, 2, 0, 0]
+
+ mode = 'edge'
+
+ output = [
+ [1.0, 1.0, 1.0, 1.2],
+ [2.3, 2.3, 2.3, 3.4],
+ [4.5, 4.5, 4.5, 5.7],
+ ]
+
+
+
+ Example 4 (`wrap` mode):
+
+ ::
+
+ data = [
+ [1.0, 1.2],
+ [2.3, 3.4],
+ [4.5, 5.7],
+ ]
+
+ pads = [2, 1, 1, 1]
+
+ mode = 'wrap'
+
+ output = [
+ [3.4, 2.3, 3.4, 2.3],
+ [5.7, 4.5, 5.7, 4.5],
+ [1.2, 1.0, 1.2, 1.0],
+ [3.4, 2.3, 3.4, 2.3],
+ [5.7, 4.5, 5.7, 4.5],
+ [1.2, 1.0, 1.2, 1.0],
+ ]
+
+
+
+
+ Args:
+ data: (differentiable) Input tensor.
+
+ pads: (non-differentiable) Tensor of integers indicating the number of
+ padding elements to add or remove (if negative) at the beginning and end
+ of each axis. For 2D input tensor, it is the number of pixels. `pads`
+ should be a 1D tensor of shape [2 * num_axes] where `num_axes` refers to
+ the number of elements in the `axes` input or the input rank if `axes`
+ are not provided explicitly. `pads` format should be: [x1_begin,
+ x2_begin, ..., x1_end, x2_end,...], where xi_begin is the number of pad
+ values added at the beginning of axis `axes[i]` and xi_end, the number
+ of pad values added at the end of axis `axes[i]`.
+
+ constant_value: (optional, non-differentiable) (Optional) A scalar value to
+ be used if the mode chosen is `constant` (by default it is 0, empty
+ string or False).
+
+ axes: (optional, non-differentiable) 1-D tensor of axes that `pads` apply
+ to. Negative value means counting dimensions from the back. Accepted
+ range is [-r, r-1] where r = rank(data). Behavior is undefined if an
+ axis is repeated. If not provided, all axes are assumed (`[0, 1, ...,
+ input_rank-1]`).
+
+ mode: Supported modes: `constant` (default), `reflect`, `edge`, `wrap`
+ """
+
+ schema = get_schema("Pad", 24, "")
+ op = Op(self, "Pad", schema)
+ return op(*self._prepare_inputs(schema, data, pads, constant_value, axes), mode=mode)
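+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset) reproducing Example 1 above, with `data` the 3x2 tensor and `pads`
+ # the INT64 tensor [0, 2, 0, 0]:
+ #
+ #     padded = op.Pad(data, pads, constant_value=0.0, mode="constant")
+ #     # pads two zeros at the start of axis 1, giving a 3x4 output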
+
+ T1_QuantizeLinear = TypeVar("T1_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32)
+
+ T2_QuantizeLinear = TypeVar(
+ "T2_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, FLOAT8E8M0, INT32
+ )
+
+ T3_QuantizeLinear = TypeVar(
+ "T3_QuantizeLinear",
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ INT16,
+ INT4,
+ INT8,
+ UINT16,
+ UINT4,
+ UINT8,
+ )
+
+ def QuantizeLinear(
+ self,
+ x: T1_QuantizeLinear,
+ y_scale: T2_QuantizeLinear,
+ y_zero_point: Optional[T3_QuantizeLinear] = None,
+ *,
+ axis: int = 1,
+ block_size: int = 0,
+ output_dtype: int = 0,
+ precision: int = 0,
+ saturate: int = 1,
+ ) -> T3_QuantizeLinear:
+ r"""[🌐 QuantizeLinear(24)](https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html#quantizelinear-24 "Online Documentation")
+
+
+ The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the
+ low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization
+ granularity. The quantization formula is `y = saturate((x / y_scale) + y_zero_point)`.
+
+ Saturation is done according to:
+ - uint16: [0, 65535]
+ - int16: [-32768, 32767]
+ - uint8: [0, 255]
+ - int8: [-128, 127]
+ - uint4: [0, 15]
+ - int4: [-8, 7]
+
+ For `(x / y_scale)`, the result is rounded to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
+
+ `y_zero_point` and `y` must have the same type. `y_zero_point` is usually not used for quantization to float8 and 4-bit types, but the quantization
+ formula remains the same for consistency, and the type of the input `y_zero_point` still determines the quantization type.
+ `x` and `y_scale` are allowed to have different types. The type of `y_scale` determines the precision of the division operation between `x` and
+ `y_scale`, unless the `precision` attribute is specified.
+
+ There are three supported quantization granularities, determined by the shape of `y_scale`.
+ In all cases, `y_zero_point` must have the same shape as `y_scale`.
+ - Per-tensor (per-layer) quantization: `y_scale` is a scalar.
+ - Per-axis quantization: The scale must be a 1-D tensor whose length equals the size of the quantization axis. For an input shape
+ `(D0, ..., Di, ..., Dn)` and `axis=i`, `y_scale` is a 1-D tensor of length `Di`.
+ - Blocked quantization: The scale's shape is identical to the input's shape, except for one dimension, in which
+ blocking is performed. Given `x` shape `(D0, ..., Di, ..., Dn)`, `axis=i`, and block size `B`: `y_scale` shape is
+ `(D0, ..., ceil(Di/B), ..., Dn)`.
+
+
+ Args:
+ x: N-D full precision Input tensor to be quantized.
+
+ y_scale: Scale for doing quantization to get `y`. For per-tensor/layer
+ quantization the scale is a scalar, for per-axis quantization it is a
+ 1-D Tensor and for blocked quantization it has the same shape as the
+ input, except for one dimension in which blocking is performed.
+
+ y_zero_point: (optional) Zero point for doing quantization to get `y`. Shape
+ must match `y_scale`. Default is uint8 with zero point of 0 if it's not
+ specified.
+
+ axis: (Optional) The axis of the dequantizing dimension of the input tensor.
+ Used only for per-axis and blocked quantization. Negative value means
+ counting dimensions from the back. Accepted range is `[-r, r-1]` where
+ `r = rank(input)`. When the rank of the input is 1, per-tensor
+ quantization is applied, rendering the axis unnecessary in this
+ scenario.
+
+ block_size: (Optional) The size of the quantization block (number of times
+ every scale is replicated). Used only for blocked quantization. The
+ block size is a positive integer. Given `x` shape `(D0, ..., Di, ...,
+ Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted
+ range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]`.
+
+ output_dtype: (Optional) The output data type. If not supplied, the output
+ data type is inferred from `y_zero_point` data type (`T3`). If neither
+ `output_dtype` nor `y_zero_point` are supplied, output data type is
+ uint8. If both `output_dtype` and `y_zero_point` are specified,
+ `output_dtype` must be `T3`.
+
+ precision: (Optional) The precision of the division operation between `x`
+ and `y_scale`. If not provided, it will be the same as the type of
+ `y_scale`.
+
+ saturate: The parameter defines how the conversion behaves if an input value
+ is out of range of the destination type. It only applies for float 8
+ quantization (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz).
+ It is true by default. All cases are fully described in two tables
+ inserted in the operator description.
+ """
+
+ schema = get_schema("QuantizeLinear", 24, "")
+ op = Op(self, "QuantizeLinear", schema)
+ return op(
+ *self._prepare_inputs(schema, x, y_scale, y_zero_point),
+ axis=axis,
+ block_size=block_size,
+ output_dtype=output_dtype,
+ precision=precision,
+ saturate=saturate,
+ )
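+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): per-tensor quantization to uint8, with `x` a FLOAT tensor,
+ # `scale` a FLOAT scalar, and `zp` a UINT8 scalar:
+ #
+ #     y = op.QuantizeLinear(x, scale, zp)
+ #     # per the formula above: y = saturate((x / scale) + zp), saturated to [0, 255]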
+
+ T_Reshape = TypeVar(
+ "T_Reshape",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Reshape(self, data: T_Reshape, shape: INT64, *, allowzero: int = 0) -> T_Reshape:
+ r"""[🌐 Reshape(24)](https://onnx.ai/onnx/operators/onnx__Reshape.html#reshape-24 "Online Documentation")
+
+
+ Reshape the input tensor similar to numpy.reshape.
+ First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor.
+ At most one dimension of the new shape can be -1. In this case, the value is
+ inferred from the size of the tensor and the remaining dimensions. A dimension
+ could also be 0, in which case the actual dimension value is unchanged (i.e. taken
+ from the input tensor). If 'allowzero' is set, and the new shape includes 0, the
+ dimension will be set explicitly to zero (i.e. not taken from input tensor).
+ Shape (second input) could be an empty shape, which means converting to a scalar.
+ The input tensor's shape and the output tensor's shape are required to have the same number of elements.
+
+ If the attribute 'allowzero' is set, it is invalid for the specified shape to
+ contain both a zero value and -1, as the value of the dimension corresponding
+ to -1 cannot be determined uniquely.
+
+
+ Args:
+ data: (differentiable) An input tensor.
+
+ shape: (non-differentiable) Specified shape for output.
+
+ allowzero: (Optional) By default, when any value in the 'shape' input is
+ equal to zero the corresponding dimension value is copied from the input
+ tensor dynamically. allowzero=1 indicates that if any value in the
+ 'shape' input is set to zero, the zero value is honored, similar to
+ NumPy.
+ """
+
+ schema = get_schema("Reshape", 24, "")
+ op = Op(self, "Reshape", schema)
+ return op(*self._prepare_inputs(schema, data, shape), allowzero=allowzero)
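+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset), with `data` of shape [2, 3, 4] and INT64 shape tensors:
+ #
+ #     flat = op.Reshape(data, shape)    # shape == [-1]     -> output shape [24]
+ #     half = op.Reshape(data, shape2)   # shape2 == [0, -1] -> output shape [2, 12]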
+
+ V_Scan = TypeVar(
+ "V_Scan",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Scan(
+ self,
+ *initial_state_and_scan_inputs: V_Scan,
+ body: GraphProto,
+ num_scan_inputs: int,
+ scan_input_axes: Optional[Sequence[int]] = None,
+ scan_input_directions: Optional[Sequence[int]] = None,
+ scan_output_axes: Optional[Sequence[int]] = None,
+ scan_output_directions: Optional[Sequence[int]] = None,
+ ) -> V_Scan:
+ r"""[🌐 Scan(24)](https://onnx.ai/onnx/operators/onnx__Scan.html#scan-24 "Online Documentation")
+
+
+ Scan can be used to iterate over one or more scan_input tensors,
+ constructing zero or more scan_output tensors. It combines ideas from general recurrences,
+ functional programming constructs such as scan, fold, map, and zip, and is intended to enable
+ generalizations of RNN-like constructs for sequence-to-sequence processing.
+ Other tensors (referred to as state_variables here) can be used to carry a state
+ when iterating from one element to another (similar to hidden-state in RNNs, also referred
+ to as loop-carried dependences in the context of loops).
+ Many common usages involve a single scan_input tensor (where functionality
+ similar to scan, fold and map can be obtained). When more than one scan_input is used,
+ a behavior similar to zip is obtained.
+
+ The attribute body must be a graph, specifying the computation to be performed in
+ every iteration. It takes as input the current values of the state_variables and
+ the current iterated element of the scan_inputs. It must return the (updated) values
+ of the state_variables and zero or more scan_output_element tensors. The values of the
+ scan_output_element tensors are concatenated over all the iterations to produce the
+ scan_output values of the scan construct (similar to the concatenated intermediate
+ hidden-state values of RNN-like constructs). All the output tensors (state_variables as
+ well as scan_output_element tensors) are required to have the same shape in each iteration
+ of the loop (a restriction imposed to enable efficient memory allocation).
+
+ Note that the iterated element passed to the body subgraph does not have a sequence
+ axis. It will have a rank one less than the rank of the corresponding scan_input.
+
+ The scan operation returns the final values of the state_variables as well as the
+ scan_outputs.
+
+ The optional attribute scan_input_directions specifies the direction (forward or backward)
+ for each scan input. If this attribute is omitted, all sequences are scanned in the forward
+ direction. A bidirectional scan may be performed by specifying the same tensor input twice
+ in the scan_inputs, once with a forward direction, and once with a backward direction.
+
+ The scan_output of the operation is produced by concatenating the scan_output_element
+ values produced by the body in each iteration. The optional attribute scan_output_directions
+ specifies the direction in which scan_output is constructed (by appending or prepending the
+ scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
+ is omitted, the scan_output_element is appended to the scan_output in each iteration.
+
+ The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
+ If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
+ batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
+ Note that scanning a non-zero axis may be less efficient than scanning axis zero.
+
+ The optional attribute scan_output_axes specifies the axis along which the scan_outputs
+ are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
+ scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
+ value of 1.
+
+ Note that because of the ONNX restriction that only the last parameter of an operator can
+ be variadic, the initial-states and scan-inputs are listed together as one input parameter.
+ Similarly, the final-states and scan-outputs are listed together as one output parameter.
+ The attribute num_scan_inputs indicates the number M of scan-inputs.
+
+ The behavior of
+
+ Scan <
+ num_scan_inputs = m,
+ body = loop-body,
+ scan_input_axes = [axis_1, ..., axis_m]
+ > (init_1, ..., init_n, scan_1, ..., scan_m)
+
+ is equivalent to the following pseudo-code:
+
+ // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
+ // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
+ sequence_length = scan_1.shape[axis_1];
+
+ // initialize state-variables
+ st_1 = init_1; ... st_n = init_n;
+ // initialize scan-output variables: [] denotes an empty tensor
+ scan_out_1 = []; ...; scan_out_k = [];
+ // identify number of iterations:
+
+ // execute loop
+ for (int t = 0; t < sequence_length; ++t) {
+ // generate the scan-input elements: the notation T<axis=k>[t] indicates the sub-tensor
+ // of rank one less than T obtained by indexing T at position t along axis k.
+ si_1 = scan_1[t];
+ ... ;
+ si_m = scan_m[t];
+ // execute loop-body
+ st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
+ // accumulate the scan-output elements
+ scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k);
+ }
+
+ return st_1, ..., st_n, scan_out_1, ..., scan_out_k;
+
+ *Sample usage: Encoding RNN using a Scan*
+
+ The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
+ recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
+ be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes
+ %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
+ values are computed in the outer graph, they need to be passed in as extra state_variables.
+
+ graph rnn-encoding {
+ %H_0 = ...
+ %X = ...
+ %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
+ return %Y, %Y_h
+ }
+
+ graph rnn-cell-1 (
+ %H_tminus1[FLOAT, tensor]
+ %X_t[FLOAT, tensor]
+ ) {
+ %Wi = ...
+ %Ri = ...
+ %Wbi = ...
+ %Rbi = ...
+ %t1 = X_t * (Wi^T)
+ %t2 = H_tminus1*(Ri^T)
+ %t3 = Add(%t1, %t2)
+ %t4 = Add(%t3, %Wbi)
+ %t5 = Add(%t4, %Rbi)
+ %Ht = Tanh(%t5)
+ %Accumulate = Identity(%Ht)
+ return %Ht, %Accumulate
+ }
+
+
+
+ Args:
+ initial_state_and_scan_inputs: (variadic, heterogeneous) Initial values of
+ the loop's N state variables followed by M scan_inputs
+
+ body: The graph run each iteration. It has N+M inputs: (loop state
+ variables..., scan_input_elts...). It has N+K outputs: (loop state
+ variables..., scan_output_elts...). Each scan_output is created by
+ concatenating the value of the specified scan_output_elt value at the
+ end of each iteration of the loop. It is an error if the dimensions of
+ these values change across loop iterations.
+
+ num_scan_inputs: An attribute specifying the number of scan_inputs M.
+
+ scan_input_axes: An optional list of M flags. The i-th element of the list
+ specifies the axis to be scanned (the sequence axis) for the i-th
+ scan_input. If omitted, 0 will be used as the scan axis for every
+ scan_input. Negative value for an axis means counting dimensions from
+ the back. Accepted range is [-r, r-1] where r = rank(input).
+
+ scan_input_directions: An optional list of M flags. The i-th element of the
+ list specifies the direction to be scanned for the i-th scan_input
+ tensor: 0 indicates forward direction and 1 indicates reverse direction.
+ If omitted, all scan_input tensors will be scanned in the forward
+ direction.
+
+ scan_output_axes: An optional list of K flags. The i-th element of the list
+ specifies the axis for the i-th scan_output. The scan outputs are
+ accumulated along the specified axis. If omitted, 0 will be used as the
+ scan axis for every scan_output. Negative value for an axis means
+ counting dimensions from the back. Accepted range is [-r, r-1].
+
+ scan_output_directions: An optional list of K flags, one for each
+ scan_output. The i-th element of the list specifies whether the i-th
+ scan_output should be constructed by appending or prepending a new value
+ in each iteration: 0 indicates appending and 1 indicates prepending. If
+ omitted, all scan_output tensors will be produced by appending a value
+ in each iteration.
+ """
+
+ schema = get_schema("Scan", 24, "")
+ op = Op(self, "Scan", schema)
+ return op(
+ *self._prepare_inputs(schema, *initial_state_and_scan_inputs),
+ body=body,
+ num_scan_inputs=num_scan_inputs,
+ scan_input_axes=scan_input_axes,
+ scan_input_directions=scan_input_directions,
+ scan_output_axes=scan_output_axes,
+ scan_output_directions=scan_output_directions,
+ )
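+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): one state variable and one scan input (scanned along axis 0),
+ # where `init` is the initial state, `xs` the scan input, and `body_g` a
+ # GraphProto with inputs (state_in, elt) and outputs (state_out, scan_elt):
+ #
+ #     final_state, ys = op.Scan(init, xs, body=body_g, num_scan_inputs=1)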
+
+ T_Shape = TypeVar(
+ "T_Shape",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T1_Shape: TypeAlias = INT64
+
+ def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) -> T1_Shape:
+ r"""[🌐 Shape(24)](https://onnx.ai/onnx/operators/onnx__Shape.html#shape-24 "Online Documentation")
+
+
+ Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor.
+ Optional attributes start and end can be used to compute a slice of the input tensor's shape.
+ If start axis is omitted, the slice starts from axis 0.
+ The end axis, if specified, is exclusive (and the returned value will not include the size of that axis).
+ If the end axis is omitted, the axes up to the last one will be included.
+ Negative axes indicate counting back from the last axis.
+ Note that axes will be clamped to the range [0, r], where r is the
+ rank of the input tensor if they are out-of-range (after adding r in the case of
+ negative axis). Thus, specifying any end value > r is equivalent to specifying an end
+ value of r, and specifying any start value < -r is equivalent to specifying a start
+ value of 0. If start > end, the result will be an empty shape.
+
+ Examples:
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ No attributes specified.
+ Output: [2, 3, 4]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ start: -1
+ Output: [4]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ end: -1
+ Output: [2, 3]
+
+
+
+ ::
+
+ Input tensor with shape: [2, 3, 4]
+ start: 1
+ end: 2
+ Output: [3]
+
+
+
+
+ Args:
+ data: (non-differentiable) An input tensor.
+
+ end: (Optional) Ending axis for slicing the shape. Negative value means
+ counting dimensions from the back. If omitted, sizes of all axes up to and
+ including the last one will be included.
+
+ start: (Optional) Starting axis for slicing the shape. Default value is 0.
+ Negative value means counting dimensions from the back.
+ """
+
+ schema = get_schema("Shape", 24, "")
+ op = Op(self, "Shape", schema)
+ return op(*self._prepare_inputs(schema, data), end=end, start=start)
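+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset), matching the examples above for `data` of shape [2, 3, 4]:
+ #
+ #     s = op.Shape(data)                    # -> [2, 3, 4]
+ #     last = op.Shape(data, start=-1)       # -> [4]
+ #     mid = op.Shape(data, start=1, end=2)  # -> [3]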
+
+ T_Size = TypeVar(
+ "T_Size",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ T1_Size: TypeAlias = INT64
+
+ def Size(self, data: T_Size) -> T1_Size:
+ r"""[🌐 Size(24)](https://onnx.ai/onnx/operators/onnx__Size.html#size-24 "Online Documentation")
+
+
+ Takes a tensor as input and outputs an int64 scalar equal to the total number of elements of the input tensor.
+
+
+ Args:
+ data: (non-differentiable) An input tensor.
+ """
+
+ schema = get_schema("Size", 24, "")
+ op = Op(self, "Size", schema)
+ return op(*self._prepare_inputs(schema, data))
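+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): for `data` of shape [2, 3, 4],
+ #
+ #     n = op.Size(data)  # -> int64 scalar 24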
+
+ T_SplitToSequence = TypeVar(
+ "T_SplitToSequence",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ I_SplitToSequence = TypeVar("I_SplitToSequence", INT32, INT64)
+
+ S_SplitToSequence: TypeAlias = Union[
+ Sequence[BFLOAT16],
+ Sequence[BOOL],
+ Sequence[COMPLEX128],
+ Sequence[COMPLEX64],
+ Sequence[DOUBLE],
+ Sequence[FLOAT],
+ Sequence[FLOAT16],
+ Sequence[INT16],
+ Sequence[INT32],
+ Sequence[INT64],
+ Sequence[INT8],
+ Sequence[STRING],
+ Sequence[UINT16],
+ Sequence[UINT32],
+ Sequence[UINT64],
+ Sequence[UINT8],
+ ]
+
+ def SplitToSequence(
+ self,
+ input: T_SplitToSequence,
+ split: Optional[I_SplitToSequence] = None,
+ *,
+ axis: int = 0,
+ keepdims: int = 1,
+ ) -> S_SplitToSequence:
+ r"""[🌐 SplitToSequence(24)](https://onnx.ai/onnx/operators/onnx__SplitToSequence.html#splittosequence-24 "Online Documentation")
+
+
+ Split a tensor into a sequence of tensors, along the specified 'axis'.
+ Lengths of the parts can be specified using the optional argument 'split'.
+ If the argument 'split' is not specified, a default scalar value of 1
+ is used as the value of 'split'.
+ 'split' must contain only positive numbers.
+ 'split' is either a scalar (tensor of empty shape), or a 1-D tensor.
+ If 'split' is a scalar, then 'input' will be split into chunks all of size 'split'
+ if possible. The last chunk alone may be smaller than 'split' if the 'input' size
+ along the given axis 'axis' is not divisible by 'split'.
+ If 'split' is a 1-dimensional tensor, the input tensor is split into 'size(split)' chunks,
+ with lengths of the parts on 'axis' specified in 'split'. In this scenario, the sum of entries
+ in 'split' must be equal to the dimension size of input tensor on 'axis'.
+
+
+ Args:
+ input: The tensor to split
+
+ split: (optional) Length of each output. It can be either a scalar (tensor of
+ empty shape), or a 1-D tensor. All values must be >= 0.
+
+ axis: Which axis to split on. A negative value means counting dimensions
+ from the back. Accepted range is [-rank, rank-1].
+
+ keepdims: Keep the split dimension or not. Default 1, which means we keep
+ split dimension. If input 'split' is specified, this attribute is
+ ignored.
+ """
+
+ schema = get_schema("SplitToSequence", 24, "")
+ op = Op(self, "SplitToSequence", schema)
+ return op(*self._prepare_inputs(schema, input, split), axis=axis, keepdims=keepdims)
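+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): splitting a [6, 4] tensor `x` along axis 0, with `split` an
+ # INT64 tensor:
+ #
+ #     seq = op.SplitToSequence(x, split)  # split == [2, 4] -> tensors [2,4], [4,4]
+ #     rows = op.SplitToSequence(x)        # default split=1 -> six [1, 4] tensors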
+
+ T_Squeeze = TypeVar(
+ "T_Squeeze",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Squeeze(self, data: T_Squeeze, axes: Optional[INT64] = None) -> T_Squeeze:
+ r"""[🌐 Squeeze(24)](https://onnx.ai/onnx/operators/onnx__Squeeze.html#squeeze-24 "Online Documentation")
+
+
+ Remove single-dimensional entries from the shape of a tensor.
+ Takes an input `axes` with a list of axes to squeeze.
+ If `axes` is not provided, all the single dimensions will be removed from
+ the shape. If an axis is selected with shape entry not equal to one, an error is raised.
+
+
+ Args:
+ data: (differentiable) Tensors with at least max(dims) dimensions.
+
+ axes: (optional, non-differentiable) 1D tensor of integers indicating the
+ dimensions to squeeze. Negative value means counting dimensions from the
+ back. Accepted range is [-r, r-1] where r = rank(data).
+ """
+
+ schema = get_schema("Squeeze", 24, "")
+ op = Op(self, "Squeeze", schema)
+ return op(*self._prepare_inputs(schema, data, axes))
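+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): for `data` of shape [1, 3, 1, 5] and `axes` an INT64 tensor:
+ #
+ #     y = op.Squeeze(data, axes)  # axes == [0] -> shape [3, 1, 5]
+ #     z = op.Squeeze(data)        # all size-1 dims removed -> shape [3, 5]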
+
+ T_Swish = TypeVar("T_Swish", BFLOAT16, DOUBLE, FLOAT, FLOAT16)
+
+ def Swish(self, X: T_Swish, *, alpha: float = 1.0) -> T_Swish:
+ r"""[🌐 Swish(24)](https://onnx.ai/onnx/operators/onnx__Swish.html#swish-24 "Online Documentation")
+
+
+ Swish function takes one input data (Tensor<T>) and produces one output data (Tensor<T>) of the same shape,
+ where $Swish(x) = x * sigmoid(alpha * x)$.
+
+
+ Args:
+ X: (differentiable) Input tensor
+
+ alpha: Coefficient to multiply with input before sigmoid.
+ """
+
+ schema = get_schema("Swish", 24, "")
+ op = Op(self, "Swish", schema)
+ return op(*self._prepare_inputs(schema, X), alpha=alpha)
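+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): with the default alpha=1.0 this is the SiLU activation,
+ # x * sigmoid(x):
+ #
+ #     y = op.Swish(x, alpha=1.0)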
+
+ T_TensorScatter = TypeVar(
+ "T_TensorScatter",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def TensorScatter(
+ self,
+ past_cache: T_TensorScatter,
+ update: T_TensorScatter,
+ write_indices: Optional[INT64] = None,
+ *,
+ axis: int = -2,
+ mode: str = "linear",
+ ) -> T_TensorScatter:
+ r"""[🌐 TensorScatter(24)](https://onnx.ai/onnx/operators/onnx__TensorScatter.html#tensorscatter-24 "Online Documentation")
+
+
+ TensorScatter is a generic tensor update operation, motivated by the requirements for KV cache updates for Attention
+ ops commonly found in LLMs. It is a functional operation that models an in-place update to a KV cache buffer.
+
+ The past and present cache tensors have the same shape (batch_size, D1, D2, ..., max_sequence_length, ..., Dn), with
+ the sequence dimension (indicated by the `axis` attribute) being max_sequence_length, so the sizes of these tensors do
+ not need to grow between iterations. The `update` tensor's shape only differs from the cache tensors in the sequence
+ dimension: (batch_size, D1, D2, ..., sequence_length, ..., Dn), where sequence_length <= max_sequence_length.
+
+ The optional `write_indices` input indicates the write index for each sample in the batch, assumed to be zero
+ if not provided. When the `mode` attribute is set to "circular", the write index is modulo max_sequence_length.
+ The operation can be described using the following pseudocode:
+
+ ::
+
+ for prefix_idx in np.ndindex(past_cache.shape[:axis]):
+ batch_idx = prefix_idx[0]
+ for sequence_idx in range(sequence_length):
+ cache_idx = (*prefix_idx, write_indices[batch_idx] + sequence_idx)
+ if mode == "circular":
+ cache_idx = tuple(np.mod(np.asarray(cache_idx), max_sequence_length))
+ update_idx = (*prefix_idx, sequence_idx)
+ present_cache[cache_idx] = update[update_idx]
+
+
+
+ During the prefill phase of attention, only the first two inputs are needed. During the decode phase, `write_indices`
+ is also needed so that the incoming key or value update can be appended after the last valid token for each sample
+ in the batch.
+
+
+ Args:
+ past_cache: (differentiable) Past state cache for key or value with shape
+ `(batch_size, D1, D2, ..., max_sequence_length, ..., Dn)`.
+
+ update: (differentiable) New update tensor with shape `(batch_size, D1, D2,
+ ..., sequence_length, ..., Dn)`.
+
+ write_indices: (optional, non-differentiable) Write indices for the incoming
+ update tensor in the cache. Shape is `(batch_size,)`. Assumed to be all
+ zeros if not provided.
+
+ axis: Sequence dimension of the `past_cache` and `update` tensors. It cannot
+ be 0 (the batch dimension). Default is -2.
+
+ mode: Write mode of cache update. Supported modes include `linear` and
+ `circular`. `linear` mode requires
+ `write_indices + sequence_length <= max_sequence_length`. For `circular` mode,
+ the updates happen in wrap-around fashion, i.e., the update index is taken
+ modulo `max_sequence_length`.
+ """
+
+ schema = get_schema("TensorScatter", 24, "")
+ op = Op(self, "TensorScatter", schema)
+ return op(
+ *self._prepare_inputs(schema, past_cache, update, write_indices),
+ axis=axis,
+ mode=mode,
+ )
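+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): a decode-phase KV-cache append, where `cache` has shape
+ # (batch, heads, max_seq_len, head_dim), `update` has sequence length 1, and
+ # `pos` holds each sample's next write index:
+ #
+ #     new_cache = op.TensorScatter(cache, update, pos, axis=-2, mode="linear")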
+
+ T_TopK = TypeVar(
+ "T_TopK",
+ BFLOAT16,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
+ )
+
+ I_TopK: TypeAlias = INT64
+
+ def TopK(
+ self, X: T_TopK, K: INT64, *, axis: int = -1, largest: int = 1, sorted: int = 1
+ ) -> Tuple[T_TopK, I_TopK]:
+ r"""[🌐 TopK(24)](https://onnx.ai/onnx/operators/onnx__TopK.html#topk-24 "Online Documentation")
+
+
+ Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of
+ shape [a_0, a_1, ..., a_{n-1}] and integer argument k, return two outputs:
+
+ * Value tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1}]
+ which contains the values of the top k elements along the specified axis
+ * Index tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1}] which
+ contains the indices of the top k elements (original indices from the input
+ tensor).
+
+ * If "largest" is 1 (the default value) then the k largest elements are returned.
+ * If "sorted" is 1 (the default value) then the resulting k elements will be sorted.
+ * If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined.
+
+ Given two equivalent values, this operator uses the indices along the axis as
+ a tiebreaker. That is, the element with the lower index will appear first.
+
+
+ Args:
+ X: (differentiable) Tensor of shape [a_0, a_1, ..., a_{n-1}]
+
+ K: (non-differentiable) A 1-D tensor containing a single positive value
+ corresponding to the number of top elements to retrieve
+
+ axis: Dimension on which to do the sort. Negative value means counting
+ dimensions from the back. Accepted range is [-r, r-1] where r =
+ rank(input).
+
+ largest: Whether to return the top-K largest or smallest elements.
+
+ sorted: Whether to return the elements in sorted order.
+ """
+
+ schema = get_schema("TopK", 24, "")
+ op = Op(self, "TopK", schema)
+ return op(
+ *self._prepare_inputs(schema, X, K),
+ axis=axis,
+ largest=largest,
+ sorted=sorted,
+ )
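+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): top-2 values along the last axis, with `k` a 1-D INT64 tensor
+ # holding [2]:
+ #
+ #     values, indices = op.TopK(x, k)             # largest first, sorted
+ #     values, indices = op.TopK(x, k, largest=0)  # smallest instead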
+
+ T_Transpose = TypeVar(
+ "T_Transpose",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Transpose(
+ self, data: T_Transpose, *, perm: Optional[Sequence[int]] = None
+ ) -> T_Transpose:
+ r"""[🌐 Transpose(24)](https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-24 "Online Documentation")
+
+
+ Transpose the input tensor similar to numpy.transpose. For example, when
+ perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape
+ will be (2, 1, 3).
+
+
+ Args:
+ data: (differentiable) An input tensor.
+
+ perm: A list of integers. By default, reverse the dimensions, otherwise
+ permute the axes according to the values given. Its length must be equal
+ to the rank of the input.
+ """
+
+ schema = get_schema("Transpose", 24, "")
+ op = Op(self, "Transpose", schema)
+ return op(*self._prepare_inputs(schema, data), perm=perm)
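+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): for `data` of shape (1, 2, 3),
+ #
+ #     y = op.Transpose(data, perm=[1, 0, 2])  # -> shape (2, 1, 3)
+ #     z = op.Transpose(data)                  # default reverses dims -> (3, 2, 1)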
+
+ T_Unsqueeze = TypeVar(
+ "T_Unsqueeze",
+ BFLOAT16,
+ BOOL,
+ COMPLEX128,
+ COMPLEX64,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ FLOAT4E2M1,
+ FLOAT8E4M3FN,
+ FLOAT8E4M3FNUZ,
+ FLOAT8E5M2,
+ FLOAT8E5M2FNUZ,
+ FLOAT8E8M0,
+ INT16,
+ INT32,
+ INT4,
+ INT64,
+ INT8,
+ STRING,
+ UINT16,
+ UINT32,
+ UINT4,
+ UINT64,
+ UINT8,
+ )
+
+ def Unsqueeze(self, data: T_Unsqueeze, axes: INT64) -> T_Unsqueeze:
+ r"""[🌐 Unsqueeze(24)](https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#unsqueeze-24 "Online Documentation")
+
+
+ Insert single-dimensional entries to the shape of an input tensor (`data`).
+ Takes one required input `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`).
+
+ For example, given an input tensor (`data`) of shape [3, 4, 5], then
+ Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing the same data as `data` but with shape [1, 3, 4, 5, 1].
+
+ The input `axes` must not contain any duplicate entries; it is an error if it does.
+ The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`.
+ Each value in `axes` should be within the (inclusive) range [-output_rank, output_rank - 1].
+ The values in `axes` may appear in any order.
+
+
+ Args:
+ data: (differentiable) Original tensor
+
+ axes: (non-differentiable) 1D tensor of integers indicating the dimensions
+ to be inserted. Negative value means counting dimensions from the back.
+ Accepted range is [-r, r-1] where r = rank(expanded).
+ """
+
+ schema = get_schema("Unsqueeze", 24, "")
+ op = Op(self, "Unsqueeze", schema)
+ return op(*self._prepare_inputs(schema, data, axes))
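+
+ # A minimal usage sketch (editor's illustration; `op` is an assumed instance of
+ # this opset): for `data` of shape [3, 4, 5] and `axes` an INT64 tensor holding
+ # [0, 4]:
+ #
+ #     y = op.Unsqueeze(data, axes)  # -> shape [1, 3, 4, 5, 1]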
diff --git a/onnxscript/onnx_opset/_impl/opset3.py b/onnxscript/onnx_opset/_impl/opset3.py
index f9bbf5d770..fd684dd238 100644
--- a/onnxscript/onnx_opset/_impl/opset3.py
+++ b/onnxscript/onnx_opset/_impl/opset3.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset4.py b/onnxscript/onnx_opset/_impl/opset4.py
index 0a4f68981a..a1b7fb890b 100644
--- a/onnxscript/onnx_opset/_impl/opset4.py
+++ b/onnxscript/onnx_opset/_impl/opset4.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset5.py b/onnxscript/onnx_opset/_impl/opset5.py
index f445cfdce4..d7e34f8d5d 100644
--- a/onnxscript/onnx_opset/_impl/opset5.py
+++ b/onnxscript/onnx_opset/_impl/opset5.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset6.py b/onnxscript/onnx_opset/_impl/opset6.py
index 911192df22..b7b7981154 100644
--- a/onnxscript/onnx_opset/_impl/opset6.py
+++ b/onnxscript/onnx_opset/_impl/opset6.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -211,7 +210,18 @@ def BatchNormalization(
)
T2_Cast: TypeAlias = Union[
- BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
]
def Cast(self, input: T1_Cast, *, to: int) -> T2_Cast:
@@ -370,7 +380,7 @@ def Elu(self, X: T_Elu, *, alpha: float = 1.0) -> T_Elu:
Args:
- X: (differentiable) 1D input tensor
+ X: (differentiable) Input tensor
alpha: Coefficient of ELU.
"""
diff --git a/onnxscript/onnx_opset/_impl/opset7.py b/onnxscript/onnx_opset/_impl/opset7.py
index e584d06c5a..eed9bde7d2 100644
--- a/onnxscript/onnx_opset/_impl/opset7.py
+++ b/onnxscript/onnx_opset/_impl/opset7.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset8.py b/onnxscript/onnx_opset/_impl/opset8.py
index 39d01f198b..6bedb39b86 100644
--- a/onnxscript/onnx_opset/_impl/opset8.py
+++ b/onnxscript/onnx_opset/_impl/opset8.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: D402
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset9.py b/onnxscript/onnx_opset/_impl/opset9.py
index 7d99f002ff..be1cec969d 100644
--- a/onnxscript/onnx_opset/_impl/opset9.py
+++ b/onnxscript/onnx_opset/_impl/opset9.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: E741, D402
# --------------------------------------------------------------------------
from __future__ import annotations
@@ -313,7 +312,18 @@ def Constant(self, *, value: TensorProto) -> T_Constant:
T1_ConstantOfShape: TypeAlias = INT64
T2_ConstantOfShape: TypeAlias = Union[
- BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
]
def ConstantOfShape(
@@ -402,7 +412,18 @@ def Erf(self, input: T_Erf) -> T_Erf:
)
T2_EyeLike: TypeAlias = Union[
- BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8
+ BOOL,
+ DOUBLE,
+ FLOAT,
+ FLOAT16,
+ INT16,
+ INT32,
+ INT64,
+ INT8,
+ UINT16,
+ UINT32,
+ UINT64,
+ UINT8,
]
def EyeLike(
@@ -633,7 +654,7 @@ def MatMul(self, A: T_MatMul, B: T_MatMul) -> T_MatMul:
r"""[🌐 MatMul(9)](https://onnx.ai/onnx/operators/onnx__MatMul.html#matmul-9 "Online Documentation")
- Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
Args:
@@ -1142,7 +1163,12 @@ def Scan(
Tind_Scatter = TypeVar("Tind_Scatter", INT32, INT64)
def Scatter(
- self, data: T_Scatter, indices: Tind_Scatter, updates: T_Scatter, *, axis: int = 0
+ self,
+ data: T_Scatter,
+ indices: Tind_Scatter,
+ updates: T_Scatter,
+ *,
+ axis: int = 0,
) -> T_Scatter:
r"""[🌐 Scatter(9)](https://onnx.ai/onnx/operators/onnx__Scatter.html#scatter-9 "Online Documentation")
diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py
index a190eb17f9..d69cc686a0 100644
--- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py
+++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: N801, D417
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py
index a78e3ae551..49b38d3344 100644
--- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py
+++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: N801
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py
index 0092b4fd40..57c0d90a4e 100644
--- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py
+++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: N801
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py
index 552e545d75..02dc271c6e 100644
--- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py
+++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py
@@ -2,13 +2,12 @@
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
+# ruff: noqa: N801
# --------------------------------------------------------------------------
from __future__ import annotations
diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py
new file mode 100644
index 0000000000..d3f3f0b5cc
--- /dev/null
+++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py
@@ -0,0 +1,157 @@
+# --------------------------------------------------------------------------
+# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
+# ⚙️ Generated by 'python -m opgen'
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# pylint: disable=W0221,W0222,R0901,W0237
+# mypy: disable-error-code=override
+# ruff: noqa: N801
+# --------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from typing import Optional, Sequence, TypeVar
+
+from onnx import TensorProto
+from onnx.defs import get_schema
+
+from onnxscript.onnx_opset._impl.opset_ai_onnx_ml4 import Opset_ai_onnx_ml4
+from onnxscript.onnx_types import DOUBLE, FLOAT, FLOAT16
+from onnxscript.values import Op, Opset
+
+
+class Opset_ai_onnx_ml5(Opset_ai_onnx_ml4):
+ def __new__(cls):
+ return Opset.__new__(cls, "ai.onnx.ml", 5)
+
+ T_TreeEnsemble = TypeVar("T_TreeEnsemble", DOUBLE, FLOAT, FLOAT16)
+
+ def TreeEnsemble(
+ self,
+ X: T_TreeEnsemble,
+ *,
+ aggregate_function: int = 1,
+ leaf_targetids: Sequence[int],
+ leaf_weights: TensorProto,
+ membership_values: Optional[TensorProto] = None,
+ n_targets: Optional[int] = None,
+ nodes_falseleafs: Sequence[int],
+ nodes_falsenodeids: Sequence[int],
+ nodes_featureids: Sequence[int],
+ nodes_hitrates: Optional[TensorProto] = None,
+ nodes_missing_value_tracks_true: Optional[Sequence[int]] = None,
+ nodes_modes: TensorProto,
+ nodes_splits: TensorProto,
+ nodes_trueleafs: Sequence[int],
+ nodes_truenodeids: Sequence[int],
+ post_transform: int = 0,
+ tree_roots: Sequence[int],
+ ) -> T_TreeEnsemble:
+ r"""[🌐 ai.onnx.ml::TreeEnsemble(5)](https://onnx.ai/onnx/operators/onnx_aionnxml_TreeEnsemble.html#treeensemble-5 "Online Documentation")
+
+
+ Tree Ensemble operator. Returns the regressed values for each input in a batch.
+ Inputs have dimensions `[N, F]` where `N` is the input batch size and `F` is the number of input features.
+ Outputs have dimensions `[N, num_targets]` where `N` is the batch size and `num_targets` is the number of targets, which is a configurable attribute.
+
+ The encoding of this attribute is split along interior nodes and the leaves of the trees. Notably, attributes with the prefix `nodes_*` are associated with interior nodes, and attributes with the prefix `leaf_*` are associated with leaves.
+ The attributes `nodes_*` must all have the same length and encode a sequence of tuples, as defined by taking all the `nodes_*` fields at a given position.
+
+ All fields prefixed with `leaf_*` represent tree leaves, and similarly define tuples of leaves and must have identical length.
+
+ This operator can be used to implement both the previous `TreeEnsembleRegressor` and `TreeEnsembleClassifier` nodes.
+ The `TreeEnsembleRegressor` node maps directly to this node and requires changing how the nodes are represented.
+ The `TreeEnsembleClassifier` node can be implemented by adding an `ArgMax` node after this node to determine the top class.
+ To encode class labels, a `LabelEncoder` or `GatherND` operator may be used.
+
+
+ Args:
+ X: Input of shape [Batch Size, Number of Features]
+
+ aggregate_function: Defines how to aggregate leaf values within a target. One
+ of 'AVERAGE' (0), 'SUM' (1), 'MIN' (2), 'MAX' (3); defaults to 'SUM' (1)
+
+ leaf_targetids: The index of the target that this leaf contributes to (this
+ must be in range `[0, n_targets)`).
+
+ leaf_weights: The weight for each leaf.
+
+ membership_values: Members to test membership of for each set membership
+ node. List all of the members to test against in the order that the
+ 'BRANCH_MEMBER' mode appears in `node_modes`, delimited by `NaN`s. Will
+ have the same number of sets of values as nodes with mode
+ 'BRANCH_MEMBER'. This may be omitted if the model doesn't contain any
+ 'BRANCH_MEMBER' nodes.
+
+ n_targets: The total number of targets.
+
+        nodes_falseleafs: 1 if the false branch of the node is a leaf, 0 if it is
+            an interior node. To represent a tree that is a single leaf (has only
+            one node), use a single `nodes_*` entry whose true and false branches
+            reference the same `leaf_*` entry.
+
+        nodes_falsenodeids: If `nodes_falseleafs` is false at an entry, this
+            represents the position of the false branch node and can be used to
+            index into a `nodes_*` entry. If `nodes_falseleafs` is true, it is an
+            index into the `leaf_*` attributes.
+
+ nodes_featureids: Feature id for each node.
+
+        nodes_hitrates: Popularity of each node, used for performance tuning; may
+            be omitted.
+
+ nodes_missing_value_tracks_true: For each node, define whether to follow the
+ true branch (if attribute value is 1) or false branch (if attribute
+ value is 0) in the presence of a NaN input feature. This attribute may
+ be left undefined and the default value is false (0) for all nodes.
+
+ nodes_modes: The comparison operation performed by the node. This is encoded
+ as an enumeration of 0 ('BRANCH_LEQ'), 1 ('BRANCH_LT'), 2
+ ('BRANCH_GTE'), 3 ('BRANCH_GT'), 4 ('BRANCH_EQ'), 5 ('BRANCH_NEQ'), and
+ 6 ('BRANCH_MEMBER'). Note this is a tensor of type uint8.
+
+ nodes_splits: Thresholds to do the splitting on for each node with mode that
+ is not 'BRANCH_MEMBER'.
+
+        nodes_trueleafs: 1 if the true branch of the node is a leaf, 0 if it is an
+            interior node. To represent a tree that is a single leaf (has only one
+            node), use a single `nodes_*` entry whose true and false branches
+            reference the same `leaf_*` entry.
+
+        nodes_truenodeids: If `nodes_trueleafs` is false at an entry, this
+            represents the position of the true branch node and can be used to
+            index into a `nodes_*` entry. If `nodes_trueleafs` is true, it is an
+            index into the `leaf_*` attributes.
+
+        post_transform: Indicates the transform to apply to the score. One of
+            'NONE' (0), 'SOFTMAX' (1), 'LOGISTIC' (2), 'SOFTMAX_ZERO' (3) or
+            'PROBIT' (4); defaults to 'NONE' (0).
+
+ tree_roots: Index into `nodes_*` for the root of each tree. The tree
+ structure is derived from the branching of each node.
+ """
+
+ schema = get_schema("TreeEnsemble", 5, "ai.onnx.ml")
+ op = Op(self, "TreeEnsemble", schema)
+ return op(
+ *self._prepare_inputs(schema, X),
+ aggregate_function=aggregate_function,
+ leaf_targetids=leaf_targetids,
+ leaf_weights=leaf_weights,
+ membership_values=membership_values,
+ n_targets=n_targets,
+ nodes_falseleafs=nodes_falseleafs,
+ nodes_falsenodeids=nodes_falsenodeids,
+ nodes_featureids=nodes_featureids,
+ nodes_hitrates=nodes_hitrates,
+ nodes_missing_value_tracks_true=nodes_missing_value_tracks_true,
+ nodes_modes=nodes_modes,
+ nodes_splits=nodes_splits,
+ nodes_trueleafs=nodes_trueleafs,
+ nodes_truenodeids=nodes_truenodeids,
+ post_transform=post_transform,
+ tree_roots=tree_roots,
+ )
diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_preview_training1.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_preview_training1.py
deleted file mode 100644
index cb201bdf97..0000000000
--- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_preview_training1.py
+++ /dev/null
@@ -1,577 +0,0 @@
-# --------------------------------------------------------------------------
-# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
-# ⚙️ Generated by 'python -m opgen'
-# --------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-# pylint: disable=W0221,W0222,R0901,W0237
-# mypy: disable-error-code=override
-# ruff: noqa: N801,E741
-# ruff: noqa: D214,D402,D405,D411,D412,D416,D417
-# --------------------------------------------------------------------------
-
-from __future__ import annotations
-
-from typing import Optional, Sequence, TypeVar, Union
-
-from onnx.defs import get_schema
-from typing_extensions import TypeAlias
-
-from onnxscript.onnx_types import (
- BOOL,
- COMPLEX64,
- COMPLEX128,
- DOUBLE,
- FLOAT,
- FLOAT16,
- INT8,
- INT16,
- INT32,
- INT64,
- STRING,
- UINT8,
- UINT16,
- UINT32,
- UINT64,
-)
-from onnxscript.values import Op, Opset
-
-
-class Opset_ai_onnx_preview_training1(Opset):
- def __new__(cls):
- return Opset.__new__(cls, "ai.onnx.preview.training", 1)
-
- T1_Adagrad = TypeVar("T1_Adagrad", DOUBLE, FLOAT)
-
- T2_Adagrad: TypeAlias = INT64
-
- T3_Adagrad = TypeVar("T3_Adagrad", DOUBLE, FLOAT)
-
- def Adagrad(
- self,
- R: T1_Adagrad,
- T: T2_Adagrad,
- *inputs: T3_Adagrad,
- decay_factor: float = 0.0,
- epsilon: float = 9.999999974752427e-07,
- norm_coefficient: float = 0.0,
- ) -> T3_Adagrad:
- r"""[🌐 ai.onnx.preview.training::Adagrad(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Adagrad.html#adagrad-1 "Online Documentation")
-
-
- Compute one iteration of ADAGRAD, a stochastic gradient based optimization
- algorithm. This operator can conduct the optimization of multiple tensor variables.
-
- Let's define the behavior of this operator. As you can imagine, ADAGRAD requires
- some parameters:
-
- - The initial learning-rate "R".
- - The update count "T". That is, the number of training iterations conducted.
- - A L2-norm regularization coefficient "norm_coefficient".
- - A learning-rate decay factor "decay_factor".
- - A small constant "epsilon" to avoid dividing-by-zero.
-
- At each ADAGRAD iteration, the optimized tensors are moved along a direction
- computed based on their estimated gradient and accumulated squared gradient. Assume
- that only a single tensor "X" is updated by this operator. We need the value of "X",
- its gradient "G", and its accumulated squared gradient "H". Therefore, variables in
- this operator's input list are sequentially "R", "T", "X", "G", and "H". Other
- parameters are given as attributes because they are usually constants. Also, the
- corresponding output tensors are the new value of "X" (called "X_new"), and then
- the new accumulated squared gradient (called "H_new"). Those outputs are computed
- from the given inputs following the pseudo code below.
-
- Let "+", "-", "*", and "/" are all element-wise arithmetic operations with
- numpy-style broadcasting support. The pseudo code to compute those outputs is:
-
- // Compute a scalar learning-rate factor. At the first update of X, T is generally
- // 0 (0-based update index) or 1 (1-based update index).
- r = R / (1 + T * decay_factor);
-
- // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm.
- G_regularized = norm_coefficient * X + G;
-
- // Compute new accumulated squared gradient.
- H_new = H + G_regularized * G_regularized;
-
- // Compute the adaptive part of per-coordinate learning rate. Note that Sqrt(...)
- // computes element-wise square-root.
- H_adaptive = Sqrt(H_new) + epsilon
-
- // Compute the new value of "X".
- X_new = X - r * G_regularized / H_adaptive;
-
-        If one assigns this operator to optimize multiple inputs, for example, "X_1" and "X_2", the same
- pseudo code may be extended to handle all tensors jointly. More specifically, we can view "X" as a
- concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should
- be concatenated too) and then just reuse the entire pseudo code.
-
- Note that ADAGRAD was first proposed in http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
- In that reference paper, this operator is a special case of the Figure 1's composite mirror
- descent update.
-
-
- Args:
- R: The initial learning rate.
-
- T: The update count of "X". It should be a scalar.
-
- inputs: (variadic, heterogeneous) The current values of optimized tensors,
- followed by their respective gradients, followed by their respective
-            accumulated squared gradients. For example, if two tensors "X_1" and "X_2"
-            are optimized, the input list would be ["X_1", "X_2", gradient of "X_1",
- gradient of "X_2", accumulated squared gradient of "X_1", accumulated
- squared gradient of "X_2"].
-
-        decay_factor: The decay factor of learning rate after one update. The
- effective learning rate is computed by r = R / (1 + T * decay_factor).
- Default to 0 so that increasing update counts doesn't reduce the
- learning rate.
-
- epsilon: Small scalar to avoid dividing by zero.
-
- norm_coefficient: Regularization coefficient in 0.5 * norm_coefficient *
- ||X||_2^2. Default to 0, which means no regularization.
- """
-
- schema = get_schema("Adagrad", 1, "ai.onnx.preview.training")
- op = Op(self, "Adagrad", schema)
- return op(
- *self._prepare_inputs(schema, R, T, *inputs),
- decay_factor=decay_factor,
- epsilon=epsilon,
- norm_coefficient=norm_coefficient,
- )
-
- T1_Adam = TypeVar("T1_Adam", DOUBLE, FLOAT)
-
- T2_Adam: TypeAlias = INT64
-
- T3_Adam = TypeVar("T3_Adam", DOUBLE, FLOAT)
-
- def Adam(
- self,
- R: T1_Adam,
- T: T2_Adam,
- *inputs: T3_Adam,
- alpha: float = 0.8999999761581421,
- beta: float = 0.9990000128746033,
- epsilon: float = 9.999999974752427e-07,
- norm_coefficient: float = 0.0,
- norm_coefficient_post: float = 0.0,
- ) -> T3_Adam:
- r"""[🌐 ai.onnx.preview.training::Adam(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Adam.html#adam-1 "Online Documentation")
-
-
- Compute one iteration of Adam, a stochastic gradient based optimization
- algorithm. This operator can conduct the optimization of multiple tensor variables.
-
- Let's define the behavior of this operator. First of all, Adam requires
- some parameters:
-
- - The learning-rate "R".
- - The update count "T". That is, the number of training iterations conducted.
- - A L2-norm regularization coefficient "norm_coefficient".
- - A small constant "epsilon" to avoid dividing-by-zero.
- - Two coefficients, "alpha" and "beta".
-
- At each Adam iteration, the optimized tensors are moved along a direction
- computed based on their exponentially-averaged historical gradient and
- exponentially-averaged historical squared gradient. Assume that only a tensor
- "X" is being optimized. The rest of required information is
-
- - the value of "X",
- - "X"'s gradient (denoted by "G"),
- - "X"'s exponentially-averaged historical gradient (denoted by "V"), and
- - "X"'s exponentially-averaged historical squared gradient (denoted by "H").
-
- Some of those parameters are passed into this operator as input tensors and others
- are stored as this operator's attributes. Specifically, this operator's input tensor
- list is ["R", "T", "X", "G", "V", "H"]. That is, "R" is the first input, "T" is
- the second input, and so on. Other parameters are given as attributes because they
- are constants. Moreover, the corresponding output tensors are
-
- - the new value of "X" (called "X_new"),
- - the new exponentially-averaged historical gradient (denoted by "V_new"), and
- - the new exponentially-averaged historical squared gradient (denoted by "H_new").
-
- Those outputs are computed following the pseudo code below.
-
- Let "+", "-", "*", and "/" are all element-wise arithmetic operations with
- numpy-style broadcasting support. The pseudo code to compute those outputs is:
-
- // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm.
- G_regularized = norm_coefficient * X + G
-
- // Update exponentially-averaged historical gradient.
- V_new = alpha * V + (1 - alpha) * G_regularized
-
- // Update exponentially-averaged historical squared gradient.
- H_new = beta * H + (1 - beta) * G_regularized * G_regularized
-
- // Compute the element-wise square-root of H_new. V_new will be element-wisely
- // divided by H_sqrt for a better update direction.
- H_sqrt = Sqrt(H_new) + epsilon
-
- // Compute learning-rate. Note that "alpha**T"/"beta**T" is alpha's/beta's T-th power.
- R_adjusted = T > 0 ? R * Sqrt(1 - beta**T) / (1 - alpha**T) : R
-
- // Compute new value of "X".
- X_new = X - R_adjusted * V_new / H_sqrt
-
- // Post-update regularization.
- X_final = (1 - norm_coefficient_post) * X_new
-
- If there are multiple inputs to be optimized, the pseudo code will be applied
- independently to each of them.
-
-
- Args:
- R: The initial learning rate.
-
- T: The update count of "X". It should be a scalar.
-
- inputs: (variadic, heterogeneous) The tensors to be optimized, followed by
- their respective gradients, followed by their respective accumulated
- gradients (aka momentum), followed by their respective accumulated
-            squared gradients. For example, to optimize tensors "X_1" and "X_2",
- the input list would be ["X_1", "X_2", gradient of "X_1", gradient of
- "X_2", accumulated gradient of "X_1", accumulated gradient of "X_2",
- accumulated squared gradient of "X_1", accumulated squared gradient of
- "X_2"].
-
- alpha: Coefficient of previously accumulated gradient in running average.
- Default to 0.9.
-
- beta: Coefficient of previously accumulated squared-gradient in running
- average. Default to 0.999.
-
- epsilon: Small scalar to avoid dividing by zero.
-
- norm_coefficient: Regularization coefficient of 0.5 * norm_coefficient *
- ||X||_2^2. Default to 0, which means no regularization.
-
- norm_coefficient_post: Regularization coefficient of 0.5 * norm_coefficient
- * ||X||_2^2. Default to 0, which means no regularization.
- """
-
- schema = get_schema("Adam", 1, "ai.onnx.preview.training")
- op = Op(self, "Adam", schema)
- return op(
- *self._prepare_inputs(schema, R, T, *inputs),
- alpha=alpha,
- beta=beta,
- epsilon=epsilon,
- norm_coefficient=norm_coefficient,
- norm_coefficient_post=norm_coefficient_post,
- )
-
- T1_Gradient = TypeVar(
- "T1_Gradient",
- BOOL,
- COMPLEX128,
- COMPLEX64,
- DOUBLE,
- FLOAT,
- FLOAT16,
- INT16,
- INT32,
- INT64,
- INT8,
- STRING,
- UINT16,
- UINT32,
- UINT64,
- UINT8,
- )
-
- T2_Gradient: TypeAlias = Union[DOUBLE, FLOAT, FLOAT16]
-
- def Gradient(
- self,
- *Inputs: T1_Gradient,
- xs: Sequence[str],
- y: str,
- zs: Optional[Sequence[str]] = None,
- ) -> T2_Gradient:
- r"""[🌐 ai.onnx.preview.training::Gradient(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Gradient.html#gradient-1 "Online Documentation")
-
-
- Gradient operator computes the partial derivatives of a specific tensor w.r.t.
- some other tensors. This operator is widely used in gradient-based training
- algorithms. To illustrate its use, let's consider a computation graph,
-
- ::
-
- X -----.
- |
- v
- W --> Conv --> H --> Gemm --> Y
- ^
- |
- Z
-
-
-
- , where W and Z are trainable tensors. Note that operators' attributes are
- omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of
- Y with respect to W (Z). The user can compute gradient by inserting Gradient
- operator to form another graph shown below.
-
- ::
-
- W --> Conv --> H --> Gemm --> Y
- | ^ ^
- | | |
- | X Z
- | | |
- | | .----------'
- | | | (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
- | | | "xs" followed by "zs")
- | v v
- '---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
- | |
- | '-----------------------------------> dY/dW (1st output of Gradient)
- |
- '---------------------------------------> dY/dZ (2nd output of Gradient)
-
-
-
- By definition, the tensor "y" is a function of independent variables in "xs"
- and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable
- variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H"
- cannot appear in "xs" and "zs". The reason is that "H" can be determined by
- tensors "W" and "X" and therefore "H" is not an independent variable.
-
- All outputs are optional. If needed, for example, user can assign an empty
- string to the 1st output name of that Gradient to skip the generation of dY/dW.
- Note that the concept of optional outputs can also be found in ONNX's RNN, GRU,
- and LSTM.
-
- Gradient operator can compute derivative against intermediate tensors. For
- example, the gradient of Y with respect to H can be done via
-
- ::
-
- W --> Conv --> H --> Gemm --> Y
- ^ | ^
- | | |
- X | Z
- .-------' |
- | .----------'
- | | (H/Z is the 1st/2nd input of Gradient as shown in "xs")
- v v
- Gradient(xs=["H", "Z"], y="Y")
- | |
- | '-----------------------------------> dY/dH (1st output of Gradient)
- |
- '---------------------------------------> dY/dZ (2nd output of Gradient)
-
-
-
- It is possible to represent high-order differentiation using Gradient operators.
- For example, given the following linear model:
-
- ::
-
- W --> Gemm --> Y --> Loss --> O
- ^ ^
- | |
- X L
-
-
-
- To compute the 2nd order derivative of O with respect to W (denoted by
- d^2O/dW^2), one can do
-
- ::
-
- W --> Gemm --> Y --> Loss --> O
- | ^ ^
- | | |
- | X .------------L
- | | | |
- | | | v
- +------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient)
- | | | |
- | | | '---> dO/dW (2nd output of Gradient)
- | v v
- '---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)dX (1st output of
- | Gradient)
- |
- |
- '---> d^2O/dW^2 (2nd output of Gradient)
-
-
-
- The tensors named in attributes "xs", "zs", and "y" define the differentiated
- computation graph, and the inputs to Gradient node define the values at
- which the gradient is computed. We can feed different tensors to the identified
- graph. For example, one can compute the gradient of Y with respect to H at
- a specific value of H, H_1, by providing that value as an input to the Gradient
- node.
-
- ::
-
- W --> Conv --> H --> Gemm --> Y
- ^ ^
- | |
- X Z
-
- Z_1 (2nd input of Gradient)
- |
- v
- H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Y = Y_1.
- |
- '------------------------------> dY/dZ (2nd output of Gradient)
-
-
-
- When the inputs of Gradient are the tensors named in "xs" and "zs", the
- computation can be optimized. More specifically, intermediate variables in
- forward pass can be reused if the gradient is computed via reverse-mode
- auto-differentiation.
-
-
-
- Args:
- Inputs: (variadic, heterogeneous) The values fed into graph identified by
- the attributes. The i-th input is the value of the i-th tensor specified
- in the concatenated list of the attribute "xs" and the attribute "zs".
- For example, if xs=["A", "B"] and zs=["C"], the first input is used as
- the value of symbol "A" and the 3rd input is substituted for all the
- occurrences of "C".
-
- xs: Input tensor names of the differentiated sub-graph. It contains only the
- necessary differentiated inputs of a (sub-)graph. Variables (usually
- called intermediate variables) that can be generated from inputs cannot
- be included in this attribute.
-
- y: The targeted tensor. It can be viewed as the output of the differentiated
- function. The attribute "xs" and attribute "zs" are the minimal
- independent variable set that determines the value of "y".
-
- zs: Input tensor names of the differentiated sub-graph. It contains only the
- necessary non-differentiated inputs of a (sub-)graph. Variables (usually
- called intermediate variables) that can be generated from inputs cannot
- be included in this attribute.
- """
-
- schema = get_schema("Gradient", 1, "ai.onnx.preview.training")
- op = Op(self, "Gradient", schema)
- return op(*self._prepare_inputs(schema, *Inputs), xs=xs, y=y, zs=zs)
-
- T1_Momentum = TypeVar("T1_Momentum", DOUBLE, FLOAT)
-
- T2_Momentum: TypeAlias = INT64
-
- T3_Momentum = TypeVar("T3_Momentum", DOUBLE, FLOAT)
-
- def Momentum(
- self,
- R: T1_Momentum,
- T: T2_Momentum,
- *inputs: T3_Momentum,
- alpha: float,
- beta: float,
- mode: str,
- norm_coefficient: float,
- ) -> T3_Momentum:
- r"""[🌐 ai.onnx.preview.training::Momentum(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Momentum.html#momentum-1 "Online Documentation")
-
-
- Compute one iteration of stochastic gradient update with momentum.
- This operator can conduct the optimization of multiple tensor variables.
-
- Let's define the behavior of this operator. As you can imagine, SG with momentum requires
- several parameters:
-
- - The learning-rate "R".
- - The update count "T". That is, the number of conducted training iterations. It should
- be zero in the first training iteration.
- - A L2-norm regularization coefficient "norm_coefficient".
- - A decay coefficient of previous accumulated gradient (i.e., momentum) "alpha".
- - The scaling coefficient of current gradient "beta".
-      - An attribute "mode" to choose whether standard momentum or Nesterov's momentum
-        should be used.
-
- For the sake of simplicity, assume that there is only one tensor (called "X") to be optimized.
- Other necessary inputs are "X"'s gradient (called "G") and "X"'s momentum (called "V"). This
- Momentum operator maps all these inputs to the new value of "X" (called "X_new") and its new
- momentum (called "V_new").
-
- This operator supports two different momentum algorithms. Set the attribute "mode" to
- "nesterov" if Nesterov's momentum is desired. Otherwise, set the attribute "model" to
- "standard" to use standard momentum. Computation details are described subsequently.
-
- Let "+", "-", "*", and "/" are all element-wise operations with numpy-style broadcasting.
-
- Pseudo code for SG with standard momentum:
-
- // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared
- // values of all elements in X.
- G_regularized = norm_coefficient * X + G
-
- // In the first training iteration, beta should always be 1.
- beta_adjusted = T > 0 ? beta : 1
-
- // Compute the current momentum based on previous momentum and the current gradient.
- V_new = alpha * V + beta_adjusted * G_regularized
-
- // Update X.
- X_new = X - R * V_new
-
- Pseudo code for SG with Nesterov's momentum:
-
- // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared
- // values of all elements in X.
- G_regularized = norm_coefficient * X + G;
-
- // In the first training iteration, beta should always be 1.
- beta_adjusted = T > 0 ? beta : 1
-
- // Compute the current momentum based on previous momentum and the current gradient.
- V_new = alpha * V + beta_adjusted * G_regularized;
-
- // Compute final update direction and then update X.
- X_new = X - R * (G_regularized + alpha * V_new)
-
-        If one assigns this operator to optimize multiple inputs, for example, "X_1" and "X_2", the same
-        pseudo code would be extended to handle all tensors jointly. More specifically, we can view "X" as a
- concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should
- be concatenated too) and then our pseudo code becomes applicable.
-
-
- Args:
- R: The learning rate.
-
- T: Update count of "X". It should be a scalar.
-
- inputs: (variadic, heterogeneous) It sequentially contains the current
- values of optimized tensors, then their gradient tensors, and finally
- their momentum tensors. For example, if two tensors "X_1" and "X_2" are
-            optimized, the expected input list would be ["X_1", "X_2", gradient of
- "X_1", gradient of "X_2", momentum of "X_1", momentum of "X_2"].
-
- alpha: The decay factor of momentum. It should be a scalar.
-
- beta: The coefficient of gradient in computing new momentum. It should be a
- scalar.
-
- mode: Its value should be either "nesterov" or "standard". The value
- "nesterov" leads to the use of Nesterov's momentum while "standard"
- invokes stochastic gradient method using standard momentum
-
- norm_coefficient: Coefficient of 0.5 * norm_coefficient * ||X||^2.
- """
-
- schema = get_schema("Momentum", 1, "ai.onnx.preview.training")
- op = Op(self, "Momentum", schema)
- return op(
- *self._prepare_inputs(schema, R, T, *inputs),
- alpha=alpha,
- beta=beta,
- mode=mode,
- norm_coefficient=norm_coefficient,
- )
diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index 6af57d4b1d..9642e3f111 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -1,7 +1,5 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-# --------------------------------------------------------------------------
from __future__ import annotations
@@ -9,31 +7,27 @@
from typing import ClassVar, Optional, Tuple, Union
import onnx
-import onnx.helper
+import onnx_ir as ir
-DType = onnx.TensorProto.DataType
+_DType = ir.DataType
+_DimType = Union[int, str, type(None)]
+_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]
-DimType = Union[int, str, type(None)]
+_tensor_type_shape_cache: dict[_DType, TensorType] = {}
+tensor_type_registry: dict[_DType, TensorType] = {}
-def check_dim(dim):
+def _check_dim(dim):
if not isinstance(dim, (int, str, type(None))):
raise TypeError(f"Invalid dimension {dim}")
-ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]
-
-
-def check_shape(shape):
+def _check_shape(shape):
if isinstance(shape, tuple):
for dim in shape:
- check_dim(dim)
+ _check_dim(dim)
elif shape != Ellipsis:
- check_dim(shape)
-
-
-tensor_type_registry: dict[DType, TensorType] = {}
-_tensor_type_shape_cache: dict[DType, TensorType] = {}
+ _check_dim(shape)
class TensorType(abc.ABC):
@@ -60,13 +54,13 @@ class TensorType(abc.ABC):
tensor: FLOAT[128, 1024]
"""
- dtype: ClassVar[DType]
- shape: ClassVar[Optional[ShapeType]]
+ dtype: ClassVar[_DType]
+ shape: ClassVar[Optional[_ShapeType]]
def __new__(cls):
raise NotImplementedError("TensorTypes cannot be instantiated")
- def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
+ def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None):
cls.dtype = dtype
cls.shape = shape
if shape is None:
@@ -78,9 +72,9 @@ def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
)
tensor_type_registry[dtype] = cls
else:
- check_shape(shape)
+ _check_shape(shape)
- def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
+ def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]:
if cls.shape is not None:
raise ValueError("Invalid usage: shape already specified.")
if shape is None:
@@ -103,98 +97,116 @@ def to_type_proto(cls) -> onnx.TypeProto:
shape = cls.shape # example: "FLOAT[10,20]"
else:
shape = [cls.shape] # example: "FLOAT[10]"
- return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
+ return onnx.helper.make_tensor_type_proto(cls.dtype, shape) # noqa: TID251
@classmethod
def to_string(cls) -> str:
return f"tensor({cls.__name__.lower()})"
-class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
+class FLOAT(TensorType, dtype=ir.DataType.FLOAT):
+ pass
+
+
+class UINT8(TensorType, dtype=ir.DataType.UINT8):
+ pass
+
+
+class INT8(TensorType, dtype=ir.DataType.INT8):
+ pass
+
+
+class UINT16(TensorType, dtype=ir.DataType.UINT16):
+ pass
+
+
+class INT16(TensorType, dtype=ir.DataType.INT16):
pass
-class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
+class INT32(TensorType, dtype=ir.DataType.INT32):
pass
-class INT8(TensorType, dtype=onnx.TensorProto.INT8):
+class INT64(TensorType, dtype=ir.DataType.INT64):
pass
-class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
+class STRING(TensorType, dtype=ir.DataType.STRING):
pass
-class INT16(TensorType, dtype=onnx.TensorProto.INT16):
+class BOOL(TensorType, dtype=ir.DataType.BOOL):
pass
-class INT32(TensorType, dtype=onnx.TensorProto.INT32):
+class FLOAT16(TensorType, dtype=ir.DataType.FLOAT16):
pass
-class INT64(TensorType, dtype=onnx.TensorProto.INT64):
+class DOUBLE(TensorType, dtype=ir.DataType.DOUBLE):
pass
-class STRING(TensorType, dtype=onnx.TensorProto.STRING):
+class UINT32(TensorType, dtype=ir.DataType.UINT32):
pass
-class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
+class UINT64(TensorType, dtype=ir.DataType.UINT64):
pass
-class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
+class COMPLEX64(TensorType, dtype=ir.DataType.COMPLEX64):
pass
-class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
+class COMPLEX128(TensorType, dtype=ir.DataType.COMPLEX128):
pass
-class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
+class BFLOAT16(TensorType, dtype=ir.DataType.BFLOAT16):
pass
-class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
+class FLOAT8E4M3FN(TensorType, dtype=ir.DataType.FLOAT8E4M3FN):
pass
-class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
+class FLOAT8E4M3FNUZ(TensorType, dtype=ir.DataType.FLOAT8E4M3FNUZ):
pass
-class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
+class FLOAT8E5M2(TensorType, dtype=ir.DataType.FLOAT8E5M2):
pass
-class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
+class FLOAT8E5M2FNUZ(TensorType, dtype=ir.DataType.FLOAT8E5M2FNUZ):
pass
-class FLOAT8E4M3FN(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FN):
+class INT4(TensorType, dtype=ir.DataType.INT4):
pass
-class FLOAT8E4M3FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FNUZ):
+class UINT4(TensorType, dtype=ir.DataType.UINT4):
pass
-class FLOAT8E5M2(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2):
+class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1):
pass
-class FLOAT8E5M2FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2FNUZ):
+class FLOAT8E8M0(TensorType, dtype=ir.DataType.FLOAT8E8M0):
pass
-def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
+def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto, *, reversible: bool = True) -> str:
"""Converts an onnx type into the string representation of the type in *onnxscript*.
Args:
onnx_type: an instance of onnx TypeProto
+ reversible: if True, the conversion produces only types that are
+ recognized by the onnxscript converter.
Returns:
The string representation of the type in onnxscript
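+
+    Example (an illustrative sketch)::
+
+        >>> t = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [10, 20])
+        >>> onnx_type_to_onnxscript_repr(t)
+        'FLOAT[10,20]'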
@@ -218,6 +230,10 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
return name
return f"{name}[{','.join(shape)}]"
return f"{name}[...]"
+ if not reversible:
+ if onnx_type.HasField("sequence_type"):
+ elem_type = onnx_type.sequence_type.elem_type
+ return f"List[{onnx_type_to_onnxscript_repr(elem_type)}]"
raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.")
diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py
index 03c1e748eb..978a1b4d65 100644
--- a/onnxscript/optimizer/__init__.py
+++ b/onnxscript/optimizer/__init__.py
@@ -1,114 +1,131 @@
-import logging
-from typing import Any
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from typing import TypeVar
+
+__all__ = [
+ "basic_constant_propagation",
+ "fold_constants_ir",
+ "fold_constants",
+ "inline",
+ "optimize_ir",
+ "optimize",
+ "remove_unused_nodes",
+]
import onnx
-import onnx.shape_inference
-
-from onnxscript import rewriter
-from onnxscript.optimizer.constant_folding import fold_constants
-from onnxscript.optimizer.remove_unused import remove_unused_nodes
-from onnxscript.optimizer.remove_unused_function import remove_unused_functions
-from onnxscript.optimizer.simple_function_folding import (
- inline_functions_with_unused_outputs,
- inline_simple_functions,
-)
-from onnxscript.rewriter import (
- broadcast_to_matmul,
- cast_constant_of_shape,
- gemm_to_matmul_add,
- no_op,
-)
-
-logger = logging.getLogger(__name__)
+import onnx_ir.passes.common as common_passes
+
+import onnxscript.optimizer._constant_folding as constant_folding
+from onnxscript import ir
+from onnxscript.optimizer._constant_folding import basic_constant_propagation
+from onnxscript.optimizer._constant_folding import fold_constants as fold_constants_ir
+from onnxscript.optimizer._optimizer import optimize_ir
+
+_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
def optimize(
- model: onnx.ModelProto,
+ model: _ModelProtoOrIr,
num_iterations: int = 2,
*,
onnx_shape_inference: bool = True,
stop_if_no_change: bool = True,
- external_data_folder: str = "",
- **kwargs: Any,
-) -> onnx.ModelProto:
- """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.
+ input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+ output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+ inline: bool = True,
+) -> _ModelProtoOrIr:
+ """Optimizes a model.
Args:
- model (onnx.ModelProto): The model to optimize.
- num_iterations (int, optional): Number of iterations to perform.
- onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
- Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
- This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
- the symbolic shapes recorded from dynamo tracing.
- stop_if_no_change (bool, optional): Whether to stop if no change is detected.
- external_data_folder (str, optional): The folder to store external data.
- **kwargs: Additional keyword arguments. For BC purposes.
+ model: The model to be optimized.
+ num_iterations: Number of times the optimization loop is repeated.
+        onnx_shape_inference: Applies node-level shape inference as part of optimization.
+ input_size_limit: Will not apply constant folding to ops with any input of size
+ greater than this. Does not apply to special ops like Shape() and Size().
+ output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
+ of the output tensor is greater than this.
+ stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
+ inline: If True, inlines all functions in the model.
+
+ Returns:
+ The optimized model. If the input was a ModelProto, the output will also be a
+ ModelProto. If the input was an ir.Model, the output will also be an ir.Model.
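+
+    Example (an illustrative sketch; assumes a serialized model at "model.onnx")::
+
+        import onnx
+        from onnxscript import optimizer
+
+        proto = onnx.load("model.onnx")
+        proto = optimizer.optimize(proto, num_iterations=2)
+        onnx.save(proto, "model_optimized.onnx")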
"""
- if kwargs.pop("function_aware_folding", None) is not None:
- logger.warning(
- "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
- "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
- "This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
- "See 'onnx_shape_inference' for more details."
+ if isinstance(model, ir.Model):
+        # In this case, optimization is done in place.
+ # TODO(justinchuby): Maybe make functional
+ optimize_ir(
+ model,
+ num_iterations=num_iterations,
+ onnx_shape_inference=onnx_shape_inference,
+ stop_if_no_change=stop_if_no_change,
+ input_size_limit=input_size_limit,
+ output_size_limit=output_size_limit,
+ inline=inline,
)
- for _ in range(num_iterations):
- if onnx_shape_inference:
- if model.ByteSize() < 1024 * 1024 * 1024 * 2:
- # NOTE: strict mode is disabled because it crashes on the models
- # that have different shapes inferred from the model carried shapes.
- # The case can be found in:
- # https://github.com/microsoft/onnxscript/issues/1443
- model = onnx.shape_inference.infer_shapes(
- model, check_type=True, strict_mode=False, data_prop=True
- )
- else:
- logger.warning(
- "The model size is too large for full model shape inference. "
- "Skipping this step."
- )
-
- inline_simple_functions(model)
- modified = fold_constants(
- model, external_data_folder, onnx_shape_inference=onnx_shape_inference
+ return model
+ else:
+ assert isinstance(model, onnx.ModelProto)
+ model_ir = ir.serde.deserialize_model(model)
+ optimize_ir(
+ model_ir,
+ num_iterations=num_iterations,
+ onnx_shape_inference=onnx_shape_inference,
+ stop_if_no_change=stop_if_no_change,
+ input_size_limit=input_size_limit,
+ output_size_limit=output_size_limit,
+ inline=inline,
)
+ # Move the model back to the proto
+ new_proto = ir.serde.serialize_model(model_ir)
+ return new_proto
- remove_unused_nodes(model)
- inline_simple_functions(model)
- remove_unused_functions(model)
- inline_functions_with_unused_outputs(model)
- # NOTE: This is general rewrite rules
- model = rewriter.rewrite(
- model,
- pattern_rewrite_rules=[
- *no_op.rules.rules, # TODO: merge this rule into constant folding?
- *broadcast_to_matmul.rules.rules,
- gemm_to_matmul_add.rule,
- *cast_constant_of_shape.rules.rules,
- ],
- )
- if stop_if_no_change and not modified:
- logger.debug("Stopping after %d iterations.", _)
- break
- for node in model.graph.node:
- logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)
+def inline(model: ir.Model) -> None:
+ """Inline all function calls (recursively) in the model."""
+ if model.functions:
+ common_passes.InlinePass()(model)
- for function in model.functions:
- for node in function.node:
- logger.debug(
- "Function %s::%s node %s::%s name %s.",
- function.domain,
- function.name,
- node.domain,
- node.op_type,
- node.name,
- )
- return model
+def fold_constants(
+ model: ir.Model | onnx.ModelProto, *args, **kwargs
+) -> constant_folding.FoldConstantsResult:
+ """Fold constants in a model in place."""
+ if isinstance(model, ir.Model):
+ return constant_folding.fold_constants(model, *args, **kwargs)
+ else:
+ assert isinstance(model, onnx.ModelProto)
+ model_proto = model
+ model = ir.serde.deserialize_model(model_proto)
+ result = constant_folding.fold_constants(model, *args, **kwargs)
+ # Move the model back to the proto
+ new_proto = ir.serde.serialize_model(model)
+ model_proto.Clear()
+ model_proto.CopyFrom(new_proto)
+ return result
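+
+
+# A minimal usage sketch (assumes `model_proto` is an onnx.ModelProto that is
+# already loaded; both calls below mutate it in place):
+#
+#   result = fold_constants(model_proto)
+#   remove_unused_nodes(model_proto)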
-__all__ = [
- "fold_constants",
- "remove_unused_nodes",
- "optimize",
-]
+def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
+ """Removes unused nodes from a model inplace."""
+ if isinstance(model, ir.Model):
+ common_passes.RemoveUnusedNodesPass()(model)
+ else:
+ model_ir = ir.serde.deserialize_model(model)
+ model_ir = common_passes.RemoveUnusedNodesPass()(model_ir).model
+ new_proto = ir.serde.serialize_model(model_ir)
+ model.Clear()
+ model.CopyFrom(new_proto)
+
+
+def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
+ """Removes unused functions from a model inplace."""
+ if isinstance(model, ir.Model):
+ common_passes.RemoveUnusedFunctionsPass()(model)
+ else:
+ model_ir = ir.serde.deserialize_model(model)
+ model_ir = common_passes.RemoveUnusedFunctionsPass()(model_ir).model
+ new_proto = ir.serde.serialize_model(model_ir)
+ model.Clear()
+ model.CopyFrom(new_proto)
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
new file mode 100644
index 0000000000..9a740c783c
--- /dev/null
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -0,0 +1,1372 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files.
+
+from __future__ import annotations
+
+__all__ = [
+ "basic_constant_propagation",
+ "fold_constants",
+ "FoldConstantsPass",
+ "FOLDED_FROM_KEY",
+]
+
+import dataclasses
+import logging
+import math
+import typing
+from typing import Any, Callable, Iterable, Sequence, Union
+
+import numpy as np
+import onnx
+import onnx.reference.ops
+import onnx_ir as ir
+
+import onnxscript.utils.utils as utils
+from onnxscript.ir import _tape
+
+DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 8192
+
+DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512
+
+# Key used to store the metadata
+FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from"
+
+
+_NON_DETERMINISTIC_OPS = frozenset(
+ {
+ "RandomUniform",
+ "RandomNormal",
+ "RandomUniformLike",
+ "RandomNormalLike",
+ "Multinomial",
+ }
+)
+
+# Ops to always fold regardless of their input size limits, as long as they are
+# the single consumer of the large input tensors.
+_DEFAULT_ALWAYS_FOLD_OPS = frozenset(
+ {
+ ("", "Transpose"),
+ }
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _is_control_flow_op(node: ir.Node) -> bool:
+ graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
+ return any(attr.type in graph_types for attr in node.attributes.values())
+
+
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+ return node.op_type in _NON_DETERMINISTIC_OPS and utils.is_onnx_domain(node.domain)
+
+
+def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
+ return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+
+
+# "Standard" evaluators are used to perform constant-folding.
+# The API below works only for non-control-flow ops (ops without any graph-attributes).
+# This currently uses ONNX's reference implementation, but we could also
+# use ORT's implementation if we wanted to.
+
+
+def _process_constant_node(node: ir.Node) -> None:
+ """Sets const_value of output value of a Constant op node."""
+ if not _is_onnx_op(node, "Constant"):
+ return
+ if len(node.attributes) != 1:
+ return
+ attr_name, attr_value = next(iter(node.attributes.items()))
+ if len(node.outputs) != 1:
+ return
+ ir_value = node.outputs[0]
+
+ if attr_value is None or not isinstance(attr_value, ir.Attr):
+ return
+
+    # Even if this is an attribute, the value property might not be set; this
+    # happens e.g. for attribute references, i.e., when ref_attr_name is set.
+ if attr_value.value is None:
+ # For now reject this to prevent TypeError from accessing Nones below
+ return
+
+ const_value: ir.TensorProtocol
+ if attr_name in {"value_float", "value_floats"}:
+ const_value = ir.Tensor(
+ np.array(attr_value.value, dtype=np.float32), name=ir_value.name
+ )
+ elif attr_name in {"value_int", "value_ints"}:
+ const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name)
+ elif attr_name in {"value_string", "value_strings"}:
+ const_value = ir.StringTensor(
+ np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name
+ )
+ elif attr_name == "value":
+ const_value = typing.cast(ir.TensorProtocol, attr_value.value)
+ else:
+ return
+
+ ir_value.const_value = const_value
+ ir_value.shape = const_value.shape # type: ignore
+ ir_value.dtype = const_value.dtype
+
+
+def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None:
+ """Performs basic constant propagation for a sequence of nodes.
+
+ Just marks the output values of Constant op nodes with their const_value.
+ """
+ for node in nodes:
+ _process_constant_node(node)
+
+
+class ReferenceEvaluator:
+ def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None:
+ try:
+ op_impl_class = onnx.reference.ops.load_op(domain, op, version)
+ return op_impl_class.eval # noqa: TRY300
+ except Exception:
+ return None
+
+ def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any:
+ logger.debug("Evaluating %s::%s", domain, op)
+ evaluator = self.get_evaluator(domain, op, version)
+ if evaluator is None:
+ return None
+ try:
+ return evaluator(*args, **kwargs)
+ except Exception as e:
+ logger.warning("Evaluation failed: %s", e)
+ return None
+
+
+_reference_evaluator = ReferenceEvaluator()
+
+
+@dataclasses.dataclass
+class Replacement:
+ """A replacement for a node in the graph."""
+
+ new_outputs: Sequence[ir.Value]
+ new_nodes: Sequence[ir.Node]
+
+
+# The optimizer tracks an optional symbolic value for each value in the model.
+# The symbolic value attached to a value X can be:
+# - another IR value Y (indicating that X is equal to Y)
+# - a list of IR values [Y1, Y2, ...] (indicating that X is a sequence of values Y1, Y2, ...)
+# - a Shape object (indicating that X is a shape value)
+# A Shape object as a symbolic value indicates that the corresponding value is
+# 1-D (or 0-D) tensor of INT64 values. The values in this object may be constants
+# or symbolic dimension values (like "batch_size", "sequence_length", etc.).
+# Currently, we assume that symbolic dimensions are also guaranteed to be non-negative.
+# TODO: Add support for negative symbolic dimensions.
+
+SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape]
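+# Example (illustrative): if X = Shape(T) and T has shape ("batch_size", 1024),
+# the optimizer can record ir.Shape(["batch_size", 1024]) as the symbolic value
+# of X, even though the concrete integer values of X are unknown.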
+
+
+class OptimizerState:
+ def __init__(self):
+ self._sym_value_map: dict[ir.Value, SymbolicValue] = {}
+ self._initializer_inputs: list[set[ir.Value]] = []
+
+ @property
+ def symbolic_value_map(self) -> dict[ir.Value, SymbolicValue]:
+ return self._sym_value_map
+
+ def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None:
+ if value is None:
+ return None
+ return self._sym_value_map.get(value)
+
+ def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
+ self._sym_value_map[value] = sym_value
+
+ def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
+ const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
+ if const_value is not None:
+ if const_value.ndim == 1:
+ return ir.Shape(const_value.tolist())
+ return None
+ sym_value = self.get_sym_value(value)
+ if isinstance(sym_value, ir.Shape):
+ return sym_value
+ # TODO use shape of value if available
+ return None
+
+
+# The "partial evaluators" below are non-standard evaluators. They are used to perform
+# partial evaluation and/or static program analysis (abstract interpretation).
+
+# A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns
+# a Replacement for the node or None (if no replacement is needed). It may also return just
+# the ir.Value or ir.Values to replace the output values of the node, when the new nodes
+# can be inferred from the RewriterContext used to build the new nodes.
+
+RewriterContext = _tape.Builder
+ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
+PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]
+
+
+@dataclasses.dataclass
+class PartialEvaluator:
+ """A class that represents a partial-evaluator for a particular op.
+
+ It is applicable for a specific version range (min_version, max_version) of the op.
+ The min_version and max_version can be None, indicating that there is no version
+ constraint in that direction.
+ """
+
+ min_version: int | None
+ max_version: int | None
+ function: PartialEvaluatorFunction
+
+ def valid_for(self, version: int) -> bool:
+ """Returns True if this evaluator is applicable for the given version."""
+ return (self.min_version is None or version >= self.min_version) and (
+ self.max_version is None or version <= self.max_version
+ )
+
+
+class PartialEvaluatorRegistry:
+ """A class that maintains a registry of evaluators for ops."""
+
+ def __init__(self):
+ self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {}
+
+ def lookup_evaluators(self, domain: str, opname: str, version: int):
+ evaluator_list = self.op_evaluators.get((domain, opname), [])
+ return [
+ evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version)
+ ]
+
+ def register(
+ self, opname: str, domain: str = "", version=None
+ ) -> Callable[[PartialEvaluatorFunction], PartialEvaluatorFunction]:
+ if (domain, opname) in self.op_evaluators:
+ evaluator_list = self.op_evaluators[(domain, opname)]
+ else:
+ evaluator_list = []
+ self.op_evaluators[(domain, opname)] = evaluator_list
+ if version is None:
+ min_version = None
+ max_version = None
+ elif isinstance(version, int):
+ min_version = version
+ max_version = version
+ elif isinstance(version, tuple):
+ min_version, max_version = version
+
+ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction:
+ evaluator_list.append(PartialEvaluator(min_version, max_version, function))
+ return function
+
+ return decorator
+
+
+registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry()
+
+register = registry.register
+
+
+def _same_shape(shape1: ir.Shape, shape2: ir.Shape) -> bool:
+    # Comparison of shapes as tuples works except when a dimension is None
+    # (which represents an unknown dimension value). Two shapes such as
+    # ("Batch", 1024) and ("Batch", 1024) are considered equal because the named
+    # symbolic dimension denotes the same value, but (None, 1024) and (None, 1024)
+    # are not, since each None may stand for a different unknown value.
+ if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in shape1):
+ return False
+ return shape1.dims == shape2.dims
+
+
+def _get_numpy_value(
+ val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
+) -> np.ndarray | None:
+ """Returns the numpy value of a constant value, if available.
+
+ It returns None if the value is not a constant value, or if the value is not of
+ the specified element dtype, or if the size of the value exceeds the specified
+ size_limit.
+ """
+ if val is None:
+ return None
+ const_value = val.const_value
+ if const_value is not None:
+ if dtype is not None and const_value.dtype != dtype:
+ return None
+ if size_limit is not None and const_value.size > size_limit:
+ return None
+ try:
+ # Turn the constant value into a numpy array representation with the
+ # specifics of this conversion handled by the tensor type
+ array = const_value.numpy()
+ # Can/should not reinterpret strings via .view, resulting in
+ # "TypeError: Cannot change data-type for array of references."
+ # There is also no reason to reinterpret strings, this is only
+ # relevant for some arithmetic types
+ if const_value.dtype != ir.DataType.STRING:
+ # Reinterpret the array with `.view()` because some
+ # implementations of ir.TensorProtocol (e.g. PyTorch<=2.7) do
+ # not use ml_dtypes for bfloat16 etc.
+ array = array.view(const_value.dtype.numpy())
+ except FileNotFoundError:
+ # External data is not available.
+ logger.warning(
+ "External data for value '%s' is not available. "
+ "This may lead to incorrect constant folding.",
+ val.name,
+ )
+ return None
+ assert isinstance(array, np.ndarray)
+ return array
+ return None
+
+
+def _get_bool_value(val: ir.Value | None) -> bool | None:
+ if val is None:
+ return None
+ value = _get_numpy_value(val)
+ if value is None:
+ return None
+ if value.size == 1 and value.dtype == bool:
+ return value.item(0)
+ return None
+
+
+def _get_input(node: ir.Node, index: int) -> ir.Value | None:
+ if index < len(node.inputs):
+ return node.inputs[index]
+ return None
+
+
+def _get_output(node: ir.Node, index: int) -> ir.Value | None:
+ if index < len(node.outputs):
+ return node.outputs[index]
+ return None
+
+
+def _get_input_element_type(node: ir.Node, index: int) -> int:
+ input = _get_input(node, index)
+ if input is not None and input.type is not None:
+ return input.type.dtype.value
+ return ir.DataType.UNDEFINED.value
+
+
+def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None:
+ if name in node.attributes:
+ attr = node.attributes[name]
+ if not isinstance(attr, ir.Attr):
+ return None
+ attr_val = attr.value
+ if isinstance(attr_val, int):
+ return attr_val
+ # This is an invalid model: attribute has invalid/unexpected type.
+ # For now, we just return None. We could raise an error too.
+ return None
+ return default
+
+
+@register("Add")
+def add(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Propagate symbolic dim values."""
+
+ def get_dim_value(input_index):
+ input = _get_input(node, input_index)
+ if input is None:
+ return None
+ shape_value: ir.Shape | None = state.get_shape_value(input)
+ if shape_value is None or len(shape_value) != 1:
+ return None
+ dim: int | ir.SymbolicDim = shape_value[0]
+ return dim if isinstance(dim, int) else dim.value
+
+ dim0 = get_dim_value(0)
+ dim1 = get_dim_value(1)
+ if dim0 is None or dim1 is None:
+ return None
+ if isinstance(dim0, int) and isinstance(dim1, int):
+ result_dim_value: int | ir.SymbolicDim = dim0 + dim1
+ else:
+ result_dim_value = ir.SymbolicDim(f"{dim0}+{dim1}")
+ output = _get_output(node, 0)
+ if output is not None:
+ state.set_sym_value(output, ir.Shape([result_dim_value]))
+
+
+@register("Abs")
+def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Replace an Abs node by Identity when applicable.
+
+ Currently, addresses Abs applied to symbolic shapes.
+ """
+ input = _get_input(node, 0)
+ input_sym_value = state.get_shape_value(input)
+ if input_sym_value is None:
+ return None
+ if any(isinstance(d, int) and d < 0 for d in input_sym_value):
+ return None
+ # Abs applied to a symbolic shape of the form [1, 1, SequenceLength].
+ # We assume that SequenceLength is a non-negative integer.
+ # The Abs op is redundant in this case.
+ return op.Identity(input)
+
+
+@register("Gather")
+def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Replace a Gather node by a constant when applicable.
+
+ Currently, handles the case of Gathering from a shape tensor.
+ """
+ input = _get_input(node, 0)
+ indices = _get_input(node, 1)
+ if input is None or indices is None:
+ return None
+ input_sym_value = state.get_shape_value(input)
+ if input_sym_value is None:
+ return None
+ axis = _get_int_attribute(node, "axis", None)
+ if axis != 0:
+ return None
+ indices_numpy_value = _get_numpy_value(indices)
+ if indices_numpy_value is None:
+ return None
+ if indices_numpy_value.ndim != 1:
+ return None
+ gathered = [input_sym_value[i] for i in indices_numpy_value]
+ output = _get_output(node, 0)
+ if output is not None:
+ state.set_sym_value(output, ir.Shape(gathered))
+ if all(isinstance(d, int) for d in gathered):
+ return op.Constant(value_ints=ir.AttrInt64s("value_ints", gathered))
+ return None
+
+
+def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Propagates symbolic shape value of input 0 to output 0.
+
+ Applies to ops like Reshape/Squeeze/Unsqueeze where the shape of the tensor may change
+ but the values in the tensor remain the same.
+ """
+ input = _get_input(node, 0)
+ input_shape_value = state.get_shape_value(input)
+ output = _get_output(node, 0)
+ if output is not None and input_shape_value is not None:
+ state.set_sym_value(output, input_shape_value)
+ return None
+
+
+@register("Reshape")
+def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Replace a Reshape node by Identity when applicable.
+
+ Also propagate symbolic shape values.
+ """
+ input = _get_input(node, 0)
+ shape = _get_input(node, 1)
+ if input is None or shape is None:
+ return None
+
+ input_shape = input.shape
+ shape_value = state.get_shape_value(shape)
+
+ if shape_value is None or input_shape is None:
+ return _propagate_shape_value(node, op, state)
+
+ # No need to check for special values like -1, 0, etc. here
+ if _same_shape(input_shape, shape_value):
+ return op.Identity(input)
+ return _propagate_shape_value(node, op, state)
+
+
+@register("Squeeze")
+def squeeze(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Propagate symbolic shape values."""
+ return _propagate_shape_value(node, op, state)
+
+
+@register("Cast")
+def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ input = _get_input(node, 0)
+ output = _get_output(node, 0)
+
+ if input is None or output is None:
+ return None
+
+ # TODO(rama): Parts of the following logic (implementing type/shape inference
+ # for Cast op) should be unnecessary. Generic incremental shape-inference
+ # should handle this. Only the optimization to eliminate redundant Cast ops
+ # should be needed here.
+
+ output.shape = _merge_shapes(output.shape, input.shape)
+
+ input_dtype = _get_input_element_type(node, 0)
+ output_dtype = _get_int_attribute(node, "to", None)
+ if output_dtype is not None:
+ if input_dtype == output_dtype:
+ return op.Identity(input)
+ output.type = ir.TensorType(ir.DataType(output_dtype))
+ return None
+
+
+@register("CastLike")
+def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ input0 = node.inputs[0]
+ source_element_type = _get_input_element_type(node, 0)
+ target_element_type = _get_input_element_type(node, 1)
+
+ if target_element_type == ir.DataType.UNDEFINED:
+ return None
+ if source_element_type == target_element_type:
+ return op.Identity(input0)
+ return op.Cast(input0, to=target_element_type)
+
+
+@register("Shape")
+def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ input = node.inputs[0]
+ if input is None:
+ return None
+ shape = input.shape
+ if shape is None:
+ return None
+ start = _get_int_attribute(node, "start", 0)
+ end = _get_int_attribute(node, "end", None)
+ shape_slice = shape[start:end]
+ output = _get_output(node, 0)
+ if output is not None:
+ state.set_sym_value(output, ir.Shape(shape_slice))
+ if all(isinstance(d, int) for d in shape_slice):
+ return op.Constant(value_ints=ir.AttrInt64s("value_ints", list(shape_slice)))
+ return None
+
+
+@register("Size")
+def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ input = _get_input(node, 0)
+ if input is None:
+ return None
+ shape = input.shape
+ if shape is None:
+ return None
+ size = 1
+ for d in shape:
+ if not isinstance(d, int):
+ return None
+ size *= d
+ return op.Constant(value_int=size)
+
+
+@register("If")
+def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ cond_input = _get_input(node, 0)
+ cond = _get_bool_value(cond_input)
+ if cond is not None:
+ # cond is a constant-value: inline the branch
+ branch = "then_branch" if cond else "else_branch"
+ graph_attr = node.attributes.get(branch)
+ if graph_attr is None:
+ return None
+ if graph_attr.type != ir.AttributeType.GRAPH:
+ return None
+ assert isinstance(graph_attr, ir.Attr)
+ graph = graph_attr.as_graph()
+            # Copy the graph outputs, then clear them so that the values are free to move
+ formal_outs = list(graph.outputs)
+ graph.outputs.clear()
+ actual_outs = node.outputs
+ renamings = {
+ formal.name: actual.name
+ for formal, actual in zip(formal_outs, actual_outs)
+ if actual is not None
+ }
+ # TODO: Extend renaming to intermediate values.
+
+ def rename(name):
+ return renamings.get(name, name)
+
+ graph_nodes = list(graph)
+ graph.remove(graph_nodes)
+ for sub_node in graph_nodes:
+ # TODO: handle renaming inside subgraphs in nodes
+ for v in sub_node.outputs:
+ v.name = rename(v.name)
+ # Avoid name collision.
+ sub_node.name = f"{node.name}_{sub_node.name}"
+
+ # TODO: we should handle initializers as well!
+ return Replacement(formal_outs, graph_nodes)
+ return None
+
+
+@register("Identity")
+def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ del op
+ input = node.inputs[0]
+ output = node.outputs[0]
+ if input is not None and output is not None:
+ input.shape = _merge_shapes(input.shape, output.shape)
+ if input.type is None:
+ input.type = output.type
+ state.set_sym_value(output, input)
+ return None
+
+
+@register("SequenceConstruct")
+def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ del op
+ output = node.outputs[0]
+ if output is not None:
+ state.set_sym_value(output, list(node.inputs))
+ return None
+
+
+@register("Concat")
+def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Replace a Concat node with a single input by Identity"""
+
+ # Replace Concat(x) by Identity(x)
+ inputs = node.inputs
+ if len(inputs) == 1:
+ return op.Identity(inputs[0])
+
+ axis = _get_int_attribute(node, "axis", None)
+ if axis is None:
+ return None
+
+ # Eliminate zero-length operands from Concat
+ def has_zero_size(operand: ir.Value | None) -> bool:
+ if operand is None:
+ return False # Invalid model
+ if (shape := operand.shape) is None:
+ return False
+ try:
+ # We have already checked that axis is an int value (!= None)
+ dim_size = shape[axis] # type: ignore[index]
+ except IndexError:
+ return False
+ return dim_size == 0 # return False if symbolic or None or non-zero int value
+
+ new_inputs = [x for x in inputs if not has_zero_size(x)]
+ if len(new_inputs) != len(inputs):
+ if new_inputs:
+ # Remove zero-length operands from Concat
+ logger.debug(
+ "Concat: removing zero-length operand(s) %s => %s", inputs, new_inputs
+ )
+ return op.Concat(*new_inputs, axis=axis)
+ elif inputs:
+ # All operands are zero-length. Concat is a no-op, but we need to use one of the
+ # inputs to get the other dimensions correct:
+ logger.debug("Concat: removing all zero-length operands %s", inputs)
+ return op.Identity(inputs[0])
+ else:
+ # No inputs: invalid model.
+ return None
+
+    # Track the symbolic shape values carried by the concatenated tensors.
+    # This propagation only applies when concatenating along axis 0.
+    if axis != 0:
+ return None
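+    # Example: if the operands carry the symbolic shape values ("B",) and
+    # (256,), the output is recorded as carrying the shape value ("B", 256).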
+ shapes = [state.get_shape_value(input) for input in inputs]
+ if any(shape is None for shape in shapes):
+ return None
+ concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims) # type: ignore[union-attr]
+ output = node.outputs[0]
+ if output is None:
+ return None
+ state.set_sym_value(output, concatenated)
+ return None
+
+
+@register("Dropout", version=(12, None))
+def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Replace a Dropout by Identity when applicable."""
+
+ def optimized_dropout():
+ input = node.inputs[0]
+ output = op.Identity(input)
+ if len(node.outputs) == 1:
+ return output
+ else:
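+            # The optional mask output of a no-op Dropout is all-True and has
+            # the same shape as the input.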
+ true_tensor = ir.tensor([True])
+ input_shape = op.Shape(input)
+ mask = op.ConstantOfShape(input_shape, value=true_tensor)
+ return output, mask
+
+ inputs = node.inputs
+ if (len(inputs) <= 2) or inputs[2] is None:
+ # No training_mode specified:
+ return optimized_dropout()
+ if _get_bool_value(inputs[2]) is False:
+ # training_mode is False: dropout is not applied.
+ return optimized_dropout()
+ ratio = _get_numpy_value(inputs[1])
+ if ratio is None:
+ return None
+ if ratio.size != 1: # Only scalar dropout ratio is supported.
+ return None
+ if ratio.item() == 0:
+ # dropout ratio is 0: dropout is not applied.
+ return optimized_dropout()
+ return None
+
+
+@register("Expand")
+def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Replace an Expand node by Identity when applicable."""
+ if len(node.inputs) != 2:
+ return None
+ if (input := node.inputs[0]) is None:
+ return None
+ if (input_shape := input.shape) is None:
+ # Input shape is not known.
+ return None
+ if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
+ # Target shape is not known.
+ expanded_sym_shape = state.get_shape_value(node.inputs[1])
+ if expanded_sym_shape is None or not _same_shape(input_shape, expanded_sym_shape):
+ return None
+ return op.Identity(input)
+ if expanded_shape.ndim != 1:
+ # Target shape must be a 1D tensor. Erroneous model.
+ return None
+ if input_shape.dims == tuple(expanded_shape.tolist()):
+ return op.Identity(input)
+ return None
+
+
+@register("ConcatFromSequence")
+def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ input = node.inputs[0]
+ inputs = state.get_sym_value(input)
+ if inputs is None or any(x is None for x in inputs):
+ return None
+ new_axis = _get_int_attribute(node, "new_axis", 0)
+ axis = _get_int_attribute(node, "axis", None)
+ if axis is None:
+ return None
+ if input is not None and isinstance(inputs, list):
+ if new_axis == 0:
+ logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs])
+ return op.Concat(*inputs, axis=axis)
+ if new_axis == 1:
+            # Unsqueeze each input along the concat axis when new_axis is 1
+ axis_value = op.Constant(value_int=axis)
+ unsqueezed_inputs = []
+ for node_input in inputs:
+ unsqueezed_input = op.Unsqueeze(
+ node_input, axis_value, _outputs=[f"{node_input.name}_unsqueeze"]
+ )
+ unsqueezed_inputs.append(unsqueezed_input)
+ # Send unsqueezed outputs to Concat
+ logger.debug(
+ "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs]
+ )
+ return op.Concat(*unsqueezed_inputs, axis=axis)
+ return None
+
+
+@register("SplitToSequence")
+def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ """Rewriting pattern.
+
+ From
+
+ splits = onnx::SplitToSequence(input, split, axis=axis)
+
+ to
+
+ split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis)
+ splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n)
+
+ or
+
+ split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1)
+ splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n)
+
+    where the number of output tensors in `splits` is statically known.
+ onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator.
+ This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly.
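+
+    When `keepdims=0`, each Split output is additionally passed through a
+    Squeeze on the split axis.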
+ """
+ input = node.inputs[0]
+ if len(node.inputs) == 1:
+ # split is not provided
+ return None
+ split = node.inputs[1]
+ output = node.outputs[0]
+
+ if input is None or split is None or output is None:
+ return None
+
+ axis = _get_int_attribute(node, "axis", 0)
+ if axis is None:
+ return None
+ shape = input.shape
+ if shape is None:
+ return None
+ rank = len(shape)
+ if axis < 0:
+ axis = axis + rank
+ if axis < 0 or axis >= rank:
+ return None
+
+ # NOTE: Split needs to either be a scalar or a 1-D tensor. We need to
+ # calculate the number of outputs for Split.
+    # If split is a scalar, we split into chunks of size 'split' if possible.
+    # * The split dimension size and split_value have to be known.
+    # If split is a 1-D tensor, we split into 'size(split)' chunks.
+    # * Get the size from split_value if it is a numpy array.
+    # * Get the size from the statically known shape if split_value is not available.
+ split_value = _get_numpy_value(split)
+ split_shape = (
+ split.shape.numpy() if split.shape is not None and split.shape.is_static() else None
+ )
+
+ # No information about split value or shape.
+ if split_value is None and split_shape is None:
+ return None
+
+ if isinstance(split_shape, tuple) and len(split_shape) == 1:
+ # If split_shape is known, we can use it to determine the number of outputs.
+ split_dimension_size = split_shape[0]
+ assert isinstance(split_dimension_size, int)
+ num_outputs = split_dimension_size
+ split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
+ split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
+    elif split_value is not None and split_value.ndim == 1:
+ # split into 'size(split)' chunks
+ num_outputs = split_value.size
+ split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
+ split_values = op.Split(input, split, axis=axis, _outputs=split_outputs)
+    elif split_value is not None and split_value.ndim == 0:
+ # split into chunks all of size 'split' if possible.
+ split_dimension_size = shape[axis]
+ if not isinstance(split_dimension_size, int):
+ return None
+ num_outputs = math.ceil(split_dimension_size / split_value.item())
+ split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
+ split_values = op.Split(
+ input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs
+ )
+ else:
+ return None
+
+ # If Split returns a single value, we need to wrap it into a list.
+ if isinstance(split_values, ir.Value):
+ split_values = [split_values]
+
+ keepdims = _get_int_attribute(node, "keepdims", 1)
+ if keepdims is None:
+ return None
+ if keepdims == 0:
+ # squeeze the split dimension if keepdims is 0
+ axis_val = op.Constant(value_ints=[axis], _outputs=[f"{output.name}_axis"])
+ squeezed_values = []
+ for i in range(num_outputs):
+ squeezed = op.Squeeze(
+ split_values[i], axis_val, _outputs=[f"{split_outputs[i]}_squeeze"]
+ )
+ squeezed_values.append(squeezed)
+ split_values = squeezed_values
+
+ logger.debug("SplitToSequence => Split + SequenceConstruct")
+
+    return op.SequenceConstruct(*split_values)
+
+
+@register("SequenceAt")
+def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+ input = node.inputs[0]
+ position = node.inputs[1]
+ output = node.outputs[0]
+ if input is not None and position is not None:
+ input_vals = state.get_sym_value(input)
+ position_val = _get_numpy_value(position)
+ if isinstance(input_vals, list) and position_val is not None:
+ if position_val.size != 1:
+ return None
+ position_val = position_val.item()
+ try:
+ result = input_vals[position_val] # type: ignore[index]
+ except IndexError:
+ return None
+ state.set_sym_value(output, result)
+ logger.debug("SequenceAt %s => %s", input.name, result.name)
+ return op.Identity(result)
+ return None
+
+
+def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
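+    """Merge two shapes dimension-wise, preferring concrete dims over symbolic ones.
+
+    An illustrative sketch: merging ``(2, "N", "K")`` with ``("M", 3, "K")``
+    yields ``(2, 3, "K")``.
+    """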
+ def merge_dims(dim1, dim2):
+ if dim1 == dim2:
+ return dim1
+ if not isinstance(dim1, ir.SymbolicDim):
+ return dim1 # Prefer int value over symbolic dim
+ if not isinstance(dim2, ir.SymbolicDim):
+ return dim2
+ if dim1.value is None:
+ return dim2
+ return dim1
+
+ if shape1 is None:
+ return shape2
+ if shape2 is None:
+ return shape1
+ if len(shape1) != len(shape2):
+ raise ValueError("Shapes must have the same rank.")
+ return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
+
+
+def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None:
+ """Record the set of original input values that contributed to the constant-folded outputs."""
+ folded_from: set[str] = set()
+ for input in original_node.inputs:
+ if input is None:
+ continue
+ folded_from.update(input.meta.get(FOLDED_FROM_KEY, set()))
+ assert input.name is not None
+ folded_from.add(input.name)
+
+ for new_output in replacement.new_outputs:
+ if new_output is None:
+ continue
+ new_output.meta[FOLDED_FROM_KEY] = folded_from
+ # Store the string representation of the set to metadata_props to persist it across serialization
+ new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from))
+
+
+class FoldConstantsPass(ir.passes.InPlacePass):
+ """A pass that folds constant expressions in the model.
+
+ Attributes:
+ shape_inference: Whether to perform shape inference.
+ input_size_limit: Maximum size of input tensors to fold.
+ output_size_limit: Maximum size of output tensors to fold.
+        should_fold: An optional function that takes a node and returns True if
+            the node should be folded, False if it should not be folded, or
+            None to use the default folding rules.
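+
+    Example (a minimal sketch; ``model`` is assumed to be an ``ir.Model``)::
+
+        pass_ = FoldConstantsPass(
+            shape_inference=False,
+            input_size_limit=1024 * 1024,
+            output_size_limit=1024 * 1024,
+        )
+        result = pass_(model)  # result.modified tells whether the model changed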
+ """
+
+ def __init__(
+ self,
+ *,
+ shape_inference: bool,
+ input_size_limit: int,
+ output_size_limit: int,
+ should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
+ ) -> None:
+ self.shape_inference = shape_inference
+ self.input_size_limit = input_size_limit
+ self.output_size_limit = output_size_limit
+ self.should_fold = should_fold
+
+ self._opset_imports: dict[str, int] = {}
+ self._counts: dict[str, int] = {}
+ self._sizes: dict[str, int] = {}
+ self._modified: bool = False
+ self._state = OptimizerState()
+ self._reset()
+
+ def _reset(self) -> None:
+ """Reset internal states for a new run."""
+ self._counts = {}
+ self._sizes = {}
+ self._modified = False
+ self._state = OptimizerState()
+
+ def _do_inference(self, node: ir.Node) -> None:
+ output_types = {}
+
+ # TODO: handle optional inputs
+ def get_constant_value(x: ir.Value) -> onnx.TensorProto | None:
+ value = _get_numpy_value(x, size_limit=20)
+ if value is not None:
+ assert x.const_value is not None
+ return ir.serde.serialize_tensor(x.const_value)
+ return None
+
+ def get_type(value: ir.Value) -> onnx.TypeProto | None:
+ if value.type is not None:
+ type_proto = ir.serde.serialize_type(value.type)
+ if value.shape is not None:
+ ir.serde.serialize_shape_into(type_proto, value.shape)
+ return type_proto
+ return None
+
+ input_types = {x.name: get_type(x) for x in node.inputs if x is not None}
+ input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None}
+ input_data = {k: v for k, v in input_data.items() if v is not None}
+ if any(t is None for t in input_types.values()):
+ logger.debug(
+ "Skipping shape inference for node %r due to missing input type.",
+ node.name,
+ )
+ else:
+ # TODO: pass in constant values, ir_version
+ try:
+ schema = onnx.defs.get_schema(
+ node.op_type, self._opset_imports[node.domain], node.domain
+ )
+ output_types = onnx.shape_inference.infer_node_outputs(
+ schema,
+ ir.serde.serialize_node(node),
+ input_types, # type: ignore[arg-type]
+ input_data, # type: ignore[arg-type]
+ )
+ for output in node.outputs:
+ if output.name in output_types:
+ inferred_type = output_types[output.name]
+ # TODO: merge types, check for conflicts
+ inferred_shape = ir.serde.deserialize_type_proto_for_shape(
+ inferred_type
+ )
+ output.shape = _merge_shapes(output.shape, inferred_shape)
+ output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
+ except Exception as e:
+ logger.debug(
+ "Skipping shape inference for node %r due to exception: %s",
+ node.name,
+ e,
+ )
+
+ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
+ irvalue = node.outputs[0]
+ if not isinstance(value, np.ndarray):
+            # ONNX does not have a way to represent non-tensor constants, e.g. a sequence.
+ # So, a constant-value of type sequence is not folded, but it can be used
+ # to optimize subsequent operations when possible.
+ logger.info(
+ "Skip storing constant folded value %s due to unsupported type %s.",
+ irvalue.name,
+ type(value),
+ )
+ return None
+
+ tensor = ir.tensor(value)
+ tensor.name = irvalue.name
+ irvalue.const_value = tensor
+
+ if value.size > self.output_size_limit:
+            # Allow cases like Transpose(weight) to be folded even if the result is
+            # large, as long as weight has no other uses: model size does not increase.
+ removed_input_size = 0
+ for input in node.inputs:
+ if (input is not None) and (len(input.uses()) == 1):
+ array = _get_numpy_value(input)
+ if array is not None:
+ removed_input_size += array.size
+ increased_size = value.size - removed_input_size
+ if increased_size > 0:
+ logger.info(
+ "Skip storing constant folded nvalue %s due to large size %s.",
+ irvalue.name,
+ value.size,
+ )
+ return None
+
+ logger.debug(
+ "New constant for value %s dtype: %s shape: %s",
+ irvalue.name,
+ value.dtype,
+ value.shape,
+ )
+
+ attributes = ir.convenience.convert_attributes({"value": tensor})
+ node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
+ return node
+
+ def process_node(self, node: ir.Node) -> Replacement | None:
+ """Process a node and return a Replacement if the node can be replaced."""
+ for i, value in enumerate(node.inputs):
+ sym_value = self._state.get_sym_value(value)
+ if isinstance(sym_value, ir.Value):
+ logger.debug(
+ "Node [%s]: Replacing input %s with %s",
+ node.name,
+ value.name, # type: ignore[union-attr]
+ sym_value.name,
+ )
+ node.replace_input_with(i, sym_value)
+ self._modified = True
+ # TODO(rama): consider merging type/other info from both values
+
+ # Propagate const_value, and manually find out shape and type
+ # to avoid potentially expensive shape inference on large tensors.
+ if _is_onnx_op(node, "Constant"):
+ _process_constant_node(node)
+ # Do incremental shape inference
+ elif self.shape_inference and not _is_control_flow_op(node):
+ self._do_inference(node)
+
+ if node.domain not in self._opset_imports:
+ return None
+ version = self._opset_imports[node.domain]
+ op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
+ for optimizer in op_optimizers:
+ assert optimizer
+ context = RewriterContext()
+ output = optimizer(node, context, self._state)
+ if output is not None:
+ if isinstance(output, Replacement):
+ return output
+ if isinstance(output, ir.Value):
+ output = [output]
+ return Replacement(output, context.nodes)
+
+ if _is_onnx_op(node, "Constant"):
+ logger.debug("Skipping constant folding for Constant node %r", node.name)
+ return None
+
+ if _is_control_flow_op(node):
+ logger.info(
+ "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet",
+ node.name,
+ node.domain,
+ node.op_type,
+ )
+
+ return None
+
+ if _is_non_deterministic_op(node):
+ logger.info(
+ "Skipping constant folding for non-deterministic op %r (%s::%s)",
+ node.name,
+ node.domain,
+ node.op_type,
+ )
+ return None
+
+ if any(x.is_graph_input() for x in node.inputs if x is not None):
+ logger.info(
+ "Skipping constant folding for node %r because it is graph input to preserve graph signature",
+ node.name,
+ )
+ return None
+
+ # Ensure all node inputs are constants
+ if any(x.const_value is None for x in node.inputs if x is not None):
+ return None
+
+ should_fold = self.should_fold(node)
+
+ if should_fold is False:
+ logger.info(
+ "Skipping constant folding for node %r because should_fold returned False",
+ node.name,
+ )
+ return None
+
+ elif should_fold is None:
+ # Use default rules to decide whether to fold the node:
+ # - ConstantOfShape is preserved to avoid increasing model size unnecessarily
+            # - If any input tensor's size exceeds input_size_limit, skip folding the node
+ if _is_onnx_op(node, "ConstantOfShape"):
+ logger.info(
+ "Skipping constant folding for node %r because ConstantOfShape is preserved by default",
+ node.name,
+ )
+ return None
+
+ input_tensors = [x.const_value if x is not None else None for x in node.inputs]
+ large_inputs = [
+ tensor is not None and tensor.size > self.input_size_limit
+ for tensor in input_tensors
+ ]
+ if any(large_inputs):
+ # Decide whether to fold large constants
+ assert len(node.inputs) == len(large_inputs)
+ if (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS and all(
+ len(input.consumers()) == 1 or (not is_large)
+ for input, is_large in zip(node.inputs, large_inputs)
+ if input is not None
+ ):
+ # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node,
+ # we can still fold it even if the input size exceeds the limit
+ pass
+ else:
+ # Skip folding large tensors
+ if logger.isEnabledFor(logging.INFO):
+ input_sizes = [
+ tensor.size for tensor in input_tensors if tensor is not None
+ ]
+ logger.info(
+ "Skipping constant folding for node %r due to large input sizes: %s",
+ node,
+ input_sizes,
+ )
+ return None
+ else:
+ logger.info(
+ "Constant folding node %r because should_fold returned True",
+ node.name,
+ )
+
+ input_values = [_get_numpy_value(x) for x in node.inputs]
+
+ def convert(av):
+ if av.type == ir.AttributeType.TENSOR:
+ return ir.serde.serialize_tensor(av.value)
+ return av.value
+
+ # TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node
+ attr_values = {name: convert(attr) for name, attr in node.attributes.items()}
+ outputs = _reference_evaluator.evaluate(
+ node.domain, node.op_type, version, *input_values, **attr_values
+ )
+
+ if outputs is None:
+ return None
+ if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
+ replacement = self.new_constant(node, outputs)
+ if replacement is None:
+ return None
+ return Replacement(replacement.outputs, [replacement])
+ else:
+ logger.warning(
+ "Skipping constant folding for op %s with multiple outputs.", node.op_type
+ )
+ return None
+
+ def replace_node(
+ self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function
+ ) -> None:
+ logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)
+
+        # Record the names of the values that have contributed to the replacement
+ _record_contributing_values(node, replacement)
+
+ ir.convenience.replace_nodes_and_values(
+ root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
+ )
+
+ self._modified = True
+
+ # TODO: what about new opset_imports?
+ # TODO: track statistics about replaced nodes and sizes of new constants
+
+ def visit_attribute(self, attr: ir.Attr) -> None:
+ if attr.is_ref():
+ return
+ if attr.type == ir.AttributeType.GRAPH:
+ self.visit_graph(attr.as_graph())
+ elif attr.type == ir.AttributeType.GRAPHS:
+ for graph in attr.as_graphs():
+ self.visit_graph(graph)
+
+ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
+ replacement = self.process_node(node)
+ if replacement is None:
+ # No change. Process attributes.
+ for attr in node.attributes.values():
+ self.visit_attribute(attr)
+ return
+ else:
+ self.replace_node(node, replacement, root)
+
+ def visit_graph(self, graph: ir.Graph) -> None:
+ for node in graph:
+ self.visit_node(node, graph)
+
+        # Replace graph outputs if they can be folded. These are typically
+        # outputs of Identity nodes.
+ for i, output in enumerate(graph.outputs):
+ if output is None:
+ continue
+ sym_value = self._state.get_sym_value(output)
+ if not isinstance(sym_value, ir.Value):
+ # An output must be a Value
+ continue
+ if not _sym_value_can_replace_graph_output(graph, sym_value, output):
+ continue
+ # Rename sym_value to match the output name
+ sym_value.name = output.name
+ graph.outputs[i] = sym_value
+ self._modified = True
+
+ def visit_function(self, function: ir.Function) -> None:
+ for node in function:
+ self.visit_node(node, function)
+
+ def call(self, model: ir.Model) -> FoldConstantsResult:
+ self._reset()
+ self._opset_imports = model.opset_imports
+ self.visit_graph(model.graph)
+ for function in model.functions.values():
+ # TODO(rama): Should we specialize functions?
+ self.visit_function(function)
+ return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)
+
+
+def _sym_value_can_replace_graph_output(
+ graph: ir.Graph, sym_value: ir.Value, output: ir.Value
+) -> bool:
+ if (producer := sym_value.producer()) is None:
+ # If the sym_value has no producer, it is some graph's input
+ # ONNX does not allow a graph input to be a graph output
+ return False
+ if producer.graph is not graph:
+ # The sym_value must be produced by a node in the graph to be an output of this graph
+ return False
+ if sym_value.is_graph_output():
+ # If the sym_value is already an output of a graph, we cannot rename it
+ # to this output name. Otherwise the graph output represented by sym_value
+ # will lose its name.
+ return False
+ return True
+
+
+@dataclasses.dataclass
+class FoldConstantsResult(ir.passes.PassResult):
+ symbolic_value_map: dict[ir.Value, SymbolicValue]
+
+ # Add conversion to bool for backward compatibility. The previously returned value
+ # for the fold_constants method was a boolean indicating whether the model was modified.
+ def __bool__(self) -> bool:
+ return self.modified
+
+
+def fold_constants(
+ model: ir.Model,
+ *,
+ onnx_shape_inference: bool = False,
+ input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+ output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+ should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
+) -> FoldConstantsResult:
+ """
+ Applies constant folding optimization to the model.
+
+ Args:
+ model: The ONNX model to optimize.
+ onnx_shape_inference: Whether to enable ONNX shape inference during
+ constant folding. Defaults to False.
+ input_size_limit: The maximum size of input tensors
+ that can be considered for constant folding. Defaults to
+ `DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`.
+ output_size_limit: The maximum size of output tensors
+ that can be stored after constant folding. Defaults to
+ `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
+ should_fold: An optional function that takes a node and returns True if
+ the node should be considered for folding, False if it should not be folded,
+ or None to use the default rules. Defaults to a function that always returns None.
+
+ Returns:
+ An instance of `FoldConstantsResult`.
+
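+    Example (a minimal sketch; ``model`` is assumed to be an ``ir.Model``)::
+
+        result = fold_constants(model, onnx_shape_inference=True)
+        if result:  # FoldConstantsResult is truthy when the model was modified
+            ...
+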
+ """
+ folder_pass = FoldConstantsPass(
+ shape_inference=onnx_shape_inference,
+ input_size_limit=input_size_limit,
+ output_size_limit=output_size_limit,
+ should_fold=should_fold,
+ )
+ return folder_pass(model) # type: ignore[return-value]
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
new file mode 100644
index 0000000000..d3d76c4a23
--- /dev/null
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -0,0 +1,699 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import onnx
+import parameterized
+
+import onnxscript.optimizer as optimizer
+from onnxscript import ir
+from onnxscript.optimizer import _constant_folding
+
+
+class FoldConstantsTest(unittest.TestCase):
+ def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs):
+ if isinstance(model, str):
+ model = ir.from_onnx_text(model)
+ _constant_folding.fold_constants(
+ model, onnx_shape_inference=onnx_shape_inference, **kwargs
+ )
+ optimizer.remove_unused_nodes(model)
+ # Ensure the model is valid after optimization
+ onnx.checker.check_model(ir.serde.serialize_model(model))
+ return model
+
+ def test_fold_add(self):
+ model = """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, four)
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+
+ def test_fold_cast_like(self):
+ model = """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ two_float = CastLike(two, x)
+ four = Add(two_float, two_float)
+ z = Mul(x, four)
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+
+ def test_fold_shape(self):
+ model = """
+
+ agraph (float[16, 16] x) => (float[16, 16] z) {
+ shape = Shape(x)
+ rank = Size(shape)
+ two_float = CastLike(rank, x)
+ four = Add(two_float, two_float)
+ z = Mul(x, four)
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+
+ def test_fold_shape_slice(self):
+ model = """
+
+ agraph (float[M, N, 16, 16] x) => (float[M, N, 16, 16] z) {
+ shape = Shape (x)
+ two = Size(shape)
+ two_float = CastLike(two, x)
+ four = Add(two_float, two_float)
+ z = Mul(x, four)
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(optimized.graph[0].outputs[0].name, "four")
+
+ def test_fold_if_cond(self):
+ model = """
+
+ agraph (float[16, 16] x) => (float[16, 16] z) {
+ shape = Shape(x)
+ rank = Size(shape)
+ zero = Constant ()
+ zero_cast = CastLike (zero, rank)
+ is_scalar = Equal(zero_cast, rank)
+ z = If (is_scalar) <
+ then_branch = then_graph () => (then_z) { then_z = Add (x, x) },
+ else_branch = else_graph () => (else_z) { else_z = Mul (x, x) }
+ >
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertEqual(optimized.graph[0].outputs[0].name, "z")
+ self.assertEqual(optimized.graph[0].op_type, "Mul")
+
+ def test_fold_inside_if_branch(self):
+ model = """
+
+ agraph (float[16, 16] x, bool cond) => (float[16, 16] z) {
+ two = Constant ()
+ z = If (cond) <
+ then_branch = then_graph () => (then_z) {
+ three = Constant ()
+ temp = Add (two, three)
+ then_z = Mul (temp, x)
+ },
+ else_branch = else_graph () => (else_z) {
+ four = Constant ()
+ temp = Add (two, four)
+ else_z = Mul (temp, x)
+ }
+ >
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 1)
+ then_graph = optimized.graph[0].attributes["then_branch"].as_graph()
+ self.assertEqual(len(then_graph), 2)
+ else_graph = optimized.graph[0].attributes["else_branch"].as_graph()
+ self.assertEqual(len(else_graph), 2)
+
+ def test_fold_if_propagate(self):
+ model = """
+
+ agraph (float[16, 16] x) => (float[16, 16] z) {
+ shape = Shape(x)
+ rank = Size(shape)
+ zero = Constant ()
+ two = Constant ()
+ zero_cast = CastLike (zero, rank)
+ is_scalar = Equal(zero_cast, rank)
+ m = If (is_scalar) <
+ then_branch = then_graph () => (then_z) { then_z = Add (x, x) },
+ else_branch = else_graph () => (else_z) { else_z = Mul (two, two) }
+ >
+ m_square = Mul (m, m)
+ z = Mul (x, m_square)
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(optimized.graph[0].outputs[0].name, "m_square")
+ self.assertEqual(optimized.graph[0].op_type, "Constant")
+
+ def test_fold_redundant_cast(self):
+ model = """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ x_cast = CastLike(x, two)
+ z = Mul(x_cast, two)
+ }
+ """
+
+ optimized = self._fold(model, onnx_shape_inference=True)
+ self.assertEqual(len(optimized.graph), 2)
+
+ def test_fold_redundant_cast2(self):
+ model = """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ z = CastLike(x, two)
+ }
+ """
+
+ optimized = self._fold(model, onnx_shape_inference=True)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertEqual(optimized.graph[0].op_type, "Identity")
+ self.assertEqual(optimized.graph[0].outputs[0].name, "z")
+ self.assertEqual(optimized.graph[0].inputs[0].name, "x")
+
+ def test_shape_inference(self):
+ model = """
+
+ agraph (int64[64] x) => (int64[N] z) {
+ one = Constant ()
+ cond = Equal(one, one)
+ temp = If (cond) <
+ then_branch = then_graph () => (then_z) {
+ shape1 = Constant ()
+ then_z = Reshape(x, shape1)
+ },
+ else_branch = else_graph () => (else_z) {
+ shape2 = Constant ()
+ else_z = Reshape(x, shape2)
+ }>
+ shape = Shape(temp) # shape = [8, 8] or [64], but [8, 8] after constant propagation
+ rank = Size(shape) # rank = 2 or 1, but 2 after constant propagation
+ C = Add (rank, rank)
+ z = Mul(x, C)
+ }
+ """
+
+ optimized = self._fold(model, onnx_shape_inference=True)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(optimized.graph[0].outputs[0].name, "C")
+
+    def test_static_split_to_sequence_with_scalar_split_and_sequence_at_is_folded_as_split(
+ self,
+ ):
+ model = """
+<
+ ir_version: 8,
+ opset_import: ["" : 18]
+>
+func (float[1,512] x) => (float[1,512] return_val) {
+ int64_128 = Constant ()
+ splits = SplitToSequence (x, int64_128)
+ int64_0 = Constant ()
+ split_0 = SequenceAt (splits, int64_0)
+ int64_1 = Constant ()
+ split_1 = SequenceAt (splits, int64_1)
+ int64_2 = Constant ()
+ split_2 = SequenceAt (splits, int64_2)
+ int64_3 = Constant ()
+ split_3 = SequenceAt (splits, int64_3)
+ return_val = Concat (split_0, split_1, split_2, split_3)
+}"""
+
+ # TODO: There is an unrelated limitation that `symbolic_value` is not
+        # utilized when the value is only referenced by a graph output.
+ # E.g., the following test model will not have this optimization
+ # applied.
+ #
+ # <
+ # ir_version: 8,
+ # opset_import: ["" : 18]
+ # >
+ # func (float[1,512] x) => ( split_0, split_1, split_2, split_3) {
+ # int64_128 = Constant ()
+ # splits = SplitToSequence (x, int64_128)
+ # int64_0 = Constant ()
+ # split_0 = SequenceAt (splits, int64_0)
+ # int64_1 = Constant ()
+ # split_1 = SequenceAt (splits, int64_1)
+ # int64_2 = Constant ()
+ # split_2 = SequenceAt (splits, int64_2)
+ # int64_3 = Constant ()
+ # split_3 = SequenceAt (splits, int64_3)
+ # }
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(len(optimized.graph[-2].outputs), 4)
+ self.assertEqual(optimized.graph[-2].op_type, "Split")
+
+    def test_static_split_to_sequence_with_list_split_and_sequence_at_is_folded_as_split(
+ self,
+ ):
+ model = """
+<
+ ir_version: 8,
+ opset_import: ["" : 18]
+>
+func (float[1,512] x) => (float[1,N] return_val) {
+ const = Constant ()
+ splits = SplitToSequence (x, const)
+ int64_0 = Constant ()
+ split_0 = SequenceAt (splits, int64_0)
+ int64_1 = Constant ()
+ split_1 = SequenceAt (splits, int64_1)
+ int64_2 = Constant ()
+ split_2 = SequenceAt (splits, int64_2)
+ return_val = Concat (split_0, split_1, split_2)
+}"""
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 3)
+ self.assertEqual(len(optimized.graph[-2].outputs), 3)
+ self.assertEqual(optimized.graph[-2].op_type, "Split")
+
+    def test_static_split_to_sequence_with_list_split_no_keepdims_and_sequence_at_is_folded_as_split_with_squeeze(
+ self,
+ ):
+ model = """
+<
+ ir_version: 8,
+ opset_import: ["" : 18]
+>
+func (float[1,3] x) => (float[1,3] return_val) {
+ const = Constant ()
+ splits = SplitToSequence (x, const)
+ int64_0 = Constant ()
+ split_0 = SequenceAt (splits, int64_0)
+ int64_1 = Constant ()
+ split_1 = SequenceAt (splits, int64_1)
+ int64_2 = Constant ()
+ split_2 = SequenceAt (splits, int64_2)
+ return_val = Concat (split_0, split_1, split_2)
+}"""
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 7)
+ self.assertEqual(len(optimized.graph[1].outputs), 3)
+ self.assertEqual(optimized.graph[1].op_type, "Split")
+ self.assertEqual(len([n for n in optimized.graph if n.op_type == "Squeeze"]), 3)
+
+ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0(
+ self,
+ ):
+ model = """
+<
+ ir_version: 8,
+ opset_import: ["" : 18]
+>
+func (float[1,3] x) => (float[1,3] return_val) {
+ const = Constant ()
+ splits = SplitToSequence (x, const)
+ return_val = ConcatFromSequence (splits)
+}"""
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 3)
+ self.assertEqual(optimized.graph[2].op_type, "Concat")
+
+ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
+ self,
+ ):
+ model = """
+<
+ ir_version: 8,
+ opset_import: ["" : 18]
+>
+func (float[1,3] x) => (float[1,3] return_val) {
+ const = Constant ()
+ splits = SplitToSequence (x, const)
+ return_val = ConcatFromSequence (splits)
+}"""
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 7)
+ self.assertEqual(optimized.graph[6].op_type, "Concat")
+
+ def test_dynamic_split_to_sequence_list_shape_rewrite(self):
+ # split is a graph input with known 1-D static shape [4]; values unknown (not constant)
+ # Ensures the branch: if isinstance(split_shape, tuple) and len(split_shape) == 1
+ model = """
+<
+ ir_version: 8,
+ opset_import: ["" : 18]
+>
+func (float[2,N] x, int64[4] split) => (float[2,N] return_val) {
+ splits = SplitToSequence (x, split)
+ i0 = Constant ()
+ s0 = SequenceAt (splits, i0)
+ i1 = Constant ()
+ s1 = SequenceAt (splits, i1)
+ i2 = Constant ()
+ s2 = SequenceAt (splits, i2)
+ i3 = Constant ()
+ s3 = SequenceAt (splits, i3)
+ return_val = Concat (s0, s1, s2, s3)
+}"""
+ optimized = self._fold(model)
+ # Expect: Split + Concat (index constants & SequenceAt removed)
+ split_nodes = [n for n in optimized.graph if n.op_type == "Split"]
+ self.assertEqual(len(split_nodes), 1)
+ self.assertEqual(len(split_nodes[0].outputs), 4)
+ self.assertEqual(split_nodes[0].op_type, "Split")
+ self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph))
+
+ def test_dynamic_split_to_sequence_list_shape_no_keepdims(self):
+ # keepdims=0 path with dynamic (non-constant) splits input; triggers squeeze logic.
+ model = """
+<
+ ir_version: 8,
+ opset_import: ["" : 18]
+>
+func (float[1,M] x, int64[3] split) => (float[1,M] return_val) {
+ splits = SplitToSequence (x, split)
+ i0 = Constant ()
+ s0 = SequenceAt (splits, i0)
+ i1 = Constant ()
+ s1 = SequenceAt (splits, i1)
+ i2 = Constant ()
+ s2 = SequenceAt (splits, i2)
+ return_val = Concat (s0, s1, s2)
+}"""
+ optimized = self._fold(model)
+ split_nodes = [n for n in optimized.graph if n.op_type == "Split"]
+ self.assertEqual(len(split_nodes), 1)
+ self.assertEqual(len(split_nodes[0].outputs), 3)
+ self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph))
+ # Each split output should have a corresponding Squeeze (keepdims=0 branch)
+ squeeze_nodes = [n for n in optimized.graph if n.op_type == "Squeeze"]
+ self.assertEqual(len(squeeze_nodes), 3)
+
+ def test_initializer_input_not_folded(self):
+ model_text = """
+
+ agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
+ {
+    # c is a graph input with a default value (not a constant), so the following should not be folded.
+ two_c = Add (c, c)
+ z = Mul (x, two_c)
+ }"""
+ optimized = self._fold(model_text)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual(optimized.graph.node(0).op_type, "Add")
+
+ @parameterized.parameterized.expand(
+ [
+ ("output = Dropout(input)",),
+ ("output = Dropout(input, zero, true)",),
+ ("output = Dropout(input, half)",),
+ ("output = Dropout(input, half, false)",),
+ ]
+ )
+ def test_dropout_identity(self, dropout_node: str):
+ model = f"""
+
+ agraph (float[N] input) => (float[N] output)
+
+ {{
+ {dropout_node}
+ }}
+ """
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+
+ @parameterized.parameterized.expand(
+ [
+ ("output, mask = Dropout(input)",),
+ ("output, mask = Dropout(input, zero, true)",),
+ ("output, mask = Dropout(input, half)",),
+ ("output, mask = Dropout(input, half, false)",),
+ ]
+ )
+ def test_dropout_identity_mask(self, dropout_node: str):
+ model = f"""
+
+ agraph (float[N] input) => (float[N] output, bool[N] mask)
+
+ {{
+ {dropout_node}
+ }}
+ """
+ optimized = self._fold(model)
+ nodes = list(optimized.graph)
+ self.assertEqual(len(nodes), 3)
+ ops = [node.op_type for node in nodes]
+ self.assertEqual(ops, ["Identity", "Shape", "ConstantOfShape"])
+
+ def test_concat_identity(self):
+ model = """
+
+ agraph (float[N] x) => (float[N] z)
+ {
+ z = Concat (x)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+
+ def test_concat_zero_length(self):
+ model = """
+
+ agraph (float[N, 128] x1, float[N, 0] x2, float[N, 128] x3) => (float[N, M] z)
+ {
+ z = Concat (x1, x2, x3)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x1", "x3"])
+
+ def test_concat_zero_length_identity(self):
+ model = """
+
+ agraph (float[N, 0] x1, float[N, 128] x2, float[N, 0] x3) => (float[N, M] z)
+ {
+ z = Concat (x1, x2, x3)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+ self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x2"])
+
+ def test_concat_zero_length_output(self):
+ model = """
+
+ agraph (float[N, 0] x1, float[N, 0] x2, float[N, 0] x3) => (float[N, M] z)
+ {
+ z = Concat (x1, x2, x3)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+ self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x1"])
+
+ def test_expand_identity(self):
+ model = """
+
+ agraph (float[128, 256] x) => (float[128, 256] z)
+ {
+ shape = Constant ()
+ z = Expand (x, shape)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+ def test_expand_identity_symdim(self):
+ model = """
+
+ agraph (float[B, 256] x) => (float[B, 256] z)
+ {
+ b = Shape (x)
+ const_256 = Constant ()
+ shape = Concat (b, const_256)
+ z = Expand (x, shape)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+ def test_abs_symdim(self):
+ model = """
+
+ agraph (float[B, 256] x) => (float[B, 256] z)
+ {
+ b = Shape (x)
+ const_256 = Constant ()
+ b_256 = Concat (b, const_256)
+ shape = Abs (b_256)
+ z = Expand (x, shape)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+ def test_reshape_identity(self):
+ model = """
+
+ agraph (float[128, 256] x) => (float[128, 256] z)
+ {
+ shape = Constant ()
+ z = Reshape (x, shape)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+ def test_reshape_identity_symdim(self):
+ model = """
+
+ agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
+ {
+ b = Shape (y)
+ const_256 = Constant ()
+ shape = Concat (b, const_256)
+ z = Reshape (x, shape)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+ def test_gather_symdim(self):
+ model = """
+
+ agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
+ {
+ b_128 = Shape (y)
+ index_0 = Constant ()
+ b = Gather (b_128, index_0)
+ const_256 = Constant ()
+ shape = Concat (b, const_256)
+ z = Reshape (x, shape)
+ }
+ """
+ optimized = self._fold(model)
+ self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
+
+ def test_input_size_limit(self):
+ model_text = """
+
+ agraph (float[M, 256] x) => (float[M, 256] z)
+ # placeholder for large initializer of shape [256, 256]
+ {
+ w_squared = Mul (w, w)
+ z = Add (x, w_squared)
+ }
+ """
+ model = ir.from_onnx_text(model_text)
+ w = model.graph.initializers["w"]
+ w.shape = ir.Shape([256, 256])
+ w.const_value = ir.tensor(np.random.random((256, 256)).astype(np.float32))
+
+ # Input size limit will prevent folding of Mul op
+ optimized = self._fold(model, onnx_shape_inference=False, input_size_limit=128 * 128)
+ ops = [node.op_type for node in optimized.graph]
+ self.assertEqual(ops, ["Mul", "Add"])
+
+ # Input size limit will allow folding of Mul op
+ # Since there is no increase in model-size, output-size is not a concern.
+ optimized = self._fold(model, input_size_limit=256 * 256, output_size_limit=256 * 256)
+ ops = [node.op_type for node in optimized.graph]
+ self.assertEqual(ops, ["Constant", "Add"])
+
+ def test_transpose_is_always_folded(self):
+ model_text = """
+
+ agraph (float[M, 256] x) => (float[M, 512] z)
+ # placeholder for large initializer of shape [512, 256]
+ {
+ z = Transpose (w)
+ }
+ """
+ model = ir.from_onnx_text(model_text)
+ w = model.graph.initializers["w"]
+ w.shape = ir.Shape([512, 256])
+ w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))
+
+ # Input size limit will not prevent folding of Transpose op
+ optimized = self._fold(model, input_size_limit=1)
+ ops = [node.op_type for node in optimized.graph]
+ self.assertEqual(ops, ["Constant"])
+
+ def test_node_is_folded_if_specified_as_should_fold(self):
+ model_text = """
+
+ agraph (float[M, 256] x) => (float[42, 42] z)
+
+ {
+ z = ConstantOfShape (w)
+ }
+ """
+ model = ir.from_onnx_text(model_text)
+
+ # ConstantOfShape is not folded by default
+ optimized = self._fold(model)
+ ops = [node.op_type for node in optimized.graph]
+ self.assertEqual(ops, ["ConstantOfShape"])
+
+ # But ConstantOfShape is folded when specified in should_fold
+ optimized = self._fold(
+ model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None
+ )
+ ops = [node.op_type for node in optimized.graph]
+ self.assertEqual(ops, ["Constant"])
+ np.testing.assert_array_equal(
+ optimized.graph.node(0).attributes["value"].as_tensor().numpy(),
+ np.ones((42, 42), dtype=np.int64),
+ )
+
+ def test_multi_graph_identity_output_preserves_output_name(self):
+ model = """
+
+ agraph (float[N] x) => (float[N] graph_output1, float[N] graph_output2) {
+ t = Identity(x)
+ graph_output1 = Identity(t)
+ graph_output2 = Identity(t)
+ }"""
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+ self.assertEqual([n.op_type for n in optimized.graph], ["Identity", "Identity"])
+ self.assertEqual(
+ [n.outputs[0].name for n in optimized.graph], ["graph_output1", "graph_output2"]
+ )
+ self.assertEqual([input.name for input in optimized.graph.inputs], ["x"])
+
+ # This should not be constant-foldable as the constant references an
+ # attribute and thus the shape cannot be resolved. At the same time it
+ # should not fail due to the attribute value being None in
+ # _process_constant_node
+ def test_attribute_reference(self):
+ model = """
+
+ agraph () => (int64[N] z) {
+ x = Constant ()
+ z = Shape (x)
+ }
+ """
+
+ optimized = self._fold(model)
+ self.assertEqual(len(optimized.graph), 2)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxscript/optimizer/_function_folding_test.py b/onnxscript/optimizer/_function_folding_test.py
new file mode 100644
index 0000000000..6f2b052b9e
--- /dev/null
+++ b/onnxscript/optimizer/_function_folding_test.py
@@ -0,0 +1,159 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+
+import onnx
+
+import onnxscript.testing
+from onnxscript import ir, optimizer
+
+
+def _create_model(model_text: str) -> ir.Model:
+ """Create a model from the given text."""
+ model = onnx.parser.parse_model(model_text)
+ return ir.serde.deserialize_model(model)
+
+
+class FunctionFoldingTest(unittest.TestCase):
+ def test_identity(self):
+ model = _create_model(
+ """
+
+ agraph (float[N] x1, bool cond1) => (float[N] z1) {
+ z1 = local.fun1(x1, cond1)
+ }
+
+ fun1 (x, cond) => (z) {
+ t = Identity(x)
+ t2 = Identity(t)
+ t3 = If (cond) <
+ then_branch = then_graph() => (t4) {
+ t5 = Identity(t2)
+ t4 = Identity(t5)
+ },
+                else_branch = else_graph() => (t6) {
+ t7 = Identity(t)
+ t6 = Identity(t7)
+ }
+ >
+ t4 = Add(t3, t3)
+ z = Identity(t4)
+ }"""
+ )
+ optimized = optimizer.optimize(
+ model, onnx_shape_inference=False, num_iterations=1, inline=True
+ )
+ self.assertEqual(len(optimized.functions), 0)
+ self.assertEqual(len(optimized.graph), 2)
+
+ def test_sequence_concat(self):
+ model = _create_model(
+ """
+
+ agraph (float[N] x1) => (float[M] z1) {
+ z1 = local.fun1(x1)
+ }
+
+ fun1 (x) => (z) {
+ t0 = Add (x, x)
+ t2 = Add (x, x)
+ t3 = SequenceConstruct (x, t0, t2, x)
+ z = ConcatFromSequence (t3)
+ }"""
+ )
+ optimized = optimizer.optimize(
+ model, onnx_shape_inference=False, num_iterations=1, inline=False
+ )
+ function = optimized.functions[("local", "fun1", "")]
+ self.assertEqual(len(function), 3)
+ self.assertEqual(function[2].op_type, "Concat")
+
+ def test_sequence_at(self):
+ model = _create_model(
+ """
+
+ agraph (float[N] x) => (float[M] z) {
+ t0 = Add (x, x)
+ t1 = Mul (x, x)
+ s = SequenceConstruct (x, t0, t1)
+ one = Constant ()
+ z = SequenceAt (s, one)
+ }"""
+ )
+ optimized = optimizer.optimize(
+ model, onnx_shape_inference=False, num_iterations=1, inline=False
+ )
+ expected = _create_model(
+ """
+
+ agraph (float[N] x) => (float[M] z) {
+ z = Add (x, x)
+ }"""
+ )
+ # TODO(justinchuby): Implement assert_isomorphic_graph for IR objects
+ onnxscript.testing.assert_isomorphic_graph(
+ ir.to_proto(optimized.graph), ir.to_proto(expected.graph)
+ )
+
+ def test_single_user_function_is_modified_inplace_after_folding(self):
+ model = _create_model(
+ """
+
+ agraph (float[N] x1) => (float[M] z1) {
+ z1 = local.fun1(x1)
+ }
+
+ fun1 (x) => (z) {
+ t0 = Add (x, x)
+ t2 = Add (x, x)
+ t3 = SequenceConstruct (x, t0, t2, x)
+ z = ConcatFromSequence (t3)
+ }"""
+ )
+ optimized = optimizer.optimize(
+ model, onnx_shape_inference=False, num_iterations=1, inline=False
+ )
+ self.assertEqual(next(iter(optimized.functions.values())).name, "fun1")
+
+ def test_fold_nested_if_function_succeeds(self):
+ model = _create_model(
+ """
+ <
+ ir_version: 9,
+ opset_import: ["this" : 1, "" : 18]
+ >
+ func (float[1,512] x, float[1,512] y) => ( out) {
+ out = this.foldable_func (x, y)
+ }
+ <
+ domain: "this",
+ opset_import: ["" : 18]
+ >
+ foldable_func (x, y) => (z_6)
+ {
+ cond = Constant ()
+ z_6 = If (cond) ( z_2) {
+ cond_0 = Not (cond)
+ z_2 = If (cond_0) ( z) {
+ z = Add (x, x)
+ }, else_branch: graph = elseGraph_5 () => ( z_1) {
+ z_1 = Identity (x)
+ }>
+ }, else_branch: graph = elseGraph_4 () => ( z_5) {
+ z_5 = If (cond) ( z_3) {
+ z_3 = Add (y, y)
+ }, else_branch: graph = elseGraph_10 () => ( z_4) {
+ z_4 = Add (x, y)
+ }>
+ }>
+ }"""
+ )
+ optimized = optimizer.optimize(model, onnx_shape_inference=False, inline=True)
+
+ self.assertEqual(len(optimized.functions), 0)
+ self.assertEqual(len(optimized.graph), 1)
+ self.assertNotIn("If", {n.op_type for n in optimized.graph})
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py
new file mode 100644
index 0000000000..307144462f
--- /dev/null
+++ b/onnxscript/optimizer/_optimizer.py
@@ -0,0 +1,74 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import logging
+from typing import Callable
+
+import onnx_ir as ir
+import onnx_ir.passes.common as common_passes
+
+from onnxscript import rewriter
+from onnxscript.optimizer import _constant_folding
+
+logger = logging.getLogger(__name__)
+
+
+def optimize_ir(
+ model: ir.Model,
+ num_iterations: int = 2,
+ *,
+ onnx_shape_inference: bool = True,
+ stop_if_no_change: bool = True,
+ input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+ output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+ should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
+ inline: bool = True,
+) -> None:
+ """Optimizes a model.
+
+ Args:
+ model: The model to be optimized.
+ num_iterations: Number of times the optimization loop is repeated.
+        onnx_shape_inference: Applies node-level shape inference as part of optimization.
+ stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
+ input_size_limit: Will not apply constant folding to ops with any input of size
+ greater than this. Does not apply to special ops like Shape() and Size().
+ output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
+ of the output tensor is greater than this.
+        should_fold: An optional function that takes a node and returns True if
+            the node should be folded, False if it should not be folded, or
+            None to use the default folding rules.
+ inline: If True, inlines all functions in the model.
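+
+    Example (a minimal sketch; ``model`` is assumed to be a loaded ``ir.Model``)::
+
+        optimize_ir(model, num_iterations=1, onnx_shape_inference=False)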
+ """
+ passes = [
+ ir.passes.PassManager(
+ [
+ _constant_folding.FoldConstantsPass(
+ shape_inference=onnx_shape_inference,
+ input_size_limit=input_size_limit,
+ output_size_limit=output_size_limit,
+ should_fold=should_fold,
+ ),
+ rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
+ common_passes.RemoveUnusedNodesPass(),
+ common_passes.RemoveUnusedFunctionsPass(),
+ common_passes.RemoveUnusedOpsetsPass(),
+ ],
+ steps=num_iterations,
+ early_stop=stop_if_no_change,
+ ),
+ common_passes.RemoveUnusedNodesPass(),
+ common_passes.LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=0),
+ common_passes.LiftSubgraphInitializersToMainGraphPass(),
+ common_passes.DeduplicateInitializersPass(),
+ common_passes.CommonSubexpressionEliminationPass(),
+ ]
+ if inline:
+ # Inline all functions first before optimizing
+ passes = [common_passes.InlinePass(), *passes]
+ optimizer_pass = ir.passes.Sequential(*passes)
+ assert optimizer_pass.in_place
+ result = optimizer_pass(model)
+ assert result.model is model
diff --git a/onnxscript/optimizer/_optimizer_test.py b/onnxscript/optimizer/_optimizer_test.py
new file mode 100644
index 0000000000..0aed7f57ca
--- /dev/null
+++ b/onnxscript/optimizer/_optimizer_test.py
@@ -0,0 +1,83 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+
+import onnx
+import onnx_ir as ir
+
+import onnxscript.optimizer as optimizer
+
+
+class OptimizerTest(unittest.TestCase):
+ def _model_proto(self) -> onnx.ModelProto:
+ return onnx.parser.parse_model(
+ """
+ <
+ ir_version: 8,
+ opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1],
+ producer_name: "pytorch",
+ producer_version: "2.2.0"
+ >
+ main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val)
+ < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1>
+ {
+ _val_1 = Constant ()
+ _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1)
+ _val_3 = Constant ()
+ getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3)
+ _val_5 = Constant ()
+ getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5)
+ return_val = Concat (getitem_1, getitem)
+ }
+
+
+ aten_split (self, split_size) => (return_val)
+ {
+ return_val = SplitToSequence (self, split_size)
+ }
+
+
+ aten_getitem (self, i) => (return_val)
+ {
+ return_val = SequenceAt (self, i)
+ }
+
+
+ Rank (input) => (return_val)
+ {
+ tmp = Shape (input)
+ return_val = Size (tmp)
+ }
+
+
+ IsScalar (input) => (return_val)
+ {
+ tmp = Shape (input)
+ tmp_0 = Size (tmp)
+ tmp_1 = Constant ()
+ return_val = Equal (tmp_0, tmp_1)
+ }
+ """
+ )
+
+ def test_static_split_to_sequence_with_uneven_split_proto(self):
+ model_proto = self._model_proto()
+ optimized = optimizer.optimize(
+ model_proto, num_iterations=1, onnx_shape_inference=False
+ )
+ self.assertEqual(len(optimized.graph.node), 2)
+ self.assertEqual(len(optimized.graph.node[0].output), 2)
+ self.assertEqual(optimized.graph.node[0].op_type, "Split")
+
+ def test_static_split_to_sequence_with_uneven_split_ir(self):
+ model_proto = self._model_proto()
+ model_ir = ir.serde.deserialize_model(model_proto)
+ optimizer.optimize_ir(model_ir, num_iterations=1, onnx_shape_inference=False)
+ self.assertEqual(len(model_ir.graph), 2)
+ self.assertEqual(len(model_ir.graph.node(0).outputs), 2)
+ self.assertEqual(model_ir.graph.node(0).op_type, "Split")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py
deleted file mode 100644
index 283a13fd13..0000000000
--- a/onnxscript/optimizer/constant_folding.py
+++ /dev/null
@@ -1,283 +0,0 @@
-from __future__ import annotations
-
-import logging
-from typing import Any, Sequence
-
-import numpy as np
-import onnx
-import onnx.reference.ops
-
-import onnxscript._legacy_ir as ir
-from onnxscript._legacy_ir import visitor
-from onnxscript.optimizer import evaluator
-from onnxscript.utils.utils import (
- is_control_flow_op,
- is_onnx_domain,
-)
-
-logger = logging.getLogger(__name__)
-
-_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024
-
-# Ops excluded from constant-propagation:
-# * Random ops, which are not deterministic (checked below)
-# * Control flow ops (checked by presence of graph-attribute)
-
-non_deterministic_ops = frozenset(
- {
- "RandomUniform",
- "RandomNormal",
- "RandomUniformLike",
- "RandomNormalLike",
- "Multinomial",
- }
-)
-
-onnx_domain = frozenset({"", "onnx.ai"})
-
-
-def is_non_deterministic_op(node: onnx.NodeProto) -> bool:
- return node.op_type in non_deterministic_ops and is_onnx_domain(node.domain)
-
-
-def is_constant_op(node: onnx.NodeProto) -> bool:
- return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain(node.domain)
-
-
-class ConstantFolder(visitor.FunctionCallsiteProtoTransformer):
- def __init__(
- self,
- registry: evaluator.PartialEvaluatorRegistry,
- external_data_folder: str,
- *,
- do_shape_inference: bool,
- ) -> None:
- self.registry = registry
- # TODO: make evaluator a parameter
- self.evaluate = evaluator.reference_evaluator.evaluate
- self._do_shape_inference = do_shape_inference
- self._init()
- super().__init__(external_data_folder, do_shape_inference=do_shape_inference)
-
- def _init(self) -> None:
- self.counts: dict[str, int] = {}
- self.sizes: dict[str, int] = {}
-
- def add_count(self, op: str, size: int = 1):
- self.counts[op] = self.counts.get(op, 0) + 1
- self.sizes[op] = self.sizes.get(op, 0) + size
-
- def foldable_value(self, name: str, value):
- """Checks if a runtime-constant can and should be folded into the graph.
-
- We fold constants only if they are tensors (not lists of tensors, for example)
- and their size is below the configured limit (_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT).
- """
- if value is ir.NotConstant:
- return None
-
- if not isinstance(value, np.ndarray):
- # ONNX has no way to represent non-tensor constants, e.g. a sequence.
- # A sequence-valued constant is therefore not folded into the graph, but it
- # can still be used to optimize subsequent operations when possible.
- logger.warning(
- "Skip storing constant folded value %s due to unsupported type %s.",
- name,
- type(value),
- )
- return None
-
- if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT:
- logger.warning(
- "Skip storing constant folded nvalue %s due to large size %s.",
- name,
- value.nbytes,
- )
- return None
-
- return onnx.numpy_helper.from_array(value, name)
-
- def new_constant(self, name, value):
- if isinstance(value, (int, float, np.ScalarType)):
- value = np.array(value)
-
- info = self.lookup_or_create(name)
- info.value = value
-
- tensor = self.foldable_value(name, value)
- if tensor is None:
- return None
-
- logger.debug(
- "New constant for value %s dtype: %s shape: %s",
- name,
- value.dtype,
- value.shape,
- )
- info.type = onnx.helper.make_tensor_type_proto(
- onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape
- )
- node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor)
- return [node]
-
- def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]:
- if self.scopes.current_scope().current_function_scope():
- # Need to resolve ref_attr_name if inside a function.
- attr_dict = {}
- for attribute in attributes:
- concrete_attribute = (
- self.lookup_ref_attribute(attribute.ref_attr_name)
- if attribute.ref_attr_name
- else attribute
- )
- if concrete_attribute is None:
- continue
- attr_dict[attribute.name] = onnx.helper.get_attribute_value(concrete_attribute)
- return attr_dict
- return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes}
-
- def replace_copy(self, node: onnx.NodeProto) -> None:
- for i in range(len(node.input)):
- input = self.get_input(node, i)
- if input is not None and input.is_copy():
- old_value = self.lookup_or_create(input.name)
- assert isinstance(input.symbolic_value, str)
- new_value = self.lookup_or_create(input.symbolic_value)
- # Merge meta info. It is important to do if the new value
- # is created by evaluator, and thus carries zero meta info.
- # Since this is a copy, the meta info should be the same.
- new_value.identity_merge_from(old_value)
- node.input[i] = input.symbolic_value
-
- def process_function_outputs(self, function: onnx.FunctionProto) -> bool:
- # Resolve copies for function subgraph outputs.
- # Function inputs are excluded because a direct edge from a function input
- # to a function output is illegal in ONNX.
- prohibited_value_set = set(function.input)
- updated = False
- for i, output_name in enumerate(function.output):
- output = self.lookup(output_name)
- if (
- output is not None
- and output.is_copy()
- and output.symbolic_value not in prohibited_value_set
- ):
- old_value = self.lookup_or_create(output.name)
- assert isinstance(output.symbolic_value, str)
- new_value = self.lookup_or_create(output.symbolic_value)
- new_value.identity_merge_from(old_value)
- function.output[i] = output.symbolic_value
- updated = True
- return updated
-
- def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
- self.replace_copy(node)
-
- super().process_node(node)
-
- inputs = [self.lookup(x) for x in node.input]
- attrs = self.convert_attributes(node.attribute)
-
- domain = node.domain
- op = node.op_type
- version = self.lookup_version(domain)
-
- # The check below is currently disabled; enabling it would guarantee that
- # none of the optimizations that follow sees an undefined input:
- # if any(x is Undefined for x in inputs):
- #     return None
-
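- # Try registered partial evaluators first; each may return replacement
- # nodes, or a concrete output value to be turned into a Constant.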
- op_optimizers = self.registry.lookup_evaluators(domain, op, version)
- for optimizer in op_optimizers:
- assert optimizer
- output = optimizer(self, node)
- if output is None:
- continue
- if isinstance(output, list):
- return output
- else:
- # Currently handles single output only
- self.add_count(node.op_type, output.size)
- return self.new_constant(node.output[0], output)
-
- if is_control_flow_op(node) or is_non_deterministic_op(node):
- return None
-
- input_values = [x.value if x is not None else None for x in inputs]
- if any(x is ir.NotConstant for x in input_values):
- return None
-
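- # No input is NotConstant at this point, so attempt full evaluation of
- # the node with the reference evaluator.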
- outputs = self.evaluate(domain, op, version, *input_values, **attrs)
- # TODO: what if evaluated value is None?
- if outputs is None:
- return None
- if len(node.output) == 1 and not isinstance(outputs, (tuple, list)):
- replacement = self.new_constant(node.output[0], outputs)
- if is_constant_op(node):
- return None
- self.add_count(op, outputs.size)
- return replacement
- else:
- logger.warning("Skipping constant folding for op %s with multiple outputs.", op)
- return None
-
- def process_function_node(
- self, node: onnx.NodeProto
- ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]:
- self.replace_copy(node)
-
- _, new_function = super().process_function_node(node)
-
- # Replace function node with Constant if all outputs are constants
- ir_values = [self.lookup(output_name) for output_name in node.output]
- tensors = [
- self.foldable_value(output_name, ir_value.value if ir_value is not None else None)
- for output_name, ir_value in zip(node.output, ir_values)
- ]
- if all(tensor is not None for tensor in tensors):
- replacements = []
- for output_name, tensor in zip(node.output, tensors):
- newnode = onnx.helper.make_node(
- "Constant", inputs=[], outputs=[output_name], value=tensor
- )
- replacements.append(newnode)
- logger.debug(
- "Function node replacements: node %s %s (%s/%s)",
- node.name,
- [replacement.output for replacement in replacements],
- len(replacements),
- len(node.output),
- )
- return replacements, new_function
- return None, new_function
-
- def visit_model(self, model: onnx.ModelProto) -> None:
- self._init()
-
- super().visit_model(model)
-
-
-def fold_constants(
- model: onnx.ModelProto,
- external_data_folder: str = "",
- *,
- onnx_shape_inference: bool = False,
-) -> bool:
- """
- Applies constant folding optimization to the model.
- Returns true iff the model was modified.
- """
- folder = ConstantFolder(
- evaluator.registry,
- external_data_folder,
- do_shape_inference=onnx_shape_inference,
- )
- folder.visit_model(model)
- for op, count in folder.counts.items():
- logger.info(
- "Constant-folded '%s' %s times, with a total size of %s.",
- op,
- count,
- folder.sizes[op],
- )
- return folder.modified
diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/constant_folding_test.py
deleted file mode 100644
index 64a27e33de..0000000000
--- a/onnxscript/optimizer/constant_folding_test.py
+++ /dev/null
@@ -1,444 +0,0 @@
-import unittest
-
-import onnx
-import pytest
-
-from onnxscript import optimizer
-
-
-class FoldConstantsTest(unittest.TestCase):
- def test_fold_add(self):
- model = onnx.parser.parse_model(
- """
- <ir_version: 7, opset_import: [ "" : 17 ]>
- agraph (float[N] x) => (float[N] z) {
- two = Constant <value_float=2.0> ()
- four = Add(two, two)
- z = Mul(x, four)
- }
- """
- )
- optimized = optimizer.optimize(model, num_iterations=1)
- self.assertEqual(len(optimized.graph.node), 2)
- self.assertEqual(optimized.graph.node[0].output[0], "four")
-
- def test_fold_cast_like(self):
- model = onnx.parser.parse_model(
- """
- <ir_version: 7, opset_import: [ "" : 17 ]>
- agraph (float[N] x) => (float[N] z) {
- two = Constant <value_int=2> ()
- two_float = CastLike(two, x)
- four = Add(two_float, two_float)
- z = Mul(x, four)
- }
- """
- )
- optimized = optimizer.optimize(model, num_iterations=1)
- self.assertEqual(len(optimized.graph.node), 2)
- self.assertEqual(optimized.graph.node[0].output[0], "four")
-
- def test_fold_shape(self):
- model = onnx.parser.parse_model(
- """
- <ir_version: 7, opset_import: [ "" : 17 ]>
- agraph (float[16, 16] x) => (float[16, 16] z) {
- shape = Shape(x)
- rank = Size(shape)
- two_float = CastLike(rank, x)
- four = Add(two_float, two_float)
- z = Mul(x, four)
- }
- """
- )
- optimized = optimizer.optimize(model, num_iterations=1)
- self.assertEqual(len(optimized.graph.node), 2)
- self.assertEqual(optimized.graph.node[0].output[0], "four")
-
- def test_fold_shape_slice(self):
- model = onnx.parser.parse_model(
- """
- <ir_version: 7, opset_import: [ "" : 17 ]>
- agraph (float[M, N, 16, 16] x) => (float[M, N, 16, 16] z) {
- shape = Shape <start: int = 2, end: int = 4> (x)
- two = Size(shape)
- two_float = CastLike(two, x)
- four = Add(two_float, two_float)
- z = Mul(x, four)
- }
- """
- )
- optimized = optimizer.optimize(model, num_iterations=1)
- self.assertEqual(len(optimized.graph.node), 2)
- self.assertEqual(optimized.graph.node[0].output[0], "four")
-
- def test_fold_if_cond(self):
- model = onnx.parser.parse_model(
- """
- <ir_version: 7, opset_import: [ "" : 17 ]>
- agraph (float[16, 16] x) => (float[16, 16] z) {
- shape = Shape(x)
- rank = Size(shape)
- zero = Constant <value_float=0.0> ()
- zero_cast = CastLike (zero, rank)
- is_scalar = Equal(zero_cast, rank)
- z = If (is_scalar) <
- then_branch = then_graph () => (then_z) { then_z = Add (x, x) },
- else_branch = else_graph () => (else_z) { else_z = Mul (x, x) }
- >
- }
- """
- )
- optimized = optimizer.optimize(model, num_iterations=1)
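- # Size(Shape(x)) folds to rank 2, Equal(0, 2) folds to false, and the If
- # collapses to its else branch, leaving a single Mul node.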
- self.assertEqual(len(optimized.graph.node), 1)
- self.assertEqual(optimized.graph.node[0].output[0], "z")
- self.assertEqual(optimized.graph.node[0].op_type, "Mul")
-
- def test_fold_inside_if_branch(self):
- model = onnx.parser.parse_model(
- """
- <ir_version: 7, opset_import: [ "" : 17 ]>
- agraph (float[16, 16] x, bool cond) => (float[16, 16] z) {
- two = Constant <value_float=2.0> ()
- z = If (cond) <
- then_branch = then_graph () => (then_z) {
- three = Constant <value_float=3.0> ()
- temp = Add (two, three)
- then_z = Mul (temp, x)
- },
- else_branch = else_graph () => (else_z) {
- four = Constant <value_float=4.0> ()
- temp = Add (two, four)
- else_z = Mul (temp, x)
- }
- >
- }
- """
- )
- optimized = optimizer.optimize(model, num_iterations=1)
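- # Folding happens inside each branch: Add(two, three) and Add(two, four)
- # become Constants, so each branch keeps exactly two nodes (Constant + Mul).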
- self.assertEqual(len(optimized.graph.node), 1)
- then_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "then_branch")
- self.assertEqual(len(then_graph.node), 2)
- else_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "else_branch")
- self.assertEqual(len(else_graph.node), 2)
-
- def test_fold_if_propagate(self):
- model = onnx.parser.parse_model(
- """
- <ir_version: 7, opset_import: [ "" : 17 ]>
- agraph (float[16, 16] x) => (float[16, 16] z) {
- shape = Shape(x)
- rank = Size(shape)
- zero = Constant