diff --git a/.azure-pipelines/_release-template.yml b/.azure-pipelines/_release-template.yml deleted file mode 100644 index 9bf1d9d0cf..0000000000 --- a/.azure-pipelines/_release-template.yml +++ /dev/null @@ -1,21 +0,0 @@ -# Template steps for the release pipeline - -steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.11' - displayName: 'Set Up Python' - - script: python -m pip install --upgrade pip build wheel - displayName: 'Install Python build dependencies' - - script: python -m build - displayName: 'Build ONNX Script wheel dev version' - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: 'dist' - Contents: '*.*' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 - displayName: 'Publish onnxscript' - inputs: - ArtifactName: onnxscript diff --git a/.azure-pipelines/publish-dev.yml b/.azure-pipelines/publish-dev.yml new file mode 100644 index 0000000000..77968d313b --- /dev/null +++ b/.azure-pipelines/publish-dev.yml @@ -0,0 +1,45 @@ +trigger: none +name: onnxscript-publish-dev.$(Date:yyyyMMdd).$(Rev:r) +resources: + repositories: + - repository: 1ESPipelineTemplates + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + pipelines: + - pipeline: onnxscript-release-dev + source: onnxscript-release-dev + trigger: true +extends: + template: v1/1ES.Official.PipelineTemplate.yml@1ESPipelineTemplates + parameters: + stages: + - stage: Release + dependsOn: [] + jobs: + - job: onnxscript_publish_dev + templateContext: + type: releaseJob + isProduction: true + inputs: + - input: pipelineArtifact + artifactName: drop + pipeline: onnxscript-release-dev + targetPath: $(Pipeline.Workspace)/drop + pool: + name: 'onnxruntime-Win-CPU-2022' + steps: + - task: EsrpRelease@9 + displayName: 'ESRP Release' + inputs: + connectedservicename: esrp_release + keyvaultname: 'ortbuildkeyvault' + signcertname: 'esrpcodesign' + clientid: '53d54d02-978d-4305-8572-583cf6711c4f' + contenttype: PyPi + folderlocation: '$(Pipeline.Workspace)/drop' + owners: 'justinchu@microsoft.com' + approvers: 'grama@microsoft.com' + mainpublisher: AIFrameworks + usemanagedidentity: true + domaintenantid: '975f013f-7f24-47e8-a7d3-abc4752bf346' diff --git a/.azure-pipelines/publish.yml b/.azure-pipelines/publish.yml new file mode 100644 index 0000000000..e37d34a282 --- /dev/null +++ b/.azure-pipelines/publish.yml @@ -0,0 +1,50 @@ +trigger: none +name: onnxscript-publish.$(Date:yyyyMMdd).$(Rev:r) +resources: + repositories: + - repository: 1ESPipelineTemplates + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + pipelines: + - pipeline: onnxscript-release + source: onnxscript-release + trigger: true +extends: + template: v1/1ES.Official.PipelineTemplate.yml@1ESPipelineTemplates + parameters: + stages: + - stage: Release + dependsOn: [] + jobs: + - deployment: onnxscript_publish + templateContext: + type: releaseJob + isProduction: true + inputs: + - input: pipelineArtifact + artifactName: drop + pipeline: onnxscript-release + targetPath: $(Pipeline.Workspace)/drop + environment: + name: 'onnxscript-release' + pool: + name: 'onnxruntime-Win-CPU-2022' + strategy: + runOnce: + deploy: + steps: + - task: EsrpRelease@9 + displayName: 'ESRP Release' + inputs: + connectedservicename: esrp_release + keyvaultname: 'ortbuildkeyvault' + signcertname: 'esrpcodesign' + clientid: '53d54d02-978d-4305-8572-583cf6711c4f' + contenttype: PyPi + folderlocation: '$(Pipeline.Workspace)/drop' + owners: 'justinchu@microsoft.com' + approvers: 'grama@microsoft.com' + mainpublisher: AIFrameworks + usemanagedidentity: true + domaintenantid: '975f013f-7f24-47e8-a7d3-abc4752bf346' diff --git a/.azure-pipelines/release-dev.yml b/.azure-pipelines/release-dev.yml index 81ffa68b3a..61f780ed31 100644 --- a/.azure-pipelines/release-dev.yml +++ b/.azure-pipelines/release-dev.yml @@ -3,9 +3,28 @@ # To configure triggers, see https://github.com/microsoft/onnx-converters-private/wiki/ONNX-Script-release trigger: none -pool: - vmImage: ubuntu-latest variables: CI: 'true' -steps: - - template: _release-template.yml + +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + sdl: + sourceAnalysisPool: + name: onnxruntime-Win-CPU-2022 + os: windows + pool: + name: 'onnxruntime-Ubuntu2204-AMD-CPU' + os: 'linux' + stages: + - template: stages/release-stage.yml diff --git a/.azure-pipelines/release.yml b/.azure-pipelines/release.yml index 130ae5a09c..b5fde4c319 100644 --- a/.azure-pipelines/release.yml +++ b/.azure-pipelines/release.yml @@ -2,19 +2,30 @@ trigger: none -pool: - vmImage: ubuntu-latest variables: CI: 'true' # Set the release environment variable to build a release version of the wheel ONNX_SCRIPT_RELEASE: 1 -steps: - - template: _release-template.yml - # Test the wheels. This needs to happen after PublishBuildArtifacts - # to avoid interference with the artifacts - - script: python -m pip install -r requirements-dev.txt - displayName: 'Install Python dependencies' - - script: python -m pip install dist/*.whl --no-deps - displayName: 'Install wheel' - - script: python -m pytest -v -n auto - displayName: 'Run tests' + +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + sdl: + sourceAnalysisPool: + name: onnxruntime-Win-CPU-2022 + os: windows + pool: + name: 'onnxruntime-Ubuntu2204-AMD-CPU' + os: 'linux' + stages: + - template: stages/release-stage.yml diff --git a/.azure-pipelines/stages/jobs/steps/release-steps.yml b/.azure-pipelines/stages/jobs/steps/release-steps.yml new file mode 100644 index 0000000000..be1d9e8860 --- /dev/null +++ b/.azure-pipelines/stages/jobs/steps/release-steps.yml @@ -0,0 +1,20 @@ +steps: +- task: UsePythonVersion@0 + inputs: + versionSpec: '3.11' + displayName: 'Set Up Python' +- script: python -m pip install --upgrade pip build wheel + displayName: 'Install Python build dependencies' +- script: python -m build + displayName: 'Build ONNX Script wheel' +- task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: 'dist' + Contents: '*.*' + TargetFolder: '$(Build.ArtifactStagingDirectory)' +- task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Python Wheel' + inputs: + ArtifactName: 'onnxscript' + targetPath: '$(Build.ArtifactStagingDirectory)' diff --git a/.azure-pipelines/stages/release-stage.yml b/.azure-pipelines/stages/release-stage.yml new file mode 100644 index 0000000000..881fdbd60b --- /dev/null +++ b/.azure-pipelines/stages/release-stage.yml @@ -0,0 +1,11 @@ +stages: +- stage: Stage + jobs: + - job: Job + steps: + - template: jobs/steps/release-steps.yml + # Test the wheels. This needs to happen after PublishBuildArtifacts + # to avoid interference with the artifacts + - script: python -m pip install dist/*.whl --no-deps + displayName: 'Install wheel' + condition: eq(variables['ONNX_SCRIPT_RELEASE'], 1) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000000..b74c06fed3 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,5 @@ +## Code Standards + +### Required Before Each Commit +- Run `lintrunner -a` before committing any changes to ensure proper code formatting +- This will run lintrunner on all updated files to maintain consistent style diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 0000000000..2434ad5390 --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,30 @@ +changelog: + exclude: + authors: + - dependabot + categories: + - title: Breaking Changes + labels: + - "topic: breaking changes" + - title: Core ONNX Script + labels: + - "topic: onnxscript core" + - "topic: ast converter" + - title: Optimizer and rewriter + labels: + - "module: rewriter" + - "module: optimizer" + - "topic: ort-fusions" + - title: ONNX IR + labels: + - "module: IR" + - "topic: passes" + - title: Torch Lib + labels: + - "module: torchlib" + - title: Documentation + labels: + - "topic: documentation" + - title: Other Changes + labels: + - "*" diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index a4cedc9daa..6953a76929 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -41,7 +41,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index d0ecd01ebf..3fe51a3a5a 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -6,7 +6,7 @@ on: - main - 'gh/**/base' # ghstack base branches pull_request: - types: [opened, synchronize, reopened, ready_for_review] + merge_group: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -20,7 +20,7 @@ jobs: pull-requests: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: misspell # Check spelling uses: reviewdog/action-misspell@v1 with: @@ -43,19 +43,20 @@ jobs: permissions: security-events: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: # Version range or exact version of Python to use, using SemVer's version range syntax. Reads from .python-version if unset. python-version: "3.10" - name: Install ONNXScript run: | - # The code is from azure-pipelines.yml # Install dependencies python -m pip install --upgrade pip python -m pip install --upgrade setuptools - python -m pip install -q -r requirements-dev.txt + python -m pip install -r requirements-dev.txt + # FIXME: numpy 2.2 has some typing changes that break the mypy CI but it's otherwise fine + python -m pip install "numpy<2.2" # Install packages python -m pip install -e . lintrunner init diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index f28d6ce349..faf40b9ec3 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -13,6 +13,7 @@ on: # Allows you to run this workflow manually from the Actions tab workflow_dispatch: + merge_group: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -25,32 +26,23 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] name: - - py312-torch-nightly + - py312 - py311 - py311-torch-nightly - py311-onnx-weekly - py311-ort-nightly - - py311-experimental-torchlib-tracing - - py311-experimental-torchlib-onnx-ir + - py311-onnx-ir-git - py310 - - py39 - - py38 include: + - name: py312 + python-version: "3.12" + nox-tag: test build - name: py311 python-version: "3.11" - nox-tag: test build + nox-tag: test - name: py310 python-version: "3.10" nox-tag: test - - name: py39 - python-version: "3.9" - nox-tag: test - - name: py38 - python-version: "3.8" - nox-tag: test - - name: py312-torch-nightly - python-version: "3.12" - nox-tag: test-torch-nightly - name: py311-torch-nightly python-version: "3.11" nox-tag: test-torch-nightly @@ -60,17 +52,14 @@ jobs: - name: py311-ort-nightly python-version: "3.11" nox-tag: test-ort-nightly - - name: py311-experimental-torchlib-tracing - python-version: "3.11" - nox-tag: test-experimental-torchlib-tracing - - name: py311-experimental-torchlib-onnx-ir + - name: py311-onnx-ir-git python-version: "3.11" - nox-tag: test-experimental-torchlib-onnx-ir + nox-tag: test-onnx-ir-git runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install nox @@ -78,33 +67,26 @@ jobs: - name: Pull Test Data run: git lfs pull - name: Run tests - run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml + run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junitxml junit.xml env: CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}" CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}" - name: Upload coverage to Codecov if: always() - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} - - name: Upload Test Results - if: always() - uses: actions/upload-artifact@v3 + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 with: - name: Test Results (${{ matrix.name }}-${{ matrix.os }}) - path: pytest.xml + token: ${{ secrets.CODECOV_TOKEN }} - name: Upload torchlib error reports if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: Error reports (${{ matrix.name }}-${{ matrix.os }}) path: error_reports - - name: Upload IR profiling results - if: matrix.name == 'py311' || matrix.name == 'py311-onnx-weekly' - uses: actions/upload-artifact@v3 - with: - name: IR profiling results - path: tests/ir/serde_test_profiles build_docs: strategy: @@ -113,9 +95,9 @@ jobs: os: [ubuntu-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" cache: pip @@ -137,9 +119,9 @@ jobs: update_readme: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 - name: Update readme run: | python docs/update_readme.py @@ -148,23 +130,3 @@ jobs: echo "Update readme by running `python docs/update_readme.py`" exit 1 fi - - publish-test-results: - name: "Publish Tests Results to Github" - needs: test - runs-on: ubuntu-latest - permissions: - checks: write - # only needed unless run with comment_mode: off - pull-requests: write - if: always() - steps: - - name: Download Artifacts - uses: actions/download-artifact@v3 - with: - path: artifacts - - - name: Publish Test Results - uses: EnricoMi/publish-unit-test-result-action@v2 - with: - files: "artifacts/**/*.xml" diff --git a/.github/workflows/pages.yaml b/.github/workflows/pages.yaml index 1e6aa4142c..ce638dc60d 100644 --- a/.github/workflows/pages.yaml +++ b/.github/workflows/pages.yaml @@ -25,14 +25,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Setup Pages uses: actions/configure-pages@v4 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel @@ -42,7 +42,7 @@ jobs: - name: Build documentation run: python -m sphinx docs dist/html - name: Upload documentation archive - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v4 with: path: 'dist/html' - name: Deploy to GitHub Pages diff --git a/.gitignore b/.gitignore index 0e9a057b9f..3344aa7659 100644 --- a/.gitignore +++ b/.gitignore @@ -41,9 +41,11 @@ coverage.xml .pytest_cache/ cover/ test-output.xml +*.sarif # Sphinx documentation docs/_build/ +docs/sg_execution_times.rst # Jupyter Notebook .ipynb_checkpoints @@ -93,10 +95,13 @@ dmypy.json # Generated files *.onnx +*.csv +*.xlsx !testdata/**/*.onnx *.onnxlib **/onnx_backend_test_code/** docs/auto_examples/* +docs/**/generated/* tests/export/* tests/models/testoutputs/* tests/mylib.onnxlib diff --git a/.lintrunner.toml b/.lintrunner.toml index aa88d1f66e..907f3bfce6 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -46,17 +46,17 @@ exclude_patterns = [ 'onnxscript/onnx_types.py', 'onnxscript/**/*_test.py', # Skip linting test files for speed 'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy - 'onnxscript/optimizer/evaluator.py', # FIXME - 'onnxscript/optimizer/constant_folding.py', # FIXME + 'onnxscript/optimizer/_legacy/evaluator.py', # FIXME + 'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME 'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME - 'onnxscript/_legacy_ir/irbuilder.py', # FIXME - 'onnxscript/optimizer/fold_constants_v0.py', # FIXME + 'onnxscript/rewriter/ort_fusions/models/*.py', # onnxscript code + 'onnxscript/rewriter/ort_fusions/models/_phi2lm.py', # onnxscript code + 'onnxscript/rewriter/ort_fusions/models/_phi4lm.py', # onnxscript code + 'onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py', # onnxscript code 'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME 'onnxscript/tools/function_unittest_producer.py', # FIXME - 'onnxscript/_legacy_ir/visitor.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME - 'onnxscript/rewriter/generic_pattern.py', # FIXME ] command = [ 'python', @@ -113,16 +113,14 @@ include_patterns = [ '**/*.py', ] exclude_patterns = [ - 'examples/**', # TODO: Merge with docs/examples - 'docs/examples/**', - 'docs/tutorial/examples/**', + 'examples/**', + 'docs/**', 'onnxscript/converter_test.py', 'tests/functions/**', 'tests/models/**', 'tests/onnx_backend_test_code/**', 'onnxscript/optimizer/**', # FIXME 'onnxscript/rewriter/**', # FIXME - 'onnxscript/_legacy_ir/**', # FIXME ] command = [ 'python', diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index f9ba8cf65f..686e5e7a09 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -7,3 +7,4 @@ Resources: - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns +- Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 66d4781c4f..346fad1f6a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,19 +1,3 @@ - - - - - - -
⚠️ -NOTE: ONNX Script is in very early -and active development and the team anticipates -breaking changes as the project evolves. -ONNX Script is not ready for production, -but early feedback is welcome. -⚠️
- ----- - # Contributing to ONNX Script We're always looking for your help to improve the product (bug fixes, new features, documentation, etc). Currently ONNX Script is under early and heavy development, so we encourage proposing any major changes by [filing an issue](https://github.com/microsoft/onnxscript/issues) to discuss your idea with the team first. diff --git a/README.md b/README.md index 484917be66..ec3ce7bcc8 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,19 @@ models using a subset of Python. ONNX Script is: * **Debuggable:** allows for eager-mode evaluation that provides for a more delightful ONNX model debugging experience. +This repo also covers: + +* **ONNX Script Optimizer:** provides functionality to optimize an ONNX + model by performing optimizations and clean-ups such as constant folding, + dead code elimination, etc. +* **ONNX Rewriter:** provides functionality to replace certain patterns in + an ONNX graph with replacement patterns based on user-defined rewrite rules. + Note however that ONNX Script does **not** intend to support the entirety of the Python language. +Website: [https://microsoft.github.io/onnxscript/](https://microsoft.github.io/onnxscript/) + ## Design Overview ONNX Script provides a few major capabilities for authoring and debugging @@ -140,6 +150,63 @@ result = Hardmax(v) More examples can be found in the [docs/examples](docs/examples) directory. +## ONNX Script Tools + +### ONNX Optimizer + +The ONNX Script Optimizer tool provides the user with the functionality to optimize an ONNX model by performing optimizations and clean-ups such as constant folding, dead code elimination, etc. In order to utilize the optimizer tool: + +```python +import onnxscript + +onnxscript.optimizer.optimize(onnx_model) +``` + +For a detailed summary of all the optimizations applied by the optimizer call, refer to the tutorial [Optimizing a Model using the Optimizer](https://microsoft.github.io/onnxscript/tutorial/optimizer/optimize.html) + +### ONNX Rewriter + +The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on user-defined rewrite rules. The rewriter tools allows two different methods in which patterns in the graph can be rewritten. + +### Pattern-based rewriting + +For this style of rewriting, the user provides a `target_pattern` that is to be replaced, a `replacement_pattern` and a `match_condition` (pattern rewrite will occur only if the match condition is satisfied). A simple example on how to use the pattern-based rewriting tool is as follows: + +```python +from onnxscript.rewriter import pattern + +# The target pattern +def erf_gelu_pattern(op, x): + return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) + +def erf_gelu_pattern_2(op, x): + return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5 + +# The replacement pattern +def gelu(op, x: ir.Value): + return op.Gelu(x, domain="com.microsoft") + +# Create multiple rules +rule1 = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement +) +rule2 = pattern.RewriteRule( + erf_gelu_pattern_2, # Target Pattern + gelu, # Replacement +) +# Create a Rewrite Rule Set with multiple rules. +rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2]) +# Apply rewrites +model_with_rewrite_applied = onnxscript.rewriter.rewrite( + model, # Original ONNX Model + pattern_rewrite_rules=rewrite_rule_set, +) +return model_with_rewrite_applied +``` + +For a detailed tutorial on how to create target_pattern, replacement_pattern and match_condition blocks in order to utilize the pattern-based rewriter, refer to the tutorial [Pattern-based Rewrite Using Rules](https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html) + ## Development Guidelines Every change impacting the converter or the eager evaluation must be diff --git a/VERSION b/VERSION index 6e8bf73aa5..7d8568351b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.0 +0.5.4 diff --git a/docs/_templates/classtemplate.rst b/docs/_templates/classtemplate.rst new file mode 100644 index 0000000000..24a5ac1803 --- /dev/null +++ b/docs/_templates/classtemplate.rst @@ -0,0 +1,14 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: + :undoc-members: + :member-order: bysource + +.. + autogenerated from docs/_templates/classtemplate.rst diff --git a/docs/_templates/classtemplate_inherited.rst b/docs/_templates/classtemplate_inherited.rst new file mode 100644 index 0000000000..07c84a9068 --- /dev/null +++ b/docs/_templates/classtemplate_inherited.rst @@ -0,0 +1,16 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: + :undoc-members: + :inherited-members: + :member-order: bysource + + +.. + autogenerated from docs/_templates/classtemplate.rst diff --git a/docs/_templates/functiontemplate.rst b/docs/_templates/functiontemplate.rst new file mode 100644 index 0000000000..f41fb0d764 --- /dev/null +++ b/docs/_templates/functiontemplate.rst @@ -0,0 +1,12 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autofunction:: {{ name }} + + +.. + autogenerated from docs/_templates/functiontemplate.rst diff --git a/docs/api/index.md b/docs/api/index.md index 59162fb166..a6dd4bd59b 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -1,8 +1,31 @@ # API +## Author Models + ```{toctree} +:maxdepth: 1 + decorator opsets converter values ``` + +## Model transformation + +```{toctree} +:maxdepth: 1 + +optimizer +rewriter +rewriter_pattern +version_converter +``` + +## Testing + +```{toctree} +:maxdepth: 1 + +testing +``` diff --git a/docs/api/optimizer.md b/docs/api/optimizer.md new file mode 100644 index 0000000000..6c8adf21bb --- /dev/null +++ b/docs/api/optimizer.md @@ -0,0 +1,18 @@ +# onnxscript.optimizer + +```{eval-rst} +.. automodule::onnxscript.optimizer +.. currentmodule:: onnxscript +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: functiontemplate.rst + :nosignatures: + + optimizer.optimize + optimizer.inline + optimizer.basic_constant_propagation + optimizer.fold_constants +``` diff --git a/docs/api/rewriter.md b/docs/api/rewriter.md new file mode 100644 index 0000000000..8ff015844b --- /dev/null +++ b/docs/api/rewriter.md @@ -0,0 +1,26 @@ +# onnxscript.rewriter + +```{eval-rst} +.. automodule::onnxscript.rewriter +.. currentmodule:: onnxscript +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: functiontemplate.rst + :nosignatures: + + rewriter.rewrite +``` + +## IR passes + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: classtemplate.rst + :nosignatures: + + rewriter.RewritePass +``` diff --git a/docs/api/rewriter_pattern.md b/docs/api/rewriter_pattern.md new file mode 100644 index 0000000000..c7deccc6dd --- /dev/null +++ b/docs/api/rewriter_pattern.md @@ -0,0 +1,40 @@ +# onnxscript.rewriter.pattern + +```{eval-rst} +.. automodule::onnxscript.rewriter.pattern +.. currentmodule:: onnxscript +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: classtemplate.rst + :nosignatures: + + rewriter.pattern.Pattern + rewriter.pattern.StringPattern + rewriter.pattern.StringConstantPattern + rewriter.pattern.PrefixPattern + rewriter.pattern.AttrPattern + rewriter.pattern.AttrConstantPattern + rewriter.pattern.OpsetPatternBuilder + rewriter.pattern.OpPatternBuilder + rewriter.pattern.MatchResult + rewriter.pattern.ValuePattern + rewriter.pattern.NodePattern + rewriter.pattern.NodeOutputPattern + rewriter.pattern.AnyValue + rewriter.pattern.Constant + rewriter.pattern.OrValue + rewriter.pattern.GraphPattern + rewriter.pattern.ReplacementSubgraph + rewriter.pattern.ReplacementPatternFunction + rewriter.pattern.PatternMatcher + rewriter.pattern.SimplePatternMatcher + rewriter.pattern.RewriteRule + rewriter.pattern.RewriteRuleSet + rewriter.pattern.RewriteRuleClassBase + rewriter.pattern.MatchStatus + rewriter.pattern.MatchInfo + rewriter.pattern.MatchingTracer +``` diff --git a/docs/api/testing.md b/docs/api/testing.md new file mode 100644 index 0000000000..d7d5fca800 --- /dev/null +++ b/docs/api/testing.md @@ -0,0 +1,6 @@ +# Testing + +```{eval-rst} +.. automodule:: onnxscript.testing + :members: +``` diff --git a/docs/api/version_converter.md b/docs/api/version_converter.md new file mode 100644 index 0000000000..0478efbf5a --- /dev/null +++ b/docs/api/version_converter.md @@ -0,0 +1,28 @@ +# onnxscript.version_converter + +```{eval-rst} +.. automodule::onnxscript.version_converter +.. currentmodule:: onnxscript +``` + +## Functions + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: functiontemplate.rst + :nosignatures: + + version_converter.convert_version +``` + +## IR passes + +```{eval-rst} +.. autosummary:: + :toctree: generated + :template: classtemplate.rst + :nosignatures: + + version_converter.ConvertVersionPass +``` diff --git a/docs/conf.py b/docs/conf.py index 63dd8e7d44..d96ffe067f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,9 @@ -# Configuration file for the Sphinx documentation builder. -# To run the documentation: python -m sphinx docs dist/html +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Configuration file for the Sphinx documentation builder. + +To run the documentation: python -m sphinx docs dist/html +""" import os import re @@ -20,7 +24,7 @@ # -- General configuration --------------------------------------------------- extensions = [ - "myst_parser", + "myst_nb", "sphinx_copybutton", "sphinx_exec_code", "sphinx_gallery.gen_gallery", @@ -84,7 +88,11 @@ "python": (f"https://docs.python.org/{sys.version_info.major}", None), "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), + "onnx": ("https://onnx.ai/onnx/", None), + "onnx_ir": ("https://onnx.ai/ir-py/", None), "onnxruntime": ("https://onnxruntime.ai/docs/api/python/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), + "torch": ("https://pytorch.org/docs/main/", None), } # -- Options for Sphinx Gallery ---------------------------------------------- diff --git a/docs/examples/01_plot_selu.py b/docs/examples/01_plot_selu.py index 57a1f03c11..5ad3c49355 100644 --- a/docs/examples/01_plot_selu.py +++ b/docs/examples/01_plot_selu.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ Generating a FunctionProto ========================== diff --git a/docs/examples/02_plot_square_loss.py b/docs/examples/02_plot_square_loss.py index 5dce3545c8..181e4cd2ac 100644 --- a/docs/examples/02_plot_square_loss.py +++ b/docs/examples/02_plot_square_loss.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ Generating a ModelProto ======================= diff --git a/docs/examples/03_export_lib.py b/docs/examples/03_export_lib.py index 8a8993b7a8..f710fcb880 100644 --- a/docs/examples/03_export_lib.py +++ b/docs/examples/03_export_lib.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ Generating a LibProto ===================== diff --git a/docs/examples/04_plot_eager_mode_evaluation.py b/docs/examples/04_plot_eager_mode_evaluation.py index 740e2275af..d1c8f7fb75 100644 --- a/docs/examples/04_plot_eager_mode_evaluation.py +++ b/docs/examples/04_plot_eager_mode_evaluation.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ Eager mode evaluation ===================== diff --git a/docs/examples/05_plot_model_props.py b/docs/examples/05_plot_model_props.py index 4e10339bea..950b0e3467 100644 --- a/docs/examples/05_plot_model_props.py +++ b/docs/examples/05_plot_model_props.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ ModelProto Properties ===================== diff --git a/docs/examples/06_plot_model_local_funs.py b/docs/examples/06_plot_model_local_funs.py index 3a60b3e6cc..fdb0e434bb 100644 --- a/docs/examples/06_plot_model_local_funs.py +++ b/docs/examples/06_plot_model_local_funs.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ Model Local Functions ===================== diff --git a/docs/index.md b/docs/index.md index 3cd5e3db30..4dd0472706 100644 --- a/docs/index.md +++ b/docs/index.md @@ -103,7 +103,7 @@ result = MatmulAdd(x, wt, bias) Overview tutorial/index api/index -intermediate_representation/index +ir/index auto_examples/index articles/index ``` diff --git a/docs/intermediate_representation/index.md b/docs/intermediate_representation/index.md deleted file mode 100644 index fd3199671b..0000000000 --- a/docs/intermediate_representation/index.md +++ /dev/null @@ -1,8 +0,0 @@ -# ONNX IR - -```{toctree} -:maxdepth: 1 - -tensors -ir_api -``` diff --git a/docs/intermediate_representation/ir_api.md b/docs/intermediate_representation/ir_api.md deleted file mode 100644 index 2d1d8ebcb6..0000000000 --- a/docs/intermediate_representation/ir_api.md +++ /dev/null @@ -1,9 +0,0 @@ -# onnxscript.ir - - - -```{eval-rst} -.. automodule:: onnxscript.ir - :members: - :undoc-members: -``` diff --git a/docs/intermediate_representation/tensors.md b/docs/intermediate_representation/tensors.md deleted file mode 100644 index cca80264ee..0000000000 --- a/docs/intermediate_representation/tensors.md +++ /dev/null @@ -1,322 +0,0 @@ -# Tensor Representation in the IR - -The ONNX IR offers the {py:class}`ir.TensorProtocol ` interface for usings different data structures as backing data for tensors. Besides the traditional {py:class}`onnx.TensorProto`, you can also use {py:class}`np.ndarray`, {py:class}`torch.Tensor`, {py:class}`jax.Array`, and virtually anything else to represent tensors in the graph. This allows for them to be accessed and serialized via the same `TensorProtocol` interface, without incurring additional copies at initialization. - -## The `TensorProtocol` - -{py:class}`ir.TensorProtocol ` defines a read-only interface for representing tensors. A tensor class implementing the interface has attributes like `name`, `shape`, `dtype`, `size`, `nbytes` and `metadata_props` to describe basic properties of the tensor. Additionally, it should implement two methods {py:meth}`numpy ` and {py:meth}`__array__ ` which will produce equivalent NumPy arrays from the backing data. - -:::{note} -When interacting with initializers, constant values and tensor attributes, it is best to assume `TensorProtocol` and only use `isinstance` to check for concrete classes when there is a need. -::: - -## Tensor Classes - -### ir.TensorProtoTensor - -The ONNX spec defines [different ways](https://github.com/onnx/onnx/blob/d6f87121ba256ac6cc4d1da0463c300c278339d2/onnx/onnx.proto#L567-L654) for storing tensor data as an {py:class}`onnx.TensorProto ` protocol buffer message. The IR has corresponding classes for each of these data storage methods. - -We use the {py:class}`ir.TensorProtoTensor ` as a wrapper around the proto to implement the `ir.TensorProtocol` interface. You can access `shape`, `dtype` etc. as usual. A copy is incurred only when `numpy()` is called. - -:::{note} -Directly initializing an `ir.TensorProtoTensor`, as below, is possible. However, it is usually recommended to use `ir.serde.deserialize_tensor` because it handles all types of `TensorProto`s (`ir.TensorProtoTensor` doesn't handle external tensors, for example). Please refer to [From `TensorProto`s and back](#from-tensorprotos-and-back) for an example. -::: - -```{eval-rst} -.. exec_code:: - - import onnx - from onnxscript import ir - - tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3]) - tensor = ir.TensorProtoTensor(tensor_proto) - print("tensor: ", tensor) # TensorProtoTensor(name='tensor') - print("shape: ", tensor.shape) # ir.Shape([3]) - print("dtype: ", tensor.dtype) # ir.DataType.INT16 - print(tensor.raw == tensor_proto) # The raw field is the exact tensor_proto provided at initialization - print("tobytes: ", tensor.tobytes()) # b'\x01\x00\x02\x00\x03\x00' - print("numpy: ", tensor.numpy()) # array([1, 2, 3], dtype=int16) -``` - -### ir.ExternalTensor - -Tensor data stored externally in the disk are typically large and will take up memory when loaded. The {py:class}`ir.ExternalTensor ` class uses memory mapping to avoid loading the tensor into memory. You are able to use the tensor as a normal NumPy array with minimal memory usage. - -Refer to {py:func}`ir.serde.deserialize_tensor ` to find an example on converting an `onnx.TensorProto` to an {py:class}`ir.ExternalTensor `. - -### ir.Tensor - -{py:class}`ir.Tensor ` is a wrapper around NumPy array compatible array objects like {py:class}`np.ndarray` and {py:class}`torch.Tensor`. It is best for creating in-memory tensors without converting it to a `TensorProto` to reduce the conversion overhead. - -:::{tip} -An array object is compatible if it defines the `__array__` method. -::: - -To create a tensor from an array, simply initialize it with an NumPy array - -```python -tensor = ir.Tensor(np.random.rand(1, 2)) -``` - -The initializer will obtain dtype and shape information from the array. - -To create a tensor from objects other than NumPy array, you need to specify the dtype: - -```{eval-rst} -.. exec_code:: - - import torch - from onnxscript import ir - - torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16) - tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16) - print(tensor.numpy()) # array([1., 2., 3.], dtype=float16) -``` - -### String Tensor - -Use {py:class}`ir.StringTensor ` to create a string tensor. - - - -### Sparse Tensor - -Sparse tensors are not yet supported, but they are on our roadmap. - -## From `TensorProto`s and back - -In the following scenario, we show how to go from a `TensorProto` to an `ir.Tensor`, run some computation, then turn it back to an `ir.Tensor` and finally `TensorProto` - -```{eval-rst} -.. exec_code:: - - from onnxscript import ir - import onnx - import numpy as np - - # 1. Create the TensorProto - proto = onnx.helper.make_tensor( - "tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6] - ) - - # 2. Create an IR Tensor from the Protobuf message - tensor = ir.serde.deserialize_tensor(proto) - # Note that we get a TensorProtoTensor that implements the TensorProtocol - print("tensor:", tensor) # TensorProtoTensor(name='tensor') - print("tensor.numpy():", tensor.numpy()) # [[1. 2. 3.] - # [4. 5. 6.]] - print("tensor.tobytes():", tensor.tobytes()) # b'\x00<\x00@\x00B\x00D\x00E\x00F' - - # 3. Do computation using numpy - mean = tensor.numpy().mean(axis=0) - print("mean:", mean) # array([2.5, 3.5, 4.5], dtype=float16) - - # 4. Create a Tensor from the ndarray. Note that we use ir.Tensor - tensor_mean = ir.Tensor(mean) - print("tensor_mean:", tensor_mean) # Tensor(array([2.5, 3.5, 4.5], dtype=float16), name='') - - # 5. Obtain the TensorProto from ir.Tensor - mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean) - print("mean_tensor_proto:", mean_tensor_proto) - print( - "onnx.numpy_helper.to_array(mean_tensor_proto):", - onnx.numpy_helper.to_array(mean_tensor_proto) - # array([2.5, 3.5, 4.5], dtype=float16) - ) - - # You can obtain the bytes data as well - print("tensor_mean.tobytes():", tensor_mean.tobytes()) - print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes()) - - # Explore other methods defined by TensorProtocol: - print("\n# Explore other methods defined by TensorProtocol:") - print("tensor_mean.shape:", tensor_mean.shape) - print("tensor_mean.dtype:", tensor_mean.dtype) - print("tensor_mean.name:", tensor_mean.name) - print("tensor_mean.doc_string:", tensor_mean.doc_string) - print("tensor_mean.raw:", tensor_mean.raw) - print("tensor_mean.metadata_props:", tensor_mean.metadata_props) - print("tensor_mean.size:", tensor_mean.size) - print("tensor_mean.nbytes:", tensor_mean.nbytes) - print("tensor_mean.raw:", tensor_mean.raw) - print("\nUse the display() method to view the tensor") - tensor_mean.display() -``` - -## Working with non-native NumPy dtypes: bfloat16, float8, int4 - -`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, the value is the bit representation for the dtype: - -- `int8` for (unpacked) int4, with the sign bit extended to 8 bits. -- `uint8` for (unpacked) uint4. -- `uint8` for 8-bit data types like float8. -- `uint16` for bfloat16. - -uint4/int4 is always unpacked; `tobyte()` produces a packed representation as expected. - -Initialization of `ir.Tensor` requires the NumPy array to follow these typing constraints as well. - -:::{tip} -You can use the [ml_dtypes package](https://github.com/jax-ml/ml_dtypes) to extend NumPy and work with these values. - -```bash -pip install --upgrade ml_dtypes -``` - -::: - -The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its values, and create a new tensor to store the transformed values. - -```{eval-rst} -.. exec_code:: - - from onnxscript import ir - import numpy as np - - array = np.array([0b1, 0b11], dtype=np.uint8) - tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN) - print(tensor) # Tensor(array([1, 3], dtype=uint8), name='') - print("tensor.numpy():", tensor.numpy()) # array([1, 3], dtype=uint8) - - # You can use the ml_dtypes package to work with these values in NumPy - import ml_dtypes - float8_array = tensor.numpy().view(ml_dtypes.float8_e4m3fn) - print("float8_array:", float8_array) # array([0.00195312, 0.00585938], dtype='float8_e4m3fn') - - # Compute - times_100 = float8_array * 100 - print("times_100:", times_100) - - # Create a new tensor out of the new value; dtype must be specified - new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN) - print("new_tensor:", new_tensor) # Tensor(array([36, 49], dtype=uint8), name='') - print("new_tensor == times_100", new_tensor.numpy().view(ml_dtypes.float8_e4m3fn) == times_100) # array([ True, True]) - -``` - -## Advanced Usage - -### Subclass ir.Tensor for More Efficient Access and Broader dtype Support - -{py:class}`ir.Tensor` internally converts any array compatible objects into NumPy arrays to produce the byte representation in `tobytes()`. This can be inefficient due to the additional conversion. It also limits support for dtypes not supported by NumPy like bfloat16, because the `__array__` method would fail. - -To fully support arrays from other frameworks, it is usually a good idea to create specialized classes to handle them. The `TorchTensor` class below demonstrates how you can subclass `ir.Tensor` to handle PyTorch tensors: - -```{eval-rst} -.. exec_code:: - - import ctypes - from typing import Any - - import torch - from onnxscript import ir - - # Define utilities to convert PyTorch data types so users do not need to specify manually - _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { - torch.bfloat16: ir.DataType.BFLOAT16, - torch.bool: ir.DataType.BOOL, - torch.complex128: ir.DataType.COMPLEX128, - torch.complex64: ir.DataType.COMPLEX64, - torch.float16: ir.DataType.FLOAT16, - torch.float32: ir.DataType.FLOAT, - torch.float64: ir.DataType.DOUBLE, - torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, - torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, - torch.float8_e5m2: ir.DataType.FLOAT8E5M2, - torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, - torch.int16: ir.DataType.INT16, - torch.int32: ir.DataType.INT32, - torch.int64: ir.DataType.INT64, - torch.int8: ir.DataType.INT8, - torch.uint8: ir.DataType.UINT8, - } - - - def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType: - return _TORCH_DTYPE_TO_ONNX[dtype] - - class TorchTensor(ir.Tensor): - def __init__(self, tensor: torch.Tensor): - # Pass the tensor as the raw data to ir.Tensor's constructor - super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype)) - - def __array__(self, dtype: Any = None) -> "np.ndarray": - # numpy() calls __array__ in ir.Tensor - if self.dtype == ir.DataType.BFLOAT16: - return self.raw.view(torch.uint16).__array__(dtype) - if self.dtype in { - ir.DataType.FLOAT8E4M3FN, - ir.DataType.FLOAT8E4M3FNUZ, - ir.DataType.FLOAT8E5M2, - ir.DataType.FLOAT8E5M2FNUZ - }: - return self.raw.view(torch.uint8).__array__(dtype) - return self.raw.__array__(dtype) - - def tobytes(self) -> bytes: - # Implement tobytes to support native PyTorch types so we can use types like bloat16 - # Reading from memory directly is also more efficient because - # it avoids the copy to NumPy array - tensor = self.raw.detach().cpu().contiguous() - return bytes( - (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( - tensor.data_ptr() - ) - ) - - # Test the implementation - torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16) - tensor = TorchTensor(torch_tensor) - print("tensor: ", tensor) - print("numpy: ", tensor.numpy()) - print("tobytes: ", tensor.tobytes()) # b'\x80?\x00@@@' - print("nbytes: ", tensor.nbytes) # 6 -``` - -The `TorchTensor` class above implements `tobytes()` to produce the correct bytes representation for the tensor when it is serialized into an ONNX file / TensorProto. The class also implements the `__array__()` method to return the bit representation for types NumPy does not support. This way analysis passes can still perform computation on these values. - -### Computation with different Frameworks - -Since `ir.Tensor` implements the `__array__` method and `__dlpack__` methods, its content can be shared with computation frameworks without copying. For example: - -```{eval-rst} -.. exec_code:: - - from onnxscript import ir - - # We can call numpy methods directly on ir.Tensor - import numpy as np - print(np.multiply(ir.Tensor(np.array([1, 2])), 42)) # array([42., 84.]) - - # We can transfer arrays to different frameworks - import jax.numpy as jnp - import jax - import torch - - # Create ir.Tensor - jax_array = jnp.array([10., 20.]) - ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT) - torch_tensor = torch.tensor([30., 40.]) - ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) - - # Use numpy for computation - print(np.multiply(ir_tensor_jax, ir_tensor_torch)) # array([300., 800.], dtype=float32) - - # Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same - jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch) - print(jax_array_from_ir + jax_array) # [40. 60.] - - # Use PyTorch for computation - torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax) - print(torch_tensor_from_ir - torch_tensor) # tensor([-20., -20.]) - - # They can all be serialized into TensorProto - proto = ir.serde.serialize_tensor(ir_tensor_jax) - print(type(proto)) # - print(proto) - - # The value is exactly the same as jax_array - print(ir.serde.deserialize_tensor(proto).numpy()) # [10. 20.] -``` - -This is particularly useful if you are creating passes on the graph that requires doing computation on concrete values. You are free to use your favorite frameworks to create the passes. The transformed graph that contains newly created `ir.Tensor`s will be compatible with downstream passes even if they leverage other computation frameworks. diff --git a/docs/ir/index.md b/docs/ir/index.md new file mode 100644 index 0000000000..ae6b0802b5 --- /dev/null +++ b/docs/ir/index.md @@ -0,0 +1,5 @@ +# ONNX IR + +ONNX IR is now an official ONNX project! Documentation has been migrated to [onnx.ai/ir-py/](https://onnx.ai/ir-py/). + +You may continue to use `onnxscript.ir` unchanged for compatibility with older (<0.3) versions of ONNX Script. diff --git a/docs/test/test_documentation_examples.py b/docs/test/test_documentation_examples.py index dcdcde2818..3cf7ac3b30 100644 --- a/docs/test/test_documentation_examples.py +++ b/docs/test/test_documentation_examples.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import os import subprocess @@ -36,6 +34,9 @@ def do_test_folder(self, folder): if tested == 0: raise RuntimeError(f"No example was tested in folder {folder}.") + @unittest.skipIf( + sys.platform != "linux", reason="No need to run the documentation on every OS." + ) def test_documentation_examples(self): this = os.path.abspath(os.path.dirname(__file__)) onxc = os.path.normpath(os.path.join(this, "..", "..")) diff --git a/docs/tutorial/examples/dropout.py b/docs/tutorial/examples/dropout.py index 850b22edc4..4530c7f34d 100644 --- a/docs/tutorial/examples/dropout.py +++ b/docs/tutorial/examples/dropout.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/firstdim.py b/docs/tutorial/examples/firstdim.py index 187fedf569..63476949fd 100644 --- a/docs/tutorial/examples/firstdim.py +++ b/docs/tutorial/examples/firstdim.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/forloop.py b/docs/tutorial/examples/forloop.py index 3b32b1a0eb..75a13205d7 100644 --- a/docs/tutorial/examples/forloop.py +++ b/docs/tutorial/examples/forloop.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/forwhileloop.py b/docs/tutorial/examples/forwhileloop.py index 100f246c76..ffca170d43 100644 --- a/docs/tutorial/examples/forwhileloop.py +++ b/docs/tutorial/examples/forwhileloop.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/hardmax_end_to_end.py b/docs/tutorial/examples/hardmax_end_to_end.py index e4cd881eb3..9b49a5ca77 100644 --- a/docs/tutorial/examples/hardmax_end_to_end.py +++ b/docs/tutorial/examples/hardmax_end_to_end.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import onnx # We use ONNX opset 15 to define the function below. diff --git a/docs/tutorial/examples/leaky_relu.py b/docs/tutorial/examples/leaky_relu.py index 92fce52b10..e1d09a2a3d 100644 --- a/docs/tutorial/examples/leaky_relu.py +++ b/docs/tutorial/examples/leaky_relu.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/leaky_relu_attr_promoted.py b/docs/tutorial/examples/leaky_relu_attr_promoted.py index eb736162e3..058dc19366 100644 --- a/docs/tutorial/examples/leaky_relu_attr_promoted.py +++ b/docs/tutorial/examples/leaky_relu_attr_promoted.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/omitted_input.py b/docs/tutorial/examples/omitted_input.py index b4e839dd26..df35f49686 100644 --- a/docs/tutorial/examples/omitted_input.py +++ b/docs/tutorial/examples/omitted_input.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/outerscope_redef_error.py b/docs/tutorial/examples/outerscope_redef_error.py index a810e8eb71..41bd820d93 100644 --- a/docs/tutorial/examples/outerscope_redef_error.py +++ b/docs/tutorial/examples/outerscope_redef_error.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import graph, script from onnxscript import opset15 as op @@ -13,7 +15,7 @@ def Sum(sum_in, next): return sum_out, sum_out g = op.Constant(value=1) - all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1) + _all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1) return cumulative_sum except Exception as e: diff --git a/docs/tutorial/examples/scanloop.py b/docs/tutorial/examples/scanloop.py index c12da498da..6a409716a7 100644 --- a/docs/tutorial/examples/scanloop.py +++ b/docs/tutorial/examples/scanloop.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import graph, script from onnxscript import opset15 as op @@ -9,5 +11,5 @@ def Sum(sum_in, next): sum_out = sum_in + next return sum_out, sum_out - all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1) + _all_sum, cumulative_sum = op.Scan(0, X, body=Sum, num_scan_inputs=1) return cumulative_sum diff --git a/docs/tutorial/examples/softplus.py b/docs/tutorial/examples/softplus.py index 0929bc0a0b..18c194ea5d 100644 --- a/docs/tutorial/examples/softplus.py +++ b/docs/tutorial/examples/softplus.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. # We use ONNX opset 15 to define the function below. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/tensor_attr.py b/docs/tutorial/examples/tensor_attr.py index de24de9f70..312ad7c5eb 100644 --- a/docs/tutorial/examples/tensor_attr.py +++ b/docs/tutorial/examples/tensor_attr.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnx import TensorProto, helper from onnxscript import opset15 as op diff --git a/docs/tutorial/examples/tensor_attr2.py b/docs/tutorial/examples/tensor_attr2.py index a602b914c8..eb60b04bcd 100644 --- a/docs/tutorial/examples/tensor_attr2.py +++ b/docs/tutorial/examples/tensor_attr2.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnx import TensorProto, helper from onnxscript import opset15 as op diff --git a/docs/tutorial/examples/tensor_attr_short.py b/docs/tutorial/examples/tensor_attr_short.py index ddf32295cf..b6a2452b9b 100644 --- a/docs/tutorial/examples/tensor_attr_short.py +++ b/docs/tutorial/examples/tensor_attr_short.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnxscript import opset15 as op from onnxscript import script diff --git a/docs/tutorial/examples/whileloop.py b/docs/tutorial/examples/whileloop.py index 36b153c810..68bcfbea46 100644 --- a/docs/tutorial/examples/whileloop.py +++ b/docs/tutorial/examples/whileloop.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnx import TensorProto from onnx.helper import make_tensor diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md index f3d2173270..708793a8a0 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -123,7 +123,9 @@ subsequently modified, this modification has no effect on the attribute-value or the ONNX function/model created. This may potentially cause the behavior of eager-mode execution to be inconsistent with the ONNX construct generated. -Thus, the example shown above is equivalent to the following: +Thus, the second assignment to `script_const` in the following code has no effect +on the subsequent call to `tensor_attr.to_function_proto()`, which will use the +original value of `script_const`: ```{literalinclude} examples/tensor_attr2.py ``` @@ -268,7 +270,7 @@ ONNX perspective, the two assignments to *g* represent two distinct tensors ```{toctree} :maxdepth: 1 -optimizer/index rewriter/index -``` +optimizer/index +``` diff --git a/docs/tutorial/optimizer/optimize.md b/docs/tutorial/optimizer/optimize.md index 5ceb7dfb80..8ff36f4c67 100644 --- a/docs/tutorial/optimizer/optimize.md +++ b/docs/tutorial/optimizer/optimize.md @@ -15,6 +15,7 @@ onnxscript.optimizer.optimize(model) ``` ### optimize API + The `onnxscript.optimizer.optimize` call takes in several optional parameters that allows the caller to further fine-tune the process of optimization. ```{eval-rst} @@ -24,12 +25,8 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th ## Description of optimizations applied by `onnxscript.optimizer.optimize` -:::{table} -:widths: auto -:align: center - -| Optimization 'onnxscript.optimizer.` + .. | Description | -| - | - | +| Optimization | Description | +|-------------|-------------| | **Constant folding**
`constant_folding.fold_constants` | Applies constant folding optimization to the model. | | **Constant propagation**
`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. | | **Sequence simplification**
`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. | @@ -37,17 +34,3 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th | **Remove unused functions**
`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. | | **Inline functions with unused outputs**
`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. | | **Inline simple functions**
`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. | -::: - -## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize` - -```{eval-rst} -.. autosummary:: - :nosignatures: - - onnxscript.rewriter.broadcast_to_matmul - onnxscript.rewriter.cast_constant_of_shape - onnxscript.rewriter.gemm_to_matmul_add - onnxscript.rewriter.no_op - -``` diff --git a/docs/tutorial/rewriter/allow_other_inputs.md b/docs/tutorial/rewriter/allow_other_inputs.md new file mode 100644 index 0000000000..29ccabca03 --- /dev/null +++ b/docs/tutorial/rewriter/allow_other_inputs.md @@ -0,0 +1,27 @@ +# Specifying variable inputs in the pattern + +This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting. +The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs +beyond those specified in the pattern. If it is set to `False` (the default), then the node must +have exactly the specified inputs for a successful match. If set to `True`, the pattern will +match nodes that have the specified inputs plus any number of additional inputs. + +This is particularly useful when matching operations like `Conv` that can have optional inputs +(such as bias), or when creating generic patterns that should work with various input configurations. + +```{literalinclude} examples/allow_other_inputs.py +:pyobject: conv_pattern +``` + +```{literalinclude} examples/allow_other_inputs.py +:pyobject: conv_replacement +``` + +```{literalinclude} examples/allow_other_inputs.py +:pyobject: apply_rewrite +``` + +In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation +might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting +`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs +in the pattern definition. diff --git a/docs/tutorial/rewriter/attributes.md b/docs/tutorial/rewriter/attributes.md new file mode 100644 index 0000000000..ba72cc5ade --- /dev/null +++ b/docs/tutorial/rewriter/attributes.md @@ -0,0 +1,23 @@ +# Specifying attributes in the pattern + +This section demonstrates the use of attribute values in pattern-based rewriting. +First, write a target pattern and replacement pattern in a similar way to the previous examples. +The example pattern below will match successfully only against Dropout nodes with the +attribute value `training_mode` set to `False`. + +The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes +not specified in the pattern. If it is set to `False`, then the node must have only the specified +attribute values, and no other attributes, for a successful match. The default value for this +option is `True`. + +```{literalinclude} examples/allow_other_attributes.py +:pyobject: add_pattern +``` + +```{literalinclude} examples/allow_other_attributes.py +:pyobject: add_replacement +``` + +```{literalinclude} examples/allow_other_attributes.py +:pyobject: apply_rewrite +``` diff --git a/docs/tutorial/rewriter/commute.md b/docs/tutorial/rewriter/commute.md new file mode 100644 index 0000000000..d0690892f2 --- /dev/null +++ b/docs/tutorial/rewriter/commute.md @@ -0,0 +1,83 @@ +(heading-target-commute)= +# Utilizing `commute` parameter for pattern-matching + +```{warning} +Please note that the section below describes a convenience feature for handling commutative operators +in pattern matching. However, the implementation is a simple, brute-force, technique that generates a collection +of rewrite-rules from a given rule, taking commutativity of addition and multiplication into account. This can +lead to an exponential increase in the number of rewrite-rules. So, it should be used with caution. Pattern +disjunctions (_OR Patterns_) described earlier can be used judiciously to get a somewhat more efficient +implementation in practice (even though the potential for exponential increase still exists within the +pattern matching algorithm). Reimplementing commutativity handling using pattern disjunctions is future +work. +``` + +Extending the previous [simple example](heading-target-simple), assuming a scenario where we have a graph with the following structure. + +![commute](examples/img/erfgelu_03_commute.png){align=center width=500px} + +In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched. + +![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center} + + +If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched. + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` + +```{image} examples/img/erfgelu_06_commute.png +:alt: The resulting graph after matching. +:width: 400px +:align: center +``` + +Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods. + +(heading-target-commute-ruleset)= + +## 1. Creating a rule-set with different patterns. + +This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example: + +```python +from onnxscript.rewriter import pattern +rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) +``` + +In order to apply this method to the example above, first create the two separate target patterns as follows: + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern_2 +``` + +:::{note} +:name: rule-application-order-matters + +When you pass multiple rules in `pattern_rewrite_rules`, the **order in which they appear is important**. +This is because some rules may depend on patterns created or modified by earlier rules. For example, if `rule2` can only match after `rule1` has made a specific change in the model, then `rule1` must come **before** `rule2` in the list. +If you're not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur. +::: + + +Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite_with_ruleset +``` + +## 2. Using the `commute` parameter while creating a rule. + +Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target pattern for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite_with_commute +``` + +For the both of the aforementioned methods, the final graph with both rewrites applied should look as follows: + +![commute](examples/img/erfgelu_07_commute.png){align=center width=300px} diff --git a/docs/tutorial/rewriter/conditional_rewrite.md b/docs/tutorial/rewriter/conditional_rewrite.md new file mode 100644 index 0000000000..379788e657 --- /dev/null +++ b/docs/tutorial/rewriter/conditional_rewrite.md @@ -0,0 +1,103 @@ +# Using the `match_condition` parameter for pattern-matching + +This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration. + +Let us consider a model which consists of the following pattern. + +![target_pattern](examples/img/broadcast_01.png){align=center} + +Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: + +1. Input shapes check: `input_a` and `input_b` should be broadcastable +2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)` + +If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite. + +First, write a target pattern and replacement pattern in a similar way to the first example. + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: two_reshapes_matmul_reshape_pattern +``` + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: matmul_pattern +``` + +:::{note} +:name: omitting inputs in signature + +The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters. + +Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature. +::: + +In order to validate whether matmul broadcast is sufficient, we write a condition checking function as below. +Note that the relevant inputs passed to the check function are all instances of {py:class}`onnx_ir.Value`. These represent +the values in the input graph IR that matched against the corresponding _pattern variables_ in the target +pattern. Please see documentation of the [IR API](https://onnx.ai/ir-py/) for more details on how to use it, for example to identify +the type or shape or rank of these values. + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: check_if_not_need_reshape +``` + +With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: apply_rewrite +``` + +The final graph with the applied rewrite looks as follows: + +![broadcast_rewrite](examples/img/broadcast_02.png){align=center} + +# Using MatchContext for Advanced Condition Checking + +The `context` parameter passed to condition functions is an instance of {py:class}`onnxscript.rewriter.MatchContext`, which provides access to additional information about the pattern match that can be useful for sophisticated condition checking. + +## MatchContext Properties + +The MatchContext provides the following read-only properties: + +- `model`: The entire ONNX model being matched +- `graph_or_function`: The specific graph or function being matched +- `root`: The root node of the matching subgraph +- `output_values`: The output values of the matching subgraph +- `nodes`: All nodes that are part of the matching subgraph + +## Example Usage + +Here's an example showing how to use the MatchContext to implement more sophisticated condition checking: + +```python +def advanced_condition_check(context, x, y, **_): + """Example condition function using MatchContext.""" + + # Access the main node of the pattern match + main_node = context.root + + # Check that the main_node does not have an attribute called "alpha" + if "alpha" in main_node.attributes: + return False + + # Access the broader graph context and check that x occurs as a graph-input + model = context.model + if x not in model.graph.inputs: + return False + + # You can inspect the matched nodes for advanced validation + for node in context.nodes: + if node.op_type == "Constant": + # Check properties of constant nodes in the match + pass + + # Access output values for shape/type validation + outputs = context.output_values + if len(outputs) > 0 and outputs[0].shape is not None: + # Validate output shapes + pass + + return True +``` + +This context information enables condition functions to make decisions based on the broader graph structure, the specific nodes involved in the match, and relationships between matched patterns and the rest of the model. diff --git a/docs/tutorial/rewriter/domain_option.md b/docs/tutorial/rewriter/domain_option.md new file mode 100644 index 0000000000..30a7384b59 --- /dev/null +++ b/docs/tutorial/rewriter/domain_option.md @@ -0,0 +1,38 @@ +# Specifying domains in the pattern + +This section demonstrates the use of the `_domain` option in pattern-based rewriting. +The `_domain` option allows you to specify which operator domain the pattern should match against, +and also allows you to create replacement operations in specific domains. + +ONNX operators can belong to different domains: +- The default ONNX domain (empty string or "ai.onnx") +- Custom domains like "com.microsoft" for Microsoft-specific operations +- User-defined domains for custom operations + +## Matching operations from a specific domain + +```{literalinclude} examples/domain_option.py +:pyobject: custom_relu_pattern +``` + +In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the +"custom.domain" domain will be matched, not standard ONNX `Relu` operations. + +## Creating replacement operations in a specific domain + +```{literalinclude} examples/domain_option.py +:pyobject: microsoft_relu_replacement +``` + +Here, the replacement operation is created in the "com.microsoft" domain, which might +provide optimized implementations of standard operations. + +## Complete rewrite example + +```{literalinclude} examples/domain_option.py +:pyobject: apply_rewrite +``` + +This example shows how domain-specific pattern matching can be used to migrate operations +between different operator domains, such as replacing custom domain operations with +standard ONNX operations or vice versa. diff --git a/docs/tutorial/rewriter/examples/allow_other_attributes.py b/docs/tutorial/rewriter/examples/allow_other_attributes.py new file mode 100644 index 0000000000..67e14ad659 --- /dev/null +++ b/docs/tutorial/rewriter/examples/allow_other_attributes.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Onnx Pattern Rewriting with attributes + +This script shows how to define a rewriting rule based on patterns that +are dependent on the attributes of the nodes. +""" + +import onnx + +import onnxscript +from onnxscript import FLOAT, opset18, script +from onnxscript.rewriter import pattern + + +@script() +def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: + add = opset18.Add(A, B) + result = opset18.Dropout(add, training_mode=False) + return result + + +_model = original_model.to_model_proto() +onnx.checker.check_model(_model) + + +#################################### +# The target pattern +# ===================== + + +def add_pattern(op, input): + return op.Dropout(input, training_mode=False, _allow_other_attributes=True) + + +#################################### +# The replacement pattern +# ===================== + + +def add_replacement(op, input, **_): + return op.Identity(input) + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + add_rule = pattern.RewriteRule( + add_pattern, # target pattern + add_replacement, # replacement pattern + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([add_rule]) + # Apply rewrite while passing match_condition + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/tutorial/rewriter/examples/allow_other_inputs.py b/docs/tutorial/rewriter/examples/allow_other_inputs.py new file mode 100644 index 0000000000..cc3a3d926f --- /dev/null +++ b/docs/tutorial/rewriter/examples/allow_other_inputs.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""ONNX Pattern Rewriting with variable number of inputs + +This script shows how to define a rewriting rule based on patterns that +can match nodes with additional inputs beyond those specified in the pattern. +""" + +import onnx + +import onnxscript +from onnxscript import FLOAT, opset18, script +from onnxscript.rewriter import pattern + + +@script() +def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]: + # Conv with bias - has 3 inputs: input, weight, bias + result = opset18.Conv(A, B, C) + return result + + +_model = original_model.to_model_proto() +onnx.checker.check_model(_model) + + +#################################### +# The target pattern +# ===================== + + +def conv_pattern(op, input, weight): + # Pattern to match Conv operations, allowing additional inputs like bias + # _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs) + # even though we only specify 2 inputs in the pattern + return op.Conv(input, weight, _allow_other_inputs=True) + + +#################################### +# The replacement pattern +# ===================== + + +def conv_replacement(op, input, weight, **_): + # Replace with a custom operation in a different domain + return op.OptimizedConv(input, weight, _domain="custom.domain") + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + conv_rule = pattern.RewriteRule( + conv_pattern, # target pattern + conv_replacement, # replacement pattern + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([conv_rule]) + # Apply rewrite + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/tutorial/rewriter/examples/broadcast_matmul.py b/docs/tutorial/rewriter/examples/broadcast_matmul.py index 22b374e5b2..cf56b49f07 100644 --- a/docs/tutorial/rewriter/examples/broadcast_matmul.py +++ b/docs/tutorial/rewriter/examples/broadcast_matmul.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Onnx Pattern Rewriting with match condition parameter. This script shows how to define a rewriting rule based on patterns while @@ -9,12 +11,11 @@ import logging -import numpy as np import onnx import onnxscript from onnxscript import FLOAT, ir, opset18, script -from onnxscript.rewriter import _ir_utils, pattern +from onnxscript.rewriter import pattern logger = logging.getLogger(__name__) @@ -40,14 +41,12 @@ def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1 # The target pattern # ===================== -_op = pattern.onnxop - -def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c): - reshape_a = _op.Reshape(input_a, shape_a) - reshape_b = _op.Reshape(input_b, shape_b) - matmul = _op.MatMul(reshape_a, reshape_b) - return _op.Reshape(matmul, shape_c) +def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c): + reshape_a = op.Reshape(input_a, shape_a) + reshape_b = op.Reshape(input_b, shape_b) + matmul = op.MatMul(reshape_a, reshape_b) + return op.Reshape(matmul, shape_c) #################################### @@ -65,72 +64,79 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_): def check_if_not_need_reshape( - input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_ + context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_ ) -> bool: - """If matmul broadcasting is enough, then we don't need the reshapes. + """Condition to check if we need to replace the pattern. + + If matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: 1. Input shapes check: input_a and input_b should be broadcastable 2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b) If the above are true, then we don't need the reshapes. + + Returns: + True if we need to replace the pattern, False otherwise. """ input_a_shape = input_a.shape input_b_shape = input_b.shape - # TODO: Get a helper func to get const_value - shape_c_value = _ir_utils.propagate_const_value(shape_c) - shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr] - if shape_c is None: - return False - if not isinstance(shape_c, np.ndarray): - logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c)) + shape_c_tensor = shape_c.const_value + if shape_c_tensor is None: + logger.info("The value 'shape_c' is not statically known.") return False - if len(shape_c.shape) != 1: + + if len(shape_c_tensor.shape) != 1: logger.info( "Unexpected final shape. The shape of 'shape' value is %s", - shape_c.shape, + shape_c_tensor.shape, ) return False - shape_c_list = shape_c.tolist() # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape # information. So, we need to check if the shape is None and return False. - if input_a_shape is None or input_b_shape is None or shape_c is None: + if input_a_shape is None or input_b_shape is None: logger.info("Shape information is not available for the inputs and outputs.") return False - input_a_shape = list(input_a_shape) - input_b_shape = list(input_b_shape) + input_a_shape = input_a_shape.numpy() + input_b_shape = input_b_shape.numpy() + shape_c = shape_c_tensor.numpy().tolist() + + a_rank = len(input_a_shape) + b_rank = len(input_b_shape) - dim_a = len(input_a_shape) - dim_b = len(input_b_shape) + # TODO(justinchuby): Check shape size # 1. Check if input shapes are broadcastable # 1.a. If the first input is 1-D, check whether # the dim matches the last second dim of the second input. mimic_matmul_broadcast_behavior = False - if dim_a < 2: + if a_rank < 2: + if b_rank < 2: + logger.info("Optimization of dot product is not supported yet.") + return False if input_a_shape[-1] != input_b_shape[-2]: logger.info("Original shape is not MatMul compatible.") return False else: input_a_shape = [1, *input_a_shape] - dim_a = len(input_a_shape) + a_rank = len(input_a_shape) mimic_matmul_broadcast_behavior = True # 1.b. If the second input is 1-D, check whether # the dim matches the last dim of the first input. - if dim_b < 2: + if b_rank < 2: if input_b_shape[-1] != input_a_shape[-1]: logger.info("Original shape is not MatMul compatible.") return False else: input_b_shape = [*input_b_shape, 1] - dim_b = len(input_b_shape) + b_rank = len(input_b_shape) mimic_matmul_broadcast_behavior = True # 1.c. If both inputs are at least 2-D, check whether # the last dimension of the first input matches the second # last dimension of the second input, and shape[:-2] are # broadcastable. - input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]] + input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]] input_b_shape_except_last_dim = input_b_shape[:-1] broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]] for idx, (dim_from_a, dim_from_b) in enumerate( @@ -150,23 +156,27 @@ def check_if_not_need_reshape( # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b) # Prepend the broadcast_matmul_output_shape with the longer shape of input - if dim_a > dim_b: + if a_rank > b_rank: longer_shape = input_a_shape shorter_shape = input_b_shape else: longer_shape = input_b_shape shorter_shape = input_a_shape - broadcast_matmul_output_shape = ( - longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape - ) - if mimic_matmul_broadcast_behavior and dim_b == 2: + broadcast_matmul_output_shape = [ + *longer_shape[: -len(shorter_shape)], + *broadcast_matmul_output_shape, + ] + if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1: + # If input_b is expanded to 2-D, then we need to remove the last dimension broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] - if mimic_matmul_broadcast_behavior and dim_a == 2: + if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1: + # If input_a is expanded to 2-D, then we need to remove the first dimension + # of input_a, which would be the -2nd dimension of the output shape. broadcast_matmul_output_shape.pop(-2) - if shape_c_list != broadcast_matmul_output_shape: + if shape_c != broadcast_matmul_output_shape: logger.info( "Final output shape is not the same. Expected %s vs actual %s", - shape_c_list, + shape_c, broadcast_matmul_output_shape, ) return False diff --git a/docs/tutorial/rewriter/examples/domain_option.py b/docs/tutorial/rewriter/examples/domain_option.py new file mode 100644 index 0000000000..7018c04719 --- /dev/null +++ b/docs/tutorial/rewriter/examples/domain_option.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""ONNX Pattern Rewriting with domain specification + +This script shows how to define a rewriting rule that targets operations +from specific domains and replaces them with operations in other domains. +""" + +import onnx + +import onnxscript +from onnxscript import script +from onnxscript.rewriter import pattern +from onnxscript.values import Opset + +# Create an opset for the custom domain +opset = Opset("custom.domain", 1) + + +@script(opset) +def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]: + """Create a model with a Relu operation in a custom domain.""" + return opset.Relu(input) + + +_model = create_model_with_custom_domain.to_model_proto() +_model = onnx.shape_inference.infer_shapes(_model) +onnx.checker.check_model(_model) + + +#################################### +# The target pattern +# ===================== + + +def custom_relu_pattern(op, input): + # Pattern to match Relu operations from a specific domain + # _domain="custom.domain" specifies we only want to match operations from this domain + return op.Relu(input, _domain="custom.domain") + + +#################################### +# The replacement pattern +# ===================== + + +def standard_relu_replacement(op, input, **_): + # Replace with standard ONNX Relu (default domain) + return op.Relu(input) + + +#################################### +# Alternative: Replace with operation in different domain +# ===================== + + +def microsoft_relu_replacement(op, input, **_): + # Replace with operation in Microsoft's domain + return op.OptimizedRelu(input, _domain="com.microsoft") + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + relu_rule = pattern.RewriteRule( + custom_relu_pattern, # target pattern - matches custom domain operations + standard_relu_replacement, # replacement pattern - uses standard domain + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([relu_rule]) + # Apply rewrite + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +# The rewrite rule will now match the Relu operation in the custom domain +# and replace it with a standard ONNX Relu operation +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/tutorial/rewriter/examples/erfgelu.py b/docs/tutorial/rewriter/examples/erfgelu.py index f8723da594..e042d9f337 100644 --- a/docs/tutorial/rewriter/examples/erfgelu.py +++ b/docs/tutorial/rewriter/examples/erfgelu.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Onnx Pattern Rewriting. This script shows how to define a rewriting rule based on patterns. @@ -70,15 +72,13 @@ def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]: # The target pattern # ===================== -_op = pattern.onnxop +def erf_gelu_pattern(op, x): + return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) -def erf_gelu_pattern(x): - return 0.5 * (x * (_op.Erf(x / math.sqrt(2)) + 1.0)) - -def erf_gelu_pattern_2(x): - return (x * (_op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5 +def erf_gelu_pattern_2(op, x): + return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5 #################################### @@ -87,7 +87,7 @@ def erf_gelu_pattern_2(x): def gelu(op, x: ir.Value): - return op.Gelu(x, domain="com.microsoft") + return op.Gelu(x, _domain="com.microsoft") #################################### @@ -98,7 +98,7 @@ def gelu(op, x: ir.Value): def apply_rewrite(model): rule = pattern.RewriteRule( erf_gelu_pattern, # Target Pattern - gelu, # Replacement Pattern + gelu, # Replacement ) model_with_rewrite_applied = onnxscript.rewriter.rewrite( model, @@ -107,15 +107,31 @@ def apply_rewrite(model): return model_with_rewrite_applied +#################################### +# Rewrite Rule as a Class +# ===================== + + +class ErfGeluFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x): + return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5 + + def rewrite(self, op, x): + return op.Gelu(x, _domain="com.microsoft") + + +erf_gelu_rule_from_class = ErfGeluFusion.rule() + + def apply_rewrite_with_ruleset(model): # Create multiple rules rule1 = pattern.RewriteRule( erf_gelu_pattern, # Target Pattern - gelu, # Replacement Pattern + gelu, # Replacement ) rule2 = pattern.RewriteRule( erf_gelu_pattern_2, # Target Pattern - gelu, # Replacement Pattern + gelu, # Replacement ) # Create a Rewrite Rule Set with multiple rules. rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2]) @@ -131,7 +147,7 @@ def apply_rewrite_with_ruleset(model): def apply_rewrite_with_commute(model): rule = pattern.RewriteRule( erf_gelu_pattern, # Target Pattern - gelu, # Replacement Pattern + gelu, # Replacement ) # Create a Rewrite Rule Set with commute=True rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True) diff --git a/docs/tutorial/rewriter/examples/or_pattern.py b/docs/tutorial/rewriter/examples/or_pattern.py new file mode 100644 index 0000000000..0e9231cc1f --- /dev/null +++ b/docs/tutorial/rewriter/examples/or_pattern.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""OR-patterns. + +This script shows how to define a rewriting rule based on OR-patterns. +""" + +import onnx + +import onnxscript +from onnxscript import FLOAT, opset18, script +from onnxscript.rewriter import pattern + +#################################### +# The target pattern +# ===================== + + +def scaled_matmul(op, x, y, factor): + xy = op.MatMul(x, y) + choice1 = op.Mul(xy, factor) + choice2 = op.Div(xy, factor) + scaled_xy = pattern.OrValue( + [choice1, choice2], tag_var="op_type", tag_values=["Mul", "Div"] + ) + return op.Relu(scaled_xy) + + +#################################### +# The replacement pattern +# ===================== + + +def scaled_matmul_replacement(op, x, y, factor, op_type): + if op_type == "Mul": + return op.MatMulMulRelu(x, y, factor, _domain="some.domain") + elif op_type == "Div": + return op.MatMulDivRelu(x, y, factor, _domain="some.domain") + else: + raise ValueError(f"Unknown operation type: {op_type}") + + +#################################### +# Rewrite Rule +# ===================== +def apply_rewrite(model): + rule = pattern.RewriteRule( + scaled_matmul, # target pattern + scaled_matmul_replacement, # replacement pattern + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([rule]) + return onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + + +@script() +def original_model1(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: + t1 = opset18.MatMul(A, B) + c = opset18.Constant(value_float=2.0) + t2 = opset18.Mul(t1, c) + t3 = opset18.Relu(t2) + return t3 + + +_model = original_model1.to_model_proto() +onnx.checker.check_model(_model) + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) + +assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulMulRelu"] + + +@script() +def original_model2(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: + t1 = opset18.MatMul(A, B) + c = opset18.Constant(value_float=2.0) + t2 = opset18.Div(t1, c) + t3 = opset18.Relu(t2) + return t3 + + +_model = original_model2.to_model_proto() +onnx.checker.check_model(_model) + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) + +assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulDivRelu"] diff --git a/docs/tutorial/rewriter/examples/outputs_option.py b/docs/tutorial/rewriter/examples/outputs_option.py new file mode 100644 index 0000000000..88483385dc --- /dev/null +++ b/docs/tutorial/rewriter/examples/outputs_option.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""ONNX Pattern Rewriting with output specification + +This script shows how to define a rewriting rule that specifies +the number and names of outputs from operations. +""" + +import onnx + +import onnxscript +from onnxscript import FLOAT, opset18, script +from onnxscript.rewriter import pattern + + +@script() +def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]: + # Split operation that produces 2 outputs + result1, _result2 = opset18.Split(A, num_outputs=2, axis=0) + # We only return the first output for simplicity + return result1 + + +_model = original_model.to_model_proto() +onnx.checker.check_model(_model) + + +#################################### +# The target pattern with multiple outputs +# ===================== + + +def split_pattern(op, input): + # Pattern to match Split operations with 2 outputs + # num_outputs=2 corresponds to the attribute of the ONNX Split op + # _outputs=2 is an option controlling the pattern constructor + return op.Split(input, num_outputs=2, axis=0, _outputs=2) + + +#################################### +# The replacement pattern with named outputs +# ===================== + + +def custom_split_replacement(op, input, **_): + # Replace with a custom split operation using named outputs + # _outputs=["first_half", "second_half"] assigns names to the outputs + # IMPORTANT: The number of outputs must match the pattern (2 outputs) + return op.CustomSplit( + input, _domain="custom.domain", _outputs=["first_half", "second_half"] + ) + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + split_rule = pattern.RewriteRule( + split_pattern, # target pattern - matches Split with 2 outputs + custom_split_replacement, # replacement pattern - uses named outputs + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([split_rule]) + # Apply rewrite + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/tutorial/rewriter/index.md b/docs/tutorial/rewriter/index.md index 3b4e01e149..d86ae9a474 100644 --- a/docs/tutorial/rewriter/index.md +++ b/docs/tutorial/rewriter/index.md @@ -1,4 +1,4 @@ -# Rewriter Tutorials +# Rewriter Tutorial ```{toctree} rewrite_patterns diff --git a/docs/tutorial/rewriter/node_value_checkers.md b/docs/tutorial/rewriter/node_value_checkers.md new file mode 100644 index 0000000000..e9e5661431 --- /dev/null +++ b/docs/tutorial/rewriter/node_value_checkers.md @@ -0,0 +1,187 @@ +(heading-target-checkers)= +# Node and Value Level Checkers + +The pattern matching infrastructure supports custom validation logic at both the node and value levels through checker functions. These checkers allow for more sophisticated pattern matching by enabling additional constraints beyond basic operator and structure matching. + +## Value-Level Checkers + +Value-level checkers validate properties of specific values in the pattern. They are particularly useful for checking constants, shapes, or other value-specific properties. + +### Basic Usage + +A value checker is a function that takes a `MatchContext` and an `ir.Value`, and returns either a boolean or a `MatchResult`: + +```python +def is_positive_constant(context, value: ir.Value): + """Check if a value is a positive constant.""" + if value.const_value is not None: + # Get the numpy array from const_value + numpy_array = value.const_value.numpy() + + # Check if it represents a single value and is positive + if numpy_array.size != 1: + return False + + return float(numpy_array.item()) > 0 + + return False +``` + +You can use this checker directly in your pattern by passing the callable as an input: + +```python +def add_pattern(op, x, y): + # Use callable as input to create ValuePattern with checker + return op.Add(is_positive_constant, y) +``` + +This pattern will only match `Add` operations where the first input is a positive constant value. + +### Example Usage + +```python +from onnxscript.rewriter import pattern +from onnxscript import ir, optimizer +import onnx + +# Create a model with different Add operations +model_proto = onnx.parser.parse_model(""" + + agraph (float[N] x, float[N] y) => (float[N] z1, float[N] z2, float[N] z3) + { + pos_const = Constant () + neg_const = Constant () + z1 = Add(x, y) # non-constant first parameter + z2 = Add(pos_const, y) # positive constant first parameter + z3 = Add(neg_const, y) # negative constant first parameter + } +""") +model = ir.serde.deserialize_model(model_proto) + +# Apply constant propagation to set const_value fields +optimizer.basic_constant_propagation(model.graph.all_nodes()) + +# Create the pattern with value checker +rule_pattern = pattern.Pattern(add_pattern) + +# Test matching against different Add nodes +add_nodes = [node for node in model.graph if node.op_type == "Add"] + +# Non-constant first parameter - will not match +match_result = rule_pattern.match(model, model.graph, add_nodes[0]) +print(f"Non-constant: {bool(match_result)}") # False + +# Positive constant first parameter - will match +match_result = rule_pattern.match(model, model.graph, add_nodes[1]) +print(f"Positive constant: {bool(match_result)}") # True + +# Negative constant first parameter - will not match +match_result = rule_pattern.match(model, model.graph, add_nodes[2]) +print(f"Negative constant: {bool(match_result)}") # False +``` + +## Node-Level Checkers + +Node-level checkers validate properties of the operation nodes themselves, such as attributes, operation types, or other node-specific properties. + +### Basic Usage + +A node checker is a function that takes a `MatchContext` and an `ir.Node`, and returns either a boolean or a `MatchResult`: + +```python +def shape_node_checker(context, node): + """Check if a Shape operation has start attribute equal to 0.""" + return node.attributes.get_int("start", 0) == 0 +``` + +You can use this checker by passing it to the `_check` parameter of an operation: + +```python +def shape_pattern(op, x): + return op.Shape(x, _check=shape_node_checker) +``` + +This pattern will only match `Shape` operations where the `start` attribute is 0 (or not present, as the default is 0). + +### Example Usage + +```python +from onnxscript.rewriter import pattern +from onnxscript import ir +import onnx + +# Create a model with different Shape operations +model_proto = onnx.parser.parse_model(""" + + agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3) + { + z1 = Shape(x) + z2 = Shape (x) + z3 = Shape (x) + } +""") +model = ir.serde.deserialize_model(model_proto) + +# Create the pattern with node checker +rule_pattern = pattern.Pattern(shape_pattern) + +# Test matching against different Shape nodes +nodes = list(model.graph) +shape_nodes = [node for node in nodes if node.op_type == "Shape"] + +# Shape without start attribute (default 0) - will match +match_result = rule_pattern.match(model, model.graph, shape_nodes[0]) +print(f"No start attr: {bool(match_result)}") # True + +# Shape with start=0 - will match +match_result = rule_pattern.match(model, model.graph, shape_nodes[1]) +print(f"Start=0: {bool(match_result)}") # True + +# Shape with start=1 - will not match +match_result = rule_pattern.match(model, model.graph, shape_nodes[2]) +print(f"Start=1: {bool(match_result)}") # False +``` + +## Combining Checkers + +You can combine both node-level and value-level checkers in the same pattern for more sophisticated matching: + +```python +def complex_pattern(op, x, y): + # Value-level checker for first input + validated_x = is_positive_constant + # Node-level checker for the operation + return op.Add(validated_x, y, _check=lambda ctx, node: len(node.attributes) == 0) +``` + +This pattern will only match `Add` operations where: +1. The first input is a positive constant (value-level check) +2. The node has no custom attributes (node-level check) + +## Execution Timing and Limitations + +### When Checkers Are Called + +Node-level and value-level checkers are called **only at the end of the complete structural match**. This means: + +1. **Structural matching happens first**: The pattern matching engine first validates that the graph structure matches the pattern (correct operators, connections, etc.) +2. **Checkers run after structural validation**: Only after the structural match succeeds do the node and value checkers execute +3. **Order of execution**: Value-level checkers run first, followed by node-level checkers, and finally the pattern's condition function + +### Limitations with Pattern Disjunctions + +One important limitation of this design is that these checks don't compose well with pattern disjunctions (multiple alternative patterns). When searching among multiple value patterns: + +- **Only structural checking is performed initially**: If structural matching succeeds for the first alternative, other alternatives are not considered +- **Checker failures don't trigger backtracking**: If a checker fails, the entire pattern match fails rather than trying the next alternative pattern + +This means you should be careful when designing patterns with multiple alternatives that rely on checkers, as the checker logic may prevent exploration of valid alternative matches. + +## Error Handling + +Checkers can return either: +- `True`: Check passed, continue matching +- `False`: Check failed, pattern does not match +- `MatchResult`: More detailed result with potential failure reasons + +If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions. diff --git a/docs/tutorial/rewriter/or_pattern.md b/docs/tutorial/rewriter/or_pattern.md new file mode 100644 index 0000000000..6c42112467 --- /dev/null +++ b/docs/tutorial/rewriter/or_pattern.md @@ -0,0 +1,20 @@ +# OR Patterns + +*Note* : This feature is work-in-progress. + +Consider the following pattern: + +```{literalinclude} examples/or_pattern.py +:pyobject: scaled_matmul +``` + +This pattern will successfully match against the sequence "MatMul => Mul => Relu" as +well as the sequence "MatMul => Div => Relu". The matcher will bind the variable +specified in `tag_var` (`op_type` in the above example) to a value from those +listed in `tag_values` to indicate which of the alternatives was used for a +successful match. We can use this in the rewrite function to determine how +we want to rewrite the matched sub-graph, as illustrated by the following code: + +```{literalinclude} examples/or_pattern.py +:pyobject: scaled_matmul_replacement +``` diff --git a/docs/tutorial/rewriter/outputs_option.md b/docs/tutorial/rewriter/outputs_option.md new file mode 100644 index 0000000000..cc73bcc561 --- /dev/null +++ b/docs/tutorial/rewriter/outputs_option.md @@ -0,0 +1,43 @@ +# Specifying outputs in the pattern + +This section demonstrates the use of the `_outputs` option in pattern-based rewriting. +The `_outputs` option allows you to specify the number of outputs an operation produces +and optionally assign names to those outputs for easier reference in replacement patterns. + +The `_outputs` option can be specified in two ways: +- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs +- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs + +## Matching operations with multiple outputs + +```{literalinclude} examples/outputs_option.py +:pyobject: split_pattern +``` + +This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2` +specification ensures the pattern only matches operations with this specific output count. + +## Creating replacement operations with named outputs + +```{literalinclude} examples/outputs_option.py +:pyobject: custom_split_replacement +``` + +In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with +descriptive names. This can make the replacement pattern more readable and maintainable. + +**Important**: The number of outputs in the replacement pattern must match the number of +outputs in the target pattern. Since the pattern specifies `_outputs=2`, the replacement +must also produce exactly 2 outputs. + +## Complete rewrite example + +```{literalinclude} examples/outputs_option.py +:pyobject: apply_rewrite +``` + +The `_outputs` option is particularly important when: +- Working with operations that have variable numbers of outputs (like `Split`) +- Creating custom operations that need specific output configurations +- Ensuring pattern matching precision by specifying exact output counts +- Improving code readability by naming outputs in replacement patterns diff --git a/docs/tutorial/rewriter/rewrite_patterns.md b/docs/tutorial/rewriter/rewrite_patterns.md index 7312380446..50615945d1 100644 --- a/docs/tutorial/rewriter/rewrite_patterns.md +++ b/docs/tutorial/rewriter/rewrite_patterns.md @@ -1,10 +1,8 @@ -# Pattern-based Rewrite Using Rules +# Introduction -## Introduction +The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on conditional rewrite rules provided by the user. -The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on rewrite rules provided by the user. - -## Usage +# Usage There are three main components needed when rewriting patterns in the graph: @@ -12,188 +10,40 @@ There are three main components needed when rewriting patterns in the graph: 2. `replacement_pattern` : Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators. 3. `match_condition` (optional) : Pattern rewrite will occur only if the match condition is satisfied. -(heading-target-simple)= -## A Simple Example - -An simple example demonstrating the usage of this functionality using the `GELU` activation function: - -`GELU` activation function can be computed using a Gauss Error Function using the given formula: - -```{math} -\text{GELU} = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})] -``` - -We will show how we can find a subgraph matching this computation and replace it by a call to the function. - -Firstly, include all the rewriter relevant imports. - -```python -from onnxscript.rewriter import pattern -from onnxscript import ir - -_op = pattern.onnxop -``` - -Then create a target pattern that needs to be replaced using onnxscript operators. - -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern -``` - -After this, create a replacement pattern that consists of the GELU onnxscript operator. - -```{literalinclude} examples/erfgelu.py -:pyobject: gelu -``` -:::{note} -:name: type annotate ir.Value - -The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value ` class. -::: - - -For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function. - -```python -rule = pattern.RewriteRule( - erf_gelu_pattern, # Target Pattern - gelu, # Replacement Pattern -) -``` - -Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components: - -1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`. -2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]` -3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of either one of these types: - - `Sequence[PatternRewriteRule]` - - `RewriteRuleSet` - -:::{note} -:name: pattern_rewrite_rules input formatting - -`pattern_rewrite_rules` takes a sequence of `PatternRewriteRule` types or a RewriteRuleSet which is also essentially a rule set created using a sequence of `PatternRewriteRule` types, so if only a singular rewrite rule is to be passed, it needs to passed as part of a sequence. For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset). -::: - -The snippet below below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above: - -```{literalinclude} examples/erfgelu.py -:pyobject: apply_rewrite -``` - -The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended. - -![target_pattern](examples/img/erfgelu_01.png) ![replacement_pattern](examples/img/erfgelu_02.png) - +## Pattern Options -(heading-target-commute)= -## Utilizing `commute` parameter for pattern-matching -Extending the previous [simple example](heading-target-simple), assumming a scenario where we have a graph with the following structure. +When defining patterns, you can use several special options to control how patterns match and what they produce: -![commute](examples/img/erfgelu_03_commute.png){align=center width=500px} +- `_allow_other_attributes`: Controls whether the pattern allows additional attributes not specified in the pattern (default: True) +- `_allow_other_inputs`: Controls whether the pattern allows additional inputs beyond those specified (default: False) +- `_domain`: Specifies the operator domain for matching or creating operations +- `_outputs`: Specifies the number and optionally names of outputs from an operation -In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched. +These options are documented in detail in the following sections. -![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center} - - -If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched. - -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern +```{include} simple_example.md ``` -```{image} examples/img/erfgelu_06_commute.png -:alt: The resulting graph after matching. -:width: 400px -:align: center +```{include} attributes.md ``` -Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods. - -(heading-target-commute-ruleset)= -### 1. Creating a rule-set with different patterns. - -This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example: - -```python -from onnxscript.rewriter import pattern -rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) +```{include} allow_other_inputs.md ``` -In order to apply this method to the example above, first create the two separate target patterns as follows: - -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern -``` -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern_2 +```{include} domain_option.md ``` -Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. - -```{literalinclude} examples/erfgelu.py -:pyobject: apply_rewrite_with_ruleset +```{include} outputs_option.md ``` - -### 2. Using the `commute` parameter while creating a rule. - -Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target pattern for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. - -```{literalinclude} examples/erfgelu.py -:pyobject: apply_rewrite_with_commute +```{include} conditional_rewrite.md ``` -For the both of the aforementioned methods, the final graph with both rewrites applied should look as follows: - -![commute](examples/img/erfgelu_07_commute.png){align=center width=300px} - -## Using the `match_condition` parameter for pattern-matching - -This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration. - -Let us consider a model which consists of the following pattern. - -![target_pattern](examples/img/broadcast_01.png){align=center} - -Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: - -1. Input shapes check: `input_a` and `input_b` should be broadcastable -2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)` - -If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite. - -First, write a target pattern and replacement pattern in a similar way to the first example. - -```{literalinclude} examples/broadcast_matmul.py -:pyobject: two_reshapes_matmul_reshape_pattern +```{include} or_pattern.md ``` -```{literalinclude} examples/broadcast_matmul.py -:pyobject: matmul_pattern +```{include} commute.md ``` -:::{note} -:name: omitting inputs in signature - -The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters. - -Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature. -::: - -In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows: - -```{literalinclude} examples/broadcast_matmul.py -:pyobject: check_if_not_need_reshape +```{include} node_value_checkers.md ``` - -With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. - -```{literalinclude} examples/broadcast_matmul.py -:pyobject: apply_rewrite -``` - -The final graph with the applied rewrite looks as follows: - -![broadcast_rewrite](examples/img/broadcast_02.png){align=center} diff --git a/docs/tutorial/rewriter/simple_example.md b/docs/tutorial/rewriter/simple_example.md new file mode 100644 index 0000000000..53b3c89aff --- /dev/null +++ b/docs/tutorial/rewriter/simple_example.md @@ -0,0 +1,81 @@ +(heading-target-simple)= +# A Simple Example + +An simple example demonstrating the usage of this functionality using the `GELU` activation function: + +`GELU` activation function can be computed using a Gauss Error Function using the given formula: + +```{math} +\text{GELU} = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})] +``` + +We will show how we can find a subgraph matching this computation and replace it by a call to the function. + +Firstly, include all the rewriter relevant imports. + +```python +from onnxscript.rewriter import pattern +from onnxscript import ir + +``` + +Then create a target pattern that needs to be replaced using onnxscript operators. + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` + +After this, create a replacement pattern that consists of the GELU onnxscript operator. + +```{literalinclude} examples/erfgelu.py +:pyobject: gelu +``` +:::{note} +:name: type annotate ir.Value + +The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value ` class. +::: + + +For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function. + +```python +rule = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement Pattern +) +``` + +It is more convenient to organize more complex rewrite-rules as a class. The above rule can be +alternatively expressed as below. + +```{literalinclude} examples/erfgelu.py +:pyobject: ErfGeluFusion +``` + +The corresponding rewrite-rule can be obtained as below: + +```python +erf_gelu_rule_from_class = ErfGeluFusion.rule() +``` + +Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite (model, pattern_rewrite_rules)` call applies the specified rewrite rules to the given model. + +1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `ir.Model` or `onnx.ModelProto`. If the model is an `ir.Model`, the rewriter applies the changes in-place, modifying the input model. If it is an `ModelProto`, the rewriter returns a new `ModelProto` representing the transformed model. +2. `pattern_rewrite_rules` : This parameter either a `Sequence[PatternRewriteRule]` or a `RewriteRuleSet`. + +:::{note} +:name: pattern_rewrite_rules input formatting + +For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset). +::: + +The snippet below below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above: + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite +``` + +The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended. + +![target_pattern](examples/img/erfgelu_01.png) ![replacement_pattern](examples/img/erfgelu_02.png) diff --git a/docs/update_readme.py b/docs/update_readme.py index ddc5859cd5..7d39406883 100644 --- a/docs/update_readme.py +++ b/docs/update_readme.py @@ -1,4 +1,6 @@ -# Script to update end-to-end example in README.md. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Script to update end-to-end example in README.md.""" updated_readme = [] with open("README.md", encoding="utf-8") as f: @@ -12,7 +14,7 @@ with open( "docs/tutorial/examples/hardmax_end_to_end.py", encoding="utf-8" ) as example_f: - example_code = example_f.readlines() + example_code = example_f.readlines()[2:] # Skip the copyright header updated_readme += example_code if line == "```\n" and in_stub: updated_readme.append(line) diff --git a/examples/pattern_matching_example.py b/examples/pattern_matching_example.py new file mode 100644 index 0000000000..8de09ecd6a --- /dev/null +++ b/examples/pattern_matching_example.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Example demonstrating the new pattern matching functionality.""" + +import onnx.parser + +from onnxscript import ir +from onnxscript.rewriter import pattern + + +def example_standalone_pattern_matching(): + """Example showing how to use Pattern for standalone pattern matching.""" + + print("=== Standalone Pattern Matching Example ===") + + # Define a pattern that matches Identity nodes + def identity_pattern(op, x): + return op.Identity(x) + + # Create a Pattern for standalone pattern matching (no replacement) + pattern_matcher = pattern.Pattern(identity_pattern, name="IdentityMatcher") + + # Create a model with an Identity node + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + z = Identity(x) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + + # Find nodes to test pattern matching against + for node in model.graph: + print(f"Testing pattern against {node.op_type} node...") + match_result = pattern_matcher.match(model, model.graph, node) + + if match_result is not None: + print(f" ✓ Pattern matched! Found {len(match_result.nodes)} nodes in match.") + print(f" Matched node: {match_result.nodes[0].op_type}") + else: + print(f" ✗ Pattern did not match {node.op_type} node.") + + +def example_class_based_pattern(): + """Example showing how to use PatternBase for class-based pattern definition.""" + + print("\n=== Class-Based Pattern Example ===") + + class IdentityPatternClass(pattern.PatternBase): + """A class-based pattern that matches Identity nodes.""" + + def pattern(self, op, x): + return op.Identity(x) + + def check(self, context, x): + """Custom condition - always succeeds for this example.""" + print(f" Checking condition for input: {x}") + return pattern.MatchResult() # Always succeeds + + # Create an instance of the pattern class + identity_pattern_class = IdentityPatternClass(name="ClassBasedIdentity") + + # The Pattern is created internally, we can use the pattern directly + print(f"Created pattern matcher: {identity_pattern_class.name}") + + # Use it directly with the match method + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + z = Identity(x) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + + for node in model.graph: + if node.op_type == "Identity": + print(f"Testing class-based pattern against {node.op_type} node...") + match_result = identity_pattern_class.match(model, model.graph, node) + + if match_result is not None: + print(" ✓ Class-based pattern matched!") + else: + print(" ✗ Class-based pattern did not match.") + + +def example_rewrite_rule_still_works(): + """Example showing that existing RewriteRule functionality is preserved.""" + + print("\n=== Existing RewriteRule Still Works ===") + + def identity_pattern(op, x): + return op.Identity(x) + + def identity_replacement(op, x): + return op.Identity(x) # No-op replacement + + # Create a RewriteRule (which now inherits from Pattern) + rule = pattern.RewriteRule(identity_pattern, identity_replacement, name="IdentityRule") + + print(f"Created rewrite rule: {rule.name}") + print(f"Rule is also a Pattern: {isinstance(rule, pattern.Pattern)}") + + # The rule can be used both for pattern matching and rewriting + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + z = Identity(x) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + + # Use it for just pattern matching (inherited from Pattern) + for node in model.graph: + if node.op_type == "Identity": + print(f"Using RewriteRule for pattern matching on {node.op_type}...") + match_result = rule.match(model, model.graph, node) + + if match_result is not None: + print(" ✓ RewriteRule matched as a pattern matcher!") + + # Use it for rewriting (original functionality) + print("Using RewriteRule for rewriting...") + count = rule.apply_to_model(model) + print(f" Applied rule {count} times") + + +if __name__ == "__main__": + example_standalone_pattern_matching() + example_class_based_pattern() + example_rewrite_rule_still_works() + print("\n=== All Examples Completed ===") diff --git a/examples/pattern_rewriting.py b/examples/pattern_rewriting.py index 737ce02e84..fd84d7f3cb 100644 --- a/examples/pattern_rewriting.py +++ b/examples/pattern_rewriting.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Onnx Pattern Rewriting. This script shows how to define a rewriting rule based on patterns. @@ -13,9 +15,8 @@ import onnx.helper as oh import onnx.numpy_helper as onh -import onnxscript from onnxscript import ir -from onnxscript.rewriter import generic_pattern +from onnxscript.rewriter import pattern def get_rotary_model(bad_model=False): @@ -67,18 +68,17 @@ def get_rotary_model(bad_model=False): # The rewriting pattern # ===================== -op = onnxscript.opset18 -msft_op = onnxscript.values.Opset("com.microsoft", 1) - -def rotary_match_pattern(x, pos_ids, axis): +def rotary_match_pattern(op, x, pos_ids, axis): """The pattern to match.""" unsqueeze = op.Unsqueeze(x, axis) cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT) matmul = op.MatMul(pos_ids, cast) transpose = op.Transpose(matmul) - output, length = msft_op.ConcatTraining(transpose, transpose) + output, _length = op.ConcatTraining( + transpose, transpose, domain="com.microsoft", outputs=2 + ) sin = op.Sin(output) cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT) @@ -87,25 +87,13 @@ def rotary_match_pattern(x, pos_ids, axis): return cast1, cast2 -def validate_rotary_mapping(g, match_result) -> bool: - """The validation post matching. - - Returns True to validate the replacement, - False not to apply it. - - :param g: model - :param match_result: matched nodes - """ - del g - del match_result - return True - - -def rotary_apply_pattern(x, pos_ids, axis): +def rotary_apply_pattern(op, x, pos_ids, axis): """The replacement pattern.""" cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) - part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) + part1, part2 = op.RotaryEmbedding( + x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 + ) return part1, part2 @@ -115,18 +103,7 @@ def rotary_apply_pattern(x, pos_ids, axis): # # The rule is easy to create. - -rule_with_validation_function = generic_pattern.make_pattern_rule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, -) - -################################ -# ``validate_rotary_mapping`` always return True. -# This argument can be ignored in that case. - -rule = generic_pattern.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern) +rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10) ########################## # Let's apply it. @@ -161,31 +138,6 @@ def rotary_apply_pattern(x, pos_ids, axis): # The match did not happen. # Let's increase the verbosity. -rule = generic_pattern.make_pattern_rule( - rotary_match_pattern, rotary_apply_pattern, verbose=10 -) +rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10) rule.apply_to_model(ir_model) - -###################################### -# The logs shows every time the algorithm rejected a pattern. -# We can see the following: -# -# :: -# -# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast -# --hint--: BACKWARD: different node types -# --pattern -# ConcatTraining(transpose, transpose) -> (output, length) -# -- model -# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1) -# iteration=1 -# --marked-- #2 -# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320] -# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472] -# len(stacked)=0:[] -# -# Line 673 in file `generic_pattern.py`, the match was rejected. -# It says while comparing two nodes in the backward direction, -# node types do not match. -# It also says that two nodes were actually matched. diff --git a/noxfile.py b/noxfile.py index 3aad2dfc35..60c2bb901b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Test with different environment configuration with nox. Documentation: @@ -10,14 +12,12 @@ COMMON_TEST_DEPENDENCIES = ( - "beartype==0.17.2", "expecttest==0.1.6", "hypothesis", - 'numpy==1.24.4; python_version<"3.12"', - 'numpy>1.26.0; python_version>="3.12"', + "numpy", "packaging", "parameterized", - "pyinstrument", + 'psutil; sys_platform != "win32"', "pytest-cov", "pytest-randomly", "pytest-subtests", @@ -25,12 +25,14 @@ "pytest!=7.1.0", "pyyaml", "types-PyYAML", - "typing_extensions", + "typing_extensions>=4.10", + "ml-dtypes", ) -ONNX = "onnx==1.16" -ONNX_RUNTIME = "onnxruntime==1.17.1" -PYTORCH = "torch==2.2.2" -TORCHVISON = "torchvision==0.17.2" +ONNX = "onnx==1.17" +ONNX_RUNTIME = "onnxruntime==1.23.0" +PYTORCH = "torch==2.7.1" +TORCHVISON = "torchvision==0.22.1" +TRANSFORMERS = "transformers==4.37.2" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", "coloredlogs", @@ -39,6 +41,8 @@ "packaging", "protobuf", ) +ONNX_IR = "onnx_ir==0.1.10" +ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" @nox.session(tags=["build"]) @@ -56,7 +60,9 @@ def test(session): PYTORCH, TORCHVISON, ONNX, + ONNX_IR, ONNX_RUNTIME, + TRANSFORMERS, ) session.install(".", "--no-deps") session.run("pip", "list") @@ -70,9 +76,11 @@ def test_torch_nightly(session): session.install( *COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, + TRANSFORMERS, ) session.install("-r", "requirements/ci/requirements-onnx-weekly.txt") session.install("-r", "requirements/ci/requirements-pytorch-nightly.txt") + session.install(ONNX_IR, "--no-deps") session.install(".", "--no-deps") session.run("pip", "list") session.run("pytest", "onnxscript", "--doctest-modules", *session.posargs) @@ -82,7 +90,8 @@ def test_torch_nightly(session): @nox.session(tags=["test-onnx-weekly"]) def test_onnx_weekly(session): """Test with ONNX weekly (preview) build.""" - session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON) + session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON, TRANSFORMERS) + session.install(ONNX_IR, "--no-deps") session.install("-r", "requirements/ci/requirements-onnx-weekly.txt") session.install(".", "--no-deps") session.run("pip", "list") @@ -98,6 +107,8 @@ def test_ort_nightly(session): PYTORCH, TORCHVISON, ONNX, + ONNX_IR, + TRANSFORMERS, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, ) session.install("-r", "requirements/ci/requirements-ort-nightly.txt") @@ -107,43 +118,19 @@ def test_ort_nightly(session): session.run("pytest", "tests", *session.posargs) -@nox.session(tags=["test-experimental-torchlib-tracing"]) -def test_experimental_torchlib_tracing(session): - """Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on.""" - session.install( - *COMMON_TEST_DEPENDENCIES, - PYTORCH, - TORCHVISON, - ONNX, - *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, - ) - session.install("-r", "requirements/ci/requirements-ort-nightly.txt") - session.install(".", "--no-deps") - session.run("pip", "list") - session.run( - "pytest", - "tests/function_libs/torch_lib/ops_test.py", - *session.posargs, - env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"}, - ) - - -@nox.session(tags=["test-experimental-torchlib-onnx-ir"]) -def test_experimental_torchlib_onnx_ir(session): - """Test TorchLib using the ONNX IR to build graphs.""" +@nox.session(tags=["test-onnx-ir-git"]) +def test_onnx_ir_git(session): + """Test with ONNX IR Git builds.""" session.install( *COMMON_TEST_DEPENDENCIES, PYTORCH, TORCHVISON, ONNX, - *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, + ONNX_RUNTIME, + TRANSFORMERS, ) - session.install("-r", "requirements/ci/requirements-ort-nightly.txt") + session.install(ONNX_IR_MAIN) session.install(".", "--no-deps") session.run("pip", "list") - session.run( - "pytest", - "tests/function_libs/torch_lib/ops_test.py", - *session.posargs, - env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"}, - ) + session.run("pytest", "onnxscript", "--doctest-modules", *session.posargs) + session.run("pytest", "tests", *session.posargs) diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py index bee5a1b230..b839093d2b 100644 --- a/onnxscript/__init__.py +++ b/onnxscript/__init__.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- __all__ = [ "script", @@ -9,6 +7,7 @@ "ir", "optimizer", "rewriter", + "version_converter", "export_onnx_lib", "OnnxFunction", "TracedOnnxFunction", @@ -54,10 +53,14 @@ "opset18", "opset19", "opset20", + "opset21", + "opset22", "opset_ai_onnx_ml1", "opset_ai_onnx_ml2", "opset_ai_onnx_ml3", "opset_ai_onnx_ml4", + "opset_ai_onnx_ml5", + "DEBUG", ] import importlib.metadata @@ -87,10 +90,13 @@ opset18, opset19, opset20, + opset21, + opset22, opset_ai_onnx_ml1, opset_ai_onnx_ml2, opset_ai_onnx_ml3, opset_ai_onnx_ml4, + opset_ai_onnx_ml5, ) from .onnx_types import ( @@ -118,10 +124,13 @@ # isort: on -from . import ir, optimizer, rewriter +from . import ir, optimizer, rewriter, version_converter from ._internal.utils import external_tensor from .values import OnnxFunction, TracedOnnxFunction +# Set DEBUG to True to enable additional debug checks +DEBUG: bool = False + try: # noqa: SIM105 __version__ = importlib.metadata.version("onnxscript") except importlib.metadata.PackageNotFoundError: diff --git a/onnxscript/_framework_apis/__init__.py b/onnxscript/_framework_apis/__init__.py new file mode 100644 index 0000000000..2aee3dcace --- /dev/null +++ b/onnxscript/_framework_apis/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Semi-private stable APIs for framework-specific usage only.""" diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py new file mode 100644 index 0000000000..162faf4b75 --- /dev/null +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Stable APIs for PyTorch 2.5.""" + +from __future__ import annotations + +__all__ = [ + "check_model", + "convert_version", + "get_torchlib_ops", + "optimize", + "save_model_with_external_data", +] + +import dataclasses +import os +import pathlib +from typing import Callable + +from onnxscript import ir, optimizer, version_converter +from onnxscript.function_libs.torch_lib import registration + + +@dataclasses.dataclass(frozen=True) +class _OnnxFunctionMeta: + """A wrapper of onnx-script function with additional metadata. + + qualified_name: The qualified name of the aten operator. + function: The onnx-script function. + domain: The domain of the function. + name: The name of the function. + is_complex: Whether the function is a complex function. + """ + + qualified_name: str + function: Callable + domain: str + name: str + is_complex: bool = False + + +def optimize(model: ir.Model) -> ir.Model: + """Optimize the model.""" + # Internal flag. Will go away. + enabled = os.getenv("TORCH_ONNX_ENABLE_OPTIMIZATION") == "1" + if enabled: + optimizer.optimize_ir(model) + return model + + +def convert_version(model: ir.Model, target_version: int) -> ir.Model: + """Convert the model to the specified ONNX opset version.""" + # Internal flag. Will go away. + enabled = os.getenv("TORCH_ONNX_ENABLE_VERSION_CONVERSION") == "1" + if enabled: + version_converter.convert_version(model, target_version) + return model + + +def check_model(model: ir.Model) -> None: + """Check the model.""" + + del model # Unused yet + + +def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None: + """Save the model with external data. The model is unchanged after saving.""" + + # TODO(#1835): Decide if we want to externalize large attributes as well + uninitialized_values = [ + value.name for value in model.graph.initializers.values() if value.const_value is None + ] + if uninitialized_values: + raise ValueError( + f"The model contains uninitialized initializer values ({uninitialized_values}). " + "Please make sure all initializer values are initialized." + ) + destination_path = pathlib.Path(model_path) + data_path = f"{destination_path.name}.data" + + ir.save(model, model_path, external_data=data_path) + + +def get_torchlib_ops() -> list[_OnnxFunctionMeta]: + # Trigger op registration + from onnxscript.function_libs.torch_lib import ( # pylint: disable=import-outside-toplevel + ops, + ) + + del ops # Unused + + torchlib_registry = registration.default_registry + function_metas = [] + + for qualified_name, aten_overloads_func in torchlib_registry.items(): + if qualified_name.startswith("internal::"): + # Skip the custom defined internal functions + continue + + for overload_func in aten_overloads_func.overloads: + function_meta = _OnnxFunctionMeta( + qualified_name=qualified_name, + function=overload_func, + domain=overload_func.function_ir.domain, + name=overload_func.name, + is_complex=False, + ) + function_metas.append(function_meta) + for complex_func in aten_overloads_func.complex: + function_meta = _OnnxFunctionMeta( + qualified_name=qualified_name, + function=complex_func, + domain=complex_func.function_ir.domain, + name=complex_func.name, + is_complex=True, + ) + function_metas.append(function_meta) + + return function_metas diff --git a/onnxscript/_framework_apis/torch_2_6.py b/onnxscript/_framework_apis/torch_2_6.py new file mode 100644 index 0000000000..2d166cb967 --- /dev/null +++ b/onnxscript/_framework_apis/torch_2_6.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Stable APIs for PyTorch 2.6.""" + +from __future__ import annotations + +__all__ = [ + "check_model", + "convert_version", + "get_torchlib_ops", + "optimize", + "save_model_with_external_data", + "torchlib_opset", +] +import logging +from typing import TYPE_CHECKING + +from onnxscript import ir, optimizer, version_converter +from onnxscript._framework_apis.torch_2_5 import ( + check_model, + get_torchlib_ops, + save_model_with_external_data, +) + +if TYPE_CHECKING: + from onnxscript.onnx_opset._impl.opset18 import Opset18 + + +logger = logging.getLogger(__name__) + + +def optimize(model: ir.Model) -> ir.Model: + """Optimize the model.""" + optimizer.optimize_ir(model) + return model + + +def convert_version(model: ir.Model, target_version: int) -> ir.Model: + """Convert the model to the specified ONNX opset version.""" + if target_version < 18: + logger.warning("Conversion to opset < 18 is not supported.") + return model + version_converter.convert_version(model, target_version, fallback=True) + return model + + +def torchlib_opset() -> Opset18: + """Return the default opset for torchlib.""" + import onnxscript # pylint: disable=import-outside-toplevel + + return onnxscript.opset18 # type: ignore + + +def torchlib_opset_version() -> int: + """Return the default opset version for torchlib.""" + + return torchlib_opset().version diff --git a/onnxscript/_framework_apis/torch_2_7.py b/onnxscript/_framework_apis/torch_2_7.py new file mode 100644 index 0000000000..ee5e6089e5 --- /dev/null +++ b/onnxscript/_framework_apis/torch_2_7.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Stable APIs for PyTorch 2.7.""" + +from __future__ import annotations + +__all__ = [ + "check_model", + "convert_version", + "get_torchlib_ops", + "optimize", + "save_model_with_external_data", +] + +from onnxscript._framework_apis.torch_2_6 import ( + check_model, + convert_version, + get_torchlib_ops, + optimize, + save_model_with_external_data, +) diff --git a/onnxscript/_framework_apis/torch_2_8.py b/onnxscript/_framework_apis/torch_2_8.py new file mode 100644 index 0000000000..dca34086a0 --- /dev/null +++ b/onnxscript/_framework_apis/torch_2_8.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Stable APIs for PyTorch 2.8.""" + +from __future__ import annotations + +__all__ = [ + "check_model", + "convert_version", + "get_torchlib_ops", + "optimize", + "save_model_with_external_data", +] + +import onnx_ir as ir + +import onnxscript.optimizer +import onnxscript.rewriter.onnx_fusions +from onnxscript._framework_apis.torch_2_6 import ( + check_model, + convert_version, + get_torchlib_ops, + save_model_with_external_data, +) + + +def optimize(model: ir.Model) -> ir.Model: + """Optimize the model.""" + onnxscript.optimizer.optimize_ir(model) + onnxscript.rewriter.onnx_fusions.fuse(model) + return model diff --git a/onnxscript/_framework_apis/torch_2_9.py b/onnxscript/_framework_apis/torch_2_9.py new file mode 100644 index 0000000000..88c9b85734 --- /dev/null +++ b/onnxscript/_framework_apis/torch_2_9.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Stable APIs for PyTorch 2.9.""" + +from __future__ import annotations + +__all__ = [ + "check_model", + "convert_version", + "get_torchlib_ops", + "optimize", + "save_model_with_external_data", +] + +from typing import TYPE_CHECKING + +from onnxscript import version_converter +from onnxscript._framework_apis.torch_2_8 import ( + check_model, + get_torchlib_ops, + optimize, + save_model_with_external_data, +) + +if TYPE_CHECKING: + import onnx_ir as ir + + +def convert_version(model: ir.Model, target_version: int) -> ir.Model: + """Convert the model to the specified ONNX opset version. + + Starting from PyTorch 2.9, down conversion is turned on and supported. + """ + version_converter.convert_version(model, target_version, fallback=True) + return model diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py index 0901382eee..0403f60c91 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/analysis.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import ast diff --git a/onnxscript/_internal/analysis_test.py b/onnxscript/_internal/analysis_test.py index 5531ec3833..74e7ca4c18 100644 --- a/onnxscript/_internal/analysis_test.py +++ b/onnxscript/_internal/analysis_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import ast diff --git a/onnxscript/_internal/ast_utils.py b/onnxscript/_internal/ast_utils.py index 974ae75a09..4146f38e2f 100644 --- a/onnxscript/_internal/ast_utils.py +++ b/onnxscript/_internal/ast_utils.py @@ -1,23 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Utilities for working with Python ASTs.""" from __future__ import annotations import ast import inspect -import sys import textwrap -import types +from typing import Callable -PY_VERSION_GE_39 = sys.version_info >= (3, 9) - -def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]: +def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]: try: - src = inspect.getsource(f) + src = inspect.getsource(func) except OSError as e: raise RuntimeError( - f"Decorator script does not work on dynamically " - f"compiled function {f.__name__}." + f"Decorator script does not work on dynamically compiled function {func.__name__}." ) from e src = textwrap.dedent(src) top_level_ast = ast.parse(src) @@ -34,17 +32,10 @@ def normalize_subscript_expr(expr: ast.Subscript): # Returns a list of expressions, denoting the indices, after stripping the extraneous "Index" # wrapper present in python versions before 3.9 index_expr = expr.slice - if PY_VERSION_GE_39: - if isinstance(index_expr, ast.Tuple): - return index_expr.elts # multiple indices - else: - return [index_expr] # single index + if isinstance(index_expr, ast.Tuple): + return index_expr.elts # multiple indices else: - if isinstance(index_expr, ast.ExtSlice): - indices = index_expr.dims # type: ignore[attr-defined] - else: - indices = [index_expr] # single slice-index - return [x.value if isinstance(x, ast.Index) else x for x in indices] # type: ignore[attr-defined] + return [index_expr] # single index def is_print_call(stmt: ast.stmt) -> bool: diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index b79180ae59..1defac3e53 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations @@ -9,10 +7,10 @@ import numpy as np import onnx -from onnx import helper, numpy_helper +import onnx.helper # noqa: TID251 from onnx.defs import OpSchema -from onnxscript import tensor +from onnxscript import ir, tensor if TYPE_CHECKING: from onnxscript import converter @@ -26,42 +24,8 @@ # Utilities to convert a python value to TensorProto (for use by the script converter) -def _py_type_to_onnx_type(pytype: type): - if pytype is bool: - return onnx.TensorProto.BOOL - if pytype is int: - return onnx.TensorProto.INT64 - if pytype is float: - return onnx.TensorProto.FLOAT - if pytype is str: - return onnx.TensorProto.STRING - raise ValueError(f"Tensor element of type {pytype} not supported") - - def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue): - if isinstance(pyvalue, np.ndarray): - return numpy_helper.from_array(pyvalue, tensor_name) - if isinstance(pyvalue, list): - if len(pyvalue) == 0: - raise ValueError("Cannot convert an empty list to tensor") - pytype = type(pyvalue[0]) - if not all(isinstance(e, pytype) for e in pyvalue): - raise ValueError( - "Cannot convert an list with elements of different types to tensor" - ) - return helper.make_tensor( - tensor_name, - _py_type_to_onnx_type(pytype), - [len(pyvalue)], - pyvalue, - ) - onnx_type = _py_type_to_onnx_type(type(pyvalue)) - if onnx_type is onnx.TensorProto.BOOL: - return helper.make_tensor(tensor_name, onnx_type, [], [int(pyvalue)]) - if onnx_type is onnx.TensorProto.STRING: - return helper.make_tensor(tensor_name, onnx_type, [], vals=[pyvalue.encode("utf-8")]) - - return helper.make_tensor(tensor_name, onnx_type, [], [pyvalue]) + return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name)) _REPEATED_ATTRIBUTE_TYPES = frozenset( @@ -81,7 +45,7 @@ def pyvalue_to_onnx_attribute( key: str, value: Any, name_generator: Callable[[], str], - attr_type: Optional[onnx.AttributeProto.AttributeType] = None, + attr_type: onnx.AttributeProto.AttributeType | None = None, ) -> onnx.AttributeProto: """Helper function to create an ONNX AttributeProto. @@ -105,7 +69,9 @@ def pyvalue_to_onnx_attribute( name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value) ) else: - return onnx.helper.make_attribute(key, value) + # When the value is a subgraph, ONNX IR will complain that some values are + # not found from the scope. + return onnx.helper.make_attribute(key, value) # noqa: TID251 # Utilities to convert python values into onnxscript tensors. diff --git a/onnxscript/_internal/deprecation.py b/onnxscript/_internal/deprecation.py index 57769ba091..7bf18482a2 100644 --- a/onnxscript/_internal/deprecation.py +++ b/onnxscript/_internal/deprecation.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- """Utility for deprecating APIs.""" # Reference: https://github.com/pytorch/pytorch/blob/aed9bee0413dac190452fbfa9ab2a44b6e6843f5/torch/onnx/_deprecation.py @@ -14,6 +12,12 @@ T = TypeVar("T") +@functools.lru_cache(maxsize=1024) +def _warn_once(message: str): + """Issue a FutureWarning only once per message.""" + warnings.warn(message, category=FutureWarning, stacklevel=3) + + def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T], T]: """Marks functions as deprecated. @@ -32,12 +36,10 @@ def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T], def decorator(function): @functools.wraps(function) def wrapper(*args, **kwargs): - warnings.warn( + _warn_once( f"'{function.__module__}.{function.__qualname__}' " f"is deprecated in version {since} and will be " f"removed in {removed_in}. Please {instructions}.", - category=FutureWarning, - stacklevel=2, ) return function(*args, **kwargs) diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py index 54593abf32..b3591a0a8d 100644 --- a/onnxscript/_internal/param_manipulation.py +++ b/onnxscript/_internal/param_manipulation.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Function for manipulating input parameters of an Op or a OnnxFunction.""" from __future__ import annotations @@ -129,3 +131,18 @@ def tag_arguments_with_param_schemas( raise TypeError(f"Required input/attribute '{param}' was not provided") return tagged_args, tagged_kwargs + + +def turn_to_kwargs_to_avoid_ordering( + param_schemas: Sequence[values.ParamSchema], + inputs: list[Any], + attributes: dict[str, Any], +) -> dict[str, Any]: + """Return the inputs and attributes to the order of the function signature.""" + for idx, param in enumerate(param_schemas): + if param.name not in attributes: + if param.is_variadic_input: + attributes[param.name] = inputs[idx:] + elif inputs: + attributes[param.name] = inputs.pop(0) + return attributes diff --git a/onnxscript/_internal/param_manipulation_test.py b/onnxscript/_internal/param_manipulation_test.py index f7148268e0..7b67e4380d 100644 --- a/onnxscript/_internal/param_manipulation_test.py +++ b/onnxscript/_internal/param_manipulation_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. # mypy: disable-error-code=misc import collections diff --git a/onnxscript/_internal/runtime_typing.py b/onnxscript/_internal/runtime_typing.py deleted file mode 100644 index 54e7dae0c0..0000000000 --- a/onnxscript/_internal/runtime_typing.py +++ /dev/null @@ -1,39 +0,0 @@ -"""An internal wrapper for the beartype library. - -Decorate a function with `@runtime_typing.checked` to enable runtime -type checking. The decorator is a no-op when the `beartype` library is not -installed. -""" - -import typing -import warnings - -__all__ = [ - "checked", -] - -T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any]) - -try: - from beartype import beartype as checked - from beartype import roar as _roar - - # Beartype warns when we import from typing because the types are deprecated - # in Python 3.9. But there will be a long time until we can move to using - # the native container types for type annotations (when 3.9 is the lowest - # supported version). So we silence the warning. - warnings.filterwarnings( - "ignore", - category=_roar.BeartypeDecorHintPep585DeprecationWarning, - ) -except ImportError: - - def checked(func: T) -> T: # type: ignore[no-redef] - return func - -except Exception as e: # pylint: disable=broad-exception-caught - # Warn errors that are not import errors (unexpected). - warnings.warn(f"{e}", stacklevel=2) - - def checked(func: T) -> T: # type: ignore[no-redef] - return func diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py index c4537e3bcd..ce2b657cfd 100644 --- a/onnxscript/_internal/utils.py +++ b/onnxscript/_internal/utils.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import numbers @@ -9,7 +7,6 @@ import numpy as np import onnx -import onnx.helper from onnxscript import tensor @@ -67,26 +64,26 @@ def add(k, v): def value_to_type_proto(val): """Return the ONNX type of a python-value.""" if isinstance(val, (np.ndarray, tensor.Tensor)): - elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) + elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) # noqa: TID251 shape = val.shape - return onnx.helper.make_tensor_type_proto(elem_type, shape) + return onnx.helper.make_tensor_type_proto(elem_type, shape) # noqa: TID251 if isinstance(val, int): - return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) + return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) # noqa: TID251 if isinstance(val, (float, np.float32)): - return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) + return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) # noqa: TID251 if isinstance(val, list): if len(val) > 0: - return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) + return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) # noqa: TID251 # Edge-case. Cannot determine a suitable ONNX type for an empty list. # Should be using a typed-value instead. # Treated as a sequence of tensors of float-type. - return onnx.helper.make_sequence_type_proto( - onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None) + return onnx.helper.make_sequence_type_proto( # noqa: TID251 + onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None) # noqa: TID251 ) if isinstance(val, numbers.Number): nparray = np.array(val) - elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) - return onnx.helper.make_tensor_type_proto(elem_type, []) + elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251 + return onnx.helper.make_tensor_type_proto(elem_type, []) # noqa: TID251 raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.") @@ -95,7 +92,7 @@ def values_to_value_infos(name_values): skipping any None values. """ return [ - onnx.helper.make_value_info(name, value_to_type_proto(val)) + onnx.helper.make_value_info(name, value_to_type_proto(val)) # noqa: TID251 for (name, val) in name_values if val is not None ] diff --git a/onnxscript/_internal/version_utils.py b/onnxscript/_internal/version_utils.py index c66cb8d2ba..2b43c54f49 100644 --- a/onnxscript/_internal/version_utils.py +++ b/onnxscript/_internal/version_utils.py @@ -1,5 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Version utils for testing.""" +from __future__ import annotations + +import warnings +from typing import Callable, Sequence + import packaging.version @@ -23,6 +30,19 @@ def torch_older_than(version: str) -> bool: ) +def transformers_older_than(version: str) -> bool | None: + """Returns True if the transformers version is older than the given version.""" + try: + import transformers # pylint: disable=import-outside-toplevel + except ImportError: + return None + + return ( + packaging.version.parse(transformers.__version__).release + < packaging.version.parse(version).release + ) + + def onnxruntime_older_than(version: str) -> bool: """Returns True if the onnxruntime version is older than the given version.""" import onnxruntime # pylint: disable=import-outside-toplevel @@ -31,3 +51,48 @@ def onnxruntime_older_than(version: str) -> bool: packaging.version.parse(onnxruntime.__version__).release < packaging.version.parse(version).release ) + + +def numpy_older_than(version: str) -> bool: + """Returns True if the numpy version is older than the given version.""" + import numpy # pylint: disable=import-outside-toplevel + + return ( + packaging.version.parse(numpy.__version__).release + < packaging.version.parse(version).release + ) + + +def has_transformers(): + """Tells if transformers is installed.""" + try: + import transformers # pylint: disable=import-outside-toplevel + + assert transformers + return True # noqa + except ImportError: + return False + + +def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable: # type: ignore[arg-type] + """Catches warnings. + + Args: + warns: warnings to ignore + + Returns: + decorated function + """ + + def wrapper(fct): + if warns is None: + raise AssertionError(f"warns cannot be None for '{fct}'.") + + def call_f(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", warns) # type: ignore[arg-type] + return fct(self) + + return call_f + + return wrapper diff --git a/onnxscript/_legacy_ir/__init__.py b/onnxscript/_legacy_ir/__init__.py deleted file mode 100644 index 74aa693593..0000000000 --- a/onnxscript/_legacy_ir/__init__.py +++ /dev/null @@ -1,339 +0,0 @@ -from __future__ import annotations - -import dataclasses -from collections import deque -from typing import List, Tuple, Union - -import numpy as np -import onnx - - -class Unknown: - """A special value used to indicate that a value is not a statically known constant. - - We use this instead of None because None is a valid constant value (since ONNX - supports the Optional type). - """ - - instance = None - - def __init__(self) -> None: - if Unknown.instance is not None: - raise ValueError("Unknown.instance is already set") - Unknown.instance = self - - -# Singleton instance of Unknown -unknown = Unknown() -NotConstant = unknown - -# ConcreteValue: This type represents constant values that an ONNX variable can take. -# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals, -# maps, etc. -# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto. -# A uniform representation would be helpful, but we should avoid unnecessary conversions for -# large tensors. Should be cleaned up in the new IR. -ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None] - -# SymbolicValue: This information is used to enable partial-evaluation and specialization -# of sequence operations, as well as elimination of redundant Identity ops. -# The symbolic value of a variable X can be: -# - a string with the value "Y", indicating that "X" is a copy of "Y" -# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values -# Eg., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to -# "SequenceConstruct(A, B, C)". -# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of -# tensors, etc. However, we currently only handle lists of tensors. - -SymbolicValue = Union[str, List[str]] - -FunctionId = Tuple[str, str, str] - - -def get_function_id(function: onnx.FunctionProto) -> FunctionId: - return (function.domain, function.name, getattr(function, "overload", "")) - - -def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId: - return (node.domain, node.op_type, getattr(node, "overload", "")) - - -@dataclasses.dataclass -class StaticValueInfo: - name: str - value: ConcreteValue = NotConstant - type: onnx.TypeProto | None = None - symbolic_value: SymbolicValue | None = None - - def is_copy(self) -> bool: - return isinstance(self.symbolic_value, str) - - def tensor_shape_proto(self) -> onnx.TensorShapeProto | None: - """Returns the shape of a tensor or None. - - A return value of None could mean that the type is unknown or that the type is not a tensor - or that the tensor shape (that is, even the rank) is unknown. - """ - type = self.type - if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - return type.tensor_type.shape - return None - - @property - def shape(self) -> list[str | int | None] | None: - """Returns the shape in a list. - - Str means that the shape is dynamic. - """ - type = self.type - if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - dims = [] - for dim in type.tensor_type.shape.dim: - if dim.HasField("dim_param"): - dims.append(dim.dim_param) - elif dim.HasField("dim_value"): - dims.append(dim.dim_value) - else: - dims.append(None) - return dims - if self.value_as_np_array is not None: - return list(self.value_as_np_array.shape) - return None - - @property - def element_type(self) -> int | None: - """Returns the element type of a tensor, or None if type is not known or is not a tensor.""" - type = self.type - if type and type.HasField("tensor_type"): - return type.tensor_type.elem_type - return None - - def identity_merge_from(self, other: StaticValueInfo) -> None: - """Merge the value of other into self. - - This models the effect of an identity (copy) operation. - This will update static-analysis information based on incoming value. - """ - if not isinstance(other, StaticValueInfo): - raise TypeError(f"Cannot merge {other} into {self}.") - if other.value is not NotConstant: - self.value = other.value - # TODO: merge and combine best shape information from both types. - if other.tensor_shape_proto() is not None and other.element_type is not None: - self.type = other.type - # We cannot copy symbolic value across different scopes. - - # WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo - # does not fill in the following fields. These fields are filled in by the IRBuilder - # which constructs the IR from the ONNX model. - node: Node | None = None - uses: list[Node] = dataclasses.field(default_factory=list) - output_index: int | None = None - is_output: bool = False - - @property - def const_value(self) -> ConcreteValue: - return self.value - - @property - def value_as_np_array(self) -> np.ndarray | None: - if isinstance(self.value, np.ndarray): - return self.value - if isinstance(self.value, onnx.TensorProto): - return onnx.numpy_helper.to_array(self.value) - return None - - def def_node(self) -> Node | None: - return self.node - - def def_index(self) -> int: - return self.output_index # type: ignore[return-value] - - def is_same_as(self, other: StaticValueInfo) -> bool: - """Returns true if this value represents the same IR object as the other value. - - This is *not* value-equality, but rather object-equality. - """ - return self is other - - def __str__(self) -> str: - shape = self.shape - if shape is not None: - shape = [str(dim) for dim in shape] - shape_str = f"[{', '.join(shape)}]" # type: ignore[arg-type] - else: - shape_str = "None" - return ( - f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, " - f"{'has const value' if self.value is not unknown else 'no const value'}.)" - ) - - -Value = StaticValueInfo - - -class Model: - def __init__(self) -> None: - self.gen_var_counter: int = 0 - - def set( - self, - model_proto: onnx.ModelProto, - graph: Graph, - functions: list[Function], - version_map: dict[str, int], - ) -> None: - """TODO. This is a temporary patch.""" - self.original_model_proto = model_proto - self.graph = graph - self.functions = functions - self.version_map = version_map - - def make_new_name(self): - # Temporary hack. - self.gen_var_counter += 1 - return f"_gen_{self.gen_var_counter}" - - def __str__(self) -> str: - # TODO: Naive string representation for debugging. Need to improve this. - return "\n".join( - [ - f"ModelGraph: {self.graph}", - f"Functions: {self.functions}", - f"VersionMap: {self.version_map}", - ] - ) - - -class Graph: - def __init__(self, graph_proto: onnx.GraphProto): - self.original_graph_proto = graph_proto - self.nodes: deque[Node] = deque() - self.values: dict[str, Value] = {} - - @property - def name(self) -> str: - return self.original_graph_proto.name - - def __str__(self) -> str: - return "\n".join( - [ - "Graph", - f"Nodes: {[str(n) for n in self.nodes]}", - f"Values: {[str(v) for v in self.values]}", - ] - ) - - @property - def input_names(self) -> list[str]: - return [_.name for _ in self.original_graph_proto.input] - - @property - def output_names(self) -> list[str]: - return [_.name for _ in self.original_graph_proto.output] - - -class Function: - def __init__(self, function_proto: onnx.FunctionProto): - self.original_function_proto = function_proto - self.nodes = deque() # type: ignore[var-annotated] - self.values = {} # type: ignore[var-annotated] - - @property - def id(self) -> FunctionId: - return (self.domain, self.name, self.overload) - - @property - def domain(self) -> str: - return self.original_function_proto.domain - - @property - def name(self) -> str: - return self.original_function_proto.name - - @property - def overload(self) -> str: - return getattr(self.original_function_proto, "overload", "") - - def __str__(self) -> str: - return "\n".join( - [ - "Function", - f"Nodes: {[str(n) for n in self.nodes]}", - f"Values: {[str(v) for v in self.values]}", - ] - ) - - -class RefAttr: - def __init__(self, name: str, ref_attr_name: str, type) -> None: - self.name = name - self.ref_attr_name = ref_attr_name - self.type = type - - def to_proto(self) -> onnx.AttributeProto: - attr_proto = onnx.AttributeProto() - attr_proto.name = self.name - attr_proto.ref_attr_name = self.ref_attr_name - attr_proto.type = self.type - return attr_proto - - -class Node: - def __init__( - self, - node_proto: onnx.NodeProto, - populate_io: bool = False, - ) -> None: - self.original_node_proto = node_proto - self.domain: str = node_proto.domain - self.version: int | None = None - self.op_type: str = node_proto.op_type - if populate_io: - self.inputs: list[Value | None] = [Value(i) for i in node_proto.input] - self.outputs: list[Value | None] = [Value(i) for i in node_proto.output] - else: - self.inputs: list[Value | None] = [] # type: ignore[no-redef] - self.outputs: list[Value | None] = [] # type: ignore[no-redef] - # TODO: attributes are never populated. - self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {} - - def __repr__(self) -> str: - return ( - f"{self.op_type}({','.join(self.original_node_proto.input)})" - f"->{','.join(self.original_node_proto.output)}" - ) - - @property - def name(self) -> str: - return self.original_node_proto.name - - @property - def input_names(self): - return self.original_node_proto.input - - @property - def output_names(self): - return self.original_node_proto.output - - @property - def attribute(self): - return self.original_node_proto.attribute - - def set_version_if_custom_op(self, version_map: dict[str, int]) -> None: - if self.domain != "" and self.domain in version_map: - self.version = version_map[self.domain] - - def get_attribute(self, name: str) -> int | float | None: - return self.attributes.get(name, None) # type: ignore[return-value] - - def __str__(self) -> str: - return "\n".join( - [ - "Node", - f"OpType: {self.op_type}", - f"Inputs: {self.inputs}", - f"Outputs: {self.outputs}", - f"Attributes: {self.attributes}", - ] - ) diff --git a/onnxscript/_legacy_ir/visitor.py b/onnxscript/_legacy_ir/visitor.py deleted file mode 100644 index 3044fdd77e..0000000000 --- a/onnxscript/_legacy_ir/visitor.py +++ /dev/null @@ -1,922 +0,0 @@ -from __future__ import annotations - -import dataclasses -import logging -from typing import Any, Sequence - -import numpy as np -import onnx - -import onnxscript._legacy_ir as ir -from onnxscript.utils.utils import ( - get_initializer_type, - is_control_flow_op, - normalize_domain, -) - -logger = logging.getLogger(__name__) - - -def _override_inferred_value_type_with_symbolic_value_type( - symbolic_value: ir.Value | None, - inferred_value: ir.Value | None, -) -> ir.Value | None: - if inferred_value is not None and symbolic_value is not None: - inferred_value.type = symbolic_value.type - if inferred_value is None: - inferred_value = symbolic_value - return inferred_value - - -def is_local_function_node( - node: onnx.NodeProto, functions: dict[ir.FunctionId, onnx.FunctionProto] -) -> bool: - return ir.get_function_id_from_node(node) in functions - - -class FunctionShapeEnv: - def __init__(self): - # Mapping from (domain, function_name, overload) to {value_name: ir_value} - self._function_values: dict[ir.FunctionId, dict[str, ir.Value]] = {} - - def load_from_model_proto(self, model_proto: onnx.ModelProto) -> None: - for value_info in model_proto.graph.value_info: - self.load_from_value_info(value_info) - - def save_to_model_proto(self, model_proto: onnx.ModelProto) -> None: - for ( - domain, - function_name, - overload, - ), named_ir_values in self._function_values.items(): - for ir_value in named_ir_values.values(): - if ( - value_info := self.save_to_value_info( - ir_value, domain, function_name, overload - ) - ) is not None: - model_proto.graph.value_info.append(value_info) - - def load_from_value_info(self, value_info: onnx.ValueInfoProto) -> None: - function_id, ir_value = self.process_value_info(value_info) - if function_id is not None: - logger.debug( - "Loads torch symbolic value info '%s'.", - value_info.name, - ) - self._function_values.setdefault(function_id, {})[ir_value.name] = ir_value - - def process_value_info( - self, value_info: onnx.ValueInfoProto - ) -> tuple[ir.FunctionId | None, ir.Value]: - name = value_info.name - if len(splits := name.split("/")) == 2: - # Experimental function value info format. - # To be deprecated after ONNX 1.16, where value_info is introduced in FunctionProto. - function_id, value_name = splits - splits = function_id.split("::") - domain, function_name = splits[0], splits[1] - # 'overload' is introduced in ONNX 1.16, consider it as empty string prior to that. - # The code is for future proof, in case overload is encoded in this format. - overload = "" - if len(splits) == 3: - overload = splits[2] - function_id = (domain, function_name, overload) - else: - # Standard main graph value info format. - function_id = None - value_name = name - return function_id, ir.Value(name=value_name, type=value_info.type) - - def save_to_value_info( - self, value: ir.Value, domain: str, function_name: str, overload: str - ) -> onnx.ValueInfoProto | None: - if overload != "": - raise NotImplementedError("Overload is not supported yet.") - function_id = f"{domain}::{function_name}" - - if value.type is not None: - return onnx.helper.make_value_info(f"{function_id}/{value.name}", value.type) - return None - - def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | None: - """Lookup ir value of 'value_name' inside 'function'.""" - function_id = ir.get_function_id(function) - function_values = self._function_values.get(function_id) - if function_values is None or (ir_value := function_values.get(value_name)) is None: - logger.debug( - "Lookup Missed %s torch symbolic value info in function %s::%s.", - value_name, - function.domain, - function.name, - ) - return None - logger.debug( - "Lookup found %s torch symbolic value info in function %s::%s.", - value_name, - function.domain, - function.name, - ) - return ir_value - - def bind(self, value: ir.Value, domain: str, function_name: str, overload: str) -> None: - """Bind ir value 'value' to 'value_name' inside 'function'.""" - function_id = (domain, function_name, overload) - self._function_values.setdefault(function_id, {})[value.name] = value - - def get_ir_values(self, function: onnx.FunctionProto) -> dict[str, ir.Value]: - """Get all ir values inside 'function'.""" - function_id = ir.get_function_id(function) - return self._function_values.get(function_id, {}) - - -class SubScope: - values: dict[str, ir.Value] - ref_attributes: dict[str, onnx.AttributeProto] - owner: onnx.GraphProto | onnx.FunctionProto - - def __init__(self, owner: onnx.GraphProto | onnx.FunctionProto): - self.values = {} - self.ref_attributes = {} - self.owner = owner - - def lookup(self, name: str) -> ir.Value | None: - return self.values.get(name) - - def bind(self, name: str, value: ir.Value) -> None: - self.values[name] = value - - def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: - return self.ref_attributes.get(ref_attr_name) - - def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: - self.ref_attributes[ref_attr_name] = attr - - def readable_strs(self, indent: int = 0) -> list[str]: - indent_str = " " * indent - strs = [] - if isinstance(self.owner, onnx.GraphProto): - strs.append(f"Graph {self.owner.name}:") - else: - strs.append(f"Function {self.owner.name}:") - strs.append(" ir.Values:") - for name, value in self.values.items(): - strs.append(f" {name}: {value}") - strs.append(" RefAttributes:") - for name, attr in self.ref_attributes.items(): - strs.append(f" {name}: {attr}") - - return [f"{indent_str}{s}" for s in strs] - - def __str__(self) -> str: - return "\n".join(self.readable_strs()) - - -@dataclasses.dataclass -class Scope: - _sub_scopes: list[SubScope] = dataclasses.field(default_factory=list) - - def lookup(self, name: str) -> ir.Value | None: - """Lookup value by name from all SubScopes.""" - for sub_scope in reversed(self._sub_scopes): - if (result := sub_scope.lookup(name)) is not None: - return result - return None - - def bind(self, name: str, value: ir.Value) -> None: - """Bind value to name in the most recent SubScope.""" - if name == "": - raise ValueError("Cannot bind to empty name.") - if value is None: - raise ValueError(f"Cannot bind None to value {name}.") - self._sub_scopes[-1].bind(name, value) - - def lookup_or_create(self, name: str) -> ir.Value: - """Lookup value by name from all SubScopes. If not found, create a new one in most recent SubScope.""" - if name == "": - raise ValueError("Cannot lookup or create empty name.") - for sub_scope in reversed(self._sub_scopes): - if (result := sub_scope.lookup(name)) is not None: - return result - value = ir.Value(name=name) - self.bind(name, value) - return value - - def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: - for sub_scope in reversed(self._sub_scopes): - if (result := sub_scope.lookup_ref_attribute(ref_attr_name)) is not None: - return result - return None - - def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: - self._sub_scopes[-1].bind_ref_attribute(ref_attr_name, attr) - - def enter_sub_scope(self, owner: onnx.GraphProto) -> None: - self._sub_scopes.append(SubScope(owner)) - - def exit_sub_scope(self) -> SubScope: - return self._sub_scopes.pop() - - def current_function_scope(self) -> SubScope | None: - if len(self._sub_scopes) == 0: - return None - if isinstance(self._sub_scopes[0].owner, onnx.FunctionProto): - return self._sub_scopes[0] - return None - - def current_function(self) -> onnx.FunctionProto | None: - current_function_scope = self.current_function_scope() - if current_function_scope is not None: - return current_function_scope.owner - return None - - def current_graph(self) -> onnx.GraphProto | None: - for sub_scope in reversed(self._sub_scopes): - if isinstance(sub_scope.owner, onnx.GraphProto): - return sub_scope.owner - return None - - def readable_strs(self, indent: int = 0) -> list[str]: - indent_str = " " * indent - strs = [] - for i, sub_scope in enumerate(self._sub_scopes): - strs.append(f"SubScope {i}:") - strs.extend(sub_scope.readable_strs(indent=indent + 2)) - return [f"{indent_str}{s}" for s in strs] - - def __str__(self) -> str: - return "\n".join(self.readable_strs()) - - -@dataclasses.dataclass -class ScopeStack: - """Stack of scopes. - - Each Scope represents statically-nested SubScopes (where inner SubScopes can access names defined in outer SubScopes) - produced by subgraphs (occurring as attribute values), except for the first SubScope which could be produced by a function. - With a ScopeStack, there is no such possibility of referencing variables defined higher up in the stack by name. - Instead, it is meant to represent a sequence of (nested) function-calls. Each entry in the stack (except the outermost) - represents a call to a function. - - Thus, we would use a ScopeStack for a context-sensitive analysis (where we recursively process a called function). - For a context-insensitive analysis, we would only need a Scope (where we recursively process subgraphs). - - To debug, `print(scope_stack)` will print the scope structure as well as the info stored - in each scope. - """ - - _scopes: list[Scope] = dataclasses.field(default_factory=lambda: [Scope()]) - - def current_scope(self) -> Scope: - return self._scopes[-1] - - def lookup(self, name: str) -> ir.Value | None: - """Lookup value by name from the current Scope.""" - return self.current_scope().lookup(name) - - def bind(self, name: str, value: ir.Value) -> None: - """Bind value to name in the current Scope.""" - self.current_scope().bind(name, value) - - def lookup_or_create(self, name: str) -> ir.Value: - """Lookup value by name from the current Scope. If not found, create a new one.""" - return self.current_scope().lookup_or_create(name) - - def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: - return self.current_scope().lookup_ref_attribute(ref_attr_name) - - def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: - self.current_scope().bind_ref_attribute(ref_attr_name, attr) - - def enter_graph_scope(self, graph: onnx.GraphProto) -> None: - self.current_scope().enter_sub_scope(graph) - - def exit_graph_scope(self) -> SubScope: - sub_scope = self.current_scope().exit_sub_scope() - assert isinstance(sub_scope.owner, onnx.GraphProto), "Expected graph scope." - return sub_scope - - def enter_function_scope(self, function: onnx.FunctionProto) -> None: - self._scopes.append(Scope()) - self.current_scope().enter_sub_scope(function) - - def exit_function_scope(self) -> SubScope: - sub_scope = self.current_scope().exit_sub_scope() - assert isinstance(sub_scope.owner, onnx.FunctionProto), "Expected function scope." - self._scopes.pop() - return sub_scope - - def current_function(self) -> onnx.FunctionProto | None: - return self.current_scope().current_function() - - def current_graph(self) -> onnx.GraphProto | None: - return self.current_scope().current_graph() - - def __str__(self) -> str: - strs = ["ScopeStach:"] - for i, scope in enumerate(self._scopes): - strs.append(f" Scope {i}:") - strs.extend(scope.readable_strs(indent=2)) - return "\n".join(strs) - - -class ProtoVisitorCore: - def visit_model(self, model: onnx.ModelProto): - self.process_model(model) - for opset in model.opset_import: - self.process_opset_import(opset) - self.visit_graph(model.graph) - for function in model.functions: - self.visit_function(function) - - def process_model(self, model: onnx.ModelProto): - pass - - def process_opset_import(self, opset: onnx.OperatorSetIdProto): - pass - - def visit_graph(self, graph: onnx.GraphProto): - self.enter_scope(graph) - self.process_graph(graph) - for input in graph.input: - self.process_graph_input(input) - for init in graph.initializer: - self.process_initializer(init) - for value_info in graph.value_info: - self.process_value_info(value_info) - for node in graph.node: - self.visit_node(node) - for output in graph.output: - self.process_graph_output(output) - self.exit_scope(graph) - - def visit_function(self, function: onnx.FunctionProto): - self.enter_function_scope(function) - self.process_function(function) - for input in function.input: - self.process_function_input(input) - for node in function.node: - self.visit_node(node) - for output in function.output: - self.process_function_output(output) - self.exit_function_scope(function) - - def process_function_input(self, input: str): - pass - - def process_function_output(self, output: str): - pass - - def process_function(self, function: onnx.FunctionProto): - pass - - def enter_function_scope(self, function: onnx.FunctionProto): - pass - - def exit_function_scope(self, function: onnx.FunctionProto) -> SubScope: - pass - - def enter_scope(self, graph: onnx.GraphProto): - pass - - def process_graph(self, graph: onnx.GraphProto): - pass - - def exit_scope(self, graph: onnx.GraphProto) -> SubScope: - pass - - def process_graph_input(self, input: onnx.ValueInfoProto): - pass - - def process_initializer(self, init: onnx.TensorProto): - pass - - def process_value_info(self, value_info: onnx.ValueInfoProto): - pass - - def visit_node(self, node: onnx.NodeProto): - self.process_node(node) - for attr in node.attribute: - self.visit_attribute(attr) - - def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: - pass - - def process_graph_output(self, output: onnx.ValueInfoProto): - pass - - def visit_attribute(self, attr: onnx.AttributeProto): - self.process_attribute(attr) - if attr.HasField("g"): - self.visit_graph(attr.g) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - self.visit_graph(graph) - - def process_attribute(self, attr: onnx.AttributeProto): - pass - - -class ProtoVisitor(ProtoVisitorCore): - def __init__( - self, external_data_folder: str = "", *, do_shape_inference: bool = False - ) -> None: - super().__init__() - self.scopes = ScopeStack() - self.function_shape_env = FunctionShapeEnv() - self.version_map = {} # Map from domain to version - self.do_shape_inference = do_shape_inference - self.external_data_folder = external_data_folder - self.modified = False - - def process_opset_import(self, opset: onnx.OperatorSetIdProto): - domain = normalize_domain(opset.domain) - self.version_map[domain] = opset.version - - def lookup_version(self, domain: str) -> int: - domain = normalize_domain(domain) - return self.version_map.get(domain, 1) # TODO: handle missing domain - - def lookup(self, name: str) -> ir.Value | None: - if name == "": - return None - if (result := self.scopes.lookup(name)) is None: - logger.debug("Lookup value %s unfound.", name) - raise ValueError( - f"Undefined variable {name}.\n" - f"Available variables: {self.scopes.current_scope()}" - ) - logger.debug("Lookup value %s. Shape %s", name, result.tensor_shape_proto()) - return result - - def bind(self, name: str, value: ir.Value) -> None: - logger.debug("Binding value %s. Shape %s", name, value.tensor_shape_proto()) - self.scopes.bind(name, value) - - def lookup_or_create(self, name: str) -> ir.Value: - return self.scopes.lookup_or_create(name) - - def has_input(self, node: onnx.NodeProto, index: int) -> bool: - return index < len(node.input) and node.input[index] != "" - - # TODO: Cleanup handling of undefined variables. May fail in some of methods below. - - def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: - if index < len(node.input): - return self.lookup(node.input[index]) - return None - - def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: - info = self.get_input(node, index) - return info.type if info is not None else None - - def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: - info = self.get_input(node, index) - return info.element_type if info is not None else None - - def input_shape(self, node: onnx.NodeProto, index: int) -> onnx.TensorShapeProto | None: - info = self.get_input(node, index) - return info.tensor_shape_proto() if info is not None else None - - def input_const_value(self, node: onnx.NodeProto, index: int) -> Any: - if not self.has_input(node, index): - return None # This is treated as a known constant value "None" - info = self.get_input(node, index) - return info.value - - def has_output(self, node: onnx.NodeProto, index: int) -> bool: - return index < len(node.output) and node.output[index] != "" - - def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: - if index < len(node.output): - return self.lookup(node.output[index]) - return None - - def get_input_value( - self, node: onnx.NodeProto, index: int, default: Any | None = None - ) -> Any | None: - info = self.get_input(node, index) - if info is not None: - return info.value - return default - - def get_input_type( - self, node: onnx.NodeProto, index: int, default: onnx.TypeProto | None = None - ) -> onnx.TypeProto | None: - info = self.get_input(node, index) - if info is not None: - return info.type - return default - - def enter_scope(self, graph: onnx.GraphProto): - logger.debug("enter_scope: graph %s", graph.name) - self.scopes.enter_graph_scope(graph) - - def exit_scope(self, graph: onnx.GraphProto) -> SubScope: - logger.debug("exit_scope: graph %s", graph.name) - return self.scopes.exit_graph_scope() - - def enter_function_scope(self, function: onnx.FunctionProto): - logger.debug("enter_function_scope: function %s", function.name) - self.scopes.enter_function_scope(function) - ir_values = self.function_shape_env.get_ir_values(function) - for name, ir_value in ir_values.items(): - inferred_ir_value = self.lookup_or_create(name) - updated_ir_value = _override_inferred_value_type_with_symbolic_value_type( - ir_value, inferred_ir_value - ) - self.bind(name, updated_ir_value) - - def exit_function_scope(self, function: onnx.FunctionProto) -> SubScope: - logger.debug("exit_function_scope: function %s", function.name) - # Sync ir value back to function_shape_env - function_scope = self.scopes.exit_function_scope() - for ir_value in function_scope.values.values(): - self.function_shape_env.bind(ir_value, *ir.get_function_id(function)) - return function_scope - - def process_initializer(self, init: onnx.TensorProto): - array = onnx.numpy_helper.to_array(init, self.external_data_folder) - self.bind( - init.name, - ir.Value(name=init.name, value=array, type=get_initializer_type(init)), - ) - - def process_graph_input(self, input: onnx.ValueInfoProto): - self.bind(input.name, ir.Value(name=input.name, type=input.type)) - - def process_value_info(self, value_info: onnx.ValueInfoProto): - logger.debug("process_value_info: %s", value_info) - value = self.lookup_or_create(value_info.name) - value.type = value_info.type - # Populate function shape environment - self.function_shape_env.load_from_value_info(value_info) - - def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: - output_types = {} - if self.do_shape_inference and not is_control_flow_op(node): - # Control-flow ops are more complicated. Not supported here yet. - # TODO: handle optional inputs - def get_constant_value(i: int) -> onnx.TensorProto | None: - value = self.input_const_value(node, i) - if isinstance(value, np.ndarray) and value.size < 20: - return onnx.numpy_helper.from_array(value, node.input[i]) - return None - - input_types = {x: self.input_type(node, i) for i, x in enumerate(node.input)} - input_data = {x: get_constant_value(i) for i, x in enumerate(node.input)} - input_data = {k: v for k, v in input_data.items() if v is not None} - if any(t is None for t in input_types.values()): - logger.debug( - "Skipping shape inference for node %s due to missing input type.", - node.name, - ) - else: - # TODO: pass in constant values, ir_version - try: - schema = onnx.defs.get_schema( - node.op_type, self.lookup_version(node.domain), node.domain - ) - output_types = onnx.shape_inference.infer_node_outputs( - schema, node, input_types, input_data - ) - except Exception as e: - logger.debug( - "Skipping shape inference for node %s due to exception: %s", - node.name, - e, - ) - - for output in node.output: - info = self.lookup_or_create(output) - if output in output_types: - # TODO: merge types - info.type = output_types[output] - - -class ProtoTransformer(ProtoVisitor): - # TODO(lowpri) Practically this is useless. - # Subgraph only exist in 'if' nodes. 'if' nodes only exist in torchlib functions. - # There is no pre-existing value_info in torchlib functions. - # def exit_scope(self, graph: onnx.GraphProto) -> SubScope: - # # Also sync updated ir values back to value_info in graph. - # sub_scope = super().exit_scope(graph) - - def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None: - replacement = self.process_node(node) - logger.debug( - "visit_node: %s::%s %s replacement %s", - node.domain, - node.op_type, - node.name, - "found" if replacement is not None else "missed", - ) - if replacement is None: - # No change. Process attributes. - for attr in node.attribute: - self.visit_attribute(attr) - return None - else: - self.modified = True - # We recursively visit the replacement nodes. - result = [] - for newnode in replacement: - n = self.visit_node(newnode) - if n is not None: - result.extend(n) - else: - result.append(newnode) - return result - - def visit_graph(self, graph: onnx.GraphProto) -> dict[str, ir.Value]: - self.enter_scope(graph) - self.process_graph(graph) - for input in graph.input: - self.process_graph_input(input) - for init in graph.initializer: - self.process_initializer(init) - for value_info in graph.value_info: - self.process_value_info(value_info) - updates = [] - nodes = graph.node - for i, node in enumerate(nodes): - replacement = self.visit_node(node) - if replacement is not None: - updates.append((i, replacement)) - for i, replacement in reversed(updates): - old_node_name = nodes[i].name - del nodes[i] - for newnode in reversed(replacement): - logger.debug( - "Replacement node %s for %s. Size %s", - newnode.name, - old_node_name, - newnode.ByteSize(), - ) - nodes.insert(i, newnode) - for output in graph.output: - self.process_graph_output(output) - return self.exit_scope(graph) - - -class FunctionCallsiteAnalysis(ProtoVisitor): - """Collects the callsites of each function.""" - - def __init__(self): - super().__init__() - self.functions: dict[ir.FunctionId, onnx.FunctionProto] = {} - self.function_calls: dict[ir.FunctionId, list[onnx.NodeProto]] = {} - - def visit_function(self, function: onnx.FunctionProto): - # Do not visit function via model.functions. - # Only visit function at callsites. - # The purpose of this analysis is to collect the callsites of each function. - pass - - def visit_node(self, node: onnx.NodeProto) -> None: - if is_local_function_node(node, self.functions): - function_id = ir.get_function_id_from_node(node) - self.function_calls.setdefault(function_id, []).append(node) - for subnode in self.functions[function_id].node: - self.visit_node(subnode) - - def visit_model(self, model: onnx.ModelProto) -> None: - for function in model.functions: - self.functions[ir.get_function_id(function)] = function - - super().visit_model(model) - - -class FunctionRenamer: - _POSTFIX_FORMAT = "{name}|{postfix}_{count}" - - def __init__(self, postfix="folded"): - self._function_key_to_instance_count = {} - self._postfix = postfix - - def rename(self, function: onnx.FunctionProto) -> None: - domain = function.domain - name = function.name - key = (domain, name) - self._function_key_to_instance_count.setdefault(key, 0) - function.name = self._POSTFIX_FORMAT.format( - name=name, - postfix=self._postfix, - count=self._function_key_to_instance_count[key], - ) - self._function_key_to_instance_count[key] += 1 - - -class FunctionCallsiteProtoTransformer(ProtoTransformer): - """Unlike other base visitors, this is a special visitor that visits functions at their callsite. - - This allows transforming and constructing specialized functions based on callsite context. - """ - - _functions: dict[ir.FunctionId, onnx.FunctionProto] - _function_callsites: dict[ir.FunctionId, list[onnx.NodeProto]] - _new_functions: list[onnx.FunctionProto] - _function_renamer: FunctionRenamer - - def _gather_function_metadata(self, model: onnx.ModelProto): - analysis = FunctionCallsiteAnalysis() - analysis.visit_model(model) - self._functions = analysis.functions - self._function_callsites = analysis.function_calls - self._new_functions = [] - self._function_renamer = FunctionRenamer() - - def process_function_outputs(self, function: onnx.FunctionProto) -> bool: - """Process function outputs. - - This method is called when a function is visited at its callsite. - - Returns: - True if the function outputs are modified. - """ - del function # Unused - return False - - def process_function_node_outputs( - self, - node: onnx.NodeProto, - function_scope: SubScope, - ) -> None: - """Fetch value infos of function output to re-bind them for function node output.""" - function = function_scope.owner - output_values = [function_scope.lookup(output) for output in function.output] - for actual_name, formal_value in zip(node.output, output_values): - if formal_value is None: - raise RuntimeError( - "Missing output %s in function-call to %s", - actual_name, - node.op_type, - ) - actual_value = self.lookup_or_create(actual_name) - actual_value.identity_merge_from(formal_value) - if logger.level <= logging.INFO: - logger.info( - "Binding outputs for function %s. %s => %s", - function.name, - actual_value, - node.output, - ) - - def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: - return self.scopes.lookup_ref_attribute(ref_attr_name) - - def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: - self.scopes.bind_ref_attribute(ref_attr_name, attr) - - def visit_model(self, model: onnx.ModelProto): - self._gather_function_metadata(model) - - self.process_model(model) - for opset in model.opset_import: - self.process_opset_import(opset) - self.visit_graph(model.graph) - - for new_function in self._new_functions: - model.functions.append(new_function) - - self.function_shape_env.save_to_model_proto(model) - - def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None: - if is_local_function_node(node, self._functions): - function_id = ir.get_function_id_from_node(node) - if function_id not in self._functions: - # Do not recursively visit new functions. - return None - replacement, _ = self.process_function_node(node) - else: - replacement = self.process_node(node) - logger.debug( - "visit_node: %s::%s %s replacement %s", - node.domain, - node.op_type, - node.name, - "found" if replacement is not None else "missed", - ) - if replacement is None: - # No change. Process attributes. - for attr in node.attribute: - self.visit_attribute(attr) - return None - else: - self.modified = True - # We recursively visit the replacement nodes. - result = [] - for newnode in replacement: - n = self.visit_node(newnode) - if n is not None: - result.extend(n) - else: - result.append(newnode) - return result - - def process_function_node( - self, node: onnx.NodeProto - ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: - function_id = ir.get_function_id_from_node(node) - function = self._functions[function_id] - - is_unique_callsite = len(self._function_callsites[function_id]) == 1 - if not is_unique_callsite: - mutable_function = onnx.FunctionProto() - mutable_function.CopyFrom(function) - else: - mutable_function = function - - logger.info("Visit function %s node %s", function_id, node.name) - actual_input_value_infos = [self.lookup(input) for input in node.input] - # Handle omitted inputs, these are considered optional inputs of the function. - actual_input_value_infos.extend( - [None] * (len(function.input) - len(actual_input_value_infos)) - ) - ref_attributes = { - attr_proto.name: self.lookup_ref_attribute(attr_proto.ref_attr_name) - for attr_proto in node.attribute - if attr_proto.ref_attr_name - } - - self.enter_function_scope(mutable_function) - if logger.level <= logging.INFO: - printable_actual_input_value_infos = [str(x) for x in actual_input_value_infos] - logger.info( - "Actual input value infos: %s", - printable_actual_input_value_infos, - ) - logger.info("Enter function scope: %s", self.scopes.current_scope()) - - logger.debug("Binding inputs for function %s", function.name) - for actual_input_value_info, formal_input in zip( - actual_input_value_infos, function.input - ): - formal_info = ir.Value(formal_input) - if actual_input_value_info is not None: - formal_info.identity_merge_from(actual_input_value_info) - self.bind(formal_input, formal_info) - - for attr_proto in function.attribute_proto: - # Default value of function attributes. - self.bind_ref_attribute(attr_proto.name, attr_proto) - - for attr_proto in node.attribute: - if attr_proto.ref_attr_name: - concrete_attribute = ref_attributes.get(attr_proto.name) - if concrete_attribute is None: - continue - self.bind_ref_attribute(attr_proto.name, concrete_attribute) - else: - self.bind_ref_attribute(attr_proto.name, attr_proto) - - # Visit inner function nodes. - node_updates: list[tuple[int, list[onnx.NodeProto]]] = [] - nodes = mutable_function.node - for i, inner_node in enumerate(nodes): - replacement = self.visit_node(inner_node) - if replacement is not None: - node_updates.append((i, replacement)) - for i, replacement in reversed(node_updates): - old_node_name = nodes[i].name - old_node_op_type = nodes[i].op_type - del nodes[i] - for newnode in reversed(replacement): - logger.debug( - "Replacement node inside function %s: %s for %s %s. Size %s", - node.name, - newnode.output, - old_node_name, - old_node_op_type, - newnode.ByteSize(), - ) - nodes.insert(i, newnode) - added_domains = set() - del mutable_function.opset_import[:] - for inner_node in nodes: - # Update opset_import if needed. - if inner_node.domain not in added_domains: - version = self.lookup_version(inner_node.domain) - mutable_function.opset_import.append( - onnx.OperatorSetIdProto(domain=inner_node.domain, version=version) - ) - added_domains.add(inner_node.domain) - - output_updates = self.process_function_outputs(mutable_function) - - is_new_function = not is_unique_callsite and (node_updates or output_updates) - if is_new_function: - self._new_functions.append(mutable_function) - self._function_renamer.rename(mutable_function) - node.op_type = mutable_function.name - - function_scope = self.exit_function_scope(mutable_function) - - self.process_function_node_outputs(node, function_scope) - - logger.info("Exit function scope: %s", function_scope) - logger.info("Exit function %s node %s", function_id, node.name) - - if is_new_function: - return [node], mutable_function - return None, None diff --git a/onnxscript/_legacy_ir/visitor_test.py b/onnxscript/_legacy_ir/visitor_test.py deleted file mode 100644 index e4559472e3..0000000000 --- a/onnxscript/_legacy_ir/visitor_test.py +++ /dev/null @@ -1,38 +0,0 @@ -import unittest - -import onnx - -from onnxscript._legacy_ir import visitor - - -class FunctionCallsiteProtoTransformerTest(unittest.TestCase): - def test_function_optional_input_is_recorded_by_shape_env(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - z = custom.function(x) - } - < - domain: "custom", - opset_import: ["" : 18] - > - function (x, optional_y, optional_z) => (return_val) - { - return_val = custom.custom_op (x, optional_y, optional_z) - } - """ - ) - - model_visitor = visitor.FunctionCallsiteProtoTransformer() - model_visitor.visit_model(model) - self.assertIsNotNone( - model_visitor.function_shape_env.lookup(model.functions[0], "optional_y") - ) - self.assertIsNotNone( - model_visitor.function_shape_env.lookup(model.functions[0], "optional_z") - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/_thirdparty/asciichartpy.py b/onnxscript/_thirdparty/asciichartpy.py index 3cd91f84f5..88c46202ca 100644 --- a/onnxscript/_thirdparty/asciichartpy.py +++ b/onnxscript/_thirdparty/asciichartpy.py @@ -1,5 +1,5 @@ -# SPDX-License-Identifier: MIT -# Modifications Copyright (c) Microsoft. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. # # Copyright © 2016 Igor Kroitor # @@ -198,8 +198,8 @@ def plot(series, *, bin_edges=None, cfg=None): height = cfg.get("height", interval) ratio = height / interval if interval > 0 else 1 - min2 = int(floor(minimum * ratio)) - max2 = int(ceil(maximum * ratio)) + min2 = floor(minimum * ratio) + max2 = ceil(maximum * ratio) def clamp(n): return min(max(n, minimum), maximum) diff --git a/onnxscript/backend/__init__.py b/onnxscript/backend/__init__.py index 862c45ce31..59e481eb93 100644 --- a/onnxscript/backend/__init__.py +++ b/onnxscript/backend/__init__.py @@ -1,4 +1,2 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- diff --git a/onnxscript/backend/onnx_backend.py b/onnxscript/backend/onnx_backend.py index 83c9bca39a..ef93bb50b7 100644 --- a/onnxscript/backend/onnx_backend.py +++ b/onnxscript/backend/onnx_backend.py @@ -1,8 +1,6 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- - +# ruff: noqa: TID251 import os import textwrap @@ -291,7 +289,7 @@ def enumerate_onnx_tests(series, fct_filter=None) -> Iterator[OnnxBackendTest]: sub = os.path.join(root, "data", series) if not os.path.exists(sub): raise FileNotFoundError( - "Unable to find series of tests in {root!r}, subfolders:\n" + f"Unable to find series of tests in {root!r}, subfolders:\n" + "\n".join(os.listdir(root)) ) tests = os.listdir(sub) diff --git a/onnxscript/backend/onnx_backend_test.py b/onnxscript/backend/onnx_backend_test.py index b640331490..efd9d823d8 100644 --- a/onnxscript/backend/onnx_backend_test.py +++ b/onnxscript/backend/onnx_backend_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import os import unittest diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 01ab09c8f2..cfea1a501c 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -1,21 +1,20 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations from typing import Any, Optional, Sequence -import numpy +import numpy as np import onnx from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto -from onnx.helper import make_node import onnxscript.onnx_types import onnxscript.type_annotation _SINGLE_INDENT = " " +_SMALL_TENSOR_SIZE = 4 + kwlist = { "False", "None", @@ -70,10 +69,11 @@ def _get_const_repr(const_node): if tensor_proto.data_type in {TensorProto.FLOAT, TensorProto.INT64}: rank = len(tensor_proto.dims) if rank == 0: - array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) - return repr(array[0]) + array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) # noqa: TID251 + return str(array[0]) if rank == 1 and tensor_proto.dims[0] < 5: - return repr(list(onnx.numpy_helper.to_array(tensor_proto))) + nparray = onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251 + return repr(nparray.tolist()) return None @@ -121,7 +121,7 @@ def renamer(name): def _translate_type(onnx_type): """Converts a onnx type into a type defined by *onnxscript*.""" - return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type) + return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False) def _translate_signature(inputs, outputs): @@ -141,6 +141,15 @@ def input_sig(inp: ValueInfoProto | str): return f"{result}:" +def _translate_value_infos(value_infos: Sequence[ValueInfoProto]) -> str: + def _translate_value_info(value_info: ValueInfoProto) -> str: + return f"{_SINGLE_INDENT}'{_cleanup_variable_name(value_info.name)}': {_translate_type(value_info.type)}," + + lines = [_translate_value_info(x) for x in value_infos] + lines_joined = "\n".join(lines) + return "{\n" + lines_joined + "\n}" + + def _to_str(s): if isinstance(s, bytes): return s.decode("utf-8") @@ -163,7 +172,7 @@ def _attribute_value(attr: onnx.AttributeProto): if onnx.external_data_helper.uses_external_data(tensor_proto): return tensor_proto else: - return onnx.numpy_helper.to_array(tensor_proto) + return onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251 # TODO: # - onnx.AttributeProto.GRAPH # - onnx.AttributeProto.SPARSE_TENSOR @@ -249,11 +258,11 @@ def _cond_is_used_in_loop_body(graph: GraphProto) -> bool: return False -class Exporter: +class _Exporter: """Class used for recursive traversal of Proto structures.""" def __init__( - self, rename: bool, use_operators: bool = False, inline_const: bool = False + self, *, rename: bool, use_operators: bool, inline_const: bool, skip_initializers: bool ) -> None: self.use_operators = use_operators if rename: @@ -268,6 +277,8 @@ def __init__( # _name_remappings: used to undo the SSA-renaming in ONNX control-flow ops. # We map the multiple SSA-variants back to the same Python variable name. self._name_remappings: list[dict[str, str]] = [] + self.skip_initializers = skip_initializers + self.skipped_initializers: dict[str, onnx.TensorProto] = {} def _handle_attrname_conflict(self, renamer): """Add ref-attr-name-conflict handling logic to renaming function.""" @@ -311,7 +322,7 @@ def _translate_onnx_var_ref(self, var): def _rename_domain(self, domain: str) -> str: if domain in {"", "ai.onnx"}: - return "opset" # TODO: Need checks to avoid name conflicts. + return "opset" # TODO: Need checks to avoid name conflicts. return _cleanup_variable_name(domain) # type: ignore[return-value] def _make_opset_name(self, domain, version): @@ -340,18 +351,34 @@ def _translate_graph_body(self, graph, opsets, indent=0): code = [] if hasattr(graph, "initializer"): for init in graph.initializer: - node = make_node( + if self.skip_initializers: + size = 1 + for d in init.dims: + size *= d + if size > _SMALL_TENSOR_SIZE: + init_py_name = self._translate_onnx_var(init.name) + if init_py_name in self.skipped_initializers: + raise RuntimeError( + f"Initializer {init.name!r} is already present in skipped_initializers." + ) + self.skipped_initializers[init_py_name] = init + continue + node = onnx.helper.make_node( # noqa: TID251 "Constant", [], [self._translate_onnx_var(init.name)], # type: ignore[list-item] value=init, ) - code.append(self._translate_node(node, opsets, indent=indent)) + pyinit = self._translate_node(node, opsets, indent=indent) + if pyinit: + code.append(pyinit) if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0: raise NotImplementedError("Unable to convert sparse_initilizer into python.") for node in graph.node: pynode = self._translate_node(node, opsets, indent=indent) if pynode: + if node.name: + pynode += f" # {node.name}" code.append(pynode) final = "\n".join(code) @@ -367,17 +394,17 @@ def _translate_attributes(self, node): if isinstance(value, str): attributes.append((at.name, f"{value!r}")) continue - if isinstance(value, numpy.ndarray): + if isinstance(value, np.ndarray): onnx_dtype = at.t.data_type if len(value.shape) == 0: text = ( f'make_tensor("value", {onnx_dtype}, dims=[], ' - f"vals=[{value.tolist()!r}])" + f"vals=[{repr(value.tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')}])" ) else: text = ( f'make_tensor("value", {onnx_dtype}, dims={list(value.shape)!r}, ' - f"vals={value.ravel().tolist()!r})" + f"vals={repr(value.ravel().tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')})" ) attributes.append((at.name, text)) continue @@ -391,6 +418,7 @@ def _translate_attributes(self, node): text += f", offset={metadata.offset!r}" if metadata.length: text += f", length={metadata.length!r}" + text += ")" attributes.append((at.name, text)) continue attributes.append((at.name, repr(value))) @@ -400,7 +428,8 @@ def _translate_attributes(self, node): def _translate_if(self, node, opsets, indent=0): """Translates a node If into python.""" sindent = _SINGLE_INDENT * indent - code = [f"{sindent}if {node.input[0]}:"] + cond = self._translate_onnx_var_ref(node.input[0]) + code = [f"{sindent}if {cond}:"] if len(node.attribute) != 2: raise RuntimeError( f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}." @@ -484,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0): rows.extend(self._emit_assign(formal_ins, actual_ins, indent)) + if node.name: + node_name = " # " + node.name + else: + node_name = "" if use_iter_var and not use_loop_cond: - rows.append(f"{sindent}for {iter_var} in range({n_iter}):") + rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}") # The following is a hacky way to suppress the generation of # "cond_out = cond_in", which ONNX forces for a FOR loop. # TODO: a cleaner solution for this. self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in) elif not use_iter_var and use_loop_cond: - rows.append(f"{sindent}while {py_cond}:") + rows.append(f"{sindent}while {py_cond}:{node_name}") elif use_iter_var and use_loop_cond: # TODO: This needs fixing - rows.append(f"{sindent}for {iter_var} in range({n_iter}):") + rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}") rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:") rows.append(f"{sindent}{_SINGLE_INDENT * 2}break") else: @@ -685,15 +718,68 @@ def _translate_graph(self, model: onnx.ModelProto, function_name: Optional[str]) def add(line: str) -> None: result.append(line) - add("@script()") - add(f"def {function_name}{_translate_signature(graph.input, graph.output)}") + if self.skip_initializers: + indent_level = 2 + indent = _SINGLE_INDENT + else: + indent_level = 1 + indent = "" + add(f"{indent}@script()") + add(f"{indent}def {function_name}{_translate_signature(graph.input, graph.output)}") + indent = indent + _SINGLE_INDENT doc = graph.doc_string if doc: - add(f' """{doc}"""') - add(self._translate_graph_body(graph, opsets, indent=1)) + add(f'{indent}"""{doc}"""') + add(self._translate_graph_body(graph, opsets, indent=indent_level)) return_values = ", ".join(self._translate_onnx_var(x) for x in graph.output) - add(f" return {return_values}") - return "\n".join(result) + add(f"{indent}return {return_values}") + script = "\n".join(result) + if self.skipped_initializers: + value_infos = _translate_value_infos(graph.value_info) + return self._substitute_initializers(script, function_name, value_infos) + return script + + def _substitute_initializers( + self, script: str, script_function_name: str, value_infos: str + ) -> str: + init_names = self.skipped_initializers.keys() + # Formal parameters representing initializers (single level indentation) + __ = _SINGLE_INDENT + initializers_as_params = "\n".join(f"{__}{x}," for x in init_names) + + def generate_rand(name: str, value: TensorProto) -> str: + shape = ",".join(str(d) for d in value.dims) + if value.data_type == TensorProto.FLOAT: + return f"{__}{name} = np.random.rand({shape}).astype(np.float32)" + if value.data_type == TensorProto.INT8: + return f"{__}{name} = np.random.randint(-128, 127, size=({shape},), dtype=np.int8)" + raise NotImplementedError( + f"Unable to generate random initializer for data type {value.data_type}." + ) + + random_initializer_values = "\n".join( + generate_rand(key, value) for key, value in self.skipped_initializers.items() + ) + # Actual parameter values for initializers (double level indentation) + indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names) + return f""" +value_infos = {value_infos} + +def make_model( +{initializers_as_params} +): +{script} + +{__}model = {script_function_name}.to_model_proto(value_infos=value_infos) +{__}return model + +def make_model_with_random_weights(): +{random_initializer_values} +{__}model = make_model( +{indented_initializers_as_params} +{__}) +{__}return model +""" def _import_onnx_types( self, proto: onnx.ModelProto | onnx.GraphProto | onnx.FunctionProto @@ -724,7 +810,7 @@ def add(line: str) -> None: result.append(line) # Generic imports. - add("import numpy") + add("import numpy as np") add("from onnx import TensorProto") add("from onnx.helper import make_tensor") add("from onnxscript import script, external_tensor") @@ -779,9 +865,11 @@ def visit_graph(graph: onnx.GraphProto) -> None: def export2python( model_onnx, function_name: Optional[str] = None, + *, rename: bool = False, use_operators: bool = False, inline_const: bool = False, + skip_initializers: bool = False, ): """Exports an ONNX model to the *python* syntax. @@ -791,6 +879,9 @@ def export2python( function_name: main function name use_operators: use Python operators. inline_const: replace ONNX constants inline if compact + skip_initializers: generated script will not include initializers. + Instead, a function that generates the model, given initializer values, is generated, + along with one that generates random values for the initializers. Returns: python code @@ -799,11 +890,11 @@ def export2python( .. runpython:: :showcode: :process: - import numpy + import numpy as np from sklearn.cluster import KMeans from mlprodict.onnx_conv import to_onnx from mlprodict.onnx_tools.onnx_export import export2python - X = numpy.arange(20).reshape(10, 2).astype(numpy.float32) + X = np.arange(20).reshape(10, 2).astype(np.float32) tr = KMeans(n_clusters=2) tr.fit(X) onx = to_onnx(tr, X, target_opset=14) @@ -816,5 +907,10 @@ def export2python( if not isinstance(model_onnx, (ModelProto, FunctionProto)): raise TypeError(f"The function expects a ModelProto not {type(model_onnx)!r}.") - exporter = Exporter(rename, use_operators, inline_const) + exporter = _Exporter( + rename=rename, + use_operators=use_operators, + inline_const=inline_const, + skip_initializers=skip_initializers, + ) return exporter.export(model_onnx, function_name) diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index efcc8ae8a2..1f913ed897 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -1,13 +1,13 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import dataclasses import importlib +import os import pathlib import re +import sys import unittest from typing import Pattern @@ -45,14 +45,8 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): SKIP_TESTS = ( - skip( - r"^test_ai_onnx_ml_array_feature_extractor", - "ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'", - ), - skip( - r"^test_ai_onnx_ml_binarizer", - "ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'", - ), + skip(r"^test_ai_onnx_ml_array_feature_extractor", "ORT doesn't support this op"), + skip(r"^test_ai_onnx_ml_binarizer", "ORT doesn't support this op"), skip(r"^test_center_crop_pad_crop_negative_axes_hwc", "fixme: ORT segfaults"), skip(r"_scan_", "Operator Scan is not supported by onnxscript"), skip(r"^test_scan", "Operator Scan is not supported by onnxscript"), @@ -89,8 +83,27 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): "Change when the converter supports support something like 'while i < n and cond:'", ), skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"), + skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"), + skip(r"^test_attention", "ONNX Runtime 1.23 fails on these tests"), ) +if sys.platform == "win32": + SKIP_TESTS = ( + *SKIP_TESTS, + skip(r"^test_gemm_beta", "cannot import module, import_module does not work"), + skip( + r"^test_averagepool_2d_default", + "cannot import module, import_module does not work", + ), + skip("^test_bitwise_not_3d", "cannot import module, import_module does not work"), + skip( + "^test_resize_upsample_scales_linear_half_pixel_symmetric", + "cannot import module, import_module does not work", + ), + # tests are too unstable on Windows, not always the same ones are failing. + skip("test_", "cannot import module"), + ) + def load_function(obj): return ort.InferenceSession(obj.SerializeToString(), providers=("CPUExecutionProvider",)) @@ -108,16 +121,24 @@ def run_function(obj, *inputs): def extract_functions(name: str, content: str, test_folder: pathlib.Path): if not test_folder.exists(): test_folder.mkdir(exist_ok=True, parents=True) - init = test_folder / "__init__.py" - init.touch(exist_ok=True) - file = test_folder / f"{name}.py" - file.write_text(content, encoding="utf-8") + init = str(test_folder / "__init__.py") + with open(init, "w", encoding="utf-8") as f: + f.write("\n") + filename = str(test_folder / f"{name}.py") + with open(filename, "w", encoding="utf-8") as f: + f.write(content + "\n") + assert os.path.exists(filename), ( + f"{filename!r} ({os.path.abspath(filename)!r} does not exist." + ) import_name = f"tests.{test_folder.parts[-1]}.{name}" try: mod = importlib.import_module(import_name) except (SyntaxError, ImportError) as e: raise AssertionError( - f"Unable to import {import_name!r} (file: {file!r})\n----\n{content}" + f"Unable to import {import_name!r} (e={e}) (file: {filename!r}, " + f"absolute path: {os.path.abspath(filename)!r}, " + f"current folder: {os.getcwd()}" + f"\n---- CONTENT --\n{content}" ) from e functions = { k: v for k, v in mod.__dict__.items() if isinstance(v, onnxscript.OnnxFunction) @@ -137,7 +158,7 @@ class TestOnnxBackEnd(unittest.TestCase): test_folder = root_folder / "tests" / "onnx_backend_test_code" temp_folder = root_folder / "tests" / "export" - def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options): + def _proto_to_os_and_back(self, proto: onnx.FunctionProto, **export_options): """Convert a proto to onnxscript code and convert it back to a proto.""" code = onnx_export.export2python(proto, **export_options) map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder) @@ -267,16 +288,6 @@ def _load_function(_): return session def _run_function(obj, *inputs): - print(" run ONNX") - for i, inp in enumerate(inputs): - if inp is None: - print(f" input {i}: None") - else: - print( - f" input {i}: " - f"dtype={inp.dtype!r} shape={inp.shape!r}" - f"{inp.ravel().tolist()!r}" - ) try: return run_function(obj, *inputs) except Exception as e: diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 515829488d..dfcddefbd3 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import ast @@ -25,9 +23,6 @@ from onnxscript import type_annotation as ta from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation -PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39 - - logger = logging.getLogger("onnxscript") @@ -303,7 +298,7 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: return r def _make_onnx_attr( - self, attrname: str, attrval: Any, attrtype: Optional[int] = None + self, attrname: str, attrval: Any, attrtype: int | None = None ) -> irbuilder.IRAttributeValue: def tensor_name_generator() -> str: """Return name to be used for tensor, if we need to create one.""" @@ -429,9 +424,7 @@ def _emit_copy(self, original_var: str, suggested_name: str) -> str: def _is_constant_expr(self, node: ast.AST) -> None: if isinstance(node, ast.UnaryOp): - if self._is_constant_expr(node.operand): - return True - return False + return self._is_constant_expr(node.operand) if isinstance( node, ( @@ -439,14 +432,10 @@ def _is_constant_expr(self, node: ast.AST) -> None: ast.BinOp, ast.UnaryOp, ast.Compare, - ast.Num, - ast.Str, ast.Attribute, ast.List, ast.Load, - ast.NameConstant, ast.Constant, - ast.Str, ), ): return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node)) @@ -527,10 +516,10 @@ def _translate_attr( # in a NodeProto. if val is None: if attr_meta and attr_meta.required: - self.fail(expr, "Attribute '{attr_name}' is required.") + self.fail(expr, f"Attribute '{attr_name}' is required.") return None - attr_type = attr_meta.type if attr_meta else None - attr = self._make_onnx_attr(attr_name, val, attr_type) + attr_type = int(attr_meta.type) if attr_meta else None + attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type) if attr_meta and (attr.type != attr_meta.type): self.fail( expr, @@ -582,9 +571,9 @@ def _translate_expr( def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: """Translation of an expression where "None" is permitted (eg., for an optional argument). - None is represented as a NameConstant in Python 3.7 and Constant in Python 3.9. + None is represented as a Constant in Python 3.9+. """ - if isinstance(node, (ast.NameConstant, ast.Constant)) and (node.value is None): + if isinstance(node, ast.Constant) and (node.value is None): return None return self._translate_expr(node) @@ -633,7 +622,7 @@ def _translate_subscript_expr( target = f"{var_name}_subscripted" target = self.generate_unique_name(target) indices = ast_utils.normalize_subscript_expr(node) - info = self._source_of(node.slice if PY_VERSION_GE_39 else node) + info = self._source_of(node.slice) # Create cached int constants: # TODO: Do this at a graph-scope level. @@ -804,6 +793,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: non_scalar_indices.extend(scalar_indices) if non_scalar_indices: last_axis, _ = non_scalar_indices[-1] + else: + # TODO(justinchuby): Clarify what last_axis should be when non_scalar_indices is False + last_axis = None for axis, index_expr in non_scalar_indices: index_value = self._translate_expr(index_expr) axis_attr = self._make_onnx_attr("axis", axis) @@ -943,7 +935,6 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable opname = node.attr if opname in module: return values.Op(module, node.attr) - warn(f"'{opname}' is not a known op in '{module}'") return values.Op(module, node.attr) if isinstance(node, ast.Name): function_name = node.id @@ -1243,14 +1234,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: if i != len(loop_stmt.body) - 1: self.fail(s, "Instruction break must be the last one of the loop.") - _current_scope = self._current_scope() - if s.test.id not in _current_scope: + current_scope = self._current_scope() + if s.test.id not in current_scope: self.fail( loop_stmt, f"Unable to find condition variable {s.test.id!r} in known " - f"variables {list(_current_scope)!r}.", + f"variables {list(current_scope)!r}.", ) - condition_name = _current_scope[s.test.id].value + condition_name = current_scope[s.test.id].value operator_name = "Not" continue self._translate_stmt(s) @@ -1259,14 +1250,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: if cond_while is not None: # Loop while - _current_scope = self._current_scope() - if cond_while not in _current_scope: + current_scope = self._current_scope() + if cond_while not in current_scope: self.fail( loop_stmt, f"Unable to find condition variable {cond_while!r} in known " - f"variables {list(_current_scope)!r}.", + f"variables {list(current_scope)!r}.", ) - o_cond_var = _current_scope[cond_while].value + o_cond_var = current_scope[cond_while].value self.emit( [o_cond_out], diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 58ed379686..9a7ca504a7 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -1,13 +1,10 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import ast import inspect import os import pathlib -import sys import textwrap import types import typing @@ -194,6 +191,27 @@ def cast_add(x, y): self.assertEqual(y_value_info.type.tensor_type.elem_type, onnx.TensorProto.INT64) self.assertEqual(output_value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT) + def test_set_value_info(self): + @script() + def double_square(x): + square = op.Mul(x, x) + return op.Add(square, square) + + # Converting "cast_add" to a ModelProto will generate an incomplete ModelProto, + # with input-types undefined (since the script has no type-annotation). + model = double_square.to_model_proto() + graph = model.graph + self.assertEqual(len(graph.value_info), 0) + model = double_square.to_model_proto( + io_types=FLOAT["N"], value_infos={"square": FLOAT["N"]} + ) + graph = model.graph + self.assertEqual(len(graph.value_info), 1) + value_info = graph.value_info[0] + self.assertEqual(value_info.name, "square") + self.assertEqual(value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT) + self.assertEqual(value_info.type.tensor_type.shape.dim[0].dim_param, "N") + def test_onnxfns1(self): from tests.models import onnxfns1 @@ -439,8 +457,7 @@ def f1(A: FLOAT[...]) -> FLOAT[...]: r = A[index] return r - ast_name = "_ast" if sys.version_info[:2] < (3, 9) else "ast" - self.check_failure(f1, f"Left term must be a tuple not ''") + self.check_failure(f1, "Left term must be a tuple not ''") def check_run(self, onnxfn, inputs, expected_output): # Test by converting to model and running with ORT @@ -675,6 +692,24 @@ def sum(n: INT64) -> INT64: self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64)) self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64)) + def test_function_opset_import(self): + """Test that model inherits opset version from the function.""" + from onnxscript import opset19 + + @script() + def double(x): + return opset19.Add(x, x) + + @script() + def model(x): + return double(x) + + model_proto = model.to_model_proto() + onnx_opset_import = [opset for opset in model_proto.opset_import if opset.domain == ""] + + self.assertEqual(len(onnx_opset_import), 1) + self.assertEqual(onnx_opset_import[0].version, 19) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/diagnostics/infra/__init__.py b/onnxscript/diagnostics/infra/__init__.py deleted file mode 100644 index 1d771666f2..0000000000 --- a/onnxscript/diagnostics/infra/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -from ._infra import ( - DiagnosticOptions, - Graph, - Invocation, - Level, - Location, - Rule, - RuleCollection, - Stack, - StackFrame, - Tag, - ThreadFlowLocation, - levels, -) -from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnosticError - -__all__ = [ - "Diagnostic", - "DiagnosticContext", - "DiagnosticOptions", - "Graph", - "Invocation", - "Level", - "levels", - "Location", - "Rule", - "RuleCollection", - "RuntimeErrorWithDiagnosticError", - "Stack", - "StackFrame", - "Tag", - "ThreadFlowLocation", -] diff --git a/onnxscript/diagnostics/infra/_infra.py b/onnxscript/diagnostics/infra/_infra.py deleted file mode 100644 index f225a191fe..0000000000 --- a/onnxscript/diagnostics/infra/_infra.py +++ /dev/null @@ -1,319 +0,0 @@ -"""This file defines an additional layer of abstraction on top of the SARIF OM.""" - -from __future__ import annotations - -import dataclasses -import enum -import pprint -from typing import FrozenSet, List, Mapping, Optional, Sequence, Tuple - -from onnxscript.diagnostics.infra import formatter, sarif - - -class Level(enum.IntEnum): - """The level of a diagnostic. - - This class is used to represent the level of a diagnostic. The levels are defined - by the SARIF specification, and are not modifiable. For alternative categories, - please use infra.Tag instead. When selecting a level, please consider the following - guidelines: - - - NONE: Informational result that does not indicate the presence of a problem. - - NOTE: An opportunity for improvement was found. - - WARNING: A potential problem was found. - - ERROR: A serious problem was found. - - This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer - value maps to the logging levels in Python's logging module. The mapping is as - follows: - - Level.NONE = logging.DEBUG = 10 - Level.NOTE = logging.INFO = 20 - Level.WARNING = logging.WARNING = 30 - Level.ERROR = logging.ERROR = 40 - """ - - NONE = 10 - NOTE = 20 - WARNING = 30 - ERROR = 40 - - -levels = Level - - -class Tag(enum.Enum): - """The tag of a diagnostic. This class can be inherited to define custom tags.""" - - -class PatchedPropertyBag(sarif.PropertyBag): - """Key/value pairs that provide additional information about the object. - - The definition of PropertyBag via SARIF spec is "A property bag is an object (§3.6) - containing an unordered set of properties with arbitrary names." However it is not - reflected in the json file, and therefore not captured by the python representation. - This patch adds additional **kwargs to the `__init__` method to allow recording - arbitrary key/value pairs. - """ - - def __init__(self, tags: Optional[List[str]] = None, **kwargs): - super().__init__(tags=tags) - self.__dict__.update(kwargs) - - -@dataclasses.dataclass(frozen=True) -class Rule: - id: str - name: str - message_default_template: str - short_description: Optional[str] = None - full_description: Optional[str] = None - full_description_markdown: Optional[str] = None - help_uri: Optional[str] = None - - @classmethod - def from_sarif(cls, **kwargs): - """Returns a rule from the SARIF reporting descriptor.""" - short_description = kwargs.get("short_description", {}).get("text") - full_description = kwargs.get("full_description", {}).get("text") - full_description_markdown = kwargs.get("full_description", {}).get("markdown") - help_uri = kwargs.get("help_uri") - - rule = cls( - id=kwargs["id"], - name=kwargs["name"], - message_default_template=kwargs["message_strings"]["default"]["text"], - short_description=short_description, - full_description=full_description, - full_description_markdown=full_description_markdown, - help_uri=help_uri, - ) - return rule - - def sarif(self) -> sarif.ReportingDescriptor: - """Returns a SARIF reporting descriptor of this Rule.""" - short_description = ( - sarif.MultiformatMessageString(text=self.short_description) - if self.short_description is not None - else None - ) - full_description = ( - sarif.MultiformatMessageString( - text=self.full_description, markdown=self.full_description_markdown - ) - if self.full_description is not None - else None - ) - return sarif.ReportingDescriptor( - id=self.id, - name=self.name, - short_description=short_description, - full_description=full_description, - help_uri=self.help_uri, - ) - - def format(self, level: Level, *args, **kwargs) -> Tuple[Rule, Level, str]: - """Returns a tuple of (rule, level, message) for a diagnostic. - - This method is used to format the message of a diagnostic. The message is - formatted using the default template of this rule, and the arguments passed in - as `*args` and `**kwargs`. The level is used to override the default level of - this rule. - """ - return (self, level, self.format_message(*args, **kwargs)) - - def format_message(self, *args, **kwargs) -> str: - """Returns the formatted default message of this Rule. - - This method should be overridden (with code generation) by subclasses to reflect - the exact arguments needed by the message template. This is a helper method to - create the default message for a diagnostic. - """ - return self.message_default_template.format(*args, **kwargs) - - def pretty_print(self): - pass - - -@dataclasses.dataclass -class Location: - uri: Optional[str] = None - line: Optional[int] = None - message: Optional[str] = None - start_column: Optional[int] = None - end_column: Optional[int] = None - snippet: Optional[str] = None - function: Optional[str] = None - - def sarif(self) -> sarif.Location: - """Returns the SARIF representation of this location.""" - return sarif.Location( - physical_location=sarif.PhysicalLocation( - artifact_location=sarif.ArtifactLocation(uri=self.uri), - region=sarif.Region( - start_line=self.line, - start_column=self.start_column, - end_column=self.end_column, - snippet=sarif.ArtifactContent(text=self.snippet), - ), - ), - message=sarif.Message(text=self.message) if self.message is not None else None, - ) - - def pretty_print(self): - """Prints the location in a traceback style format.""" - unknown = "" - snippet = self.snippet or unknown - uri = self.uri or unknown - function = self.function or unknown - lineno = self.line if self.line is not None else unknown - message = f" # {self.message}" if self.message is not None else "" - print(f' File "{uri}", line {lineno}, in {function}\n {snippet}{message}') - - -@dataclasses.dataclass -class StackFrame: - location: Location - - def sarif(self) -> sarif.StackFrame: - """Returns the SARIF representation of this stack frame.""" - return sarif.StackFrame(location=self.location.sarif()) - - def pretty_print(self): - """Prints the stack frame in a human-readable format.""" - self.location.pretty_print() - - -@dataclasses.dataclass -class Stack: - """Records a stack trace. The frames are in order from newest to oldest stack frame.""" - - frames: List[StackFrame] = dataclasses.field(default_factory=list) - message: Optional[str] = None - - def sarif(self) -> sarif.Stack: - """Returns the SARIF representation of this stack.""" - return sarif.Stack( - frames=[frame.sarif() for frame in self.frames], - message=sarif.Message(text=self.message) if self.message is not None else None, - ) - - def pretty_print(self): - """Prints the stack in a human-readable format.""" - formatter.pretty_print_title(f"Stack: {self.message}", fill_char="-") - for frame in reversed(self.frames): - frame.pretty_print() - - -@dataclasses.dataclass -class ThreadFlowLocation: - """Records code location and the initial state.""" - - location: Location - state: Mapping[str, str] - index: int - stack: Optional[Stack] = None - - def sarif(self) -> sarif.ThreadFlowLocation: - """Returns the SARIF representation of this thread flow location.""" - return sarif.ThreadFlowLocation( - location=self.location.sarif(), - state=self.state, - stack=self.stack.sarif() if self.stack is not None else None, - ) - - def pretty_print(self, verbose: bool = False): - """Prints the thread flow location in a human-readable format.""" - formatter.pretty_print_title(f"Step {self.index}", fill_char="-") - self.location.pretty_print() - if verbose: - print(f"State: {pprint.pformat(self.state)}") - if self.stack is not None: - self.stack.pretty_print() - - -@dataclasses.dataclass -class Graph: - """A graph of diagnostics. - - This class stores the string representation of a model graph. - The `nodes` and `edges` fields are unused in the current implementation. - """ - - graph: str - name: str - description: Optional[str] = None - - def sarif(self) -> sarif.Graph: - """Returns the SARIF representation of this graph.""" - return sarif.Graph( - description=sarif.Message(text=self.graph), - properties=PatchedPropertyBag(name=self.name, description=self.description), - ) - - def pretty_print( - self, - verbose: bool = False, - ): - """Prints the diagnostics in a human-readable format. - - Args: - verbose: If True, prints all information. Otherwise, only prints compact - information. E.g., graph name and description. - log_level: The minimum level of diagnostics to print. - """ - formatter.pretty_print_title(f"Graph: {self.name}", fill_char="-") - print(self.description) - if verbose: - print(self.graph) - - -@dataclasses.dataclass -class RuleCollection: - _rule_id_name_set: FrozenSet[Tuple[str, str]] = dataclasses.field(init=False) - - def __post_init__(self) -> None: - self._rule_id_name_set = frozenset( - { - (field.default.id, field.default.name) - for field in dataclasses.fields(self) - if isinstance(field.default, Rule) - } - ) - - def __contains__(self, rule: Rule) -> bool: - """Checks if the rule is in the collection.""" - return (rule.id, rule.name) in self._rule_id_name_set - - @classmethod - def custom_collection_from_list( - cls, new_collection_class_name: str, rules: Sequence[Rule] - ) -> RuleCollection: - """Creates a custom class inherited from RuleCollection with the list of rules.""" - return dataclasses.make_dataclass( - new_collection_class_name, - [ - ( - formatter.kebab_case_to_snake_case(rule.name), - type(rule), - dataclasses.field(default=rule), - ) - for rule in rules - ], - bases=(cls,), - )() - - -class Invocation: - # TODO: Implement this. - # Tracks top level call arguments and diagnostic options. - def __init__(self) -> None: - raise NotImplementedError() - - -@dataclasses.dataclass -class DiagnosticOptions: - """Options for diagnostic context.""" - - log_verbose: bool = dataclasses.field(default=False) - log_level: Level = dataclasses.field(default=Level.ERROR) diff --git a/onnxscript/diagnostics/infra/context.py b/onnxscript/diagnostics/infra/context.py deleted file mode 100644 index 26d0c1bd27..0000000000 --- a/onnxscript/diagnostics/infra/context.py +++ /dev/null @@ -1,347 +0,0 @@ -"""A diagnostic context based on SARIF.""" - -from __future__ import annotations - -import contextlib -import dataclasses -import gzip -import logging -import typing -from typing import Callable, Generator, List, Literal, Mapping, Optional - -from onnxscript.diagnostics import infra -from onnxscript.diagnostics.infra import formatter, sarif, utils -from onnxscript.diagnostics.infra.sarif import version as sarif_version - -if typing.TYPE_CHECKING: - from typing_extensions import Self - - -@dataclasses.dataclass -class Diagnostic: - rule: infra.Rule - level: infra.Level - message: Optional[str] = None - locations: List[infra.Location] = dataclasses.field(default_factory=list) - stacks: List[infra.Stack] = dataclasses.field(default_factory=list) - graphs: List[infra.Graph] = dataclasses.field(default_factory=list) - thread_flow_locations: List[infra.ThreadFlowLocation] = dataclasses.field( - default_factory=list - ) - additional_message: Optional[str] = None - tags: List[infra.Tag] = dataclasses.field(default_factory=list) - source_exception: Optional[Exception] = None - """The exception that caused this diagnostic to be created.""" - - def __post_init__(self) -> None: - pass - - def sarif(self) -> sarif.Result: - """Returns the SARIF Result representation of this diagnostic.""" - message = self.message or self.rule.message_default_template - if self.additional_message: - message_markdown = ( - f"{message}\n\n## Additional Message:\n\n{self.additional_message}" - ) - else: - message_markdown = message - - kind: Literal["informational", "fail"] = ( - "informational" if self.level == infra.Level.NONE else "fail" - ) - - sarif_result = sarif.Result( - message=sarif.Message(text=message, markdown=message_markdown), - level=self.level.name.lower(), # type: ignore[arg-type] - rule_id=self.rule.id, - kind=kind, - ) - sarif_result.locations = [location.sarif() for location in self.locations] - sarif_result.stacks = [stack.sarif() for stack in self.stacks] - sarif_result.graphs = [graph.sarif() for graph in self.graphs] - sarif_result.code_flows = [ - sarif.CodeFlow( - thread_flows=[ - sarif.ThreadFlow( - locations=[loc.sarif() for loc in self.thread_flow_locations] - ) - ] - ) - ] - sarif_result.properties = sarif.PropertyBag(tags=[tag.value for tag in self.tags]) - return sarif_result - - def with_location(self: Self, location: infra.Location) -> Self: - """Adds a location to the diagnostic.""" - self.locations.append(location) - return self - - def with_thread_flow_location(self: Self, location: infra.ThreadFlowLocation) -> Self: - """Adds a thread flow location to the diagnostic.""" - self.thread_flow_locations.append(location) - return self - - def with_stack(self: Self, stack: infra.Stack) -> Self: - """Adds a stack to the diagnostic.""" - self.stacks.append(stack) - return self - - def with_graph(self: Self, graph: infra.Graph) -> Self: - """Adds a graph to the diagnostic.""" - self.graphs.append(graph) - return self - - def with_additional_message(self: Self, message: str) -> Self: - """Adds an additional message to the diagnostic.""" - if self.additional_message is None: - self.additional_message = message - else: - self.additional_message = f"{self.additional_message}\n{message}" - return self - - def with_source_exception(self: Self, exception: Exception) -> Self: - """Adds the source exception to the diagnostic.""" - self.source_exception = exception - return self - - def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack: - """Records the current Python call stack.""" - frames_to_skip += 1 # Skip this function. - stack = utils.python_call_stack(frames_to_skip=frames_to_skip) - self.with_stack(stack) - if len(stack.frames) > 0: - self.with_location(stack.frames[0].location) - return stack - - def record_python_call( - self, - fn: Callable, - state: Mapping[str, str], - message: Optional[str] = None, - frames_to_skip: int = 0, - ) -> infra.ThreadFlowLocation: - """Records a python call as one thread flow step.""" - frames_to_skip += 1 # Skip this function. - stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5) - location = utils.function_location(fn) - location.message = message - # Add function location to the top of the stack. - stack.frames.insert(0, infra.StackFrame(location=location)) - thread_flow_location = infra.ThreadFlowLocation( - location=location, - state=state, - index=len(self.thread_flow_locations), - stack=stack, - ) - self.with_thread_flow_location(thread_flow_location) - return thread_flow_location - - def pretty_print(self, verbose: bool = False, log_level: infra.Level = infra.Level.ERROR): - """Prints the diagnostics in a human-readable format. - - Args: - verbose: If True, prints all information. E.g. stack frames, graphs, etc. - Otherwise, only prints compact information. E.g., rule name and display message. - log_level: The minimum level of diagnostics to print. - """ - if self.level.value < log_level.value: - return - formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}") - print(self.message) - print(self.additional_message) - - if not verbose: - print("\n") - return - - formatter.pretty_print_title("Locations", fill_char="-") - for location in self.locations: - location.pretty_print() - for stack in self.stacks: - stack.pretty_print() - formatter.pretty_print_title("Thread Flow Locations", fill_char="-") - for thread_flow_location in self.thread_flow_locations: - thread_flow_location.pretty_print(verbose=verbose) - for graph in self.graphs: - graph.pretty_print(verbose=verbose) - - print() - - # TODO: print help url to rule at the end. - - -class RuntimeErrorWithDiagnosticError(RuntimeError): - """Runtime error with enclosed diagnostic information.""" - - def __init__(self, diagnostic: Diagnostic): - super().__init__(diagnostic.message) - self.diagnostic = diagnostic - - -@dataclasses.dataclass -class DiagnosticContext: - name: str - version: str - options: infra.DiagnosticOptions = dataclasses.field( - default_factory=infra.DiagnosticOptions - ) - diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list) - logger: logging.Logger = dataclasses.field( - init=True, default_factory=lambda: logging.getLogger().getChild("diagnostics") - ) - # TODO(bowbao): Implement this. - # _invocation: infra.Invocation = dataclasses.field(init=False) - _inflight_diagnostics: List[Diagnostic] = dataclasses.field( - init=False, default_factory=list - ) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return None - - def sarif(self) -> sarif.Run: - """Returns the SARIF Run object.""" - unique_rules = {diagnostic.rule for diagnostic in self.diagnostics} - return sarif.Run( - tool=sarif.Tool( - driver=sarif.ToolComponent( - name=self.name, - version=self.version, - rules=[rule.sarif() for rule in unique_rules], - ) - ), - results=[diagnostic.sarif() for diagnostic in self.diagnostics], - ) - - def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined] - """Returns the SARIF Log object.""" - return sarif.SarifLog( - version=sarif_version.SARIF_VERSION, - schema_uri=sarif_version.SARIF_SCHEMA_LINK, - runs=[self.sarif()], - ) - - def to_json(self) -> str: - return formatter.sarif_to_json(self.sarif_log()) - - def dump(self, file_path: str, compress: bool = False) -> None: - """Dumps the SARIF log to a file.""" - if compress: - with gzip.open(file_path, "wt", encoding="utf-8") as f: - f.write(self.to_json()) - else: - with open(file_path, "w", encoding="utf-8") as f: - f.write(self.to_json()) - - def log(self, diagnostic: Diagnostic) -> None: - """Adds a diagnostic to the context. - - Use this method to add diagnostics that are not created by the context. - - Args: - diagnostic: The diagnostic to add. - """ - if not isinstance(diagnostic, Diagnostic): - raise TypeError( - f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}" - ) - self.diagnostics.append(diagnostic) - self.logger.log(diagnostic.level, diagnostic.message) - self.logger.log(diagnostic.level, diagnostic.additional_message) - - def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None: - self.log(diagnostic) - if diagnostic.level == infra.Level.ERROR: - raise RuntimeErrorWithDiagnosticError(diagnostic) from diagnostic.source_exception - - @contextlib.contextmanager - def add_inflight_diagnostic( - self, diagnostic: Diagnostic - ) -> Generator[Diagnostic, None, None]: - """Adds a diagnostic to the context. - - Use this method to add diagnostics that are not created by the context. - - Args: - diagnostic: The diagnostic to add. - """ - self._inflight_diagnostics.append(diagnostic) - try: - yield diagnostic - finally: - self._inflight_diagnostics.pop() - - def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None: - """Pushes a diagnostic to the inflight diagnostics stack. - - Args: - diagnostic: The diagnostic to push. - - Raises: - ValueError: If the rule is not supported by the tool. - """ - self._inflight_diagnostics.append(diagnostic) - - def pop_inflight_diagnostic(self) -> Diagnostic: - """Pops the last diagnostic from the inflight diagnostics stack. - - Returns: - The popped diagnostic. - """ - return self._inflight_diagnostics.pop() - - def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic: - if rule is None: - # TODO(bowbao): Create builtin-rules and create diagnostic using that. - if len(self._inflight_diagnostics) <= 0: - raise AssertionError("No inflight diagnostics") - - return self._inflight_diagnostics[-1] - else: - # TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]] - for diagnostic in reversed(self._inflight_diagnostics): - if diagnostic.rule == rule: - return diagnostic - raise AssertionError(f"No inflight diagnostic for rule {rule.name}") - - def pretty_print( - self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None - ) -> None: - """Prints the diagnostics in a human-readable format. - - Args: - verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. - If not specified, uses the value of 'self.options.log_verbose'. - log_level: The minimum level of diagnostics to print. - If not specified, uses the value of 'self.options.log_level'. - """ - if verbose is None: - verbose = self.options.log_verbose - if log_level is None: - log_level = self.options.log_level - - formatter.pretty_print_title(f"Diagnostic Run {self.name} version {self.version}") - print(f"verbose: {verbose}, log level: {log_level}") - diagnostic_stats = {level: 0 for level in infra.Level} - for diagnostic in self.diagnostics: - diagnostic_stats[diagnostic.level] += 1 - formatter.pretty_print_title( - " ".join(f"{diagnostic_stats[level]} {level.name}" for level in infra.Level) - ) - - for diagnostic in self.diagnostics: - diagnostic.pretty_print(verbose, log_level) - - unprinted_diagnostic_stats = [ - (level, count) - for level, count in diagnostic_stats.items() - if count > 0 and level.value < log_level.value - ] - if unprinted_diagnostic_stats: - print( - f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} " - "were not printed due to the log level." - ) - print() diff --git a/onnxscript/diagnostics/infra/decorator.py b/onnxscript/diagnostics/infra/decorator.py deleted file mode 100644 index e72da19c42..0000000000 --- a/onnxscript/diagnostics/infra/decorator.py +++ /dev/null @@ -1,151 +0,0 @@ -from __future__ import annotations - -import functools -import traceback -from typing import Any, Callable, Dict, Optional, Tuple, Type - -from onnxscript._internal import runtime_typing -from onnxscript.diagnostics import infra -from onnxscript.diagnostics.infra import formatter, utils - -MessageFormatterType = Callable[..., str] - - -@runtime_typing.checked -def format_message_in_text( - fn: Callable, # pylint: disable=unused-argument - *args: Any, - **kwargs: Any, -) -> str: - return f"{formatter.display_name(fn)}. " - - -@runtime_typing.checked -def format_exception_in_markdown(exception: Exception) -> str: - msg_list = ["### Exception log", "```"] - msg_list.extend( - traceback.format_exception(type(exception), exception, exception.__traceback__) - ) - msg_list.append("```") - return "\n".join(msg_list) - - -@runtime_typing.checked -def format_function_signature_in_markdown( - fn: Callable, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - format_argument: Callable[[Any], str] = formatter.format_argument, -) -> str: - msg_list = [f"### Function Signature {formatter.display_name(fn)}"] - - state = utils.function_state(fn, args, kwargs) - - for k, v in state.items(): - msg_list.append(f"- {k}: {format_argument(v)}") - - return "\n".join(msg_list) - - -@runtime_typing.checked -def format_return_values_in_markdown( - return_values: Any, - format_argument: Callable[[Any], str] = formatter.format_argument, -) -> str: - return f"- Return value: {format_argument(return_values)}" - - -ModifierCallableType = Callable[ - [infra.Diagnostic, Callable, Tuple[Any, ...], Dict[str, Any], Any], None -] - - -@runtime_typing.checked -def diagnose_call( - rule: infra.Rule, - *, - level: infra.Level = infra.Level.NONE, - diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic, - format_argument: Callable[[Any], str] = formatter.format_argument, - diagnostic_message_formatter: MessageFormatterType = format_message_in_text, -) -> Callable: - def decorator(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements - common_error_message = "diagnose_call can only be applied to callables" - if not callable(fn): - raise AssertionError( # noqa: TRY004 - f"{common_error_message}. Got {type(fn)} instead of callable." - ) - arg0 = args[0] if len(args) > 0 else None - if isinstance(ctx := arg0, infra.DiagnosticContext): - pass - elif isinstance( - ctx := getattr(arg0, "diagnostic_context", None), - infra.DiagnosticContext, - ): - pass - else: - # NOTE: At decorate time, it can't tell if a callable is function or method. - # Technically both are regarded as function at that time. - raise AssertionError( # noqa: TRY004 - f"{common_error_message}. For {fn}, " - f"If it is a function, a DiagnosticContext instance must be present as " - f"the first argument. " - f"If it is a method, a DiagnosticContext instance must be present as " - f"the attribute 'diagnostic_context' of the 'self' argument." - ) - - diag = diagnostic_type( - rule, - level, - diagnostic_message_formatter(fn, *args, **kwargs), - ) - - # pop the decorator frame - # TODO(bowbao): by default diagnostic doesn't have stack. - # So need to check before doing this. Make the code cleaner. - # Option: do not capture stack by default in diagnostic initialization. - stack: Optional[infra.Stack] = None - if len(diag.stacks) > 0: - stack = diag.stacks[0] - stack.frames.pop(0) - - # set function location - fn_location = utils.function_location(fn) - diag.locations.insert(0, fn_location) - # Add function location to the top of the stack. - if stack is not None: - stack.frames.insert(0, infra.StackFrame(location=fn_location)) - - additional_messages = [ - format_function_signature_in_markdown(fn, args, kwargs, format_argument), - ] - - return_values: Any = None - with ctx.add_inflight_diagnostic(diag) as diag: - try: - return_values = fn(*args, **kwargs) - additional_messages.append( - format_return_values_in_markdown(return_values, format_argument) - ) - except Exception as e: # pylint: disable=broad-exception-caught - # Record exception. - diag.level = infra.levels.ERROR - # TODO(bowbao): Message emitting api. - diag.message = diag.message or "" - diag.message += f"Raised from:\n {type(e).__name__}: {e}" - diag.with_source_exception(e) - additional_messages.append(format_exception_in_markdown(e)) - else: - return return_values - finally: - diag.with_additional_message("\n".join(additional_messages).strip()) - ctx.log_and_raise_if_error(diag) - - return wrapper - - return decorator - - -# TODO(bowbao): decorator to report only when failed. diff --git a/onnxscript/diagnostics/infra/formatter.py b/onnxscript/diagnostics/infra/formatter.py deleted file mode 100644 index c54e81fed4..0000000000 --- a/onnxscript/diagnostics/infra/formatter.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -import dataclasses -import json -import re -from typing import Any, Callable, Dict, List, Optional, Union - -from onnxscript._internal import runtime_typing -from onnxscript.diagnostics.infra import sarif - -# A list of types in the SARIF module to support pretty printing. -# This is solely for type annotation for the functions below. -_SarifClass = Union[ - sarif.SarifLog, - sarif.Run, - sarif.ReportingDescriptor, - sarif.Result, -] - - -@runtime_typing.checked -def snake_case_to_camel_case(s: str) -> str: - splits = s.split("_") - if len(splits) <= 1: - return s - return "".join([splits[0], *map(str.capitalize, splits[1:])]) - - -@runtime_typing.checked -def camel_case_to_snake_case(s: str) -> str: - return re.sub(r"([A-Z])", r"_\1", s).lower() - - -@runtime_typing.checked -def kebab_case_to_snake_case(s: str) -> str: - return s.replace("-", "_") - - -@runtime_typing.checked -def _convert_key( - object: Union[Dict[str, Any], Any], convert: Callable[[str], str] -) -> Union[Dict[str, Any], Any]: - """Convert and update keys in a dictionary with "convert". - - Any value that is a dictionary will be recursively updated. - Any value that is a list will be recursively searched. - - Args: - object: The object to update. - convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`. - - Returns: - The updated object. - """ - if not isinstance(object, Dict): - return object - new_dict = {} - for k, v in object.items(): - new_k = convert(k) - if isinstance(v, Dict): - new_v = _convert_key(v, convert) - elif isinstance(v, List): - new_v = [_convert_key(elem, convert) for elem in v] - else: - new_v = v - if new_v is None: - # Otherwise unnesseraily bloated sarif log with "null"s. - continue - if new_v == -1: - # WAR: -1 as default value shouldn't be logged into sarif. - continue - - new_dict[new_k] = new_v - - return new_dict - - -@runtime_typing.checked -def sarif_to_json(attr_cls_obj: _SarifClass, indent: Optional[str] = " ") -> str: - dict = dataclasses.asdict(attr_cls_obj) - dict = _convert_key(dict, snake_case_to_camel_case) - return json.dumps(dict, indent=indent, separators=(",", ":")) - - -@runtime_typing.checked -def pretty_print_title( - title: str, width: int = 80, fill_char: str = "=", print_output: bool = True -) -> str: - """Pretty prints title in below format: - - ==================== title ==================== - """ - msg = f" {title} ".center(width, fill_char) - if print_output: - print(msg) - return msg - - -@runtime_typing.checked -def pretty_print_item_title( - title: str, fill_char: str = "=", print_output: bool = True -) -> str: - """Pretty prints title in below format: - - title - ===== - """ - msg_list = [] - msg_list.append(title) - msg_list.append(fill_char * len(title)) - - msg = "\n".join(msg_list) - if print_output: - print(msg) - return msg - - -@runtime_typing.checked -def format_argument(obj: Any) -> str: - return f"{type(obj)}" - - -@runtime_typing.checked -def display_name(fn: Callable) -> str: - if hasattr(fn, "__qualname__"): - return fn.__qualname__ - elif hasattr(fn, "__name__"): - return fn.__name__ - else: - return str(fn) diff --git a/onnxscript/diagnostics/infra/sarif/__init__.py b/onnxscript/diagnostics/infra/sarif/__init__.py deleted file mode 100644 index e610c3b754..0000000000 --- a/onnxscript/diagnostics/infra/sarif/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from onnxscript.diagnostics.infra.sarif._address import Address -from onnxscript.diagnostics.infra.sarif._artifact import Artifact -from onnxscript.diagnostics.infra.sarif._artifact_change import ArtifactChange -from onnxscript.diagnostics.infra.sarif._artifact_content import ArtifactContent -from onnxscript.diagnostics.infra.sarif._artifact_location import ArtifactLocation -from onnxscript.diagnostics.infra.sarif._attachment import Attachment -from onnxscript.diagnostics.infra.sarif._code_flow import CodeFlow -from onnxscript.diagnostics.infra.sarif._configuration_override import ( - ConfigurationOverride, -) -from onnxscript.diagnostics.infra.sarif._conversion import Conversion -from onnxscript.diagnostics.infra.sarif._edge import Edge -from onnxscript.diagnostics.infra.sarif._edge_traversal import EdgeTraversal -from onnxscript.diagnostics.infra.sarif._exception import Exception -from onnxscript.diagnostics.infra.sarif._external_properties import ExternalProperties -from onnxscript.diagnostics.infra.sarif._external_property_file_reference import ( - ExternalPropertyFileReference, -) -from onnxscript.diagnostics.infra.sarif._external_property_file_references import ( - ExternalPropertyFileReferences, -) -from onnxscript.diagnostics.infra.sarif._fix import Fix -from onnxscript.diagnostics.infra.sarif._graph import Graph -from onnxscript.diagnostics.infra.sarif._graph_traversal import GraphTraversal -from onnxscript.diagnostics.infra.sarif._invocation import Invocation -from onnxscript.diagnostics.infra.sarif._location import Location -from onnxscript.diagnostics.infra.sarif._location_relationship import ( - LocationRelationship, -) -from onnxscript.diagnostics.infra.sarif._logical_location import LogicalLocation -from onnxscript.diagnostics.infra.sarif._message import Message -from onnxscript.diagnostics.infra.sarif._multiformat_message_string import ( - MultiformatMessageString, -) -from onnxscript.diagnostics.infra.sarif._node import Node -from onnxscript.diagnostics.infra.sarif._notification import Notification -from onnxscript.diagnostics.infra.sarif._physical_location import PhysicalLocation -from onnxscript.diagnostics.infra.sarif._property_bag import PropertyBag -from onnxscript.diagnostics.infra.sarif._rectangle import Rectangle -from onnxscript.diagnostics.infra.sarif._region import Region -from onnxscript.diagnostics.infra.sarif._replacement import Replacement -from onnxscript.diagnostics.infra.sarif._reporting_configuration import ( - ReportingConfiguration, -) -from onnxscript.diagnostics.infra.sarif._reporting_descriptor import ReportingDescriptor -from onnxscript.diagnostics.infra.sarif._reporting_descriptor_reference import ( - ReportingDescriptorReference, -) -from onnxscript.diagnostics.infra.sarif._reporting_descriptor_relationship import ( - ReportingDescriptorRelationship, -) -from onnxscript.diagnostics.infra.sarif._result import Result -from onnxscript.diagnostics.infra.sarif._result_provenance import ResultProvenance -from onnxscript.diagnostics.infra.sarif._run import Run -from onnxscript.diagnostics.infra.sarif._run_automation_details import ( - RunAutomationDetails, -) -from onnxscript.diagnostics.infra.sarif._sarif_log import SarifLog -from onnxscript.diagnostics.infra.sarif._special_locations import SpecialLocations -from onnxscript.diagnostics.infra.sarif._stack import Stack -from onnxscript.diagnostics.infra.sarif._stack_frame import StackFrame -from onnxscript.diagnostics.infra.sarif._suppression import Suppression -from onnxscript.diagnostics.infra.sarif._thread_flow import ThreadFlow -from onnxscript.diagnostics.infra.sarif._thread_flow_location import ThreadFlowLocation -from onnxscript.diagnostics.infra.sarif._tool import Tool -from onnxscript.diagnostics.infra.sarif._tool_component import ToolComponent -from onnxscript.diagnostics.infra.sarif._tool_component_reference import ( - ToolComponentReference, -) -from onnxscript.diagnostics.infra.sarif._translation_metadata import TranslationMetadata -from onnxscript.diagnostics.infra.sarif._version_control_details import ( - VersionControlDetails, -) -from onnxscript.diagnostics.infra.sarif._web_request import WebRequest -from onnxscript.diagnostics.infra.sarif._web_response import WebResponse - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_address.py b/onnxscript/diagnostics/infra/sarif/_address.py deleted file mode 100644 index c4b691f348..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_address.py +++ /dev/null @@ -1,46 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class Address: - """A physical or virtual address, or a range of addresses, in an 'addressable region' (memory or a binary file).""" - - absolute_address: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "absoluteAddress"} - ) - fully_qualified_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullyQualifiedName"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - kind: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "kind"} - ) - length: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "length"} - ) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - offset_from_parent: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "offsetFromParent"} - ) - parent_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "parentIndex"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - relative_address: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "relativeAddress"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact.py b/onnxscript/diagnostics/infra/sarif/_artifact.py deleted file mode 100644 index afec8b5e97..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact.py +++ /dev/null @@ -1,84 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_content, - _artifact_location, - _message, - _property_bag, -) - - -@dataclasses.dataclass -class Artifact: - """A single artifact. In some cases, this artifact might be nested within another artifact.""" - - contents: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "contents"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - encoding: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "encoding"} - ) - hashes: Any = dataclasses.field(default=None, metadata={"schema_property_name": "hashes"}) - last_modified_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "lastModifiedTimeUtc"} - ) - length: int = dataclasses.field(default=-1, metadata={"schema_property_name": "length"}) - location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - mime_type: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "mimeType"} - ) - offset: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "offset"} - ) - parent_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "parentIndex"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - roles: Optional[ - List[ - Literal[ - "analysisTarget", - "attachment", - "responseFile", - "resultFile", - "standardStream", - "tracedFile", - "unmodified", - "modified", - "added", - "deleted", - "renamed", - "uncontrolled", - "driver", - "extension", - "translation", - "taxonomy", - "policy", - "referencedOnCommandLine", - "memoryContents", - "directory", - "userSpecifiedConfiguration", - "toolSpecifiedConfiguration", - "debugOutputFile", - ] - ] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "roles"}) - source_language: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "sourceLanguage"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_change.py b/onnxscript/diagnostics/infra/sarif/_artifact_change.py deleted file mode 100644 index 3db2c0444b..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact_change.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _property_bag, - _replacement, -) - - -@dataclasses.dataclass -class ArtifactChange: - """A change to a single artifact.""" - - artifact_location: _artifact_location.ArtifactLocation = dataclasses.field( - metadata={"schema_property_name": "artifactLocation"} - ) - replacements: List[_replacement.Replacement] = dataclasses.field( - metadata={"schema_property_name": "replacements"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_content.py b/onnxscript/diagnostics/infra/sarif/_artifact_content.py deleted file mode 100644 index 4038066198..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact_content.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _multiformat_message_string, - _property_bag, -) - - -@dataclasses.dataclass -class ArtifactContent: - """Represents the contents of an artifact.""" - - binary: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "binary"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - rendered: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "rendered"}) - ) - text: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "text"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_location.py b/onnxscript/diagnostics/infra/sarif/_artifact_location.py deleted file mode 100644 index ed6f9b3916..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact_location.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class ArtifactLocation: - """Specifies the location of an artifact.""" - - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "uri"} - ) - uri_base_id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "uriBaseId"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_attachment.py b/onnxscript/diagnostics/infra/sarif/_attachment.py deleted file mode 100644 index b58b858e0c..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_attachment.py +++ /dev/null @@ -1,39 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _message, - _property_bag, - _rectangle, - _region, -) - - -@dataclasses.dataclass -class Attachment: - """An artifact relevant to a result.""" - - artifact_location: _artifact_location.ArtifactLocation = dataclasses.field( - metadata={"schema_property_name": "artifactLocation"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - rectangles: Optional[List[_rectangle.Rectangle]] = dataclasses.field( - default=None, metadata={"schema_property_name": "rectangles"} - ) - regions: Optional[List[_region.Region]] = dataclasses.field( - default=None, metadata={"schema_property_name": "regions"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_code_flow.py b/onnxscript/diagnostics/infra/sarif/_code_flow.py deleted file mode 100644 index 69615f18f2..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_code_flow.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag, _thread_flow - - -@dataclasses.dataclass -class CodeFlow: - """A set of threadFlows which together describe a pattern of code execution relevant to detecting a result.""" - - thread_flows: List[_thread_flow.ThreadFlow] = dataclasses.field( - metadata={"schema_property_name": "threadFlows"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_configuration_override.py b/onnxscript/diagnostics/infra/sarif/_configuration_override.py deleted file mode 100644 index c2fa3ae0a6..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_configuration_override.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _property_bag, - _reporting_configuration, - _reporting_descriptor_reference, -) - - -@dataclasses.dataclass -class ConfigurationOverride: - """Information about how a specific rule or notification was reconfigured at runtime.""" - - configuration: _reporting_configuration.ReportingConfiguration = dataclasses.field( - metadata={"schema_property_name": "configuration"} - ) - descriptor: _reporting_descriptor_reference.ReportingDescriptorReference = ( - dataclasses.field(metadata={"schema_property_name": "descriptor"}) - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_conversion.py b/onnxscript/diagnostics/infra/sarif/_conversion.py deleted file mode 100644 index 6078c525f0..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_conversion.py +++ /dev/null @@ -1,35 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _invocation, - _property_bag, - _tool, -) - - -@dataclasses.dataclass -class Conversion: - """Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format.""" - - tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"}) - analysis_tool_log_files: Optional[List[_artifact_location.ArtifactLocation]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "analysisToolLogFiles"} - ) - ) - invocation: Optional[_invocation.Invocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "invocation"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_edge.py b/onnxscript/diagnostics/infra/sarif/_edge.py deleted file mode 100644 index 1142e61dca..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_edge.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class Edge: - """Represents a directed edge in a graph.""" - - id: str = dataclasses.field(metadata={"schema_property_name": "id"}) - source_node_id: str = dataclasses.field(metadata={"schema_property_name": "sourceNodeId"}) - target_node_id: str = dataclasses.field(metadata={"schema_property_name": "targetNodeId"}) - label: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "label"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_edge_traversal.py b/onnxscript/diagnostics/infra/sarif/_edge_traversal.py deleted file mode 100644 index dbaba449e4..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_edge_traversal.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class EdgeTraversal: - """Represents the traversal of a single edge during a graph traversal.""" - - edge_id: str = dataclasses.field(metadata={"schema_property_name": "edgeId"}) - final_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "finalState"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - step_over_edge_count: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "stepOverEdgeCount"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_exception.py b/onnxscript/diagnostics/infra/sarif/_exception.py deleted file mode 100644 index 71c0db73a8..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_exception.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _exception, _property_bag, _stack - - -@dataclasses.dataclass -class Exception: - """Describes a runtime exception encountered during the execution of an analysis tool.""" - - inner_exceptions: Optional[List[_exception.Exception]] = dataclasses.field( - default=None, metadata={"schema_property_name": "innerExceptions"} - ) - kind: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "kind"} - ) - message: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - stack: Optional[_stack.Stack] = dataclasses.field( - default=None, metadata={"schema_property_name": "stack"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_external_properties.py b/onnxscript/diagnostics/infra/sarif/_external_properties.py deleted file mode 100644 index d63a16aff8..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_external_properties.py +++ /dev/null @@ -1,96 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _address, - _artifact, - _conversion, - _graph, - _invocation, - _logical_location, - _property_bag, - _result, - _thread_flow_location, - _tool_component, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class ExternalProperties: - """The top-level element of an external property file.""" - - addresses: Optional[List[_address.Address]] = dataclasses.field( - default=None, metadata={"schema_property_name": "addresses"} - ) - artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field( - default=None, metadata={"schema_property_name": "artifacts"} - ) - conversion: Optional[_conversion.Conversion] = dataclasses.field( - default=None, metadata={"schema_property_name": "conversion"} - ) - driver: Optional[_tool_component.ToolComponent] = dataclasses.field( - default=None, metadata={"schema_property_name": "driver"} - ) - extensions: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "extensions"} - ) - externalized_properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "externalizedProperties"} - ) - graphs: Optional[List[_graph.Graph]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphs"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - invocations: Optional[List[_invocation.Invocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "invocations"} - ) - logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "logicalLocations"} - ) - policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "policies"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - results: Optional[List[_result.Result]] = dataclasses.field( - default=None, metadata={"schema_property_name": "results"} - ) - run_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "runGuid"} - ) - schema: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "schema"} - ) - taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "taxonomies"} - ) - thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "threadFlowLocations"} - ) - ) - translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "translations"} - ) - version: Optional[Literal["2.1.0"]] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - web_requests: Optional[List[_web_request.WebRequest]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequests"} - ) - web_responses: Optional[List[_web_response.WebResponse]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponses"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py b/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py deleted file mode 100644 index b5bfec0320..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py +++ /dev/null @@ -1,30 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag - - -@dataclasses.dataclass -class ExternalPropertyFileReference: - """Contains information that enables a SARIF consumer to locate the external property file that contains the value of an externalized property associated with the run.""" - - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - item_count: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "itemCount"} - ) - location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py b/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py deleted file mode 100644 index d596a7a87a..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py +++ /dev/null @@ -1,76 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _external_property_file_reference, - _property_bag, -) - - -@dataclasses.dataclass -class ExternalPropertyFileReferences: - """References to external property files that should be inlined with the content of a root log file.""" - - addresses: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "addresses"}) - artifacts: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "artifacts"}) - conversion: Optional[_external_property_file_reference.ExternalPropertyFileReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "conversion"}) - ) - driver: Optional[_external_property_file_reference.ExternalPropertyFileReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "driver"}) - ) - extensions: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "extensions"}) - externalized_properties: Optional[ - _external_property_file_reference.ExternalPropertyFileReference - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "externalizedProperties"} - ) - graphs: Optional[List[_external_property_file_reference.ExternalPropertyFileReference]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "graphs"}) - ) - invocations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "invocations"}) - logical_locations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "logicalLocations"}) - policies: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "policies"}) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - results: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "results"}) - taxonomies: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "taxonomies"}) - thread_flow_locations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "threadFlowLocations"} - ) - translations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "translations"}) - web_requests: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "webRequests"}) - web_responses: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "webResponses"}) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_fix.py b/onnxscript/diagnostics/infra/sarif/_fix.py deleted file mode 100644 index 042f70f47a..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_fix.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_change, _message, _property_bag - - -@dataclasses.dataclass -class Fix: - """A proposed fix for the problem represented by a result object. A fix specifies a set of artifacts to modify. For each artifact, it specifies a set of bytes to remove, and provides a set of new bytes to replace them.""" - - artifact_changes: List[_artifact_change.ArtifactChange] = dataclasses.field( - metadata={"schema_property_name": "artifactChanges"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_graph.py b/onnxscript/diagnostics/infra/sarif/_graph.py deleted file mode 100644 index f068e663de..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_graph.py +++ /dev/null @@ -1,30 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _edge, _message, _node, _property_bag - - -@dataclasses.dataclass -class Graph: - """A network of nodes and directed edges that describes some aspect of the structure of the code (for example, a call graph).""" - - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - edges: Optional[List[_edge.Edge]] = dataclasses.field( - default=None, metadata={"schema_property_name": "edges"} - ) - nodes: Optional[List[_node.Node]] = dataclasses.field( - default=None, metadata={"schema_property_name": "nodes"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_graph_traversal.py b/onnxscript/diagnostics/infra/sarif/_graph_traversal.py deleted file mode 100644 index ec9c92a9f8..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_graph_traversal.py +++ /dev/null @@ -1,39 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import _edge_traversal, _message, _property_bag - - -@dataclasses.dataclass -class GraphTraversal: - """Represents a path through a graph.""" - - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - edge_traversals: Optional[List[_edge_traversal.EdgeTraversal]] = dataclasses.field( - default=None, metadata={"schema_property_name": "edgeTraversals"} - ) - immutable_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "immutableState"} - ) - initial_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "initialState"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - result_graph_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "resultGraphIndex"} - ) - run_graph_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "runGraphIndex"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_invocation.py b/onnxscript/diagnostics/infra/sarif/_invocation.py deleted file mode 100644 index 6f96c9a86c..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_invocation.py +++ /dev/null @@ -1,111 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _configuration_override, - _notification, - _property_bag, -) - - -@dataclasses.dataclass -class Invocation: - """The runtime environment of the analysis tool run.""" - - execution_successful: bool = dataclasses.field( - metadata={"schema_property_name": "executionSuccessful"} - ) - account: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "account"} - ) - arguments: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "arguments"} - ) - command_line: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "commandLine"} - ) - end_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "endTimeUtc"} - ) - environment_variables: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "environmentVariables"} - ) - executable_location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "executableLocation"} - ) - exit_code: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitCode"} - ) - exit_code_description: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitCodeDescription"} - ) - exit_signal_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitSignalName"} - ) - exit_signal_number: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitSignalNumber"} - ) - machine: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "machine"} - ) - notification_configuration_overrides: Optional[ - List[_configuration_override.ConfigurationOverride] - ] = dataclasses.field( - default=None, - metadata={"schema_property_name": "notificationConfigurationOverrides"}, - ) - process_id: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "processId"} - ) - process_start_failure_message: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "processStartFailureMessage"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - response_files: Optional[List[_artifact_location.ArtifactLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "responseFiles"} - ) - rule_configuration_overrides: Optional[ - List[_configuration_override.ConfigurationOverride] - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "ruleConfigurationOverrides"} - ) - start_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "startTimeUtc"} - ) - stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stderr"} - ) - stdin: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stdin"} - ) - stdout: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stdout"} - ) - stdout_stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stdoutStderr"} - ) - tool_configuration_notifications: Optional[List[_notification.Notification]] = ( - dataclasses.field( - default=None, - metadata={"schema_property_name": "toolConfigurationNotifications"}, - ) - ) - tool_execution_notifications: Optional[List[_notification.Notification]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "toolExecutionNotifications"} - ) - ) - working_directory: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "workingDirectory"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_location.py b/onnxscript/diagnostics/infra/sarif/_location.py deleted file mode 100644 index 319856f8df..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_location.py +++ /dev/null @@ -1,44 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _location_relationship, - _logical_location, - _message, - _physical_location, - _property_bag, - _region, -) - - -@dataclasses.dataclass -class Location: - """A location within a programming artifact.""" - - annotations: Optional[List[_region.Region]] = dataclasses.field( - default=None, metadata={"schema_property_name": "annotations"} - ) - id: int = dataclasses.field(default=-1, metadata={"schema_property_name": "id"}) - logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "logicalLocations"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - physical_location: Optional[_physical_location.PhysicalLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "physicalLocation"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - relationships: Optional[List[_location_relationship.LocationRelationship]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "relationships"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_location_relationship.py b/onnxscript/diagnostics/infra/sarif/_location_relationship.py deleted file mode 100644 index 35ca00c8a6..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_location_relationship.py +++ /dev/null @@ -1,28 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class LocationRelationship: - """Information about the relation of one location to another.""" - - target: int = dataclasses.field(metadata={"schema_property_name": "target"}) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - kinds: List[str] = dataclasses.field( - default_factory=lambda: ["relevant"], metadata={"schema_property_name": "kinds"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_logical_location.py b/onnxscript/diagnostics/infra/sarif/_logical_location.py deleted file mode 100644 index 7f2880eef2..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_logical_location.py +++ /dev/null @@ -1,37 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class LogicalLocation: - """A logical location of a construct that produced a result.""" - - decorated_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "decoratedName"} - ) - fully_qualified_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullyQualifiedName"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - kind: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "kind"} - ) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - parent_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "parentIndex"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_message.py b/onnxscript/diagnostics/infra/sarif/_message.py deleted file mode 100644 index 0c9adce220..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_message.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class Message: - """Encapsulates a message intended to be read by the end user.""" - - arguments: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "arguments"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - markdown: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "markdown"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - text: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "text"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py b/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py deleted file mode 100644 index 154b9cc416..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py +++ /dev/null @@ -1,25 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class MultiformatMessageString: - """A message string or message format string rendered in multiple formats.""" - - text: str = dataclasses.field(metadata={"schema_property_name": "text"}) - markdown: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "markdown"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_node.py b/onnxscript/diagnostics/infra/sarif/_node.py deleted file mode 100644 index 0f11e37318..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_node.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _location, _message, _node, _property_bag - - -@dataclasses.dataclass -class Node: - """Represents a node in a graph.""" - - id: str = dataclasses.field(metadata={"schema_property_name": "id"}) - children: Optional[List[_node.Node]] = dataclasses.field( - default=None, metadata={"schema_property_name": "children"} - ) - label: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "label"} - ) - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_notification.py b/onnxscript/diagnostics/infra/sarif/_notification.py deleted file mode 100644 index f41a9f8d5b..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_notification.py +++ /dev/null @@ -1,49 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _exception, - _location, - _message, - _property_bag, - _reporting_descriptor_reference, -) - - -@dataclasses.dataclass -class Notification: - """Describes a condition relevant to the tool itself, as opposed to being relevant to a target being analyzed by the tool.""" - - message: _message.Message = dataclasses.field(metadata={"schema_property_name": "message"}) - associated_rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "associatedRule"}) - ) - descriptor: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "descriptor"}) - ) - exception: Optional[_exception.Exception] = dataclasses.field( - default=None, metadata={"schema_property_name": "exception"} - ) - level: Literal["none", "note", "warning", "error"] = dataclasses.field( - default="warning", metadata={"schema_property_name": "level"} - ) - locations: Optional[List[_location.Location]] = dataclasses.field( - default=None, metadata={"schema_property_name": "locations"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - thread_id: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "threadId"} - ) - time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "timeUtc"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_physical_location.py b/onnxscript/diagnostics/infra/sarif/_physical_location.py deleted file mode 100644 index 357e85af4e..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_physical_location.py +++ /dev/null @@ -1,38 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _address, - _artifact_location, - _property_bag, - _region, -) - - -@dataclasses.dataclass -class PhysicalLocation: - """A physical location relevant to a result. Specifies a reference to a programming artifact together with a range of bytes or characters within that artifact.""" - - address: Optional[_address.Address] = dataclasses.field( - default=None, metadata={"schema_property_name": "address"} - ) - artifact_location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "artifactLocation"} - ) - context_region: Optional[_region.Region] = dataclasses.field( - default=None, metadata={"schema_property_name": "contextRegion"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - region: Optional[_region.Region] = dataclasses.field( - default=None, metadata={"schema_property_name": "region"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_property_bag.py b/onnxscript/diagnostics/infra/sarif/_property_bag.py deleted file mode 100644 index 0b95c6e6e5..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_property_bag.py +++ /dev/null @@ -1,19 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - - -@dataclasses.dataclass -class PropertyBag: - """Key/value pairs that provide additional information about the object.""" - - tags: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "tags"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_rectangle.py b/onnxscript/diagnostics/infra/sarif/_rectangle.py deleted file mode 100644 index a7c9aecd1a..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_rectangle.py +++ /dev/null @@ -1,36 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class Rectangle: - """An area within an image.""" - - bottom: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "bottom"} - ) - left: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "left"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - right: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "right"} - ) - top: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "top"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_region.py b/onnxscript/diagnostics/infra/sarif/_region.py deleted file mode 100644 index 35a4b7f316..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_region.py +++ /dev/null @@ -1,58 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_content, - _message, - _property_bag, -) - - -@dataclasses.dataclass -class Region: - """A region within an artifact where a result was detected.""" - - byte_length: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "byteLength"} - ) - byte_offset: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "byteOffset"} - ) - char_length: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "charLength"} - ) - char_offset: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "charOffset"} - ) - end_column: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "endColumn"} - ) - end_line: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "endLine"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - snippet: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "snippet"} - ) - source_language: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "sourceLanguage"} - ) - start_column: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "startColumn"} - ) - start_line: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "startLine"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_replacement.py b/onnxscript/diagnostics/infra/sarif/_replacement.py deleted file mode 100644 index 125ed75708..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_replacement.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag, _region - - -@dataclasses.dataclass -class Replacement: - """The replacement of a single region of an artifact.""" - - deleted_region: _region.Region = dataclasses.field( - metadata={"schema_property_name": "deletedRegion"} - ) - inserted_content: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "insertedContent"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py b/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py deleted file mode 100644 index e3da0a77b8..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Literal, Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class ReportingConfiguration: - """Information about a rule or notification that can be configured at runtime.""" - - enabled: bool = dataclasses.field( - default=True, metadata={"schema_property_name": "enabled"} - ) - level: Literal["none", "note", "warning", "error"] = dataclasses.field( - default="warning", metadata={"schema_property_name": "level"} - ) - parameters: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "parameters"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - rank: float = dataclasses.field(default=-1.0, metadata={"schema_property_name": "rank"}) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py deleted file mode 100644 index 85e14f3763..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py +++ /dev/null @@ -1,65 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _multiformat_message_string, - _property_bag, - _reporting_configuration, - _reporting_descriptor_relationship, -) - - -@dataclasses.dataclass -class ReportingDescriptor: - """Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime reporting.""" - - id: str = dataclasses.field(metadata={"schema_property_name": "id"}) - default_configuration: Optional[_reporting_configuration.ReportingConfiguration] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "defaultConfiguration"} - ) - ) - deprecated_guids: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "deprecatedGuids"} - ) - deprecated_ids: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "deprecatedIds"} - ) - deprecated_names: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "deprecatedNames"} - ) - full_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"}) - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - help: Optional[_multiformat_message_string.MultiformatMessageString] = dataclasses.field( - default=None, metadata={"schema_property_name": "help"} - ) - help_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "helpUri"} - ) - message_strings: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "messageStrings"} - ) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - relationships: Optional[ - List[_reporting_descriptor_relationship.ReportingDescriptorRelationship] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "relationships"}) - short_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py deleted file mode 100644 index f4e6f2260d..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag, _tool_component_reference - - -@dataclasses.dataclass -class ReportingDescriptorReference: - """Information about how to locate a relevant reporting descriptor.""" - - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - tool_component: Optional[_tool_component_reference.ToolComponentReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "toolComponent"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py deleted file mode 100644 index 52db517db5..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py +++ /dev/null @@ -1,34 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _message, - _property_bag, - _reporting_descriptor_reference, -) - - -@dataclasses.dataclass -class ReportingDescriptorRelationship: - """Information about the relation of one reporting descriptor to another.""" - - target: _reporting_descriptor_reference.ReportingDescriptorReference = dataclasses.field( - metadata={"schema_property_name": "target"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - kinds: List[str] = dataclasses.field( - default_factory=lambda: ["relevant"], metadata={"schema_property_name": "kinds"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_result.py b/onnxscript/diagnostics/infra/sarif/_result.py deleted file mode 100644 index 3dfa564b54..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_result.py +++ /dev/null @@ -1,120 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _attachment, - _code_flow, - _fix, - _graph, - _graph_traversal, - _location, - _message, - _property_bag, - _reporting_descriptor_reference, - _result_provenance, - _stack, - _suppression, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class Result: - """A result produced by an analysis tool.""" - - message: _message.Message = dataclasses.field(metadata={"schema_property_name": "message"}) - analysis_target: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "analysisTarget"} - ) - attachments: Optional[List[_attachment.Attachment]] = dataclasses.field( - default=None, metadata={"schema_property_name": "attachments"} - ) - baseline_state: Optional[Literal["new", "unchanged", "updated", "absent"]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "baselineState"}) - ) - code_flows: Optional[List[_code_flow.CodeFlow]] = dataclasses.field( - default=None, metadata={"schema_property_name": "codeFlows"} - ) - correlation_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "correlationGuid"} - ) - fingerprints: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "fingerprints"} - ) - fixes: Optional[List[_fix.Fix]] = dataclasses.field( - default=None, metadata={"schema_property_name": "fixes"} - ) - graph_traversals: Optional[List[_graph_traversal.GraphTraversal]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphTraversals"} - ) - graphs: Optional[List[_graph.Graph]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphs"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - hosted_viewer_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "hostedViewerUri"} - ) - kind: Literal["notApplicable", "pass", "fail", "review", "open", "informational"] = ( - dataclasses.field(default="fail", metadata={"schema_property_name": "kind"}) - ) - level: Literal["none", "note", "warning", "error"] = dataclasses.field( - default="warning", metadata={"schema_property_name": "level"} - ) - locations: Optional[List[_location.Location]] = dataclasses.field( - default=None, metadata={"schema_property_name": "locations"} - ) - occurrence_count: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "occurrenceCount"} - ) - partial_fingerprints: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "partialFingerprints"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - provenance: Optional[_result_provenance.ResultProvenance] = dataclasses.field( - default=None, metadata={"schema_property_name": "provenance"} - ) - rank: float = dataclasses.field(default=-1.0, metadata={"schema_property_name": "rank"}) - related_locations: Optional[List[_location.Location]] = dataclasses.field( - default=None, metadata={"schema_property_name": "relatedLocations"} - ) - rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "rule"}) - ) - rule_id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "ruleId"} - ) - rule_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "ruleIndex"} - ) - stacks: Optional[List[_stack.Stack]] = dataclasses.field( - default=None, metadata={"schema_property_name": "stacks"} - ) - suppressions: Optional[List[_suppression.Suppression]] = dataclasses.field( - default=None, metadata={"schema_property_name": "suppressions"} - ) - taxa: Optional[List[_reporting_descriptor_reference.ReportingDescriptorReference]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "taxa"}) - ) - web_request: Optional[_web_request.WebRequest] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequest"} - ) - web_response: Optional[_web_response.WebResponse] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponse"} - ) - work_item_uris: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "workItemUris"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_result_provenance.py b/onnxscript/diagnostics/infra/sarif/_result_provenance.py deleted file mode 100644 index 74ea9e1e9f..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_result_provenance.py +++ /dev/null @@ -1,39 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _physical_location, _property_bag - - -@dataclasses.dataclass -class ResultProvenance: - """Contains information about how and when a result was detected.""" - - conversion_sources: Optional[List[_physical_location.PhysicalLocation]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "conversionSources"}) - ) - first_detection_run_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "firstDetectionRunGuid"} - ) - first_detection_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "firstDetectionTimeUtc"} - ) - invocation_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "invocationIndex"} - ) - last_detection_run_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "lastDetectionRunGuid"} - ) - last_detection_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "lastDetectionTimeUtc"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_run.py b/onnxscript/diagnostics/infra/sarif/_run.py deleted file mode 100644 index 8df4f9b577..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_run.py +++ /dev/null @@ -1,126 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _address, - _artifact, - _conversion, - _external_property_file_references, - _graph, - _invocation, - _logical_location, - _property_bag, - _result, - _run_automation_details, - _special_locations, - _thread_flow_location, - _tool, - _tool_component, - _version_control_details, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class Run: - """Describes a single run of an analysis tool, and contains the reported output of that run.""" - - tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"}) - addresses: Optional[List[_address.Address]] = dataclasses.field( - default=None, metadata={"schema_property_name": "addresses"} - ) - artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field( - default=None, metadata={"schema_property_name": "artifacts"} - ) - automation_details: Optional[_run_automation_details.RunAutomationDetails] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "automationDetails"}) - ) - baseline_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "baselineGuid"} - ) - column_kind: Optional[Literal["utf16CodeUnits", "unicodeCodePoints"]] = dataclasses.field( - default=None, metadata={"schema_property_name": "columnKind"} - ) - conversion: Optional[_conversion.Conversion] = dataclasses.field( - default=None, metadata={"schema_property_name": "conversion"} - ) - default_encoding: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "defaultEncoding"} - ) - default_source_language: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "defaultSourceLanguage"} - ) - external_property_file_references: Optional[ - _external_property_file_references.ExternalPropertyFileReferences - ] = dataclasses.field( - default=None, - metadata={"schema_property_name": "externalPropertyFileReferences"}, - ) - graphs: Optional[List[_graph.Graph]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphs"} - ) - invocations: Optional[List[_invocation.Invocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "invocations"} - ) - language: str = dataclasses.field( - default="en-US", metadata={"schema_property_name": "language"} - ) - logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "logicalLocations"} - ) - newline_sequences: List[str] = dataclasses.field( - default_factory=lambda: ["\r\n", "\n"], - metadata={"schema_property_name": "newlineSequences"}, - ) - original_uri_base_ids: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "originalUriBaseIds"} - ) - policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "policies"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - redaction_tokens: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "redactionTokens"} - ) - results: Optional[List[_result.Result]] = dataclasses.field( - default=None, metadata={"schema_property_name": "results"} - ) - run_aggregates: Optional[List[_run_automation_details.RunAutomationDetails]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "runAggregates"}) - ) - special_locations: Optional[_special_locations.SpecialLocations] = dataclasses.field( - default=None, metadata={"schema_property_name": "specialLocations"} - ) - taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "taxonomies"} - ) - thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "threadFlowLocations"} - ) - ) - translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "translations"} - ) - version_control_provenance: Optional[ - List[_version_control_details.VersionControlDetails] - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "versionControlProvenance"} - ) - web_requests: Optional[List[_web_request.WebRequest]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequests"} - ) - web_responses: Optional[List[_web_response.WebResponse]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponses"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_run_automation_details.py b/onnxscript/diagnostics/infra/sarif/_run_automation_details.py deleted file mode 100644 index f41dfcc284..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_run_automation_details.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class RunAutomationDetails: - """Information that describes a run's identity and role within an engineering system process.""" - - correlation_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "correlationGuid"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_sarif_log.py b/onnxscript/diagnostics/infra/sarif/_sarif_log.py deleted file mode 100644 index aa39c52f15..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_sarif_log.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import _external_properties, _property_bag, _run - - -@dataclasses.dataclass -class SarifLog: - """Static Analysis Results Format (SARIF) Version 2.1.0 JSON Schema: a standard format for the output of static analysis tools.""" - - runs: List[_run.Run] = dataclasses.field(metadata={"schema_property_name": "runs"}) - version: Literal["2.1.0"] = dataclasses.field(metadata={"schema_property_name": "version"}) - schema_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "$schema"} - ) - inline_external_properties: Optional[List[_external_properties.ExternalProperties]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "inlineExternalProperties"} - ) - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_special_locations.py b/onnxscript/diagnostics/infra/sarif/_special_locations.py deleted file mode 100644 index ee78979514..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_special_locations.py +++ /dev/null @@ -1,24 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag - - -@dataclasses.dataclass -class SpecialLocations: - """Defines locations of special significance to SARIF consumers.""" - - display_base: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "displayBase"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_stack.py b/onnxscript/diagnostics/infra/sarif/_stack.py deleted file mode 100644 index e250b75df4..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_stack.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag, _stack_frame - - -@dataclasses.dataclass -class Stack: - """A call stack that is relevant to a result.""" - - frames: List[_stack_frame.StackFrame] = dataclasses.field( - metadata={"schema_property_name": "frames"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_stack_frame.py b/onnxscript/diagnostics/infra/sarif/_stack_frame.py deleted file mode 100644 index 24d9fe8201..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_stack_frame.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _location, _property_bag - - -@dataclasses.dataclass -class StackFrame: - """A function call within a stack trace.""" - - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - module: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "module"} - ) - parameters: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "parameters"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - thread_id: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "threadId"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_suppression.py b/onnxscript/diagnostics/infra/sarif/_suppression.py deleted file mode 100644 index ae477178b0..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_suppression.py +++ /dev/null @@ -1,36 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Literal, Optional - -from onnxscript.diagnostics.infra.sarif import _location, _property_bag - - -@dataclasses.dataclass -class Suppression: - """A suppression that is relevant to a result.""" - - kind: Literal["inSource", "external"] = dataclasses.field( - metadata={"schema_property_name": "kind"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - justification: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "justification"} - ) - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - state: Optional[Literal["accepted", "underReview", "rejected"]] = dataclasses.field( - default=None, metadata={"schema_property_name": "state"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_thread_flow.py b/onnxscript/diagnostics/infra/sarif/_thread_flow.py deleted file mode 100644 index d3d1693677..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_thread_flow.py +++ /dev/null @@ -1,40 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _message, - _property_bag, - _thread_flow_location, -) - - -@dataclasses.dataclass -class ThreadFlow: - """Describes a sequence of code locations that specify a path through a single thread of execution such as an operating system or fiber.""" - - locations: List[_thread_flow_location.ThreadFlowLocation] = dataclasses.field( - metadata={"schema_property_name": "locations"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - immutable_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "immutableState"} - ) - initial_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "initialState"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py b/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py deleted file mode 100644 index 949c42d80e..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py +++ /dev/null @@ -1,63 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _location, - _property_bag, - _reporting_descriptor_reference, - _stack, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class ThreadFlowLocation: - """A location visited by an analysis tool while simulating or monitoring the execution of a program.""" - - execution_order: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "executionOrder"} - ) - execution_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "executionTimeUtc"} - ) - importance: Literal["important", "essential", "unimportant"] = dataclasses.field( - default="important", metadata={"schema_property_name": "importance"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - kinds: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "kinds"} - ) - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - module: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "module"} - ) - nesting_level: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "nestingLevel"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - stack: Optional[_stack.Stack] = dataclasses.field( - default=None, metadata={"schema_property_name": "stack"} - ) - state: Any = dataclasses.field(default=None, metadata={"schema_property_name": "state"}) - taxa: Optional[List[_reporting_descriptor_reference.ReportingDescriptorReference]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "taxa"}) - ) - web_request: Optional[_web_request.WebRequest] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequest"} - ) - web_response: Optional[_web_response.WebResponse] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponse"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_tool.py b/onnxscript/diagnostics/infra/sarif/_tool.py deleted file mode 100644 index 79589ddf77..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_tool.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag, _tool_component - - -@dataclasses.dataclass -class Tool: - """The analysis tool that was run.""" - - driver: _tool_component.ToolComponent = dataclasses.field( - metadata={"schema_property_name": "driver"} - ) - extensions: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "extensions"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_tool_component.py b/onnxscript/diagnostics/infra/sarif/_tool_component.py deleted file mode 100644 index 47925ed748..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_tool_component.py +++ /dev/null @@ -1,115 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _multiformat_message_string, - _property_bag, - _reporting_descriptor, - _tool_component_reference, - _translation_metadata, -) - - -@dataclasses.dataclass -class ToolComponent: - """A component, such as a plug-in or the driver, of the analysis tool that was run.""" - - name: str = dataclasses.field(metadata={"schema_property_name": "name"}) - associated_component: Optional[_tool_component_reference.ToolComponentReference] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "associatedComponent"} - ) - ) - contents: List[Literal["localizedData", "nonLocalizedData"]] = dataclasses.field( - default_factory=lambda: ["localizedData", "nonLocalizedData"], - metadata={"schema_property_name": "contents"}, - ) - dotted_quad_file_version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "dottedQuadFileVersion"} - ) - download_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "downloadUri"} - ) - full_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"}) - ) - full_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullName"} - ) - global_message_strings: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "globalMessageStrings"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - information_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "informationUri"} - ) - is_comprehensive: Optional[bool] = dataclasses.field( - default=None, metadata={"schema_property_name": "isComprehensive"} - ) - language: str = dataclasses.field( - default="en-US", metadata={"schema_property_name": "language"} - ) - localized_data_semantic_version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "localizedDataSemanticVersion"} - ) - locations: Optional[List[_artifact_location.ArtifactLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "locations"} - ) - minimum_required_localized_data_semantic_version: Optional[str] = dataclasses.field( - default=None, - metadata={"schema_property_name": "minimumRequiredLocalizedDataSemanticVersion"}, - ) - notifications: Optional[List[_reporting_descriptor.ReportingDescriptor]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "notifications"}) - ) - organization: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "organization"} - ) - product: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "product"} - ) - product_suite: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "productSuite"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - release_date_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "releaseDateUtc"} - ) - rules: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field( - default=None, metadata={"schema_property_name": "rules"} - ) - semantic_version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "semanticVersion"} - ) - short_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"}) - ) - supported_taxonomies: Optional[List[_tool_component_reference.ToolComponentReference]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "supportedTaxonomies"} - ) - ) - taxa: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field( - default=None, metadata={"schema_property_name": "taxa"} - ) - translation_metadata: Optional[_translation_metadata.TranslationMetadata] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "translationMetadata"} - ) - ) - version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py b/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py deleted file mode 100644 index 09cc2b9087..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py +++ /dev/null @@ -1,28 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class ToolComponentReference: - """Identifies a particular toolComponent object, either the driver or an extension.""" - - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_translation_metadata.py b/onnxscript/diagnostics/infra/sarif/_translation_metadata.py deleted file mode 100644 index f05125a599..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_translation_metadata.py +++ /dev/null @@ -1,40 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _multiformat_message_string, - _property_bag, -) - - -@dataclasses.dataclass -class TranslationMetadata: - """Provides additional metadata related to translation.""" - - name: str = dataclasses.field(metadata={"schema_property_name": "name"}) - download_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "downloadUri"} - ) - full_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"}) - ) - full_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullName"} - ) - information_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "informationUri"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - short_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_version_control_details.py b/onnxscript/diagnostics/infra/sarif/_version_control_details.py deleted file mode 100644 index f56498bb69..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_version_control_details.py +++ /dev/null @@ -1,37 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag - - -@dataclasses.dataclass -class VersionControlDetails: - """Specifies the information necessary to retrieve a desired revision from a version control system.""" - - repository_uri: str = dataclasses.field(metadata={"schema_property_name": "repositoryUri"}) - as_of_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "asOfTimeUtc"} - ) - branch: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "branch"} - ) - mapped_to: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "mappedTo"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - revision_id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "revisionId"} - ) - revision_tag: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "revisionTag"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_web_request.py b/onnxscript/diagnostics/infra/sarif/_web_request.py deleted file mode 100644 index b574882f9b..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_web_request.py +++ /dev/null @@ -1,43 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag - - -@dataclasses.dataclass -class WebRequest: - """Describes an HTTP request.""" - - body: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "body"} - ) - headers: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "headers"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - method: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "method"} - ) - parameters: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "parameters"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - protocol: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "protocol"} - ) - target: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "target"} - ) - version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_web_response.py b/onnxscript/diagnostics/infra/sarif/_web_response.py deleted file mode 100644 index 3753036ab1..0000000000 --- a/onnxscript/diagnostics/infra/sarif/_web_response.py +++ /dev/null @@ -1,43 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag - - -@dataclasses.dataclass -class WebResponse: - """Describes the response to an HTTP request.""" - - body: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "body"} - ) - headers: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "headers"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - no_response_received: Optional[bool] = dataclasses.field( - default=None, metadata={"schema_property_name": "noResponseReceived"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - protocol: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "protocol"} - ) - reason_phrase: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "reasonPhrase"} - ) - status_code: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "statusCode"} - ) - version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/version.py b/onnxscript/diagnostics/infra/sarif/version.py deleted file mode 100644 index 020a28bf76..0000000000 --- a/onnxscript/diagnostics/infra/sarif/version.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Final - -SARIF_VERSION: Final = "2.1.0" -SARIF_SCHEMA_LINK: Final = ( - "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json" -) -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/utils.py b/onnxscript/diagnostics/infra/utils.py deleted file mode 100644 index bc8f5f9c78..0000000000 --- a/onnxscript/diagnostics/infra/utils.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -import functools -import inspect -import traceback -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple - -from onnxscript._internal import runtime_typing -from onnxscript.diagnostics.infra import _infra, formatter - - -@runtime_typing.checked -def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame: - """Returns a StackFrame for the given traceback.FrameSummary.""" - snippet = frame.line - - return _infra.StackFrame( - location=_infra.Location( - uri=frame.filename, - line=frame.lineno, - snippet=snippet, - function=frame.name, - message=snippet, - ) - ) - - -@runtime_typing.checked -def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack: - """Returns the current Python call stack.""" - if frames_to_skip < 0: - raise ValueError("frames_to_skip must be non-negative") - if frames_to_log < 0: - raise ValueError("frames_to_log must be non-negative") - frames_to_skip += 2 # Skip this function and beartype. - stack = _infra.Stack() - # Frames are returned in order of oldest to newest. - frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log) - frames.reverse() - stack.frames = [python_frame(frame) for frame in frames[frames_to_skip:]] - stack.message = "Python call stack" - return stack - - -@functools.lru_cache -def _function_source_info(fn: Callable) -> Tuple[Sequence[str], int, Optional[str]]: - """Returns the source lines, line number, and source file path for the given function. - - Essentially, inspect.getsourcelines() and inspect.getsourcefile() combined. - Caching is applied to reduce the performance impact of this function. - """ - source_lines, lineno = inspect.getsourcelines(fn) - return source_lines, lineno, inspect.getsourcefile(fn) - - -@runtime_typing.checked -def function_location(fn: Callable) -> _infra.Location: - """Returns a Location for the given function.""" - source_lines, lineno, uri = _function_source_info(fn) - snippet = source_lines[0].strip() if len(source_lines) > 0 else "" - return _infra.Location( - uri=uri, - line=lineno, - snippet=snippet, - message=formatter.display_name(fn), - ) - - -@runtime_typing.checked -def function_state( - fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any] -) -> Mapping[str, Any]: - bind = inspect.signature(fn).bind(*args, **kwargs) - return bind.arguments diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index a936824cab..1d87ee135e 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import abc @@ -22,7 +20,6 @@ import numpy as np import onnx import onnx.defs -import onnx.helper import onnx.reference from typing_extensions import TypeAlias @@ -292,16 +289,16 @@ def eval_function( has_array = False for arg, param_schema in tagged_args: if param_schema.is_input: - adapted_arg, _has_array = _adapt_to_eager_mode(arg) - has_array = has_array or _has_array + adapted_arg, has_array_ = _adapt_to_eager_mode(arg) + has_array = has_array or has_array_ adapted_args.append(adapted_arg) else: adapted_args.append(arg) for key, (arg, param_schema) in tagged_kwargs.items(): if param_schema.is_input: - adapted_arg, _has_array = _adapt_to_eager_mode(arg) - has_array = has_array or _has_array + adapted_arg, has_array_ = _adapt_to_eager_mode(arg) + has_array = has_array or has_array_ adapted_kwargs[key] = adapted_arg else: adapted_kwargs[key] = arg @@ -369,7 +366,7 @@ def _onnxscript_to_numpy_value(v): if isinstance(v, list): return [_onnxscript_to_numpy_value(x) for x in v] if isinstance(v, tuple): - if len(v) > 0 and type(v[0]) is int: # noqa: E721 # pylint: disable=unidiomatic-typecheck + if len(v) > 0 and type(v[0]) is int: # pylint: disable=unidiomatic-typecheck return np.array(v, dtype=np.int64) return np.array(v) if v is None: @@ -389,8 +386,10 @@ def _numpy_to_onnxscript_value( ): """Converts an ORT encoding of an ONNX value into the encoding used by onnxscript.""" if isinstance(v, np.ndarray): - return tensor.Tensor(v) - if np.issctype(type(v)): # noqa: NPY201 + # ORT may reuse buffers when the output numpy array is provided back as input. + # We need to make a copy to ensure that the tensor is not modified in-place. + return tensor.Tensor(v.copy()) + if issubclass(type(v), np.generic): # Numpy scalar types that are not ndarray # https://numpy.org/doc/stable/reference/arrays.scalars.html return tensor.Tensor(np.array(v)) @@ -421,7 +420,7 @@ def make_tensor_name() -> str: return f"attr_{key}" return autocast.pyvalue_to_onnx_attribute( - key, value, make_tensor_name, schema.attributes[key].type + key, value, make_tensor_name, int(schema.attributes[key].type) ) # Construct ONNX model with a single op call: @@ -430,21 +429,22 @@ def make_tensor_name() -> str: num_outputs = compute_num_outputs(schema, args, kwargs) outputs = [f"output{i}" for i in range(num_outputs)] - node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) + node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) # noqa: TID251 node.attribute.extend( make_attr(key, value) for key, value in kwargs.items() if value is not None ) input_value_infos = utils.values_to_value_infos(zip(inputs, args)) implicit_value_infos = utils.values_to_value_infos(implicit_args.items()) output_value_infos = [ - onnx.helper.make_value_info(name, onnx.TypeProto()) for name in outputs + onnx.helper.make_value_info(name, onnx.TypeProto()) # noqa: TID251 + for name in outputs ] - graph = onnx.helper.make_graph( + graph = onnx.helper.make_graph( # noqa: TID251 [node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos ) - opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) - model = onnx.helper.make_model( + opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) # noqa: TID251 + model = onnx.helper.make_model( # noqa: TID251 graph, opset_imports=[opset_id], ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain), diff --git a/onnxscript/evaluator_test.py b/onnxscript/evaluator_test.py index a5ad41a78f..d42b1bab75 100644 --- a/onnxscript/evaluator_test.py +++ b/onnxscript/evaluator_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import numpy as np diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py index 37232c84eb..c5b87898c9 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import copy @@ -151,11 +153,9 @@ def __repr__(self): " Type Constraints: ", ] # Trick to get unique type constraints but maintain the order. - ordered_unique_type_constraints = { - v: None for v in self.input_type_constraints.values() - } + ordered_unique_type_constraints = dict.fromkeys(self.input_type_constraints.values()) ordered_unique_type_constraints.update( - {v: None for v in self.output_type_constraints.values()} + dict.fromkeys(self.output_type_constraints.values()) ) repr_strs += [ f" {type_constraint.name}: {type_constraint.type_strs}" @@ -175,9 +175,9 @@ def __repr__(self): repr_strs += [ " Intermediate Type Constraints: ", ] - ordered_unique_type_constraints = { - v: None for v in self.intermediate_type_constraints.values() - } + ordered_unique_type_constraints = dict.fromkeys( + self.intermediate_type_constraints.values() + ) repr_strs += [ f" {type_constraint.name}: {type_constraint.type_strs}" for type_constraint in ordered_unique_type_constraints @@ -210,15 +210,15 @@ def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConst ) # Rename type constraints to T0, T1, T2, ... - _seen_type_constraints: Set[TypeConstraint] = set() + seen_type_constraints: Set[TypeConstraint] = set() for type_constraint in ( *input_type_constraints.values(), *output_type_constraints.values(), *intermediate_type_constraints.values(), ): - if type_constraint is not None and type_constraint not in _seen_type_constraints: - type_constraint.name = f"T{len(_seen_type_constraints)}" - _seen_type_constraints.add(type_constraint) + if type_constraint is not None and type_constraint not in seen_type_constraints: + type_constraint.name = f"T{len(seen_type_constraints)}" + seen_type_constraints.add(type_constraint) return OnnxFunctionTypeConstraints( input_type_constraints, output_type_constraints, intermediate_type_constraints diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index 25586085ef..a8d15c242a 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Test cases for type constraint deduction functionality.""" from __future__ import annotations diff --git a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py index 44c3980668..eb2d8015a4 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py @@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None: functions[module_name] = {} op_name = get_op_name(func) if op_name in functions[module_name]: - logging.warning( + logging.warning( # noqa: LOG015 "Duplicated function: %s, overload: %s", op_name, func.func.name.overload_name ) continue diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py index e96d24ed4a..ebbdd43bd8 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py @@ -258,7 +258,7 @@ def _get_func_schema_in_namespace(namespaces: List[_OpNamespace]) -> Dict[str, F # to "resize(Tensor a, SymInt[] shape) -> Tensor" if "!" in op_overload_packet.schema: op_overload_packet.schema = re.sub( # type: ignore[attr-defined] - "[(][A-Za-z]![)]", "", op_overload_packet.schema + r"[(][A-Za-z]![)]", "", op_overload_packet.schema ) # FIXME: remove below code if the issue below is fixed. @@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None: if module_name not in functions: functions[module_name] = {} if op_name in functions[module_name]: - logging.warning( + logging.warning( # noqa: LOG015 "Duplicated function: %s, overload: %s", op_name, func_schema.name.overload_name, diff --git a/onnxscript/function_libs/torch_lib/__init__.py b/onnxscript/function_libs/torch_lib/__init__.py index 4c4966c2b4..18e9054a6f 100644 --- a/onnxscript/function_libs/torch_lib/__init__.py +++ b/onnxscript/function_libs/torch_lib/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """A torch function library for onnxscript. The modules are named after the torch module names for grouping: diff --git a/onnxscript/function_libs/torch_lib/_constants.py b/onnxscript/function_libs/torch_lib/_constants.py index 58cc2c0680..f4e14061ec 100644 --- a/onnxscript/function_libs/torch_lib/_constants.py +++ b/onnxscript/function_libs/torch_lib/_constants.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Shared constants for the library.""" DOMAIN = "pkg.onnxscript.torch_lib" diff --git a/onnxscript/function_libs/torch_lib/_flags.py b/onnxscript/function_libs/torch_lib/_flags.py index 560cd5baaa..79593f3464 100644 --- a/onnxscript/function_libs/torch_lib/_flags.py +++ b/onnxscript/function_libs/torch_lib/_flags.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Experimental flags. NOTE: These flags are experimental only. Any flag here can be removed at any @@ -15,6 +17,7 @@ def _load_boolean_flag( *, this_will: str, deprecated: bool = False, + default: bool = False, ) -> bool: """Load a boolean flag from environment variable. @@ -22,7 +25,9 @@ def _load_boolean_flag( name: The name of the environment variable. this_will: A string that describes what this flag will do. deprecated: Whether this flag is deprecated. + default: The default value if envvar not defined. """ + undefined = os.getenv(name) is None state = os.getenv(name) == "1" if state: if deprecated: @@ -32,6 +37,8 @@ def _load_boolean_flag( ) else: logger.warning("Experimental flag %s is enabled. This will %s.", name, this_will) + if undefined: + state = default return state @@ -42,8 +49,5 @@ def _load_boolean_flag( EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag( "TORCHLIB_EXPERIMENTAL_PREFER_TRACING", this_will="trace all traceable functions to fold if branches and collapse constant expressions", -) -EXPERIMENTAL_USE_IR: bool = _load_boolean_flag( - "TORCHLIB_EXPERIMENTAL_USE_IR", - this_will="use the ONNX IR instead of the PyTorch Graph for graph building", + default=True, ) diff --git a/onnxscript/function_libs/torch_lib/graph_building/__init__.py b/onnxscript/function_libs/torch_lib/graph_building/__init__.py index e70f7f4c27..70a35d729f 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/__init__.py +++ b/onnxscript/function_libs/torch_lib/graph_building/__init__.py @@ -1,35 +1,5 @@ -"""APIs for building an ONNX graph from a PyTorch model. - -This module exposes only three classes that will be used to build an ONNX graph -by the ONNX exporter in PyTorch: - -- :class:`TorchScriptTensor`: Represents a symbolic value in the ONNX graph. -- :class:`TorchScriptGraph`: Stores the graph being built. -- :class:`TorchScriptTracingEvaluator`: An evaluator that will record all operators - applied on the ``TorchScriptTensor``. It has a reference to the ``TorchScriptGraph`` - being built, will write to it, and will handle eager evaluations of Torch Lib - functions when desired. - -The usage is in https://github.com/pytorch/pytorch/blob/136f8378e1b5a8cb7127977b8d068fbf9c3e1247/torch/onnx/_internal/fx/fx_onnx_interpreter.py#L698-L702, -and it is very simple:: - - with onnxscript.evaluator.default_as(onnxscript_tracer): # onnxscript_tracer is a TorchScriptTracingEvaluator - output: Union[ - onnxscript_graph_building.TorchScriptTensor, - Tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ] = symbolic_fn(*onnx_args, **onnx_kwargs) - -Here, we set the default evaluator to be ``onnxscript_tracer`` so -that ONNX Script will dispatch all operators calls to the evaluator. The ``symbolic_fn`` -can be a pure Python function (e.g. trace-only) or an ONNX Script function. Either way, -they are recorded by ``onnxscript_tracer`` and onto the graph. - -The outputs, as ``TorchScriptTensor``, are then handed by to the exporter. On line -https://github.com/pytorch/pytorch/blob/136f8378e1b5a8cb7127977b8d068fbf9c3e1247/torch/onnx/_internal/fx/fx_onnx_interpreter.py#L707 -the exporter fills in type and shape information from PyTorch by calling the setters -on ``TorchScriptTensor.dtype`` and ``TorchScriptTensor.shape``. -""" - +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations __all__ = [ @@ -38,17 +8,17 @@ "TorchScriptTracingEvaluator", ] -from onnxscript.function_libs.torch_lib import _flags -if _flags.EXPERIMENTAL_USE_IR: - from ._graph_building_ir import ( - TorchScriptGraph, - TorchScriptTensor, - TorchScriptTracingEvaluator, - ) -else: - from ._graph_building_torch import ( # type: ignore[assignment] - TorchScriptGraph, - TorchScriptTensor, - TorchScriptTracingEvaluator, - ) +class _RemovedClass: + """A onnxscript tensor that wraps a torchscript Value.""" + + def __init__(self, *_, **__): + raise NotImplementedError( + "Support for dynamo_export has been dropped since onnxscript 0.4.0. " + "Please use `torch.onnx.export(..., dynamo=True)`, or downgrade to onnxscript<0.4" + ) + + +TorchScriptTensor = _RemovedClass +TorchScriptGraph = _RemovedClass +TorchScriptTracingEvaluator = _RemovedClass diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py deleted file mode 100644 index a26a612ba8..0000000000 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py +++ /dev/null @@ -1,755 +0,0 @@ -"""Graph building functions using the ONNX IR, compatible with the original TorchScriptGraph usage.""" - -from __future__ import annotations - -import ctypes -from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union - -import numpy as np -import onnx -import onnx.checker -import onnx.defs -import onnx.helper -import onnx.shape_inference -import torch -from typing_extensions import TypeAlias - -import onnxscript -from onnxscript import evaluator, ir -from onnxscript import tensor as onnxscript_tensor -from onnxscript._internal import param_manipulation -from onnxscript.function_libs.torch_lib import _flags -from onnxscript.function_libs.torch_lib.ops import common as common_ops - -__all__ = [ - "TorchScriptTensor", - "TorchScriptGraph", - "TorchScriptTracingEvaluator", -] - - -ValidArgumentType: TypeAlias = Union[ - "TorchScriptTensor", - Sequence["TorchScriptTensor"], - Sequence[float], - Sequence[int], - complex, - str, - int, - float, - bool, - None, -] -ValidInputType: TypeAlias = Union[ - "TorchScriptTensor", - Sequence["TorchScriptTensor"], - Sequence[float], - Sequence[int], - complex, - str, - int, - float, - bool, - None, -] - -_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { - torch.bfloat16: ir.DataType.BFLOAT16, - torch.bool: ir.DataType.BOOL, - torch.complex128: ir.DataType.COMPLEX128, - torch.complex64: ir.DataType.COMPLEX64, - torch.float16: ir.DataType.FLOAT16, - torch.float32: ir.DataType.FLOAT, - torch.float64: ir.DataType.DOUBLE, - torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, - torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, - torch.float8_e5m2: ir.DataType.FLOAT8E5M2, - torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, - torch.int16: ir.DataType.INT16, - torch.int32: ir.DataType.INT32, - torch.int64: ir.DataType.INT64, - torch.int8: ir.DataType.INT8, - torch.uint8: ir.DataType.UINT8, -} - - -def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType: - return _TORCH_DTYPE_TO_ONNX[dtype] - - -class _TorchTensor(ir.Tensor): # pylint: disable=too-many-ancestors - def __init__(self, tensor: torch.Tensor): - super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype)) - - def tobytes(self) -> bytes: - # Support native PyTorch types so we can use types like bloat16 - assert isinstance(self.raw, torch.Tensor) - tensor = self.raw.detach().cpu().contiguous() - return bytes( - (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( - tensor.data_ptr() - ) - ) - - -class TorchScriptTensor(ir.Value, onnxscript_tensor.Tensor): - """A onnxscript tensor that wraps a torchscript Value.""" - - def __init__( - self, - _=None, # Unused argument for backward compatibility - producer=None, - index=None, - name: str | None = None, - ): - onnxscript_tensor.Tensor.__init__(self, None) - ir.Value.__init__(self, producer, index=index, name=name) - self._is_complex: bool = False - self._concrete_value: np.ndarray | None = None - self._device: torch.device | None = None - - @property - def value(self) -> Optional[np.ndarray]: - return self._concrete_value - - @value.setter - def value(self, value: np.ndarray) -> None: - self._concrete_value = value - - @property # type: ignore[override] - def rank(self) -> int | None: - if self.shape is None: - return None - return len(self.shape) - - @property # type: ignore[override] - def shape(self) -> ir.Shape | None: - return super().shape - - @shape.setter - def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]): - # Normalize torch symbolic dimension size to str. - torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool) - self._shape = ir.Shape( - tuple(str(dim.node) if isinstance(dim, torch_sym_types) else dim for dim in shape) # type: ignore[union-attr] - ) - - @property - def dtype(self) -> ir.DataType | None: - return super().dtype - - @dtype.setter - def dtype(self, dtype: torch.dtype | ir.DataType | None): - if dtype is None: - onnx_dtype = ir.DataType.UNDEFINED - elif isinstance(dtype, ir.DataType): - onnx_dtype = dtype - else: - onnx_dtype = _torch_dtype_to_onnx_dtype(dtype) - if self._type is None: - self._type = ir.TensorType(onnx_dtype) - else: - self._type.dtype = onnx_dtype - - # TODO: Remove this when there is no mismatch output shapes between device: - # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457 - @property - def device(self) -> torch.device | None: - return self._device - - @device.setter - def device(self, device: torch.device): - self._device = device - - @property - def is_complex(self) -> bool: - return self._is_complex - - @is_complex.setter - def is_complex(self, is_complex: bool): - self._is_complex = is_complex - - @property - def onnx_dtype(self) -> int: - raise NotImplementedError("onnx_dtype is not supported for TorchScriptTensor.") - - def value_info(self) -> Optional[onnx.ValueInfoProto]: - raise NotImplementedError("value_info is not supported for TorchScriptTensor.") - - -class _Node(ir.Node): - """A node that will produce TorchScriptTensor as outputs for compatibility.""" - - def __init__( - self, - domain: str, - op_type: str, - inputs: Sequence[ir.Value | None], - attributes: Sequence[ir.Attr | ir.RefAttr] = (), - *, - overload: str = "", - num_outputs: int = 1, - version: int | None = None, - name: str | None = None, - doc_string: str | None = None, - ): - super().__init__( - domain=domain, - op_type=op_type, - inputs=inputs, - attributes=attributes, - overload=overload, - num_outputs=num_outputs, - version=version, - name=name, - doc_string=doc_string, - ) - self._outputs: tuple[TorchScriptTensor, ...] = tuple( - TorchScriptTensor(producer=self, index=i) for i in range(num_outputs) - ) - - @property # type: ignore[misc] - def outputs(self) -> Sequence[TorchScriptTensor]: - return self._outputs - - -class TorchScriptTracingEvaluator(evaluator.Evaluator): - """An onnxscript Evaluator that captures the graph.""" - - def __init__(self, graph: TorchScriptGraph): - self._graph: TorchScriptGraph = graph - - @property - def graph(self) -> TorchScriptGraph: - return self._graph - - def eval(self, schema, inputs: Sequence[ValidInputType], attributes): - return self._graph.add_op_call(schema, inputs, attributes) - - def eval_function( # type: ignore[override] - self, - function: onnxscript.OnnxFunction, - args: Sequence[ValidArgumentType], - kwargs: Mapping[str, ValidArgumentType], - ): - if _flags.EXPERIMENTAL_PREFER_TRACING: - # Special cases for handling IsScalar and Rank - if function.name == "IsScalar": - if len(args) != 1: - raise TypeError( - f"Expected 1 positional argument for function '{function}', got {len(args)}." - ) - if isinstance(args[0], TorchScriptTensor): - if args[0].rank is not None: - return args[0].rank == 0 - else: - # Fall to call add_function_call - pass - elif isinstance(args[0], Sequence): # noqa: SIM103 - return False - else: - # Python constants are scalars - return True - if function.name == "Rank": - if len(args) != 1: - raise TypeError( - f"Expected 1 positional argument for function '{function}', got {len(args)}." - ) - if isinstance(args[0], TorchScriptTensor): - if args[0].rank is not None: - return args[0].rank - else: - # Fall to call add_function_call - pass - elif isinstance(args[0], Sequence): - if all(isinstance(arg, (int, float)) for arg in args[0]): - return 1 - else: - # Fall to call add_function_call - pass - else: - # Python constants are scalars - return 0 - elif function.experimental_traceable: - # Trace the function call instead of adding the function as a node - return function.function(*args, **kwargs) - - # args/kwargs are TorchScriptTensor/python built-in based - param_schemas = function.param_schemas() - ( - inputs, - attributes, - ) = param_manipulation.separate_input_attributes_from_arguments( - param_schemas, args, kwargs, fill_defaults=True, allow_extra_kwargs=True - ) - - # Cast attributes to the correct type based on function signature - op_schema = function.op_schema - assert op_schema is not None - for name, value in attributes.items(): - attribute = op_schema.attributes[name] - if attribute.type == onnx.defs.OpSchema.AttrType.FLOAT: - # Cast int to float if the attribute is FLOAT - attributes[name] = float(value) - - # In PyTorch, an attribute annotated as `int[1]?` accepts an integer - # or a sequence. When the attribute is an integer, it is treated as - # a single element sequence. ONNX requires an attribute to either be - # an integer or a sequence. So we promote the value to a sequence here. - if attribute.type == onnx.defs.OpSchema.AttrType.INTS and isinstance(value, int): - attributes[name] = (value,) - if attribute.type == onnx.defs.OpSchema.AttrType.FLOATS and isinstance( - value, float - ): - attributes[name] = (value,) - return self._graph.add_function_call(function, inputs, attributes) - - -def _build_attribute( - key: str, - value: Union[ - float, - int, - str, - Sequence[float], - Sequence[int], - torch.Tensor, - _TorchTensor, - ir.TensorProtocol, - ], -): - """Initializes the right attribute based on type of value.""" - if isinstance(value, float): - return ir.AttrFloat32(key, value) - if isinstance(value, int): - return ir.AttrInt64(key, value) - if isinstance(value, str): - return ir.AttrString(key, value) - if isinstance(value, torch.Tensor): - return ir.AttrTensor(key, _TorchTensor(value)) - if isinstance(value, (_TorchTensor, ir.TensorProtocol)): - return ir.AttrTensor(key, value) - if isinstance(value, Sequence): - if not value: - # Treat empty sequences as empty list tensors - # TODO(justinchuby): Revisit ways to determine the type of the empty list - return ir.AttrInt64s(key, []) - if isinstance(value[0], float): - return ir.AttrFloat32s(key, list(value)) # type: ignore[arg-type] - if isinstance(value[0], int): - return ir.AttrInt64s(key, list(value)) # type: ignore - raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'") - raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'") - - -def _create_op_call_in_graph( - graph: ir.Graph, - domain: str, - op_type: str, - *, - inputs: Sequence[TorchScriptTensor], - attributes: Mapping[str, Any], - num_outputs: int = 1, -) -> Sequence[TorchScriptTensor]: - """Creates a node representing an onnx op in `graph`. - - Args: - graph: The torch graph to add the node to. - domain: The domain of the op. - op_type: The name of the op. E.g. "Add". - inputs: The onnx inputs to the op. - attributes: The onnx attributes to the op. - num_outputs: The number of outputs the op has. - - Returns: - The outputs of the created node. - """ - # Filter out None attributes, this can be convenient client side because - # now they can pass through None attributes, and have them not show up - attributes = {k: v for k, v in attributes.items() if v is not None} - - node = _Node( - domain, - op_type, - inputs=inputs, - attributes=[_build_attribute(key, value) for key, value in attributes.items()], - num_outputs=num_outputs, - ) - graph.append(node) - - return node.outputs - - -def _shared_functions() -> list[ir.Function]: - """Hack to always include the share ops.""" - - # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed - return [ - ir.serde.deserialize_function(common_ops.Rank.to_function_proto()), - ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto()), - ] - - -class TorchScriptGraph: - def __init__( - self, - parent_torch_script_graph: Optional[TorchScriptGraph] = None, - domain_name: Optional[str] = None, - ): - self._graph = ir.Graph((), (), nodes=(), name="main_graph") - # All the functions used, deduplicated by name - # key: (name, domain) - self._function_store: Dict[ir.OperatorIdentifier, ir.Function] = {} - self._initializers: Dict[str, torch.Tensor] = {} - # Mapping from initializer name to input(TorchScriptTensor). - self._initializers_inputs: Dict[str, TorchScriptTensor] = {} - # Mapping from initializer name to input(TorchScriptTensor) from parent graph. - self._initializers_inputs_from_parent: Dict[str, TorchScriptTensor] = {} - # Mapping from model local function type name to function graph. - # Local function type name is expected to be unique. Converter creates - # a unique name and a unique function graph for every module call. - self._sub_torch_script_graphs: Dict[str, TorchScriptGraph] = {} - # Parent graph. None if this is the top level graph. - self._parent_torch_script_graph = parent_torch_script_graph - # Domain name of the graph. None if this is the top level graph. - self._domain_name: Optional[str] = domain_name - - if self._domain_name is None and self._parent_torch_script_graph is not None: - raise RuntimeError( - "Domain name is not set. It is required because this 'TorchScriptGraph' instance " - "is a subgraph that represents an ONNX local function." - ) - - @property - def initializers(self) -> Mapping[str, torch.Tensor]: - return self._initializers - - # NOTE: This setter is used in torch converter when we activate fake mode, - # we need to filter out the initializers that has fake tensor. This - # is because we don't want to introduce fake tensor in onnxscript. - @initializers.setter - def initializers(self, initializers: Dict[str, torch.Tensor]): - self._initializers = initializers - - @property - def initializers_inputs(self) -> Mapping[str, TorchScriptTensor]: - return self._initializers_inputs - - @property - def initializers_inputs_from_parent(self) -> Mapping[str, TorchScriptTensor]: - return self._initializers_inputs_from_parent - - @property - def num_outputs(self) -> int: - return len(self._graph.outputs) - - @property - def domain_name(self) -> Optional[str]: - return self._domain_name - - def add_input( - self, - input_name: Optional[str], - shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - ) -> TorchScriptTensor | None: - if input_name is None: - # This input argument is None, which is mapped - # to a NULL value in TorchScript type system. - value = None - else: - value = TorchScriptTensor(name=input_name) - value.shape = shape # type: ignore[arg-type,assignment] - value.device = device - if dtype is not None: - value.dtype = dtype # type: ignore[assignment] - # TODO(titaiwang): This approach loses the information that "same SymInts - # indicates same shape", for example, [symint0, symint0, symint1] - # would all be [None, None, None] - # torch_value.setType( - # torch_value.type().with_sizes( - # [dim if isinstance(dim, int) else None for dim in shape] # type: ignore[union-attr] - # ) - # ) - self._graph.inputs.append(value) # type: ignore[arg-type] - return value - - def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor: - if name in self._initializers_inputs: - # NOTE: Previously it raises when `name` is already set. This is relaxed - # because this will be invoked multiple times when submodule is called - # multiple times. - if name in self._initializers and self._initializers[name] is not value: - raise ValueError( - f"Initializer '{name}' exists already with a different value." - ) - return self._initializers_inputs[name] # type: ignore[return-value] - - if ( - self != self._parent_torch_script_graph - and self._parent_torch_script_graph is not None - ): - # Only the root graph can have initializers. Add as initializer - # to root graph, and add as input to current graph. - self._initializers_inputs_from_parent[name] = ( - self._parent_torch_script_graph.add_initializer(name, value) - ) - else: - input = TorchScriptTensor(name=name) - self._initializers_inputs[name] = input - self._initializers[name] = value - return input - - def register_outputs( - self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]] - ): - if isinstance(outputs, TorchScriptTensor): - outputs = (outputs,) - for output in outputs: - assert isinstance( - output, TorchScriptTensor - ), f"output must be a TorchScriptTensor, not {type(output)}" - self._graph.outputs.append(output) - - def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]: - """Add a constant to the graph. - - Returns: - A single element of sequence of the constant value. - """ - if constant is None: - return (None,) - - if isinstance(constant, bool): - # Be sure to put bool before int, because bool is a subclass of int - constant_tensor = torch.tensor(constant, dtype=torch.bool) - elif isinstance(constant, float): - constant_tensor = torch.tensor(constant, dtype=torch.float) - elif isinstance(constant, int): - constant_tensor = torch.tensor(constant, dtype=torch.int64) - elif isinstance(constant, (tuple, list)) and all( - isinstance(val, int) for val in constant - ): - constant_tensor = torch.tensor(constant, dtype=torch.int64) - elif isinstance(constant, (tuple, list)) and all( - isinstance(val, float) for val in constant - ): - constant_tensor = torch.tensor(constant, dtype=torch.float) - elif isinstance(constant, complex): - # NOTE: ONNX doesn't support tensor of complex64/complex128, so we - # convert them to float32/float64 with real representation. - constant_tensor = torch.view_as_real(torch.tensor(constant).resolve_conj()) - else: - raise TypeError( - f"Constant input '{constant}' of type '{type(constant)}' is not supported" - ) - onnx_tensor = _TorchTensor(constant_tensor) - value = _create_op_call_in_graph( - self._graph, - "", - "Constant", - inputs=(), - attributes=dict(value=onnx_tensor), - ) - return value - - def _add_ir_graph_op_call( - self, - *, - domain: str, - op_type: str, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - num_outputs: int, - ) -> Sequence[TorchScriptTensor]: - graph_inputs: list[TorchScriptTensor] = [] - assert isinstance(onnx_inputs, Sequence) - for input in onnx_inputs: - # NOTE(titaiwang): input could be empty list - if ( - isinstance(input, Sequence) - and input - and all(isinstance(elem, TorchScriptTensor) for elem in input) - ): - # If all elements in the Sequence are TorchScriptTensor we know it - # should be a Sequence input in ONNX. - input_sequence = _create_op_call_in_graph( - self._graph, - "", - "SequenceConstruct", - inputs=input, # type: ignore - attributes={}, - ) - graph_inputs.extend(input_sequence) - elif not isinstance(input, TorchScriptTensor): - graph_inputs.extend(self._add_constant_to_graph(input)) # type: ignore - else: - # TODO(justinchuby): What is this case? - graph_inputs.append(input) - for key, value in onnx_attributes.items(): - assert not isinstance( - value, TorchScriptTensor - ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}." - tensors = _create_op_call_in_graph( - self._graph, - domain, - op_type, - inputs=graph_inputs, - attributes=onnx_attributes, - num_outputs=num_outputs, - ) - assert tensors, "Expected at least one output from ONNX op call." - # NOTE: TorchScriptTensor is created here, however neither dtype nor shape is - # set. It is expected that exporter will modify the tensor being returned and - # set these info. - return tensors - - def _fetch_function_dict( - self, opset_version: int - ) -> Mapping[ir.OperatorIdentifier, ir.Function]: - function_dict: Dict[ir.OperatorIdentifier, ir.Function] = {} - # Fetch local function protos. E.g., local functions representing module calls. - for ( - sub_graph_name, - sub_torch_script_graph, - ) in self._sub_torch_script_graphs.items(): - function_dict.update(sub_torch_script_graph._fetch_function_dict(opset_version)) # pylint: disable=protected-access - domain = sub_torch_script_graph.domain_name - assert domain is not None - name_domain = (sub_graph_name, domain, "") - assert ( - name_domain not in function_dict - ), f"Sub graph name already exists. {name_domain}" - function_dict[name_domain] = sub_torch_script_graph._to_function( # pylint: disable=protected-access - opset_version, sub_graph_name - ) - # Fetch torchlib function protos. - for identifier, function in self._function_store.items(): - function_dict[identifier] = function - return function_dict - - def add_op_call( - self, - onnx_op_schema: onnx.defs.OpSchema, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Sequence[TorchScriptTensor]]: - # Compute outputs from the onnx_op op schema - num_outputs = evaluator.compute_num_outputs( - onnx_op_schema, onnx_inputs, onnx_attributes - ) - result = self._add_ir_graph_op_call( - domain="", - op_type=onnx_op_schema.name, - onnx_inputs=onnx_inputs, - onnx_attributes=onnx_attributes, - num_outputs=num_outputs, - ) - - if num_outputs == 1: - return result[0] - - return result - - def add_function_call( - self, - onnx_function: onnxscript.OnnxFunction, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Sequence[TorchScriptTensor]]: - ir_function = ir.serde.deserialize_function(onnx_function.to_function_proto()) - self._function_store[ir_function.identifier()] = ir_function - num_outputs = len(onnx_function.function_ir.outputs) - # Compute outputs from the function schema - result = self._add_ir_graph_op_call( - domain=ir_function.domain, - op_type=ir_function.name, - onnx_inputs=onnx_inputs, - onnx_attributes=onnx_attributes, - num_outputs=num_outputs, - ) - - if num_outputs == 1: - return result[0] - - return result - - def add_module_call( - self, - name: str, - sub_torch_script_graph: TorchScriptGraph, - onnx_inputs: Sequence[ValidInputType], - ) -> Union[TorchScriptTensor, Sequence[TorchScriptTensor]]: - self._sub_torch_script_graphs[name] = sub_torch_script_graph - domain_name = sub_torch_script_graph.domain_name - assert domain_name is not None - - num_outputs = sub_torch_script_graph.num_outputs - result = self._add_ir_graph_op_call( - domain=domain_name, - op_type=name, - onnx_inputs=( - *onnx_inputs, - *sub_torch_script_graph.initializers_inputs_from_parent.values(), - ), - onnx_attributes={}, - num_outputs=num_outputs, - ) - - if num_outputs == 1: - return result[0] - - return result - - def _to_function(self, opset_version: int, function_name: str) -> ir.Function: - assert len(self.initializers) == 0, "Model local functions cannot have initializers." - - # Dissect the model proto and transform to function proto. - domain = self.domain_name - if domain is None: - raise RuntimeError("Domain name is not set.") - onnx_function = ir.Function( - domain=domain, - name=function_name, - graph=self._graph, - attributes=(), - ) - onnx_function.opset_imports[""] = opset_version - - return onnx_function - - def to_model_proto( - self, opset_version: int, include_initializers: bool = True - ) -> onnx.ModelProto: - function_dict: Mapping[ir.OperatorIdentifier, ir.Function] = self._fetch_function_dict( - opset_version - ) - unique_custom_domains: Dict[str, int] = {"": opset_version} - - for function in function_dict.values(): - # TODO(BowenBao): All local function domain versions are hardcoded as 1. - unique_custom_domains[function.domain] = 1 - - if include_initializers: - self._graph.initializers.update( - {name: _TorchTensor(value) for name, value in self._initializers.items()} - ) - else: - self._graph.initializers.clear() - - onnx_model = ir.Model( - self._graph, - ir_version=8, - producer_name=f"pytorch {torch.__version__}", - functions=[*function_dict.values(), *_shared_functions()], - ) - - onnx_model.opset_imports.update(unique_custom_domains) - # Include the library shared opset domain - # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed - onnx_model.opset_imports[common_ops.common_opset.domain] = ( - common_ops.common_opset.version - ) - model_proto = ir.serde.serialize_model(onnx_model) - return model_proto diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py deleted file mode 100644 index c07ba3ce81..0000000000 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ /dev/null @@ -1,1093 +0,0 @@ -"""Graph building functions for torchscript graph backend.""" - -from __future__ import annotations - -import os -import tempfile -import typing -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union - -import numpy as np -import onnx -import onnx.checker -import onnx.defs -import onnx.helper -import onnx.shape_inference -import torch -from typing_extensions import TypeAlias - -import onnxscript -from onnxscript import evaluator -from onnxscript import tensor as onnxscript_tensor -from onnxscript._internal import param_manipulation, runtime_typing -from onnxscript.function_libs.torch_lib import _flags -from onnxscript.function_libs.torch_lib.ops import common as common_ops - -__all__ = [ - "TorchScriptTensor", - "TorchScriptGraph", - "TorchScriptTracingEvaluator", -] - - -ValidArgumentType: TypeAlias = Union[ - "TorchScriptTensor", - Sequence["TorchScriptTensor"], - Sequence[float], - Sequence[int], - complex, - str, - int, - float, - bool, - None, -] -ValidInputType: TypeAlias = Union[ - "TorchScriptTensor", - Sequence["TorchScriptTensor"], - Sequence[float], - Sequence[int], - complex, - str, - int, - float, - bool, - None, -] -ValidTorchValueType: TypeAlias = Union[ - torch.Value, - Sequence[torch.Value], - Sequence[float], - Sequence[int], - complex, - str, - int, - float, - bool, - None, -] - -# Be sure to leave ample room for the rest of the proto fields. -_LARGE_MODEL_SIZE_THRESHOLD = int(2**30 * 1.8) # 1.8GB - -# TODO(justinchuby): Build a context manager to handle source information. - - -def _rename_intermediate_value(name: str) -> str: - """Prepend `_val_` to a numeric tensor name make it valid in ONNX. - - The TorchScript graph creates numeric value names by default. e.g. `0`, `1`. - These are not legal ONNX tensor names, since ONNX requires the names to be valid - C variable names. - - It also improves readability by making the names less likely to be confused - with shape values. - """ - if name.isdigit(): - # Prefix with `_` to avoid name collision - return f"_val_{name}" - return name - - -def _function_id(domain: str | None, name: str) -> str: - """Create a unique function id for a function in a domain. - - Used for generating model level unique ids for values inside a function. - """ - return f"{domain if domain is not None else ''}::{name}" - - -class TorchScriptTensor(onnxscript_tensor.Tensor): - """A onnxscript tensor that wraps a torchscript Value.""" - - def __init__( - self, - value: torch.Value, - ): - super().__init__(None) - self._torch_value: torch.Value = value - self._concrete_value: Optional[np.ndarray] = None - self._shape: Optional[Tuple[int | str | None, ...]] = None - self._torch_dtype: Optional[torch.dtype] = None - self._name: Optional[str] = None - self._is_complex: bool = False - self._device: Optional[torch.device] = None - - def __repr__(self): - return f"TorchScriptTensor('{self._torch_value!r}')" - - @property # type: ignore[override] - def value(self) -> Optional[np.ndarray]: - return self._concrete_value - - @value.setter - def value(self, value: np.ndarray): - self._concrete_value = value - - @property - @runtime_typing.checked - def name(self) -> str: - if self._name is not None: - return self._name - return self._torch_value.debugName() - - @name.setter - @runtime_typing.checked - def name(self, name: str): - self._name = name - self._torch_value.setDebugName(name) - - @property # type: ignore[override] - def rank(self) -> int | None: - if self._shape is not None: - return len(self._shape) - - value_type = self._torch_value.type() - if value_type is None: - return None - value_type = typing.cast(torch.TensorType, value_type) - return value_type.dim() - - @property # type: ignore[override] - def shape(self) -> Tuple[int | str | None, ...] | None: - if self._shape is not None: - return self._shape - - value_type = self._torch_value.type() - if value_type is None: - return None - value_type = typing.cast(torch.TensorType, value_type) - if isinstance(value_type, torch.OptionalType): - shape = value_type.getElementType().varyingSizes() # type: ignore[attr-defined] - else: - shape = value_type.varyingSizes() - if shape is None: - return None - return tuple(shape) - - @shape.setter - def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]): - # Normalize torch symbolic dimension size to str. - torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool) - self._shape = tuple( - str(dim.node) if isinstance(dim, torch_sym_types) else dim # type: ignore[union-attr] - for dim in shape - ) - # jit api does not support assigning symbolic shapes, - # hence symbols are replaced as None. - jit_shape = tuple(dim if isinstance(dim, int) else None for dim in shape) - self._torch_value.setType(self._torch_value.type().with_sizes(list(jit_shape))) - - @property # type: ignore[override] - def dtype(self) -> torch.dtype | None: - # TODO: Return numpy dtype - if self._torch_dtype is not None: - return self._torch_dtype - # Local import to avoid circular dependency - from torch.onnx import _type_utils # pylint: disable=import-outside-toplevel - - torch_dtype = _type_utils.JitScalarType.from_value( # type: ignore[attr-defined] - self._torch_value, default=_type_utils.JitScalarType.UNDEFINED - ) - if torch_dtype == _type_utils.JitScalarType.UNDEFINED: - return None - self._torch_dtype = torch_dtype.dtype() - return self._torch_dtype - - @dtype.setter - def dtype(self, dtype: torch.dtype): - self._torch_dtype = dtype - self._torch_value.setType(self._torch_value.type().with_dtype(dtype)) - - @property - def is_complex(self) -> bool: - return self._is_complex - - @is_complex.setter - def is_complex(self, is_complex: bool): - self._is_complex = is_complex - - # TODO: Remove this when there is no mismatch output shapes between device: - # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457 - @property - def device(self) -> torch.device | None: - return self._device - - @device.setter - def device(self, device: torch.device): - self._device = device - - @property - def onnx_dtype(self): - # Local import to avoid circular dependency - from torch.onnx import _type_utils # pylint: disable=import-outside-toplevel - - return _type_utils.JitScalarType.from_value( # type: ignore[attr-defined] - self._torch_value, _type_utils.JitScalarType.UNDEFINED - ).onnx_type() - - def symbolic_value(self) -> torch.Value: - """The symbolic Value in torch.Graph.""" - return self._torch_value - - def value_info(self) -> Optional[onnx.ValueInfoProto]: - try: - dtype = self.onnx_dtype.value - except torch.onnx.errors.OnnxExporterError: - return None - if dtype == onnx.TensorProto.UNDEFINED: - return None - return onnx.helper.make_tensor_value_info(self.name, dtype, self.shape) - - -@runtime_typing.checked -def _unwrap_tensor_to_torch_value( - value: Union[ - ValidArgumentType, Mapping[str, ValidArgumentType], Sequence[ValidArgumentType] - ], -) -> Union[ - ValidTorchValueType, - Dict[str, ValidTorchValueType], - List[ValidTorchValueType], - Tuple[ValidTorchValueType, ...], -]: - """Unwrap the TorchScriptTensor to torch.Value.""" - if isinstance(value, TorchScriptTensor): - return value.symbolic_value() - if isinstance(value, dict): - return {k: _unwrap_tensor_to_torch_value(v) for k, v in value.items()} # type: ignore[misc,return-value] - if isinstance(value, list): - return [_unwrap_tensor_to_torch_value(v) for v in value] # type: ignore[misc,return-value] - if isinstance(value, tuple): - return tuple(_unwrap_tensor_to_torch_value(v) for v in value) # type: ignore[misc,return-value] - - # A normal python value - return value # type: ignore[return-value] - - -@runtime_typing.checked -def _wrap_torch_value_to_tensor( - value: Union[ - torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType] - ], - *, - shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, -) -> Union[ - ValidArgumentType, - Dict[str, ValidArgumentType], - List[ValidArgumentType], - Tuple[ValidArgumentType, ...], -]: - """Wrap torch.Value to TorchScriptTensor.""" - if isinstance(value, torch.Value): - tensor = TorchScriptTensor(value) - if shape is not None: - tensor.shape = shape - if dtype is not None: - tensor.dtype = dtype - if device is not None: - tensor.device = device - return tensor - if isinstance(value, dict): - return {k: _wrap_torch_value_to_tensor(v) for k, v in value.items()} # type: ignore[misc,return-value] - if isinstance(value, list): - return [_wrap_torch_value_to_tensor(v) for v in value] # type: ignore[misc,return-value] - if isinstance(value, tuple): - return tuple(_wrap_torch_value_to_tensor(v) for v in value) # type: ignore[misc,return-value] - - return value # type: ignore[return-value] - - -def _unwrap_tensors_to_torch_values(tensors): - # TODO(justinchuby): Do we really need this? - if isinstance(tensors, Sequence): - return [_unwrap_tensor_to_torch_value(output) for output in tensors] - return _unwrap_tensor_to_torch_value(tensors) - - -class TorchScriptTracingEvaluator(evaluator.Evaluator): - """An onnxscript Evaluator that captures the graph into torchscript.""" - - def __init__(self, graph: TorchScriptGraph): - self._graph: TorchScriptGraph = graph - - @property - def graph(self) -> TorchScriptGraph: - return self._graph - - def eval(self, schema, inputs, attributes): - if _flags.EXPERIMENTAL_PREFER_TRACING: - if schema.name == "CastLike": - assert len(inputs) == 2 - # Skip CastLike if the input and output types are the same - src_input = inputs[0] - target_input = inputs[1] - dtypes_available = ( - isinstance(src_input, TorchScriptTensor) - and isinstance(target_input, TorchScriptTensor) - and src_input.dtype is not None - and target_input.dtype is not None - ) - if dtypes_available: - if src_input.dtype == target_input.dtype: - # Same type. No cast needed - return src_input - else: - # Create a Cast node - return self._graph.add_op_call( - onnx.defs.get_schema("Cast"), - (src_input,), - {"to": target_input.onnx_dtype}, - ) - return self._graph.add_op_call(schema, inputs, attributes) - - @runtime_typing.checked - def eval_function( # type: ignore[override] - self, - function: onnxscript.OnnxFunction, - args: Sequence[ValidArgumentType], - kwargs: Mapping[str, ValidArgumentType], - ): - if _flags.EXPERIMENTAL_PREFER_TRACING: - # Special cases for handling IsScalar and Rank - if function.name == "IsScalar": - if len(args) != 1: - raise TypeError( - f"Expected 1 positional argument for function '{function}', got {len(args)}." - ) - if isinstance(args[0], TorchScriptTensor): - if args[0].rank is not None: - return args[0].rank == 0 - else: - # Fall to call add_function_call - pass - elif isinstance(args[0], Sequence): # noqa: SIM103 - return False - else: - # Python constants are scalars - return True - if function.name == "Rank": - if len(args) != 1: - raise TypeError( - f"Expected 1 positional argument for function '{function}', got {len(args)}." - ) - if isinstance(args[0], TorchScriptTensor): - if args[0].rank is not None: - return args[0].rank - else: - # Fall to call add_function_call - pass - elif isinstance(args[0], Sequence): - if all(isinstance(arg, (int, float)) for arg in args[0]): - return 1 - else: - # Fall to call add_function_call - pass - else: - # Python constants are scalars - return 0 - elif function.experimental_traceable: - # Trace the function call instead of adding the function as a node - return function.function(*args, **kwargs) - - # args/kwargs are TorchScriptTensor/python built-in based - param_schemas = function.param_schemas() - ( - inputs, - attributes, - ) = param_manipulation.separate_input_attributes_from_arguments( - param_schemas, args, kwargs, fill_defaults=True, allow_extra_kwargs=True - ) - - # Cast attributes to the correct type based on function signature - op_schema = function.op_schema - assert op_schema is not None - for name, value in attributes.items(): - attribute = op_schema.attributes[name] - if attribute.type == onnx.defs.OpSchema.AttrType.FLOAT: - # Cast int to float if the attribute is FLOAT - attributes[name] = float(value) - - # In PyTorch, an attribute annotated as `int[1]?` accepts an integer - # or a sequence. When the attribute is an integer, it is treated as - # a single element sequence. ONNX requires an attribute to either be - # an integer or a sequence. So we promote the value to a sequence here. - if attribute.type == onnx.defs.OpSchema.AttrType.INTS and isinstance(value, int): - attributes[name] = (value,) - if attribute.type == onnx.defs.OpSchema.AttrType.FLOATS and isinstance( - value, float - ): - attributes[name] = (value,) - return self._graph.add_function_call(function, inputs, attributes) - - -@runtime_typing.checked -def _add_attribute_to_torchscript_node( - node: torch.Node, - key: str, - value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor], -): - """Initializes the right attribute based on type of value.""" - if isinstance(value, float): - return node.f_(key, value) - if isinstance(value, int): - return node.i_(key, value) - if isinstance(value, (str, bytes)): - return node.s_(key, value) # type: ignore[arg-type] - if isinstance(value, torch.Tensor): - return node.t_(key, value) - if isinstance(value, Sequence): - if not value: - # Treat empty sequences as empty list tensors - # TODO(justinchuby): Revisit ways to determine the type of the empty list - return node.is_(key, list(value)) # type: ignore[attr-defined] - if isinstance(value[0], float): - return node.fs_(key, list(value)) # type: ignore[arg-type] - if isinstance(value[0], int): - return node.is_(key, list(value)) # type: ignore[attr-defined] - raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'") - raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'") - - -@runtime_typing.checked -def _create_op_call_in_torch_graph( - graph: torch.Graph, - opname: str, - *, - inputs: Sequence[torch.Value], - attributes: Mapping[str, Any], - n_outputs: int = 1, -) -> Tuple[torch.Value, ...]: - """Creates a node representing an onnx op in `graph`. - - Args: - graph: The torch graph to add the node to. - opname: The name of the op to add. E.g. "onnx::Add". - inputs: The onnx inputs to the op. - attributes: The onnx attributes to the op. - n_outputs: The number of outputs the op has. - - Returns: - The outputs of the created node. - """ - # Filter out None attributes, this can be convenient client side because - # now they can pass through None attributes, and have them not show up - attributes = {k: v for k, v in attributes.items() if v is not None} - - node = graph.create(opname, inputs, n_outputs) - node = graph.insertNode(node) - node_ouputs = tuple(node.outputs()) - - assert len(node_ouputs) == n_outputs - # Add all attributes - for key, value in sorted(attributes.items()): - _add_attribute_to_torchscript_node(node, key, value) - - return node_ouputs - - -def _tensor_rawdata_size(tensor: torch.Tensor) -> int: - """Estimate the size of a tensor in bytes. - - Args: - tensor: The tensor to estimate the size of. - - Returns: - The estimated size of the tensor in bytes. - """ - return tensor.numel() * tensor.element_size() - - -def _shared_functions() -> list[onnx.FunctionProto]: - """Hack to always include the share ops.""" - - # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed - return [ - common_ops.Rank.to_function_proto(), - common_ops.IsScalar.to_function_proto(), - ] - - -class TorchScriptGraph: - def __init__( - self, - parent_torch_script_graph: Optional[TorchScriptGraph] = None, - domain_name: Optional[str] = None, - ): - self._torch_graph = torch.Graph() - # All the functions used, deduplicated by name - # key: (name, domain) - self._function_store: Dict[Tuple[str, str], onnxscript.OnnxFunction] = {} - # Mapping from intializer name to data(torch.Tensor). - self._initializers: Dict[str, torch.Tensor] = {} - # Mapping from intializer name to input(TorchScriptTensor). - self._initializers_inputs: Dict[str, TorchScriptTensor] = {} - # Mapping from intializer name to input(TorchScriptTensor) from parent graph. - self._initializers_inputs_from_parent: Dict[str, TorchScriptTensor] = {} - # Mapping from model local function type name to function graph. - # Local function type name is expected to be unique. Converter creates - # a unique name and a unique function graph for every module call. - self._sub_torch_script_graphs: Dict[str, TorchScriptGraph] = {} - # Parent graph. None if this is the top level graph. - self._parent_torch_script_graph = parent_torch_script_graph - # Domain name of the graph. None if this is the top level graph. - self._domain_name: Optional[str] = domain_name - # Mapping from `torch.Value` to `TorchScriptTensor`. - # Because `torch.Value` does not provide API to set and retrieve symbolic shapes, - # and because `TorchScriptTensor` is not accessible through the `torch.Graph` graph, - # this mapping is used to keep track of the `TorchScriptTensor` associated with - # `torch.Value`. - # `TorchScriptTensor` records dtype and symbolic shapes. - # This info is later serialized as `ValueInfoProto` inside ONNX, to - # provide shape and dtype information for nodes within nested function calls. - # https://github.com/onnx/onnx/issues/5487 - self._value_to_tensor: Dict[torch.Value, TorchScriptTensor] = {} - - if self._domain_name is None and self._parent_torch_script_graph is not None: - raise RuntimeError( - "Domain name is not set. It is required because this 'TorchScriptGraph' instance " - "is a subgraph that represents an ONNX local function." - ) - - @property - def torch_graph(self): - return self._torch_graph - - @property - def initializers(self) -> Mapping[str, torch.Tensor]: - return self._initializers - - # NOTE: This setter is used in torch converter when we activate fake mode, - # we need to filter out the initializers that has fake tensor. This - # is because we don't want to introduce fake tensor in onnxscript. - @initializers.setter - def initializers(self, initializers: Dict[str, torch.Tensor]): - self._initializers = initializers - - @property - def initializers_inputs(self) -> Mapping[str, TorchScriptTensor]: - return self._initializers_inputs - - @property - def initializers_inputs_from_parent(self) -> Mapping[str, TorchScriptTensor]: - return self._initializers_inputs_from_parent - - @property - def num_outputs(self) -> int: - return len(list(self._torch_graph.outputs())) - - @property - def domain_name(self) -> Optional[str]: - return self._domain_name - - @runtime_typing.checked - def add_input( - self, - input_name: Optional[str], - shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - ) -> TorchScriptTensor: - if input_name is None: - # This input argument is None, which is mapped - # to a NULL value in TorchScript type system. - torch_value = _create_op_call_in_torch_graph( - self._torch_graph, "prim::Constant", inputs=(), attributes={} - )[0] - torch_value.setType(torch.OptionalType.ofTensor()) - else: - torch_value = self._torch_graph.addInput(input_name) - torch_value.setType(torch_value.type().with_dtype(dtype)) # type: ignore[arg-type] - # TODO(titaiwang): This approach loses the information that "same SymInts - # indicates same shape", for example, [symint0, symint0, symint1] - # would all be [None, None, None] - torch_value.setType( - torch_value.type().with_sizes( - [dim if isinstance(dim, int) else None for dim in shape] # type: ignore[union-attr] - ) - ) - tensor_value = _wrap_torch_value_to_tensor( - torch_value, shape=shape, dtype=dtype, device=device - ) - if isinstance(tensor_value, TorchScriptTensor): - # NOTE: Only track value that maps to tensor. - # Value that maps to Sequence/Dict of tensors is not tracked. - self._value_to_tensor[torch_value] = tensor_value - return tensor_value # type: ignore[return-value] - - @runtime_typing.checked - def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor: - if name in self._initializers_inputs: - # NOTE: Previously it raises when `name` is already set. This is relaxed - # because this will be invoked multiple times when submodule is called - # multiple times. - if name in self._initializers and self._initializers[name] is not value: - raise ValueError( - f"Initializer '{name}' exists already with a different value." - ) - return self._initializers_inputs[name] # type: ignore[return-value] - - if ( - self != self._parent_torch_script_graph - and self._parent_torch_script_graph is not None - ): - # Only the root graph can have initializers. Add as initializer - # to root graph, and add as input to current graph. - self._initializers_inputs_from_parent[name] = ( - self._parent_torch_script_graph.add_initializer(name, value) - ) - else: - self._initializers[name] = value - - torch_value = self._torch_graph.addInput(name) - torch_value.setType(torch.TensorType.create_from_tensor(value)) - tensor_value = _wrap_torch_value_to_tensor( - torch_value, shape=value.shape, dtype=value.dtype - ) - if isinstance(tensor_value, TorchScriptTensor): - self._value_to_tensor[torch_value] = tensor_value - self._initializers_inputs[name] = tensor_value # type: ignore[assignment] - return tensor_value # type: ignore[return-value] - - @runtime_typing.checked - def register_outputs( - self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]] - ): - unwrapped_outputs = _unwrap_tensors_to_torch_values(outputs) - if isinstance(unwrapped_outputs, torch.Value): - self._torch_graph.registerOutput(unwrapped_outputs) - return - assert isinstance(unwrapped_outputs, Sequence) - for ts_output in unwrapped_outputs: - assert isinstance( - ts_output, torch.Value - ), f"ts_output must be a torch.Value, not {type(ts_output)}" - self._torch_graph.registerOutput(ts_output) - return - - def _add_constant_to_graph(self, constant) -> torch.Value: - if constant is None: - value = _create_op_call_in_torch_graph( - self._torch_graph, "prim::Constant", inputs=(), attributes={} - )[0] - value.setType(torch.OptionalType.ofTensor()) - value.setDebugName(_rename_intermediate_value(value.debugName())) - return value - - if isinstance(constant, bool): - # Be sure to put bool before int, because bool is a subclass of int - constant_tensor = torch.tensor(constant, dtype=torch.bool) - elif isinstance(constant, float): - constant_tensor = torch.tensor(constant, dtype=torch.float) - elif isinstance(constant, int): - constant_tensor = torch.tensor(constant, dtype=torch.int64) - elif isinstance(constant, (tuple, list)) and all( - isinstance(val, int) for val in constant - ): - constant_tensor = torch.tensor(constant, dtype=torch.int64) - elif isinstance(constant, (tuple, list)) and all( - isinstance(val, float) for val in constant - ): - constant_tensor = torch.tensor(constant, dtype=torch.float) - elif isinstance(constant, complex): - # NOTE: ONNX doesn't support tensor of complex64/complex128, so we - # convert them to float32/float64 with real representation. - constant_tensor = torch.view_as_real(torch.tensor(constant).resolve_conj()) - else: - raise TypeError( - f"Constant input '{constant}' of type '{type(constant)}' is not supported" - ) - value = _create_op_call_in_torch_graph( - self._torch_graph, - "onnx::Constant", - inputs=(), - attributes=dict(value=constant_tensor), - )[0] - value.setDebugName(_rename_intermediate_value(value.debugName())) - return value - - @runtime_typing.checked - def _add_torchscript_op_call( - self, - name: str, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - n_outputs: int, - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: - unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs) - graph_inputs = [] - assert isinstance(unwrapped_inputs, Sequence) - for input in unwrapped_inputs: - # NOTE(titaiwang): input could be empty list - if ( - isinstance(input, Sequence) - and input - and all(isinstance(elem, torch.Value) for elem in input) - ): - # If all elements in the Sequence are torch.Values we know it - # should be a Sequence input in ONNX. - input_sequence = _create_op_call_in_torch_graph( - self._torch_graph, - "onnx::SequenceConstruct", - inputs=input, - attributes={}, - )[0] - graph_inputs.append(input_sequence) - elif not isinstance(input, torch.Value): - graph_inputs.append(self._add_constant_to_graph(input)) - else: - graph_inputs.append(input) - for key, value in onnx_attributes.items(): - assert not isinstance( - value, TorchScriptTensor - ), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}." - result = _create_op_call_in_torch_graph( - self._torch_graph, - name, - inputs=graph_inputs, - attributes=onnx_attributes, - n_outputs=n_outputs, - ) - assert result, "Expected at least one output from ONNX op call." - # NOTE: TorchScriptTensor is created here, however neither dtype nor shape is - # set. It is expected that exporter will modify the tensor being returned and - # set these info. - if len(result) == 1: - tensor = TorchScriptTensor(result[0]) - tensor.name = _rename_intermediate_value(tensor.name) - self._value_to_tensor[result[0]] = tensor - return tensor - tensors = tuple(TorchScriptTensor(v) for v in result) - self._value_to_tensor.update(dict(zip(result, tensors))) - for tensor in tensors: - tensor.name = _rename_intermediate_value(tensor.name) - return tensors - - @runtime_typing.checked - def fetch_function_proto_dict( - self, opset_version: int - ) -> Mapping[Tuple[str, str], onnx.FunctionProto]: - function_proto_dict: Dict[Tuple[str, str], onnx.FunctionProto] = {} - # Fetch local function protos. E.g., local functions representing module calls. - for ( - sub_graph_name, - sub_torch_script_graph, - ) in self._sub_torch_script_graphs.items(): - function_proto_dict.update( - sub_torch_script_graph.fetch_function_proto_dict(opset_version) - ) - domain = sub_torch_script_graph.domain_name - assert domain is not None - name_domain = ( - sub_graph_name, - domain, - ) - assert ( - name_domain not in function_proto_dict - ), f"Sub graph name already exists. {name_domain}" - function_proto_dict[name_domain] = sub_torch_script_graph.to_function_proto( - opset_version, sub_graph_name - ) - # Fetch torchlib function protos. - for name_domain, function in self._function_store.items(): - function_proto_dict[name_domain] = function.to_function_proto() - return function_proto_dict - - @runtime_typing.checked - def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto): - existing_value_info = {info.name: info for info in onnx_model.graph.value_info} - - # Override value_info for top level graph inputs. - for input in self.torch_graph.inputs(): - if input not in self._value_to_tensor: - raise RuntimeError(f"Input '{input.debugName()}' has no type.") - tensor = self._value_to_tensor[input] - if (value_info := tensor.value_info()) is None: - continue - for i, input_info in enumerate(onnx_model.graph.input): - if input_info.name == input.debugName(): - # See NOTE: _C.Value re-naming. - value_info.name = input_info.name - onnx_model.graph.input.insert(i, value_info) - onnx_model.graph.input.remove(input_info) - break - - # Override value_info for top level graph outputs. - for output in self.torch_graph.outputs(): - if output not in self._value_to_tensor: - raise RuntimeError(f"Output '{output.debugName()}' has no type.") - tensor = self._value_to_tensor[output] - if (value_info := tensor.value_info()) is None: - continue - for i, output_info in enumerate(onnx_model.graph.output): - if output_info.name == output.debugName(): - # See NOTE: _C.Value re-naming. - value_info.name = output_info.name - onnx_model.graph.output.insert(i, value_info) - onnx_model.graph.output.remove(output_info) - break - - # Remove existing static/incomplete value info. - del onnx_model.graph.value_info[:] - - # Insert value info for nodes within nested function calls. - # NOTE: This is an experimental feature, will be replaced by ValueInfo inside FunctionProto - # in ONNX 1.16. https://github.com/microsoft/onnxscript/issues/1268 - # The naming strategy is subject to change. Since all local functions representing - # nn.Modules exported by dynamo exporter have unique call sites, their function - # op_type name can serve to form the unique identifier for value info. - # Store inside top level GraphProto. - new_value_info = self.generate_subgraphs_value_info_proto() - # Insert value info for nodes in top level graph. - new_value_info.update(self.generate_maingraph_value_info_proto()) - # Do not store input, output or initializer into value_info - for input in onnx_model.graph.input: - new_value_info.pop(input.name, None) - for output in onnx_model.graph.output: - new_value_info.pop(output.name, None) - for tensor in onnx_model.graph.initializer: # type: ignore[assignment] - new_value_info.pop(tensor.name, None) - existing_value_info.update(new_value_info) - onnx_model.graph.value_info.extend(existing_value_info.values()) - - return onnx_model - - @runtime_typing.checked - def add_op_call( - self, - onnx_op_schema: onnx.defs.OpSchema, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: - # Compute outputs from the onnx_op op schema - n_outputs = evaluator.compute_num_outputs(onnx_op_schema, onnx_inputs, onnx_attributes) - result = self._add_torchscript_op_call( - f"onnx::{onnx_op_schema.name}", - onnx_inputs, - onnx_attributes, - n_outputs=n_outputs, - ) - - return result - - @runtime_typing.checked - def add_function_call( - self, - onnx_function: onnxscript.OnnxFunction, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: - identifier = (onnx_function.name, onnx_function.function_ir.domain) - self._function_store[identifier] = onnx_function - - # Compute outputs from the function schema - result = self._add_torchscript_op_call( - f"{onnx_function.function_ir.domain}::{onnx_function.name}", - onnx_inputs, - onnx_attributes, - n_outputs=len(onnx_function.function_ir.outputs), - ) - - return result - - @runtime_typing.checked - def add_module_call( - self, - name: str, - sub_torch_script_graph: TorchScriptGraph, - onnx_inputs: Sequence[ValidInputType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: - self._sub_torch_script_graphs[name] = sub_torch_script_graph - domain_name = sub_torch_script_graph.domain_name - assert domain_name is not None - return self._add_torchscript_op_call( - f"{domain_name}::{name}", - onnx_inputs=( - *onnx_inputs, - *sub_torch_script_graph.initializers_inputs_from_parent.values(), - ), - onnx_attributes={}, - n_outputs=sub_torch_script_graph.num_outputs, - ) - - def generate_function_value_info_proto( - self, function_op_type: str - ) -> Mapping[str, onnx.ValueInfoProto]: - named_value_info: Dict[str, onnx.ValueInfoProto] = {} - function_id = _function_id(self.domain_name, function_op_type) - for torch_value, tensor in self._value_to_tensor.items(): - if (value_info := tensor.value_info()) is None: - continue - name = f"{function_id}/{torch_value.debugName()}" - value_info.name = name - named_value_info[name] = value_info - named_value_info.update(self.generate_subgraphs_value_info_proto()) - return named_value_info - - @runtime_typing.checked - def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: - """Unique naming strategies for values inside subgraphs, i.e. local functions. - - {function_domain::function_op_type}/{value_name} - - NOTE: Mainly designed for specialized functions, which are local functions - with only one call site. For non-specialized functions, it is assumed that - the `value_info` carried in `TorchScriptTensor` represents the general - compatible shape and type. - """ - named_value_info: Dict[str, onnx.ValueInfoProto] = {} - for name, sub_graph in self._sub_torch_script_graphs.items(): - named_value_info.update(sub_graph.generate_function_value_info_proto(name)) - return named_value_info - - @runtime_typing.checked - def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: - """Returns value info proto for values in the main graph.""" - named_value_info: Dict[str, onnx.ValueInfoProto] = {} - for torch_value, tensor in self._value_to_tensor.items(): - if (value_info := tensor.value_info()) is None: - continue - # NOTE: _C.Value re-naming. - # _C.Value's debugName is unstable. - # When duplicated names are encountered, all names involved are updated by - # TorchScript naming strategy. Hence the previous name stored in value_info - # can be outdated. - value_info.name = torch_value.debugName() - named_value_info[torch_value.debugName()] = value_info - return named_value_info - - @runtime_typing.checked - def to_function_proto(self, opset_version: int, function_name: str) -> onnx.FunctionProto: - assert len(self.initializers) == 0, "Model local functions cannot have initializers." - ( - proto, - _, - _, - _, - ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access - initializers={}, - onnx_opset_version=opset_version, - dynamic_axes={}, - defer_weight_export=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX, - strip_doc_string=False, - keep_initializers_as_inputs=False, - custom_opsets={}, - add_node_names=True, - onnx_file_path="", # Large model export. Out of scope. - node_attr_to_name={}, # Current module as function feature does not utilize attributes. - ) - - onnx_model = onnx.load_from_string(proto) - - # Dissect the model proto and transform to function proto. - domain = self.domain_name - if domain is None: - raise RuntimeError("Domain name is not set.") - onnx_function = onnx.helper.make_function( - domain=domain, - fname=function_name, - inputs=[input.name for input in onnx_model.graph.input], - outputs=[output.name for output in onnx_model.graph.output], - nodes=onnx_model.graph.node, - opset_imports=onnx_model.opset_import, - doc_string=onnx_model.doc_string, - ) - return onnx_function - - @runtime_typing.checked - def to_model_proto( - self, opset_version: int, include_initializers: bool = True - ) -> onnx.ModelProto: - function_proto_dict: Mapping[Tuple[str, str], onnx.FunctionProto] = ( - self.fetch_function_proto_dict(opset_version) - ) - unique_custom_domains: Dict[str, int] = {} - - for function_proto in function_proto_dict.values(): - # TODO(BowenBao): All local function domain versions are hardcoded as 1. - unique_custom_domains[function_proto.domain] = 1 - - initializers_size = sum( - _tensor_rawdata_size(tensor) for tensor in self.initializers.values() - ) - - large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD - - export_kwargs: dict[str, Any] = dict( - initializers=self.initializers - if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS - else {}, - onnx_opset_version=opset_version, - dynamic_axes={}, - defer_weight_export=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX, - strip_doc_string=False, - keep_initializers_as_inputs=_flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS, - custom_opsets={}, - add_node_names=True, - node_attr_to_name={}, - ) - - # We decided to cache the model to disk when the model is large. - # Alternatively, we could build the ONNX `TensorProto`s in memory - # and append them to the model proto. - # We did not do it because it is harder to get right (vs. PyTorch's battle-tested - # implementation) and creating the `TensorProto`s naively (by converting to numpy) - # is slow. - cache_model_to_disk = large_model and include_initializers - - if cache_model_to_disk: - with tempfile.TemporaryDirectory() as temp_dir: - onnx_file_path = os.path.join(temp_dir, "exported_model.onnx") - export_kwargs["onnx_file_path"] = onnx_file_path - ( - proto, - _, - _, - _, - ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access - **export_kwargs - ) - onnx_model = onnx.load_from_string(proto) - onnx.load_external_data_for_model(onnx_model, temp_dir) - else: - ( - proto, - _, - _, - _, - ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access - **export_kwargs - ) - onnx_model = onnx.load_from_string(proto) - - onnx_model.functions.extend(function_proto_dict.values()) - onnx_model.functions.extend(_shared_functions()) - - # Override value_infos with symbolic shapes. - onnx_model = self._override_with_symbolic_value_info_proto(onnx_model) - - # `_export_onnx` only exports opset_imports that is visible to it. It does not - # export opset_imports for nested functions, since it does not have access to - # them. We manually add them back and merge with existing opset_imports in the - # model proto. - while len(onnx_model.opset_import) > 0: - opsetid = onnx_model.opset_import.pop() - unique_custom_domains[opsetid.domain] = opsetid.version - onnx_model.opset_import.extend( - [ - onnx.helper.make_opsetid(domain, version) - for domain, version in unique_custom_domains.items() - ] - ) - # Include the library shared opset domain - # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed - onnx_model.opset_import.append( - onnx.helper.make_opsetid( - common_ops.common_opset.domain, common_ops.common_opset.version - ) - ) - return onnx_model diff --git a/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py deleted file mode 100644 index 76464b70ef..0000000000 --- a/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py +++ /dev/null @@ -1,231 +0,0 @@ -"""Test cases for graph building functionality.""" - -# mypy: disable-error-code="arg-type,type-arg,valid-type" -from __future__ import annotations - -import os -import sys -import unittest - -import torch - -import onnxscript -import onnxscript.testing -from onnxscript import FLOAT, evaluator -from onnxscript import opset18 as op -from onnxscript._internal import version_utils -from onnxscript.function_libs.torch_lib import graph_building, ops - -IS_WINDOWS = os.name == "nt" - - -class TestTorchScriptTracingEvaluator(unittest.TestCase): - def setUp(self): - self.opset_version = 18 - # TODO: Add test for initializer. Currently skipped since to `assert_isomorphic` - # does not check for initializers. - self.onnxscript_graph = graph_building.TorchScriptGraph() - self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph) - - def test_torchscript_tensor_keeps_torch_device(self): - x_tensor = torch.ones((1, 2, 3), dtype=torch.float32) - x = self.onnxscript_graph.add_input( - "x", x_tensor.shape, x_tensor.dtype, x_tensor.device - ) - self.assertEqual(x.device, x_tensor.device) - - x.device = torch.device("cuda") - self.assertEqual(x.device, torch.device("cuda")) - - def test_traced_constant_op_is_same_as_compiled_graph(self): - """Test for op.Constant created in graph builder""" - with evaluator.default_as(self.tracer): - output = op.Constant(value_float=0.5) - - self.onnxscript_graph.register_outputs(output) - traced = self.onnxscript_graph.to_model_proto(self.opset_version) - - @onnxscript.script() - def expected_model(): - return op.Constant(value_float=0.5) - - expected = expected_model.to_model_proto() - - onnxscript.testing.assert_isomorphic(traced, expected) - - def test_traced_graph_on_single_node_is_same_as_compiled_graph(self): - aten_relu = ops.nn.aten_relu - - x_tensor = torch.ones((1, 2, 3), dtype=torch.float32) - x = self.onnxscript_graph.add_input("x", x_tensor.shape, x_tensor.dtype) - with evaluator.default_as(self.tracer): - output = aten_relu(x) - - self.onnxscript_graph.register_outputs(output) - traced = self.onnxscript_graph.to_model_proto(self.opset_version) - - @onnxscript.script(default_opset=op) - def expected_model(x: FLOAT[1, 2, 3]): - return aten_relu(x) - - expected = expected_model.to_model_proto() - - onnxscript.testing.assert_isomorphic(traced, expected) - - @unittest.expectedFailure # The scripted version does not have output type - def test_traced_graph_on_single_node_multi_output_is_same_as_compiled_graph(self): - aten_topk = ops.core.aten_topk - - x_tensor = torch.ones((1, 2, 3), dtype=torch.float32) - x = self.onnxscript_graph.add_input("x", x_tensor.shape, x_tensor.dtype) - with evaluator.default_as(self.tracer): - output = aten_topk(x, 2) - - self.onnxscript_graph.register_outputs(output) - traced = self.onnxscript_graph.to_model_proto(self.opset_version) - - @onnxscript.script(default_opset=op) - def expected_model(x: FLOAT[1, 2, 3]): - values, indices = aten_topk(x, 2) - return values, indices - - expected = expected_model.to_model_proto() - onnxscript.testing.assert_isomorphic(traced, expected) - - def test_model_local_function_constructed_by_traced_graph_is_same_as_compiled_graph( - self, - ): - aten_abs = ops.core.aten_abs - aten_relu = ops.nn.aten_relu - - inner_graph = graph_building.TorchScriptGraph(domain_name="test_domain") - inner_tracer = graph_building.TorchScriptTracingEvaluator(inner_graph) - - x_tensor = torch.ones((1, 2, 3), dtype=torch.float32) - x = inner_graph.add_input("x", x_tensor.shape, x_tensor.dtype) - with evaluator.default_as(inner_tracer): - output = aten_abs(x) - inner_graph.register_outputs(output) - - outer_graph = graph_building.TorchScriptGraph() - outer_tracer = graph_building.TorchScriptTracingEvaluator(outer_graph) - x_tensor = torch.ones((1, 2, 3), dtype=torch.float32) - x = outer_graph.add_input("x", x_tensor.shape, x_tensor.dtype) - with evaluator.default_as(outer_tracer): - output = aten_relu(x) - output = outer_graph.add_module_call("inner", inner_graph, (output,)) - outer_graph.register_outputs(output) - traced = outer_graph.to_model_proto(self.opset_version) - - @onnxscript.script( - opset=onnxscript.values.Opset("test_domain", 1), - default_opset=op, - ) - def inner(x: FLOAT[1, 2, 3]): - return aten_abs(x) - - @onnxscript.script(default_opset=op) - def outer(x: FLOAT[1, 2, 3]): - output = aten_relu(x) - return inner(output) - - expected = outer.to_model_proto() - onnxscript.testing.assert_isomorphic(traced, expected) - - def test_add_input_with_optionaltype_does_not_raise_torch_internal_error(self): - graph = graph_building.TorchScriptGraph() - x = graph.add_input(input_name=None) - with evaluator.default_as(self.tracer): - _ = x.shape - - -class TestTorchScriptGraph(unittest.TestCase): - def test_add_initializer_raises_when_the_same_name_used_for_different_tensors(self): - graph = graph_building.TorchScriptGraph() - graph.add_initializer("x", torch.ones((1, 2, 3), dtype=torch.float32)) - with self.assertRaises(ValueError): - graph.add_initializer("x", torch.ones((1, 2, 3), dtype=torch.float32)) - - def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(self): - graph = graph_building.TorchScriptGraph() - x_tensor = torch.ones((1, 2, 3), dtype=torch.float32) - graph.add_initializer("x", x_tensor) - graph.add_initializer("x", x_tensor) - - -class _MLP(torch.nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super().__init__() - self.fc1 = torch.nn.Linear(input_size, hidden_size) - self.fc2 = torch.nn.Linear(hidden_size, output_size) - self.relu = torch.nn.ReLU() - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - - -@unittest.skipIf( - IS_WINDOWS and version_utils.torch_older_than("2.3"), - "dynamo_export not supported on Windows in PyTorch<2.3", -) -@unittest.skipIf( - sys.version_info > (3, 11), - "dynamo_export not supported due to torch.compile not functional for python>3.11", -) -class TestModelSaving(unittest.TestCase): - def test_save_initializer_to_files_for_large_model(self): - # # of model parameters: - # input_size x hidden_size + hidden_size + - # hidden_size x output_size + output_size - # ~= 3GB below - batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10 - model = _MLP(input_size, hidden_size, output_size) - x = torch.randn(batch_size, input_size) - - model_proto = torch.onnx.dynamo_export(model, x).model_proto - # Assert model is larger than 2GB (~=3GB) - self.assertGreater(model_proto.ByteSize(), 2**31) - - def test_input_output_and_initializer_are_not_stored_in_value_info(self): - batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10 - model = _MLP(input_size, hidden_size, output_size) - x = torch.randn(batch_size, input_size) - - model_proto = torch.onnx.dynamo_export(model, x).model_proto - v_names = {v.name for v in model_proto.graph.value_info} - - for i in model_proto.graph.input: - self.assertNotIn(i.name, v_names) - for o in model_proto.graph.output: - self.assertNotIn(o.name, v_names) - for i in model_proto.graph.initializer: - self.assertNotIn(i.name, v_names) - - @unittest.skipIf( - not version_utils.torch_older_than("2.4"), - "PyTorch 2.4-preview optimizes the functions away", - ) - def test_experimental_function_value_info_are_stored_in_graph_value_info(self): - batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10 - model = _MLP(input_size, hidden_size, output_size) - x = torch.randn(batch_size, input_size) - - model_proto = torch.onnx.dynamo_export(model, x).model_proto - v_names = {v.name for v in model_proto.graph.value_info} - torch_functions = [ - f for f in model_proto.functions if f.domain.startswith("pkg.torch") - ] - self.assertNotEqual(len(torch_functions), 0) - for f in torch_functions: - for n in f.node: - for i in n.input: - self.assertIn(f"{f.domain}::{f.name}/{i}", v_names) - for o in n.output: - self.assertIn(f"{f.domain}::{f.name}/{o}", v_names) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/function_libs/torch_lib/ops/__init__.py b/onnxscript/function_libs/torch_lib/ops/__init__.py index 5a1cfd76c0..b7bedaa4b8 100644 --- a/onnxscript/function_libs/torch_lib/ops/__init__.py +++ b/onnxscript/function_libs/torch_lib/ops/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. __all__ = [ "core", "fft", @@ -5,9 +7,21 @@ "nested", "nn", "prims", + "quantized_decomposed", "sparse", "special", "vision", ] -from . import core, fft, linalg, nested, nn, prims, sparse, special, vision +from . import ( + core, + fft, + linalg, + nested, + nn, + prims, + quantized_decomposed, + sparse, + special, + vision, +) diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py index ecef6852b8..b3ebbc1c53 100644 --- a/onnxscript/function_libs/torch_lib/ops/common.py +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -1,12 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Common operators shared in the torchlib library.""" +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" +from __future__ import annotations + +from collections.abc import Sequence + +import numpy.typing as npt +import onnx + import onnxscript import onnxscript.values -from onnxscript import BOOL, INT64 +from onnxscript import BOOL, INT64, ir from onnxscript import opset18 as op from onnxscript.function_libs.torch_lib import _constants, tensor_typing from onnxscript.function_libs.torch_lib.tensor_typing import RealType -from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT +from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType COMPLEX64_TYPE = COMPLEX64.dtype COMPLEX128_TYPE = COMPLEX128.dtype @@ -54,3 +64,38 @@ def cast_to(a: RealType, dtype: int) -> RealType: result = op.Cast(a, to=dtype) return result + + +def constant( + array: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible, + dtype: int | onnx.TensorProto.DataType | ir.DataType, +) -> TensorType: + """Utility for creating a constant tensor. + + Args: + array: The array to convert to a constant tensor. + dtype: The data type of the tensor. + + Returns: + A constant node. + """ + return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype))) + + +def merge_dims(dims: Sequence[int | INT64]) -> INT64: + """Concatenate dimensions into a single value.""" + + if not dims: + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [])) + + neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1])) + + result_dims = [ + op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d) + for d in dims + ] + + # Set the output type to INT64 so op.Concat can be used + for dim in result_dims: + dim.dtype = ir.DataType.INT64 + return op.Concat(*result_dims, axis=0) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a7a6073643..5127f3f9f6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1,7 +1,5 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" """torch.ops.aten operators under the `core` module. @@ -9,30 +7,31 @@ - All functions should not have the script() decorator. This is because we want to delay the compilation of the function. """ +# pylint: disable=unused-argument from __future__ import annotations import math from typing import Any, Optional, Sequence, Tuple, Union +import numpy as np +import torch + from onnxscript import ( - BFLOAT16, BOOL, COMPLEX64, COMPLEX128, DOUBLE, FLOAT, - FLOAT16, INT8, INT16, INT32, INT64, UINT8, - UINT16, - UINT32, - UINT64, graph, + ir, ) +from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( @@ -40,7 +39,6 @@ RealType, TFloat, TFloatHighPrecision, - TFloatOrBFloat16, TInt, TReal, TRealOrUInt8, @@ -56,123 +54,102 @@ _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi -IsScalar = common_ops.IsScalar Rank = common_ops.Rank -@torch_op("aten::_local_scalar_dense") -def aten__local_scalar_dense(self: Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) -> FLOAT: - """_local_scalar_dense(Tensor self) -> Scalar""" - - # Return the first element in tensor as a scalar. - return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=FLOAT.dtype) - - -@torch_op("aten::_local_scalar_dense") -def aten__local_scalar_dense_int(self: IntType) -> INT64: +@torch_op("aten::_local_scalar_dense", trace_only=True) +def aten__local_scalar_dense(self: TensorType) -> TensorType: """_local_scalar_dense(Tensor self) -> Scalar""" # Return the first element in tensor as a scalar. - return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=INT64.dtype) + if self.dtype.is_floating_point(): + dtype = ir.DataType.FLOAT + elif self.dtype == ir.DataType.BOOL: + dtype = ir.DataType.BOOL + else: + dtype = ir.DataType.INT64 + return op.Cast(op.Gather(op.Reshape(self, [-1]), 0), to=dtype) @torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax_half( - self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool -) -> FLOAT: +def aten__log_softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - # trace_only because we need to cast conditionally based on half_to_float - if half_to_float: + self_is_scalar = len(self.shape) == 0 + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: self = op.Cast(self, to=FLOAT.dtype) - - return aten__log_softmax(self, dim, half_to_float) - - -@torch_op("aten::_log_softmax", traceable=True) -def aten__log_softmax( - self: TFloatHighPrecision, - dim: int, - half_to_float: bool, # pylint: disable=unused-argument -) -> TFloatHighPrecision: - """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - self_is_scalar = IsScalar(self) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.LogSoftmax(self, axis=dim) - if self_is_scalar: # squeeze to scalar due to input is scalar - result = op.Squeeze(result) + if self_is_scalar: + result = op.Squeeze(result, op.Constant(value_ints=[0])) return result @torch_op("aten::_softmax", trace_only=True) -def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT: +def aten__softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - # trace_only because we need to cast conditionally based on half_to_float - if half_to_float: - self = op.Cast(self, to=FLOAT.dtype) - - return aten_softmax_no_dtype(self, dim) + self_is_scalar = len(self.shape) == 0 + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: + self = op.Cast(self, to=FLOAT.dtype) -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax( - self: TFloatHighPrecision, dim: int, half_to_float: bool -) -> TFloatHighPrecision: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only to reuse aten_softmax_no_dtype + if self_is_scalar: + self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + result = op.Softmax(self, axis=dim) + if self_is_scalar: + # Convert to scalar when input is scalar + result = op.Squeeze(result) - del half_to_float # Unused - return aten_softmax_no_dtype(self, dim) + return result -@torch_op(("aten::abs", "_operator::abs")) +@torch_op(("aten::abs", "_operator::abs"), trace_only=True) def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8: """abs(Tensor self) -> Tensor""" return op.Abs(self) -@torch_op("aten::abs", complex=True) +@torch_op("aten::abs", complex=True, trace_only=True) def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8: """abs(Tensor self) -> Tensor""" - # self_real = self[..., 0] - self_real = op.Slice(self, [0], [1], axes=[-1]) - # self_imag = self[..., 1] - self_imag = op.Slice(self, [1], [2], axes=[-1]) - real_pow = op.Pow(self_real, 2) - imag_pow = op.Pow(self_imag, 2) - real_plus_imag = op.Add(real_pow, imag_pow) - return op.Squeeze(op.Sqrt(real_plus_imag), axes=[-1]) + + return op.ReduceL2(self, [-1], keepdims=False) -@torch_op("aten::acos") +@torch_op("aten::acos", trace_only=True) def aten_acos(self: TFloat) -> TFloat: """acos(Tensor self) -> Tensor""" return op.Acos(self) -@torch_op("aten::acosh") +@torch_op("aten::acosh", trace_only=True) def aten_acosh(self: TFloat) -> TFloat: """acosh(Tensor self) -> Tensor""" return op.Acosh(self) -@torch_op(("aten::add", "aten::add.Tensor", "_operator::add")) -def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: +@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True) +def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - # TODO(microsoft/onnxruntime#15977): Improve fp16 precision - alpha = op.CastLike(alpha, other) - other = op.Mul(other, alpha) + + if self.dtype == ir.DataType.BOOL: + # alpha can also be bool + if alpha == 0: + return op.Identity(self) + return op.Or(self, other) + + if alpha != 1.0: + alpha = op.CastLike(alpha, other) + other = op.Mul(other, alpha) return op.Add(self, other) -@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"), trace_only=True, complex=True) +@torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True) def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" @@ -232,7 +209,7 @@ def aten_addcmul( return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2)) -@torch_op("aten::addmm") +@torch_op("aten::addmm", trace_only=True) def aten_addmm( self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0 ) -> TReal: @@ -241,6 +218,9 @@ def aten_addmm( # NOTE: ONNX Runtime does not support int inputs to Gemm as of 1.16. # To support int inputs, consider an overriding implementation that casts to float and back. + alpha = float(alpha) + beta = float(beta) + # addmm only accepts 2d tensors: https://pytorch.org/docs/stable/generated/torch.addmm.html return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta) @@ -254,7 +234,7 @@ def aten_addmv( return op.Add(op.Mul(self, beta), op.Mul(op.MatMul(mat, vec), alpha)) -@torch_op("aten::addr", traceable=True) +@torch_op("aten::addr", trace_only=True) def aten_addr( self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0 ) -> TReal: @@ -303,7 +283,7 @@ def aten_affine_grid_generator_backward( raise NotImplementedError() -@torch_op("aten::alias") +@torch_op("aten::alias", trace_only=True) def aten_alias(self: TTensor) -> TTensor: """alias(Tensor(a) self) -> Tensor(a)""" @@ -334,11 +314,11 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op("aten::all", traceable=True) +@torch_op("aten::all", trace_only=True) def aten_all(self: TTensor) -> BOOL: """all(Tensor self) -> Tensor""" - if IsScalar(self): + if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -348,19 +328,15 @@ def aten_all(self: TTensor) -> BOOL: return result -@torch_op("aten::all.dim", traceable=True) +@torch_op("aten::all.dim", trace_only=True) def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: """all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor""" - if IsScalar(self): - result = op.Cast(self, to=BOOL.dtype) - else: - self_bool = op.Cast(self, to=BOOL.dtype) - self_int = op.Cast(self_bool, to=INT64.dtype) - dims = op.Reshape(dim, op.Constant(value_ints=[-1])) - all_true = op.ReduceMin(self_int, dims, keepdims=keepdim) - result = op.Cast(all_true, to=BOOL.dtype) - return result + self_bool = op.Cast(self, to=BOOL.dtype) + self_int = op.Cast(self_bool, to=INT64.dtype) + dims = op.Reshape(dim, op.Constant(value_ints=[-1])) + all_true = op.ReduceMin(self_int, dims, keepdims=keepdim) + return op.Cast(all_true, to=BOOL.dtype) @torch_op("aten::all.dims", trace_only=True) @@ -368,7 +344,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" if not dim: - return aten_all_dims_no_dim(self, keepdim) + return _aten_all_dims_no_dim(self, keepdim) for d in dim: self = aten_all_dim(self, d, keepdim=True) if not keepdim: @@ -376,13 +352,10 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::all.dims", traceable=True) -def aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: +def _aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" - # dim is None and thus not supplied - - if IsScalar(self): + if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -398,7 +371,7 @@ def aten_allclose( other: TReal, rtol: float = 1e-05, atol: float = 1e-08, - equal_nan: bool = False, # pylint: disable=unused-argument + equal_nan: bool = False, ) -> BOOL: """allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool""" @@ -456,11 +429,11 @@ def aten_angle(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::any", traceable=True) +@torch_op("aten::any", trace_only=True) def aten_any(self: TTensor) -> BOOL: """any(Tensor self) -> Tensor""" - if IsScalar(self): + if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -471,21 +444,17 @@ def aten_any(self: TTensor) -> BOOL: return result -@torch_op("aten::any.dim", traceable=True) +@torch_op("aten::any.dim", trace_only=True) def aten_any_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: """any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor""" - if IsScalar(self): - result = op.Cast(self, to=BOOL.dtype) - else: - self_bool = op.Cast(self, to=BOOL.dtype) - # op.ReduceMax() in the next step cannot process BOOL inputs, so convert to INT64 - self_int = op.Cast(self_bool, to=INT64.dtype) - # Change dim from int to INT64[1] - dims = op.Reshape(dim, op.Constant(value_ints=[-1])) - any_true = op.ReduceMax(self_int, dims, keepdims=keepdim) - result = op.Cast(any_true, to=BOOL.dtype) - return result + self_bool = op.Cast(self, to=BOOL.dtype) + # op.ReduceMax() in the next step cannot process BOOL inputs, so convert to INT64 + self_int = op.Cast(self_bool, to=INT64.dtype) + # Change dim from int to INT64[1] + dims = op.Reshape(dim, op.Constant(value_ints=[-1])) + any_true = op.ReduceMax(self_int, dims, keepdims=keepdim) + return op.Cast(any_true, to=BOOL.dtype) @torch_op("aten::any.dims", trace_only=True) @@ -493,7 +462,7 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" if not dim: - return aten_any_dims_no_dim(self, keepdim) + return _aten_any_dims_no_dim(self, keepdim) for d in dim: self = aten_any_dim(self, d, keepdim=True) if not keepdim: @@ -501,13 +470,8 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::any.dims", traceable=True) -def aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: - """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" - - # dim is None and thus not supplied - - if IsScalar(self): +def _aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: + if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) else: self_bool = op.Cast(self, to=BOOL.dtype) @@ -538,13 +502,16 @@ def _integral_to_be_adjusted(dtype: int) -> bool: @torch_op("aten::arange", trace_only=True) -def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1) -> TensorType: +def aten_arange( + end: TRealUnlessFloat16OrInt8, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - if dtype == -1: + if dtype == -1 or dtype is None: zero = op.CastLike(0.0, end) one = op.CastLike(1.0, end) result = op.Range(zero, end, one) @@ -568,14 +535,16 @@ def aten_arange(end: Union[DOUBLE, FLOAT, INT16, INT32, INT64], dtype: int = -1) @torch_op("aten::arange.start", trace_only=True) def aten_arange_start( - start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, dtype: int = -1 + start: TRealUnlessFloat16OrInt8, + end: TRealUnlessFloat16OrInt8, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - if dtype == -1: + if dtype == -1 or dtype is None: one = op.CastLike(1.0, end) result = op.Range(start, end, one) elif _range_supported(dtype): @@ -596,7 +565,6 @@ def aten_arange_start( return result -@torch_op("aten::arange.start_step", private=True) def _adjust_args_for_arange_int_dtype( start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, @@ -617,16 +585,51 @@ def _adjust_args_for_arange_int_dtype( def aten_arange_start_step( start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, - step: TRealUnlessFloat16OrInt8, + step: TRealUnlessFloat16OrInt8 = 1.0, dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - if dtype == -1: - result = op.Range(start, end, step) + # TODO: Because this is a trace_only function, the inputs are not promoted to + # Tensor until it hits ONNX ops. However, if it's dynamic, it should be + # Tensor at this point. + # https://github.com/microsoft/onnxscript/issues/1914 + if isinstance(start, (int, float)): + start_is_int = isinstance(start, int) + else: + start_is_int = start.dtype in { + INT16.dtype, + INT32.dtype, + INT64.dtype, + } + if isinstance(end, (int, float)): + end_is_int = isinstance(end, int) + else: + end_is_int = end.dtype in { + INT16.dtype, + INT32.dtype, + INT64.dtype, + } + if isinstance(step, (int, float)): + step_is_int = isinstance(step, int) + else: + step_is_int = step.dtype in { + INT16.dtype, + INT32.dtype, + INT64.dtype, + } + if start_is_int and end_is_int and step_is_int: + result = op.Range(start, end, step) + else: + # to float + start = op.Cast(start, to=FLOAT.dtype) + end = op.Cast(end, to=FLOAT.dtype) + step = op.Cast(step, to=FLOAT.dtype) + result = op.Range(start, end, step) elif _integral_to_be_adjusted(dtype): # PyTorch arange op handles these integral types differently from INT64, # so we have to adjust these arguments accordingly. @@ -693,11 +696,23 @@ def aten_arctanh(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::argmax", traceable=True) -def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: +@torch_op("aten::argmax", trace_only=True) +def aten_argmax( + self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False +) -> INT64: + """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" + + if dim is None: + result = _aten_argmax(self, keepdim) + else: + result = _aten_argmax_dim(self, dim, keepdim) + return result + + +def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = IsScalar(self) + self_is_scaler = len(self.shape) == 0 self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.ArgMax(self, keepdims=keepdim) if self_is_scaler: @@ -706,11 +721,10 @@ def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmax", traceable=True) -def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: +def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = IsScalar(self) + self_is_scaler = len(self.shape) == 0 if self_is_scaler: self = op.Reshape(self, op.Constant(value_ints=[-1])) @@ -721,11 +735,23 @@ def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = Fals return result -@torch_op("aten::argmin", traceable=True) -def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: +@torch_op("aten::argmin", trace_only=True) +def aten_argmin( + self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False +) -> INT64: + """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" + + if dim is None: + result = _aten_argmin(self, keepdim) + else: + result = _aten_argmin_dim(self, dim, keepdim) + return result + + +def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = IsScalar(self) + self_is_scaler = len(self.shape) == 0 self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.ArgMin(self, keepdims=keepdim) if self_is_scaler: @@ -734,11 +760,10 @@ def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmin", traceable=True) -def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: +def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" - self_is_scaler = IsScalar(self) + self_is_scaler = len(self.shape) == 0 if self_is_scaler: self = op.Reshape(self, op.Constant(value_ints=[-1])) @@ -763,7 +788,7 @@ def aten_argwhere(self: TensorType) -> TensorType: @torch_op("aten::as_strided", trace_only=True) def aten_as_strided( - self: TTensor, size: INT64, stride: INT64, storage_offset: int = 0 + self: TTensor, size: INT64, stride: Sequence[int], storage_offset: int = 0 ) -> TTensor: """as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)""" @@ -851,55 +876,60 @@ def aten_as_strided_scatter( raise NotImplementedError() -@torch_op("aten::asin") +@torch_op("aten::asin", trace_only=True) def aten_asin(self: TFloat) -> TFloat: """asin(Tensor self) -> Tensor""" return op.Asin(self) -@torch_op("aten::asinh") +@torch_op("aten::asinh", trace_only=True) def aten_asinh(self: TFloat) -> TFloat: """asinh(Tensor self) -> Tensor""" return op.Asinh(self) -@torch_op("aten::atan") +@torch_op("aten::atan", trace_only=True) def aten_atan(self: TFloat) -> TFloat: """atan(Tensor self) -> Tensor""" return op.Atan(self) -@torch_op("aten::atan2") +@torch_op("aten::atan2", trace_only=True) def aten_atan2(self: TFloat, other: TFloat) -> TFloat: """atan2(Tensor self, Tensor other) -> Tensor""" # self is y, and other is x on coordinate slope = op.Div(self, other) atan = op.Atan(slope) + zero = common_ops.constant(0.0, dtype=self.dtype) + pi = common_ops.constant(_MATH_PI, dtype=self.dtype) + + second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi) + result = op.Where(op.Less(other, zero), second_third_quadrant, atan) - second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI) - result = op.Where(other < 0.0, second_third_quadrant, atan) + # Map NaN to 0 to match PyTorch behavior + result = op.Where(op.IsNaN(result), zero, result) return result -@torch_op("aten::atanh") +@torch_op("aten::atanh", trace_only=True) def aten_atanh(self: TFloat) -> TFloat: """atanh(Tensor self) -> Tensor""" return op.Atanh(self) -@torch_op("aten::atleast_1d", traceable=True) +@torch_op("aten::atleast_1d", trace_only=True) def aten_atleast_1d(self: TTensor) -> TTensor: """atleast_1d(Tensor self) -> Tensor""" - if IsScalar(self): + if len(self.shape) == 0: self = op.Reshape(self, op.Constant(value_ints=[1])) - return self + return op.Identity(self) @torch_op("aten::atleast_1d.Sequence") @@ -923,7 +953,7 @@ def aten_atleast_2d(self: TTensor) -> TTensor: if Rank(self) <= 1: self = op.Reshape(self, op.Constant(value_ints=[1, -1])) - return self + return op.Identity(self) @torch_op("aten::atleast_2d.Sequence") @@ -941,7 +971,7 @@ def reshape_to_2d(tensor): return op.SequenceMap(self, body=reshape_to_2d) -@torch_op("aten::atleast_3d", traceable=True) +@torch_op("aten::atleast_3d", trace_only=True) def aten_atleast_3d(self: TTensor) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" @@ -950,7 +980,7 @@ def aten_atleast_3d(self: TTensor) -> TTensor: self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1])) elif rank == 2: self = op.Unsqueeze(self, op.Constant(value_ints=[-1])) - return self + return op.Identity(self) @torch_op("aten::atleast_3d.Sequence") @@ -970,20 +1000,25 @@ def reshape_to_3d(tensor): return op.SequenceMap(self, body=reshape_to_3d) -@torch_op("aten::baddbmm") +@torch_op("aten::baddbmm", trace_only=True) def aten_baddbmm( self: TRealOrUInt8, batch1: TRealUnlessInt16OrInt8, batch2: TRealUnlessInt16OrInt8, - beta: float = 1.0, - alpha: float = 1.0, + beta: Optional[TFloat] = None, + alpha: Optional[TFloat] = None, ) -> TRealUnlessInt16OrInt8: """baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor""" + # beta and alpha can be SymFloat batch_mul = op.MatMul(batch1, batch2) - alpha_cast = op.CastLike(alpha, self) - mul_a = op.Mul(batch_mul, alpha_cast) - beta_cast = op.CastLike(beta, self) - mul_b = op.Mul(self, beta_cast) + if alpha is None or alpha == 1: + mul_a = batch_mul + else: + mul_a = op.Mul(batch_mul, op.CastLike(alpha, self)) + if beta is None or beta == 1: + mul_b = self + else: + mul_b = op.Mul(self, op.CastLike(beta, self)) return op.Add(mul_a, mul_b) @@ -1099,7 +1134,7 @@ def aten_batch_norm_update_stats( raise NotImplementedError() -@torch_op("aten::bernoulli") +@torch_op("aten::bernoulli", trace_only=True) def aten_bernoulli(self: TFloat) -> TFloat: """Proximal implementation of aten::bernoulli.default @@ -1126,6 +1161,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor: return op.CastLike(sampled, self) +@torch_op("aten::bilinear", trace_only=True) def aten_bilinear( input1: TensorType, input2: TensorType, @@ -1134,7 +1170,23 @@ def aten_bilinear( ) -> TensorType: """bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor""" - raise NotImplementedError() + # Bilinear transformation: y = x1^T A x2 + b + # input1 shape: (..., in1_features) + # input2 shape: (..., in2_features) + # weight shape: (out_features, in1_features, in2_features) + # bias shape: (out_features) - optional + # output shape: (..., out_features) + + # Use Einsum to compute the bilinear transformation + # "...i,oij,...j->...o" means: + # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o] + result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o") + + # Add bias if provided + if bias is not None: + result = op.Add(result, bias) + + return result def aten_binary_cross_entropy_with_logits( @@ -1167,176 +1219,179 @@ def aten_binomial( @torch_op( ( - "aten::bitwise_and", "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", "_operator::and_", - ) + ), + trace_only=True, ) -def aten_bitwise_and(self: TInt, other: TInt) -> TInt: +def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_and implements the BOOL variant - return op.BitwiseAnd(self, other) + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + if dtype.is_integer(): + return op.BitwiseAnd(self, other) + if dtype == ir.DataType.BOOL: + return op.And(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") -@torch_op("aten::bitwise_left_shift") -def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) - - result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT16.dtype) +@torch_op("aten::bitwise_and.Scalar", trace_only=True) +def aten_bitwise_and_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_and(self, other_tensor) -@torch_op("aten::bitwise_left_shift") -def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - result = op.BitShift(self, other, direction="LEFT") +@torch_op("aten::bitwise_and.Scalar_Tensor", trace_only=True) +def aten_bitwise_and_scalar_tensor(self: float, other: TTensor) -> TTensor: + """bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" - return op.Cast(result, to=INT32.dtype) + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_and(self_tensor, other) -@torch_op("aten::bitwise_left_shift") -def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64: +@torch_op( + ( + "aten::bitwise_left_shift.Tensor", + "_operator::__lshift__", + ), + trace_only=True, +) +def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt: """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + # assert other >= 0 - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) + if dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + elif dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + elif dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + elif dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + else: + raise NotImplementedError(f"Not implemented for type {dtype}") + + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT64.dtype) + return op.Cast(result, to=signed_dtype) -@torch_op("aten::bitwise_left_shift") -def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) +@torch_op( + ("aten::bitwise_left_shift.Tensor_Scalar", "aten::__lshift__.Scalar"), trace_only=True +) +def aten_bitwise_left_shift_tensor_scalar(self: TInt, other: int) -> TInt: + """bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_left_shift(self, other_tensor) - result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT8.dtype) +@torch_op("aten::bitwise_left_shift.Scalar_Tensor", trace_only=True) +def aten_bitwise_left_shift_scalar_tensor(self: int, other: TInt) -> TInt: + """bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_left_shift(self_tensor, other) -@torch_op("aten::bitwise_not") -def aten_bitwise_not(self: TInt) -> TInt: +@torch_op("aten::bitwise_not", trace_only=True) +def aten_bitwise_not(self: TTensor) -> TTensor: """bitwise_not(Tensor self) -> Tensor""" - # logical_not implements the BOOL variant - return op.BitwiseNot(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + if self.dtype.is_integer(): + return op.BitwiseNot(self) + raise NotImplementedError(f"Not implemented for type {self.dtype}") @torch_op( ( - "aten::bitwise_or", "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", "_operator::or_", - ) + ), + trace_only=True, ) -def aten_bitwise_or(self: TInt, other: TInt) -> TInt: +def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_or implements the BOOL variant - - return op.BitwiseOr(self, other) + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None -@torch_op("aten::bitwise_right_shift") -def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) + if dtype.is_integer(): + return op.BitwiseOr(self, other) + if dtype == ir.DataType.BOOL: + return op.Or(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFF), to=UINT16.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT16.dtype), op.Cast(shifted, to=INT16.dtype) - ) +@torch_op("aten::bitwise_or.Scalar", trace_only=True) +def aten_bitwise_or_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_or(self, other_tensor) -@torch_op("aten::bitwise_right_shift") -def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFFFFFF), to=UINT32.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT32.dtype), op.Cast(shifted, to=INT32.dtype) - ) +@torch_op("aten::bitwise_or.Scalar_Tensor", trace_only=True) +def aten_bitwise_or_scalar_tensor(self: int, other: TTensor) -> TTensor: + """bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_or(self_tensor, other) -@torch_op("aten::bitwise_right_shift") -def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: +@torch_op( + ( + "aten::bitwise_right_shift.Tensor", + "_operator::__rshift__", + ), + trace_only=True, +) +def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - # 0xFFFFFFFFFFFFFFFF - op.Cast(op.Constant(value_int=-1), to=UINT64.dtype), - other, - direction="RIGHT", - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT64.dtype), op.Cast(shifted, to=INT64.dtype) - ) - + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + + if dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + mask = ir.tensor(0xFF, dtype=unsigned_dtype) + elif dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + mask = ir.tensor(0xFFFF, dtype=unsigned_dtype) + elif dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype) + elif dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF + else: + raise NotImplementedError(f"Not implemented for type {dtype}") -@torch_op("aten::bitwise_right_shift") -def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" negative = op.Less(self, 0) - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) # Simulate arithmetic shift using logical shift # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFF), to=UINT8.dtype), other, direction="RIGHT" - ) + mask = op.BitShift(mask, other, direction="RIGHT") mask = op.BitwiseNot(mask) # Do logical shift shifted = op.BitShift(self, other, direction="RIGHT") @@ -1344,29 +1399,68 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: negative_shifted = op.BitwiseOr(shifted, mask) # Choose the shifted value based on the sign bit return op.Where( - negative, op.Cast(negative_shifted, to=INT8.dtype), op.Cast(shifted, to=INT8.dtype) + negative, op.Cast(negative_shifted, to=signed_dtype), op.Cast(shifted, to=signed_dtype) ) @torch_op( - ( - "aten::bitwise_xor", - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ) + ("aten::bitwise_right_shift.Tensor_Scalar", "aten::__rshift__.Scalar"), trace_only=True ) -def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: +def aten_bitwise_right_shift_tensor_scalar(self: TInt, other: int) -> TInt: + """bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_right_shift(self, other_tensor) + + +@torch_op("aten::bitwise_right_shift.Scalar_Tensor", trace_only=True) +def aten_bitwise_right_shift_scalar_tensor(self: int, other: TInt) -> TInt: + """bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_right_shift(self_tensor, other) + + +@torch_op("aten::bitwise_xor.Tensor", trace_only=True) +def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor: """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_xor implements the BOOL variant - return op.BitwiseXor(self, other) + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + + if dtype.is_integer(): + return op.BitwiseXor(self, other) + if dtype == ir.DataType.BOOL: + return op.Xor(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") + + +@torch_op("aten::bitwise_xor.Scalar", trace_only=True) +def aten_bitwise_xor_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_xor(self, other_tensor) + +@torch_op("aten::bitwise_xor.Scalar_Tensor", trace_only=True) +def aten_bitwise_xor_scalar_tensor(self: int, other: TTensor) -> TTensor: + """bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_xor(self_tensor, other) -def aten_blackman_window(window_length: int) -> TensorType: + +@torch_op("aten::blackman_window", trace_only=True) +def aten_blackman_window( + window_length: int, + dtype: int = 1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - raise NotImplementedError() + if dtype is None or dtype == -1: + dtype = 1 + return op.BlackmanWindow(window_length, output_datatype=dtype) def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType: @@ -1375,7 +1469,7 @@ def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::bmm") +@torch_op("aten::bmm", trace_only=True) def aten_bmm(self: TFloat, mat2: TFloat) -> TFloat: """bmm(Tensor self, Tensor mat2) -> Tensor""" @@ -1388,10 +1482,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::broadcast_to") -def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor: +@torch_op("aten::broadcast_to", trace_only=True) +def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor: """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - + size = common_ops.merge_dims(size) return op.Expand(self, size) @@ -1424,15 +1518,13 @@ def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: return aten_cat(tensors, dim=dim) -@torch_op("aten::cat") +@torch_op(("aten::cat", "aten::concat", "aten::concatenate"), trace_only=True) def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" - # NOTE: Having empty tensors when concatenating along non-zero dimension - # is not supported. - # TODO(justinchuby): Filter these tensors out with Sequence ops before - # calling ConcatFromSequence. - return op.ConcatFromSequence(tensors, axis=dim) + # Remove None tensors + tensors = [tensor for tensor in tensors if tensor is not None] + return op.Concat(*tensors, axis=dim) def aten_ccol_indices(self: TensorType) -> TensorType: @@ -1455,14 +1547,14 @@ def aten_cdist( raise NotImplementedError() -@torch_op("aten::ceil") +@torch_op("aten::ceil", trace_only=True) def aten_ceil(self: TFloat) -> TFloat: """ceil(Tensor self) -> Tensor""" return op.Ceil(self) -@torch_op("math::ceil") +@torch_op("math::ceil", trace_only=True) def python_math_ceil(self: TFloat) -> TInt: """ceil(Tensor self) -> Tensor""" ceil = op.Ceil(self) @@ -1515,38 +1607,68 @@ def aten_choose_qparams_optimized( raise NotImplementedError() -@torch_op("aten::chunk") -def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: - """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" - # This will create a Sequence of tensors - neg_1 = op.Constant(value_ints=[-1]) - # Get size of specified dim - self_shape = op.Shape(self) - dim_size = op.Gather(self_shape, dim, axis=0) - # Compute size/chunk to get the number of data in one chunk - num_per_chunk = op.Div(dim_size, chunks) - num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator] +if version_utils.torch_older_than("2.7.0"): + # PyTorch <2.7 does not support determining the number of outputs for the Split op + # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6 + @torch_op("aten::chunk") + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: + """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" + # This will create a Sequence of tensors + neg_1 = op.Constant(value_ints=[-1]) + # Get size of specified dim + self_shape = op.Shape(self) + dim_size = op.Gather(self_shape, dim, axis=0) + # Compute size/chunk to get the number of data in one chunk + num_per_chunk = op.Div(dim_size, chunks) + num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator] - # Compute real chunk number - num_chunk = op.Div(dim_size, num_per_chunk) - # Get something like [n, n, n, n, ...], total num_chunk - list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1)) + # Compute real chunk number + num_chunk = op.Div(dim_size, num_per_chunk) + # Get something like [n, n, n, n, ...], total num_chunk + list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1)) - remainder = op.Mod(dim_size, num_per_chunk) - if remainder > 0: # type: ignore[operator] - # Append the remainder to the [n, n, n, n, ..., r] - list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0) + remainder = op.Mod(dim_size, num_per_chunk) + if remainder > 0: # type: ignore[operator] + # Append the remainder to the [n, n, n, n, ..., r] + list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0) - return op.SplitToSequence(self, list_split, axis=dim) + return op.SplitToSequence(self, list_split, axis=dim) +else: + @torch_op("aten::chunk", trace_only=True) + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: + """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" + if chunks == 1: + return op.Identity(self) + return op.Split(self, axis=dim, num_outputs=chunks) -@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True) -def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal: - """clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor""" - clamped = self + +@torch_op("aten::clamp", trace_only=True) +def aten_clamp(self: TReal, min: Optional[float] = None, max: Optional[float] = None) -> TReal: + """clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor""" + + if min is None and max is None: + return op.Identity(self) + + if min is not None: + min = op.CastLike(min, self) + + if max is not None: + max = op.CastLike(max, self) + + return op.Clip(self, min, max) + + +@torch_op("aten::clamp.Tensor", trace_only=True) +def aten_clamp_tensor( + self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None +) -> TReal: + """clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor""" if min is None and max is None: - return clamped + return op.Identity(self) + + clamped = self # If min is greater than max torch.clamp(..., min, max) # sets all elements in input to the value of max. @@ -1562,48 +1684,58 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = return clamped -@torch_op("aten::clamp_max", traceable=True) -def aten_clamp_max(self: TReal, max_: TReal) -> TReal: - """clamp_max(Tensor self, Tensor max) -> Tensor""" +@torch_op("aten::clamp_max", trace_only=True) +def aten_clamp_max(self: TReal, max_: float) -> TReal: + """clamp_max(Tensor self, Scalar max) -> Tensor""" - self_size = op.Size(self) - max_shape = op.Shape(max_) - max_rank = op.Size(max_shape) - if self_size == 0: - result = op.Expand(self, max_shape) + # This implementation does not intend to handle when self is an empty tensor + max_ = op.CastLike(max_, self) + return op.Clip(self, None, max_) + + +@torch_op("aten::clamp_max.Tensor", trace_only=True) +def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal: + """clamp_max.Tensor(Tensor self, Tensor max) -> Tensor""" + + # This implementation does not intend to handle when self is an empty tensor + max_rank = len(max_.shape) + if max_rank == 0: + max_ = op.CastLike(max_, self) + result = op.Clip(self, None, max_) else: - if max_rank == 0: - max_ = op.CastLike(max_, self) - result = op.Clip(self, None, max_) - else: - result = op.Min(self, max_) + result = op.Min(self, max_) return result -@torch_op("aten::clamp_min", traceable=True) -def aten_clamp_min(self: TReal, min_: TReal) -> TReal: - """clamp_min(Tensor self, Tensor min) -> Tensor""" +@torch_op("aten::clamp_min", trace_only=True) +def aten_clamp_min(self: TReal, min_: float) -> TReal: + """clamp_min(Tensor self, Scalar min) -> Tensor""" - self_size = op.Size(self) - min_shape = op.Shape(min_) - min_rank = op.Size(min_shape) - if self_size == 0: - result = op.Expand(self, min_shape) + # This implementation does not intend to handle when self is an empty tensor + min_ = op.CastLike(min_, self) + return op.Clip(self, min_, None) + + +@torch_op("aten::clamp_min.Tensor", trace_only=True) +def aten_clamp_min_tensor(self: TReal, min_: TReal) -> TReal: + """clamp_min.Tensor(Tensor self, Tensor min) -> Tensor""" + + # This implementation does not intend to handle when self is an empty tensor + min_rank = len(min_.shape) + if min_rank == 0: + min_ = op.CastLike(min_, self) + result = op.Clip(self, min_, None) else: - if min_rank == 0: - min_ = op.CastLike(min_, self) - result = op.Clip(self, min_, None) - else: - result = op.Max(self, min_) + result = op.Max(self, min_) return result -@torch_op("aten::clone") +@torch_op("aten::clone", trace_only=True) def aten_clone( self: TTensor, - memory_format: str = "", # pylint: disable=unused-argument + memory_format: str = "", ) -> TTensor: """clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor""" @@ -1642,13 +1774,6 @@ def aten_combinations( raise NotImplementedError() -@torch_op("aten::complex", private=True) -def _aten_complex(real: TFloat, imag: TFloat) -> TFloat: - """Non-broadcasting complex constructor.""" - - return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1) - - @torch_op("aten::complex", trace_only=True) def aten_complex(real: TFloat, imag: TFloat) -> TFloat: """complex(Tensor real, Tensor imag) -> Tensor""" @@ -1658,33 +1783,16 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat: real = op.Expand(real, broadcasted_shape) imag = op.Expand(imag, broadcasted_shape) - return _aten_complex(real, imag) - - -@torch_op("aten::concat") -def aten_concat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: - """concat(Tensor[] tensors, int dim=0) -> Tensor""" - - # TODO(justinchuby): Combine the implementation with cat - return op.ConcatFromSequence(tensors, axis=dim) - - -@torch_op("aten::concatenate") -def aten_concatenate(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: - """concatenate(Tensor[] tensors, int dim=0) -> Tensor""" - - # TODO(justinchuby): Combine the implementation with cat - return op.ConcatFromSequence(tensors, axis=dim) + return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1) -@torch_op("aten::conj") +@torch_op("aten::conj", trace_only=True) def aten_conj(self: TTensor) -> TTensor: """conj(Tensor(a) self) -> Tensor(a)""" return op.Identity(self) -@torch_op("aten::conj", complex=True, private=True) def _complex_conjugate(self: TFloat) -> TFloat: zero = op.Constant(value_ints=[0]) one = op.Constant(value_ints=[1]) @@ -1703,8 +1811,6 @@ def _complex_conjugate(self: TFloat) -> TFloat: def aten_conj_complex(self: TFloat) -> TFloat: """conj(Tensor(a) self) -> Tensor(a)""" - # TODO(#834): Allow calling scripted functions from other - # scripted functions and remove trace only. return _complex_conjugate(self) @@ -1749,10 +1855,10 @@ def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTens return op.Pad(self, onnx_padding, value) -@torch_op("aten::contiguous") +@torch_op("aten::contiguous", trace_only=True) def aten_contiguous( self: TTensor, - memory_format: str = "contiguous_format", # pylint: disable=unused-argument + memory_format: str = "contiguous_format", ) -> TTensor: """contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)""" @@ -1940,24 +2046,32 @@ def aten_convolution( ) -> TFloat: """convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor""" + rank = len(input.shape) + + image_d = rank - 2 + + # NOTE: We assume the sequence padding/dilation/stride + # from ATen op can only be either len == 1 or + # len == rank. + if not isinstance(padding, Sequence): - padding = (padding, padding) + padding = [padding] * image_d + elif len(padding) == 1: + padding = [padding[0]] * image_d pads = [*padding, *padding] if not isinstance(dilation, Sequence): - dilation = (dilation, dilation) + dilation = [dilation] * image_d + elif len(dilation) == 1: + dilation = [dilation[0]] * image_d dilations = list(dilation) if not isinstance(stride, Sequence): - stride = (stride, stride) + stride = [stride] * image_d + elif len(stride) == 1: + stride = [stride[0]] * image_d strides = list(stride) - if bias is None: - weight_dim_0 = op.Shape(weight, start=0, end=1) - bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1])) - zero = op.CastLike(0.0, input) - bias = op.Expand(zero, bias_shape) - result = _aten_convolution_onnx( input, weight, @@ -1973,12 +2087,11 @@ def aten_convolution( return result -@torch_op("aten::convolution", private=True, traceable=True) def _aten_convolution_onnx( input: TFloat, weight: TFloat, bias: TFloat, - transposed: BOOL, + transposed: bool, strides: Sequence[int], pads: Sequence[int], dilations: Sequence[int], @@ -1992,7 +2105,7 @@ def _aten_convolution_onnx( # Alternatively we could cast transposed to BOOL. # E.g. `if op.Cast(transposed, BOOL.dtype): ...` - no_batch = Rank(input) != Rank(weight) + no_batch = len(input.shape) != len(weight.shape) if no_batch: input = op.Unsqueeze(input, op.Constant(value_ints=[0])) @@ -2076,11 +2189,11 @@ def aten_convolution_overrideable( raise NotImplementedError() -@torch_op("aten::copy") +@torch_op("aten::copy", trace_only=True) def aten_copy( self: TTensor, src: TTensor2, - non_blocking: bool = False, # pylint: disable=unused-argument + non_blocking: bool = False, ) -> TTensor: """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor""" @@ -2091,7 +2204,11 @@ def aten_copy( def aten__to_copy( self: TTensor, dtype: int = -1, - non_blocking: bool = False, # pylint: disable=unused-argument + layout: str = "", + device: str = "", + pin_memory: bool = False, + non_blocking: bool = False, + memory_format: str = "", ) -> TTensor: """_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor""" @@ -2113,14 +2230,14 @@ def aten_corrcoef(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::cos") +@torch_op("aten::cos", trace_only=True) def aten_cos(self: TFloat) -> TFloat: """cos(Tensor self) -> Tensor""" return op.Cos(self) -@torch_op("aten::cosh") +@torch_op("aten::cosh", trace_only=True) def aten_cosh(self: TFloat) -> TFloat: """cosh(Tensor self) -> Tensor""" @@ -2164,23 +2281,13 @@ def aten_cov( raise NotImplementedError() -@torch_op("aten::cross") +@torch_op(("aten::cross", "aten::linalg_cross")) def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor: """cross(Tensor self, Tensor other, int? dim=None) -> Tensor""" - zero = op.Constant(value_ints=[0]) - one = op.Constant(value_ints=[1]) - two = op.Constant(value_ints=[2]) - three = op.Constant(value_ints=[3]) - axes = op.Expand(dim, op.Constant(value_ints=[1])) - # Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073 - a1 = op.Slice(self, zero, one, axes) - a2 = op.Slice(self, one, two, axes) - a3 = op.Slice(self, two, three, axes) - b1 = op.Slice(other, zero, one, axes) - b2 = op.Slice(other, one, two, axes) - b3 = op.Slice(other, two, three, axes) + a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3) + b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3) # Broadcasting is implicitly supported by Mul c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2)) c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3)) @@ -2390,18 +2497,11 @@ def aten_cumsum( cast = self else: cast = op.Cast(self, to=dtype) - return _aten_cumsum_onnx(cast, dim) - - -@torch_op("aten::cumsum", private=True, traceable=True) -def _aten_cumsum_onnx( - self: TRealUnlessInt16OrInt8, dim: Union[INT32, INT64] -) -> TRealUnlessInt16OrInt8: - if IsScalar(self): + if len(self.shape) == 0: # A scalar - result = op.Identity(self) + result = op.Identity(cast) else: - result = op.CumSum(self, dim) + result = op.CumSum(cast, dim) return result @@ -2411,7 +2511,7 @@ def aten_data(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::deg2rad", traceable=True) +@torch_op("aten::deg2rad", trace_only=True) def aten_deg2rad(self: TFloat) -> TFloat: """deg2rad(Tensor self) -> Tensor""" @@ -2424,7 +2524,7 @@ def aten_dense_dim(self: TensorType) -> int: raise NotImplementedError() -@torch_op("aten::detach") +@torch_op("aten::detach", trace_only=True) def aten_detach(self: TensorType) -> TensorType: """detach(Tensor(a) self) -> Tensor(a)""" @@ -2458,87 +2558,10 @@ def aten_diagflat(self: TensorType, offset: int = 0) -> TensorType: @torch_op(("aten::diagonal", "aten::diagonal_copy"), trace_only=True) -def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TReal: +def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TTensor: """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" - # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims - # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 - # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2 - # [0,1,2] -> [0,1,2] when dim1=1 and dim2=2 - if dim1 < 0: - dim1 = dim1 + len(self.shape) - if dim2 < 0: - dim2 = dim2 + len(self.shape) - - self_rank = len(self.shape) - perm = list(range(self_rank)) - perm.remove(dim1) - perm.remove(dim2) - perm.append(dim1) - perm.append(dim2) - - # If rank=2, then axes=[0]; if rank=3, then axes=[1] - # This is because computing diagonal sum is on dim2 after transpose by perm - axes = [self_rank - 2] - - return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True, traceable=True) -def _aten_diagonal_onnx( - self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> TTensor: - neg_1 = op.Constant(value_ints=[-1]) - dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row - dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col - mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) - mask = op.CastLike(mask, self) - self_t = op.Transpose(self, perm=perm) - result = op.Mul(self_t, mask) - result = op.ReduceSum(result, keepdims=False, axes=axes) - # min(row, col) - min_dim_size = op.Min(dim1_size, dim2_size) - # take 2 tensors as example: - # one is 3x5 in size, min_dim_size = 3, dim1_size = 3 - # the other is 5x3 in size, min_dim_size = 3, dim1_size = 5 - # 3 rows x 5 cols 5 rows x 3 cols - # offset diagonal offset diagonal - # ---------------- ---------------- - # -4 0 -6 0 - # -3 0 -5 0 - # -2 1 -4 1 - # -1 2 -3 2 - # 0 3 -2 3 - # 1 3 -1 3 - # 2 3 0 3 - # 3 2 1 2 - # 4 1 2 1 - # 5 0 3 0 - # 6 0 4 0 - - # From above table, we can get the logic below - if offset < 0: - # row + offset - length = dim1_size + offset - start = op.Constant(value_ints=[0]) - else: # offset >= 0 - # col - offset - length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) - - # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) - end = start + length - result = op.Slice(result, start, end, axes=axes) - - return result - - -@torch_op("aten::diagonal", trace_only=True) -def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1) -> BOOL: - """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" + is_bool = self.dtype == BOOL.dtype # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 @@ -2560,23 +2583,21 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True) -def _aten_diagonal_bool_onnx( - self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> BOOL: neg_1 = op.Constant(value_ints=[-1]) dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) - self_int = op.Cast(self, to=INT64.dtype) - mask_int = op.Cast(mask, to=INT64.dtype) - self_int_t = op.Transpose(self_int, perm=perm) - result = op.Mul(self_int_t, mask_int) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) + + if is_bool: + self_int = op.Cast(self, to=INT64.dtype) + mask_int = op.Cast(mask, to=INT64.dtype) + self_int_t = op.Transpose(self_int, perm=perm) + result = op.Mul(self_int_t, mask_int) + else: + mask = op.CastLike(mask, self) + self_t = op.Transpose(self, perm=perm) + result = op.Mul(self_t, mask) result = op.ReduceSum(result, keepdims=False, axes=axes) # min(row, col) min_dim_size = op.Min(dim1_size, dim2_size) @@ -2599,20 +2620,23 @@ def _aten_diagonal_bool_onnx( # 6 0 4 0 # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) if offset < 0: # row + offset - length = dim1_size + offset + length = op.Add(dim1_size, offset_val) start = op.Constant(value_ints=[0]) else: # offset >= 0 # col - offset - length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) + length = op.Sub(dim2_size, offset_val) + start = offset_val # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) - end = start + length + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) + end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) - result = op.Cast(result, to=BOOL.dtype) + + if is_bool: + result = op.Cast(result, to=BOOL.dtype) return result @@ -2667,17 +2691,14 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType @torch_op( ( - "aten::div", "aten::div.Tensor", "aten::div.Scalar", - # When rounding_mode is None, performs a true division - # https://pytorch.org/docs/stable/generated/torch.div.html - "aten::div.Tensor_mode", - "aten::div.Scalar_mode", - "aten::divide", - "aten::true_divide", - "_operator::truediv", - ) + "aten::divide.Tensor", + "aten::divide.Scalar", + "aten::true_divide.Tensor", + "aten::true_divide.Scalar", + ), + trace_only=True, ) def aten_div(self: TFloat, other: TFloat) -> TFloat: """div.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -2686,14 +2707,19 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat: return op.Div(self, other) +@torch_op("_operator::truediv", trace_only=True) +def operator_truediv(self: TensorType, other: TensorType) -> FLOAT: + return op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + + @torch_op( ( - "aten::div", "aten::div.Tensor", "aten::div.Scalar", - "aten::divide", - "aten::true_divide", - "_operator::truediv", + "aten::divide.Tensor", + "aten::divide.Scalar", + "aten::true_divide.Tensor", + "aten::true_divide.Scalar", ), complex=True, ) @@ -2721,55 +2747,51 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat: @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat: +def aten_div_mode(self: TReal, other: TReal, rounding_mode: Optional[str] = None) -> TReal: """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" - # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison - assert rounding_mode in {"trunc", "floor"} - - if rounding_mode == "trunc": - # Rounds the results of the division towards zero. - # Equivalent to C-style integer division - result = aten_trunc(op.Div(self, other)) - else: # rounding_mode == "floor" - result = op.Floor(op.Div(self, other)) - - return result + assert rounding_mode in {"trunc", "floor", None} + if self.dtype.is_integer(): + quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) -@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode_int(self: TInt, other: TInt, rounding_mode: str) -> TInt: - """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + if rounding_mode == "trunc": + # Rounds the results of the division towards zero. + # Equivalent to C-style integer division + result = aten_trunc(quotient) + return op.CastLike(result, self) + if rounding_mode == "floor": + result = op.Floor(quotient) + return op.CastLike(result, self) - Variant for integer inputs. - """ - # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison - assert rounding_mode in {"trunc", "floor"} + assert rounding_mode is None + # When rounding_mode is None, the return type is float32 + return quotient - quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + # Float inputs if rounding_mode == "trunc": # Rounds the results of the division towards zero. # Equivalent to C-style integer division - result = aten_trunc(quotient) - else: # rounding_mode == "floor" - result = op.Floor(quotient) + return aten_trunc(op.Div(self, other)) + if rounding_mode == "floor": + return op.Floor(op.Div(self, other)) - return op.CastLike(result, self) + return op.Div(self, other) -@torch_op("aten::dot") +@torch_op("aten::dot", trace_only=True) def aten_dot(self: TFloat, tensor: TFloat) -> TFloat: """dot(Tensor self, Tensor tensor) -> Tensor""" return op.MatMul(self, tensor) -@torch_op("aten::dropout", traceable=True) +@torch_op("aten::dropout", trace_only=True) def aten_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat: """dropout(Tensor input, float p, bool train) -> Tensor""" - if IsScalar(input): + if len(input.shape) == 0: input = op.Reshape(input, op.Constant(value_ints=[-1])) result, _ = op.Dropout(input, p, train) result = op.Squeeze(result) @@ -2789,7 +2811,7 @@ def aten_dstack(tensors: Sequence[TensorType]) -> TensorType: def aten_einsum( equation: str, tensors: Sequence[TReal], - path: Optional[int] = None, # pylint: disable=unused-argument + path: Optional[int] = None, ) -> TReal: """einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor""" @@ -2797,14 +2819,14 @@ def aten_einsum( return op.Einsum(*tensors, equation=equation) -@torch_op("aten::embedding") +@torch_op("aten::embedding", trace_only=True) def aten_embedding( weight: TTensor, - indices: TTensor, + indices: TInt, padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False, -): # pylint: disable=unused-argument +) -> TTensor: # embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor return op.Gather(weight, indices) @@ -2828,9 +2850,9 @@ def aten_embedding_bag( weight: TFloat, indices: INT64, offsets: INT64, - scale_grad_by_freq: bool = False, # pylint: disable=unused-argument + scale_grad_by_freq: bool = False, mode: int = 0, # [0,1,2] indicate ["sum", "mean", "max"] - sparse: bool = False, # pylint: disable=unused-argument + sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, ) -> Tuple[TFloat, TFloat, TFloat, TFloat]: @@ -2867,8 +2889,8 @@ def _aten_embedding_bag_onnx( indices_1d = op.Reshape(indices, neg_1) # Get weight out according to indices_1d, new_weight = op.Gather(weight, indices_1d) - # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights) - new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1)) + # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights) + new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=[1])) weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1) indices_size = op.Shape(indices_1d) @@ -2962,9 +2984,9 @@ def aten_embedding_bag_padding_idx( weight: TFloat, indices: INT64, offsets: INT64, - scale_grad_by_freq: bool = False, # pylint: disable=unused-argument + scale_grad_by_freq: bool = False, mode: int = 0, # [0,1,2] indicate ["sum", "mean", "max"] - sparse: bool = False, # pylint: disable=unused-argument + sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, padding_idx: int = -1, @@ -2974,9 +2996,9 @@ def aten_embedding_bag_padding_idx( We add default values for the attributes to accommodate _embedding_bag as well: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) """ - assert ( - padding_idx is not None - ), "padding_idx must not be None. This is likely a dispatcher error" + assert padding_idx is not None, ( + "padding_idx must not be None. This is likely a dispatcher error" + ) if per_sample_weights is None: per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices)) @@ -3007,8 +3029,8 @@ def _aten_embedding_bag_1d_padding_idx_onnx( # Get weight out according to indices, # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]] indices_weight = op.Gather(weight, indices) - # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights) - indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1)) + # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights) + indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=[1])) # The element in sequence must be FLOAT32 dtype due to ORT bug indices_weight = op.Cast(indices_weight, to=FLOAT.dtype) @@ -3101,7 +3123,7 @@ def aten_embedding_dense_backward( raise NotImplementedError() -@torch_op("aten::embedding_renorm", traceable=True) +@torch_op("aten::embedding_renorm", trace_only=True) def aten_embedding_renorm( weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0 ) -> TFloat: @@ -3150,35 +3172,42 @@ def aten_embedding_sparse_backward( raise NotImplementedError() -@torch_op(("aten::empty", "aten::empty.memory_format")) -def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var] - # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor +@torch_op("aten::empty.memory_format", trace_only=True) +def aten_empty( + size: Sequence[INT64], + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, + memory_format: str = "", +) -> TensorType: # type: ignore[type-var] + """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + if dtype == -1: + dtype = FLOAT.dtype - # using Zeros to simulate np.empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + # using Zeros to simulate empty() + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @torch_op("aten::empty_like", trace_only=True) -def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor: +def aten_empty_like( + self: TTensor, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, + memory_format: str = "", +) -> TTensor: """empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - if dtype == -1: + if dtype == -1 or dtype is None: zero = op.CastLike(0, self) else: zero = op.Cast(0, to=dtype) - return _aten_empty_like_onnx(self, zero) - - -@torch_op("aten::empty_like", private=True) -def _aten_empty_like_onnx(self: TTensor, zero) -> TTensor: shape = op.Shape(self) return op.Expand(zero, shape) @@ -3191,28 +3220,32 @@ def aten_empty_quantized( raise NotImplementedError() -@torch_op("aten::empty_strided") +@torch_op("aten::empty_strided", trace_only=True) def aten_empty_strided( - size: INT64, - stride: INT64, # pylint: disable=unused-argument + size: Sequence[INT64], + stride: INT64, + layout: str = "", + dtype: int = FLOAT.dtype, + device: str = "", + pin_memory: bool = False, ) -> TTensor: # type: ignore[type-var] # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor # using Zeros to simulate empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) -@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar")) +@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"), trace_only=True) def aten_eq(self: TTensor, other: TTensor) -> BOOL: """eq.Tensor(Tensor self, Tensor other) -> Tensor""" return op.Equal(self, other) -@torch_op("aten::equal") +@torch_op("aten::equal", trace_only=True) def aten_equal(self: TTensor, other: TTensor) -> BOOL: """equal(Tensor self, Tensor other) -> bool""" @@ -3231,14 +3264,14 @@ def aten_erfinv(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::exp") +@torch_op("aten::exp", trace_only=True) def aten_exp(self: TFloat) -> TFloat: """exp(Tensor self) -> Tensor""" return op.Exp(self) -@torch_op("aten::exp2", traceable=True) +@torch_op("aten::exp2", trace_only=True) def aten_exp2(self: TFloat) -> TFloat: """exp2(Tensor self) -> Tensor""" @@ -3247,17 +3280,18 @@ def aten_exp2(self: TFloat) -> TFloat: return op.Pow(two, self) -@torch_op("aten::expand") -def aten_expand(self: TTensor, size: TInt) -> TTensor: +@torch_op("aten::expand", trace_only=True) +def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor: """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1. # To support -1 dim, we need to convert -1 to 1. - size = op.Abs(size) - return op.Expand(self, size) + # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely + # and isn't expected to appear from correct usages of SymInt. + size = [1 if isinstance(s, int) and s == -1 else s for s in size] + return op.Expand(self, common_ops.merge_dims(size)) -@torch_op("aten::expand_as", traceable=True) +@torch_op("aten::expand_as", trace_only=True) def aten_expand_as(self: TTensor, other: TTensor) -> TTensor: """expand_as(Tensor(a) self, Tensor other) -> Tensor(a)""" @@ -3272,12 +3306,6 @@ def aten_expand_copy(self: TensorType, size: INT64, implicit: bool = False) -> T raise NotImplementedError() -def aten_expm1(self: TensorType) -> TensorType: - """expm1(Tensor self) -> Tensor""" - - raise NotImplementedError() - - def aten_eye(n: int) -> TensorType: """eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" @@ -3418,15 +3446,14 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType raise NotImplementedError() -@torch_op(("aten::fill", "aten::fill.Tensor")) -def aten_fill(self: TTensor, value: TTensor) -> TTensor: +@torch_op(("aten::fill.Tensor", "aten::fill.Scalar")) +def aten_fill(self: TTensor, value: TTensor2) -> TTensor: """fill.Tensor(Tensor self, Tensor value) -> Tensor""" - # after fill, the self Tensor should keep origianl type + # Cast the value before Expand so it can be constant folded + value = op.CastLike(value, self) shape = op.Shape(self) - expanded = op.Expand(value, shape) - result = op.CastLike(expanded, self) - return result + return op.Expand(value, shape) def aten_fix(self: TensorType) -> TensorType: @@ -3435,17 +3462,63 @@ def aten_fix(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::flip") -def aten_flip(self: TTensor, dims: INT64) -> TTensor: +@torch_op("aten::flatten.using_ints", trace_only=True) +def aten_flatten(self: TTensor, start_dim: int = 0, end_dim: int = -1) -> TTensor: + """flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)""" + dim = len(self.shape) + if dim == 1: + return op.Identity(self) + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1: + if end_dim in (-1, dim - 1): + return op.Flatten(self, axis=start_dim) + elif start_dim == 0: + if end_dim in (-2, dim - 2): + return op.Flatten(self, axis=end_dim + 1) + + # if end_dim is negative add dim + if end_dim < 0: + end_dim = dim + end_dim + + input_size = op.Shape(self) + dim_head = op.Slice( + input_size, + op.Constant(value_ints=[0]), + op.Constant(value_ints=[start_dim]), + op.Constant(value_ints=[0]), + ) + final_dims = [dim_head, op.Constant(value_ints=[-1])] + if end_dim < dim - 1: + dim_tail = op.Slice( + input_size, + op.Constant(value_ints=[end_dim + 1]), + op.Constant(value_ints=[dim]), + op.Constant(value_ints=[0]), + ) + final_dims = [ + dim_head, + op.Constant(value_ints=[-1]), + dim_tail, + ] + + final_shape = op.Concat(*final_dims, axis=0) + return op.Reshape(self, final_shape) + + +@torch_op("aten::flip", trace_only=True) +def aten_flip(self: TTensor, dims: Sequence[int]) -> TTensor: """flip(Tensor self, int[] dims) -> Tensor""" - shape_dim = op.Shape(dims) - neg_1 = op.Constant(value_int=-1) - starts = op.Expand(neg_1, shape_dim) # something like [-1, -1, -1] - steps = op.Expand(neg_1, shape_dim) # something like [-1, -1, -1] - ends = op.Expand(_INT64_MIN, shape_dim) # something like [-xxx, -xxx, -xxx] - result = op.Slice(self, starts, ends, dims, steps) - return result + if not dims: + # Nothing to flip + return op.Identity(self) + + rank = len(dims) + starts = op.Constant(value_ints=[-1] * rank) # something like [-1, -1, -1] + steps = starts # something like [-1, -1, -1] + ends = op.Constant(value_ints=[_INT64_MIN] * rank) # something like [-xxx, -xxx, -xxx] + dims = op.Constant(value_ints=dims) + return op.Slice(self, starts, ends, dims, steps) def aten_fliplr(self: TensorType) -> TensorType: @@ -3460,25 +3533,49 @@ def aten_flipud(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::floor") -def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::floor", trace_only=True) +def aten_floor(self: TFloat) -> TFloat: """floor(Tensor self) -> Tensor""" return op.Floor(self) -@torch_op("math::floor") -def python_math_floor(self: TFloatOrBFloat16) -> TInt: +@torch_op("math::floor", trace_only=True) +def python_math_floor(self: TFloat) -> TInt: """floor(Tensor self) -> Tensor""" floor = op.Floor(self) return op.Cast(floor, to=INT64.dtype) -@torch_op(("aten::floor_divide", "_operator::floordiv")) -def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: +@torch_op("aten::floor_divide", trace_only=True) +def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: """floor_divide(Tensor self, Tensor other) -> Tensor""" - return op.Floor(op.Div(self, other)) + if self.dtype.is_floating_point(): + return op.Floor(op.Div(self, other)) + + assert self.dtype.is_integer() + + if not self.dtype.is_signed(): + return op.Div(self, other) + + # Convert truncation to flooring + # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70 + # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + offset = op.And( + op.Not(op.Equal(op.Sign(self), op.Sign(other))), + op.Cast(op.Mod(self, other), to=BOOL.dtype), + ) + offset = op.Cast(offset, to=self.dtype) + return op.Sub(op.Div(self, other), offset) + + +@torch_op("_operator::floordiv", trace_only=True) +def operator_floordiv(self: INT64, other: INT64) -> INT64: + # We implement floor_divide only for positive inputs (using integer division) + # because that is the usual intended case and is the most efficient. + return op.Div(self, other) def aten_fmax(self: TensorType, other: TensorType) -> TensorType: @@ -3493,14 +3590,14 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::fmod") +@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"), trace_only=True) def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8: """fmod.Tensor(Tensor self, Tensor other) -> Tensor""" return op.Mod(self, other, fmod=1) -@torch_op("aten::frac") +@torch_op("aten::frac", trace_only=True) def aten_frac(self: TFloat) -> TFloat: """frac(Tensor self) -> Tensor @@ -3531,32 +3628,41 @@ def aten_from_file( raise NotImplementedError() -@torch_op("aten::full") -def aten_full(size: INT64, fill_value: FLOAT, dtype: int = FLOAT.dtype): +@torch_op("aten::full", trace_only=True) +def aten_full( + size: Union[INT64, INT32], + fill_value: TensorType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" + if dtype != -1: + fill_value = op.Cast(fill_value, to=dtype) + size = op.Cast(size, to=INT64.dtype) - fill_value = op.Cast(fill_value, to=dtype) return op.Expand(fill_value, size) -@torch_op("aten::full_like") -def aten_full_like(self, fill_value: TTensor) -> TTensor: +@torch_op("aten::full_like", trace_only=True) +def aten_full_like( + self: TensorType, + fill_value: TensorType, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - fill_value = op.CastLike(fill_value, self) - self_shape = op.Shape(self) - - return op.Expand(fill_value, self_shape) - - -@torch_op("aten::full_like") -def aten_full_like_dtype(self, fill_value: TTensor, dtype: int) -> TTensor: - """full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + if dtype == -1: + fill_value = op.CastLike(fill_value, self) + else: + fill_value = op.Cast(fill_value, to=dtype) - fill_value = op.Cast(fill_value, to=dtype) self_shape = op.Shape(self) - return op.Expand(fill_value, self_shape) @@ -3580,25 +3686,30 @@ def aten_fused_moving_avg_obs_fake_quant( raise NotImplementedError() -@torch_op("aten::gather", traceable=True) +@torch_op("aten::gather", trace_only=True) def aten_gather( self: TReal, dim: int, index: TInt, - sparse_grad: bool = False, # pylint: disable=unused-argument + sparse_grad: bool = False, ) -> TReal: """gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor""" - if IsScalar(index): # When (index) is empty, return (self) - result = self - else: - if IsScalar(self): # Unsqueeze for GatherElements op - self = op.Reshape(self, op.Constant(value_ints=[-1])) - if op.Size(index) == 0: # Return empty array - result = op.CastLike(index, self) + if len(self.shape) == 0: + if len(index.shape) == 0: + return op.Identity(self) else: - index = op.Cast(index, to=INT64.dtype) - result = op.GatherElements(self, index, axis=dim) + return op.Expand(self, op.Shape(index)) + + is_scalar_index = len(index.shape) == 0 + if is_scalar_index: + index = op.Unsqueeze(index, [0]) + + index = op.Cast(index, to=INT64.dtype) + result = op.GatherElements(self, index, axis=dim) + + if is_scalar_index: + result = op.Squeeze(result, [0]) return result @@ -3617,25 +3728,21 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal", "_operator::ge") + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), + trace_only=True, ) -def aten_ge(self: TReal, other: TReal) -> BOOL: +def aten_ge(self: TTensor, other: TTensor) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.GreaterOrEqual(self, other) - - -@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal")) -def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL: - """ge.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + # self, other, self >= other + # F, F, T + # F, T, F + # T, F, T + # T, T, T + return op.Or(self, op.Not(other)) - # self, other, self >= other - # F, F, T - # F, T, F - # T, F, T - # T, T, T - - return op.Or(self, op.Not(other)) + return op.GreaterOrEqual(self, other) def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]: @@ -3650,8 +3757,7 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType: raise NotImplementedError() -# NOTE: The name is made up for `getitem` to be included in the registry -@torch_op("aten::getitem") +@torch_op(("_operator::getitem", "aten::getitem")) def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor: return op.SequenceAt(self, i) @@ -3748,19 +3854,6 @@ def aten_grid_sampler_3d_backward( raise NotImplementedError() -def aten_group_norm( - input: TensorType, - num_groups: int, - weight: Optional[TensorType] = None, - bias: Optional[TensorType] = None, - eps: float = 1e-05, - cudnn_enabled: bool = True, -) -> TensorType: - """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor""" - - raise NotImplementedError() - - def aten_gru_cell( input: TensorType, hx: TensorType, @@ -3774,35 +3867,57 @@ def aten_gru_cell( raise NotImplementedError() -@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater", "_operator::gt")) -def aten_gt(self: TReal, other: TReal) -> BOOL: +@torch_op( + ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), + trace_only=True, +) +def aten_gt(self: TTensor, other: TTensor) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Greater(self, other) - + if self.dtype == ir.DataType.BOOL: + # self, other, self > other + # F, F, F + # F, T, F + # T, F, T + # T, T, F -@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater")) -def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: - """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self > other - # F, F, F - # F, T, F - # T, F, T - # T, T, F + return op.And(self, op.Not(other)) - return op.And(self, op.Not(other)) + return op.Greater(self, other) -def aten_hamming_window(window_length: int) -> TensorType: +@torch_op("aten::hamming_window", trace_only=True) +def aten_hamming_window( + window_length: int, + dtype: int = 1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - raise NotImplementedError() - - -def aten_hann_window(window_length: int) -> TensorType: + if dtype is None or dtype == -1: + dtype = 1 + # ONNX uses different alpha/beta values for the Hamming window + # Whereas PyTorch uses alpha=0.54, beta=0.46, ONNX uses + # alpha=0.543478, beta=0.456522. This causes a slight difference + # in the output values, but we still uses the HammingWindow op for performance. + return op.HammingWindow(window_length, output_datatype=dtype) + + +@torch_op("aten::hann_window", trace_only=True) +def aten_hann_window( + window_length: int, + dtype: int = 1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - raise NotImplementedError() + if dtype is None or dtype == -1: + dtype = 1 + return op.HannWindow(window_length, output_datatype=dtype) def aten_hardshrink(self: TensorType, lambd: float = 0.5) -> TensorType: @@ -3819,7 +3934,7 @@ def aten_hardshrink_backward( raise NotImplementedError() -@torch_op("aten::heaviside") +@torch_op("aten::heaviside", trace_only=True) def aten_heaviside(self: TReal, values: TReal) -> TReal: """heaviside(Tensor self, Tensor values) -> Tensor""" @@ -3864,7 +3979,7 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::hstack") +# Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918 def aten_hstack(tensors: Sequence[TTensor]) -> TTensor: """hstack(Tensor[] tensors) -> Tensor""" @@ -3937,7 +4052,6 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64: return op.Shape(broadcasted) -@torch_op("aten::index.Tensor", private=True, trace_only=True) def _aten_index_onnx( self: TensorType, indices: Sequence[Optional[INT64]], @@ -3965,7 +4079,7 @@ def _aten_index_onnx( not_none_indices = [idx for idx in indices if idx is not None] broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices) final_index = op.Concat( - *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices), + *(op.Unsqueeze(op.Expand(idx, broadcast_shape), [-1]) for idx in not_none_indices), axis=-1, ) @@ -3974,7 +4088,7 @@ def _aten_index_onnx( if _has_none_in_middle(indices): # If there is None in the middle, Advanced Indexing cannot decide where to put # the new dimensions. So it places them in the front, like GatherND does. - return self + return op.Identity(self) # When the indices are consecutive, Advanced Indexing will place the new dimensions # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes. @@ -4103,7 +4217,7 @@ def aten_index_copy( raise NotImplementedError() -@torch_op(("aten::index_put", "aten::_unsafe_index_put")) +@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, indices: Sequence[INT64], @@ -4116,19 +4230,83 @@ def aten_index_put( `_. """ - # TODO(justinchuby): Handle when indicies has more than one element - index = op.SequenceAt(indices, 0) - new_index = op.Unsqueeze(index, [-1]) - - if op.Cast(accumulate, to=BOOL.dtype): - result = op.ScatterND(self, new_index, values, reduction="add") + def _make_reshape_list_broadcastable(reshape_list, values_shape): + # Remove ones until the rank of reshape_list matches values_shape. + while len(reshape_list) > len(values_shape) and 1 in reshape_list: + reshape_list.remove(1) + + # Now ensure each dimension is broadcastable: + # This is mandatory when mixing basic and advanced indexing + # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) + # the reshape list should be : [[2, 1], [1, 3], [2, 1]] + for i, r in enumerate(reshape_list): + if r not in (1, values_shape[i]): + value_index = values_shape.index(r) + # Swap elements + # For the example above the current reshape list is [1, 2] for last dim, + # to make it broadcastable, we swap the elements + reshape_list[value_index], reshape_list[i] = r, 1 + + return reshape_list + + # Ensure the number of indices matches the tensor rank. + self_rank = len(self.shape) + if len(indices) < self_rank: + indices = list(indices) + [None] * (self_rank - len(indices)) + + # Get values shape + values_shape = tuple(values.shape) + + index_vectors = [] + for i in range(self_rank): + if indices[i] is None: + # For a full slice along dim i, create a range index [0, self.shape[i]). + idx = op.Range(0, self.shape[i], 1) + reshape_update = self.shape[i] + else: + idx = indices[i] + reshape_update = math.prod(idx.shape) + # when Index is more than 1D, flatten it and also the values shape + # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) + # Indices -> (2*4,) and values shape (2*4, 32) + if len(idx.shape) > 1: + values_shape = (reshape_update, *values_shape[len(idx.shape) :]) + + # Flatten index (always working with 1D index in each dim) + idx = op.Reshape(idx, [-1]) + + # Create a reshape pattern: one value per index dimension, + # with the current dimension set to the update size. + reshape_list = [1] * len(indices) + reshape_list[i] = reshape_update + + # Adjust the reshape list to match the values shape. + reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape) + + # Reshape and expand the index. + idx = op.Reshape(idx, reshape_list, allowzero=True) + idx = op.Expand(idx, values_shape) + + # Flatten the index to 1D and unsqueeze to form a column vector. + idx = op.Reshape(idx, [-1]) + idx = op.Unsqueeze(idx, axes=[1]) + index_vectors.append(idx) + + # Concatenate the index vectors along axis=1 to form the final indices. + new_index = op.Concat(*index_vectors, axis=1) + + # Flatten values to match the indices + flat_values = op.Reshape(values, [-1]) + + if accumulate: + result = op.ScatterND(self, new_index, flat_values, reduction="add") else: - result = op.ScatterND(self, new_index, values) + result = op.ScatterND(self, new_index, flat_values) return result -@torch_op("aten::index_put") +@torch_op("aten::index_put", trace_only=True) def aten_index_put_bool( self: TReal, indices: Sequence[BOOL], @@ -4137,37 +4315,18 @@ def aten_index_put_bool( ) -> TReal: """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" - index = op.SequenceAt(indices, 0) # assume indices only have 1 element - # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it - index_int = op.Cast(index, to=INT32.dtype) - # if all False, return self - if op.ReduceSum(index_int) == 0: - result = self - else: - # change array([F,F,T,F,F]) to array([2]) - index = op.ArgMax(index_int) # assume index only have 1 True - # change array([2]) to array([2,2,2,2,2]) - self_dim_1 = op.Shape(self, start=1, end=2) - index_dim_0 = op.Shape(index, start=0, end=1) - shape = op.Concat(self_dim_1, index_dim_0, axis=0) - new_ind = op.Expand(index, shape) - new_ind_t = op.Transpose(new_ind) - - # values must have same rank with input(self) - if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator] - values = op.Unsqueeze(values, op.Constant(value_ints=[0])) - - if op.Cast(accumulate, to=BOOL.dtype): - zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self)) - zeros = op.CastLike(zeros, values) - result = op.ScatterElements(zeros, new_ind_t, values) - # FIXME: type promotion - result = op.CastLike(result, self) - result = op.Add(result, self) - else: - result = op.ScatterElements(self, new_ind_t, values) - - return result + # TODO: Support indices with more than 1 elements + index = indices[0] + # accumulate should be always False, True does not make sense but an assert would be great + # Reshape indices so it can be properly broadcasted + self_rank = len(self.shape) + index_rank = len(index.shape) + if self_rank > index_rank: + index_shape = op.Shape(index) + padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)]) + padded_shape = op.Concat(index_shape, padding, axis=0) + index = op.Reshape(index, padded_shape) + return op.Where(index, values, self) def aten_index_reduce( @@ -4183,11 +4342,11 @@ def aten_index_reduce( raise NotImplementedError() -@torch_op("aten::index_select", traceable=True) +@torch_op("aten::index_select", trace_only=True) def aten_index_select(self: TTensor, dim: int, index: IntType) -> TTensor: """index_select(Tensor self, int dim, Tensor index) -> Tensor""" - self_is_scalar = IsScalar(self) + self_is_scalar = len(self.shape) == 0 if self_is_scalar: self = op.Reshape(self, op.Constant(value_ints=[-1])) @@ -4258,12 +4417,15 @@ def aten_instance_norm( if use_input_stats: return op.InstanceNormalization(input, weight, bias, epsilon=eps) - assert ( - running_mean is not None and running_var is not None - ), "running_mean and running_var must be provided when use_input_stats is False" + assert running_mean is not None and running_var is not None, ( + "running_mean and running_var must be provided when use_input_stats is False" + ) batch_size = op.Shape(input, start=0, end=1) - bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0)) + bn_input = op.Reshape( + input, + op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0), + ) weight = op.Tile(weight, batch_size) bias = op.Tile(bias, batch_size) running_mean = op.Tile(running_mean, batch_size) @@ -4279,7 +4441,7 @@ def aten_instance_norm( momentum=1.0 - momentum, training_mode=False, ) - return op.Reshape(norm, op.Shape(input)) + return op.Reshape(norm, op.Shape(input), allowzero=True) def aten_int_repr(self: TensorType) -> TensorType: @@ -4360,21 +4522,11 @@ def aten_is_pinned(self: TensorType, device: Optional[str] = None) -> bool: raise NotImplementedError() -@torch_op("aten::is_same_size") +# is_same_size is decomposed by PyTorch def aten_is_same_size(self: TTensor, other: TTensor) -> BOOL: """is_same_size(Tensor self, Tensor other) -> bool""" - # Cannot compare different shape of two tensors using op.Equal() - # So we need to compare the rank first, if rank is same, then compare shape - result = op.Equal(Rank(self), Rank(other)) - if result: # Same rank, then compare shape - self_shape = op.Shape(self) - other_shape = op.Shape(other) - result_bool = op.Equal(self_shape, other_shape) - result_int = op.Cast(result_bool, to=INT8.dtype) - result = op.Cast(op.ReduceMin(result_int, keepdims=False), to=BOOL.dtype) - - return result + raise NotImplementedError def aten_is_set_to(self: TensorType, tensor: TensorType) -> bool: @@ -4401,7 +4553,7 @@ def aten_isclose( other: TReal, rtol: float = 1e-05, atol: float = 1e-08, - equal_nan: bool = False, # pylint: disable=unused-argument + equal_nan: bool = False, ) -> BOOL: """isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor""" @@ -4424,7 +4576,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL: @torch_op("aten::isinf") -def aten_isinf(self: TFloatOrBFloat16) -> BOOL: +def aten_isinf(self: TFloat) -> BOOL: """isinf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4433,14 +4585,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL: @torch_op("aten::isnan") -def aten_isnan(self: TFloatOrBFloat16) -> BOOL: +def aten_isnan(self: TFloat) -> BOOL: """isnan(Tensor self) -> Tensor""" return op.IsNaN(self) @torch_op("aten::isneginf") -def aten_isneginf(self: TFloatOrBFloat16) -> BOOL: +def aten_isneginf(self: TFloat) -> BOOL: """isneginf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4449,7 +4601,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL: @torch_op("aten::isposinf") -def aten_isposinf(self: TFloatOrBFloat16) -> BOOL: +def aten_isposinf(self: TFloat) -> BOOL: """isposinf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4521,6 +4673,7 @@ def aten_layer_norm( weight: Optional[TReal] = None, bias: Optional[TReal] = None, eps: float = 1e-05, + cudnn_enable: bool = True, ) -> TReal: """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor""" @@ -4528,28 +4681,10 @@ def aten_layer_norm( start_axis = -len(normalized_shape) if weight is None: - one = op.Constant(value_float=1.0) + one = op.Constant(value=ir.tensor(1, dtype=input.dtype)) weight = op.Expand(one, op.Shape(input, start=start_axis)) - if bias is None: - zero = op.Constant(value_float=0.0) - bias = op.Expand(zero, op.Shape(input, start=start_axis)) - - return _aten_layer_norm_onnx(input, weight, bias, axis=start_axis, eps=eps) - - -@torch_op("aten::layer_norm", private=True) -def _aten_layer_norm_onnx( - input: TReal, - weight: TReal, - bias: TReal, - axis: int, - eps: float = 1e-05, -) -> TReal: - """layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor""" - - # TODO(justinchuby): Use OptionalHasElement after onnx/onnx#4982 - result, _, _ = op.LayerNormalization(input, weight, bias, axis=axis, epsilon=eps) + result, _, _ = op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps) return result @@ -4565,30 +4700,36 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::le", "aten::le.Tensor", "_operator::le")) -def aten_le(self: TReal, other: TReal) -> BOOL: +@torch_op( + ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), + trace_only=True, +) +def aten_le(self: TTensor, other: TTensor) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.LessOrEqual(self, other) - - -@torch_op(("aten::le", "aten::le.Tensor", "aten::less_equal")) -def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: - """le.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + # self, other, self <= other + # F, F, T + # F, T, T + # T, F, F + # T, T, T - # self, other, self <= other - # F, F, T - # F, T, T - # T, F, F - # T, T, T + return op.Or(other, op.Not(self)) - return op.Or(other, op.Not(self)) + return op.LessOrEqual(self, other) -def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType: +@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) +def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor: """lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor""" - raise NotImplementedError() + weight = op.CastLike(weight, self) + diff = op.Sub(end, self) + return op.Where( + op.Less(weight, 0.5), + op.Add(self, op.Mul(weight, diff)), + op.Sub(end, op.Mul(diff, op.Sub(1.0, weight))), + ) def aten_lgamma(self: TensorType) -> TensorType: @@ -4609,7 +4750,7 @@ def aten_lift_fresh(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::lift_fresh_copy") +@torch_op("aten::lift_fresh_copy", trace_only=True) def aten_lift_fresh_copy(self: TensorType) -> TensorType: """lift_fresh_copy(Tensor self) -> Tensor""" @@ -4626,10 +4767,19 @@ def aten_linear_backward( @torch_op("aten::linspace", trace_only=True) def aten_linspace( - start: TFloat, end: TFloat, steps: int, dtype: int = FLOAT.dtype + start: TFloat, + end: TFloat, + steps: int, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" + if dtype == -1 or dtype is None: + dtype = FLOAT.dtype + # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896 if steps == 0: return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype) @@ -4651,43 +4801,43 @@ def aten_linspace( ) -@torch_op("aten::log") -def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::log", trace_only=True) +def aten_log(self: TFloat) -> TFloat: """log(Tensor self) -> Tensor""" return op.Log(self) -@torch_op("aten::log10") -def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::log10", trace_only=True) +def aten_log10(self: TFloat) -> TFloat: """log10(Tensor self) -> Tensor""" return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self)) @torch_op("aten::log1p") -def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log1p(self: TFloat) -> TFloat: """log1p(Tensor self) -> Tensor""" return op.Log(op.Add(self, 1.0)) -@torch_op("aten::log2") -def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::log2", trace_only=True) +def aten_log2(self: TFloat) -> TFloat: """log2(Tensor self) -> Tensor""" return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self)) -@torch_op("aten::logaddexp") -def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::logaddexp", trace_only=True) +def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat: """logaddexp(Tensor self, Tensor other) -> Tensor""" return op.Log(op.Add(op.Exp(self), op.Exp(other))) -@torch_op("aten::logaddexp2") -def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::logaddexp2", trace_only=True) +def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat: """logaddexp2(Tensor self, Tensor other) -> Tensor""" two = op.CastLike(2.0, self) summation = op.Add(op.Pow(two, self), op.Pow(two, other)) @@ -4695,11 +4845,11 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr return op.Div(op.Log(summation), op.Log(two)) -@torch_op("aten::logcumsumexp", traceable=True) -def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: +@torch_op("aten::logcumsumexp", trace_only=True) +def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat: """logcumsumexp(Tensor self, int dim) -> Tensor""" - if IsScalar(self): + if len(self.shape) == 0: result = self else: # Make dim 1-d @@ -4719,88 +4869,70 @@ def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: return result -@torch_op("aten::logdet") +@torch_op("aten::logdet", trace_only=True) def aten_logdet(self: TFloat) -> TFloat: """logdet(Tensor self) -> Tensor""" return op.Log(op.Det(self)) -@torch_op( - ( - "aten::logical_and", - "aten::bitwise_and", - "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", - ) -) -def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: +@torch_op("aten::logical_and", trace_only=True) +def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" - return op.And(self, other) + assert self.dtype == other.dtype + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op(("aten::logical_not", "aten::bitwise_not")) -def aten_logical_not(self: BOOL) -> BOOL: + +@torch_op("aten::logical_not", trace_only=True) +def aten_logical_not(self: TTensor) -> BOOL: """logical_not(Tensor self) -> Tensor""" - return op.Not(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + return op.Not(op.Cast(self, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_or", - "aten::bitwise_or", - "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", - "aten::add", - "aten::add.Tensor", - ) -) -def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: +@torch_op("aten::logical_or", trace_only=True) +def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" - return op.Or(self, other) + assert self.dtype == other.dtype + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_xor", - "aten::bitwise_xor", - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ) -) -def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: + +@torch_op("aten::logical_xor", trace_only=True) +def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" - return op.Xor(self, other) + assert self.dtype == other.dtype + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op("aten::logit", private=True) -def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16: - return op.Log(op.Div(self, op.Sub(1.0, self))) +@torch_op("aten::logit", trace_only=True) +def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: + """logit(Tensor self, float? eps=None) -> Tensor""" + one = ir.tensor(1, dtype=self.dtype) -@torch_op("aten::logit", private=True) -def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16: - eps = op.CastLike(eps, self) - one = op.CastLike(1.0, self) - temporary_self = op.Where(self <= one - eps, self, one - eps) - z = op.Where(temporary_self < eps, eps, temporary_self) + if eps is None: + return op.Log(op.Div(self, op.Sub(one, self))) - return op.Log(op.Div(z, op.Sub(one, z))) + one_minus_eps = ir.tensor(1 - eps, dtype=self.dtype) + eps = ir.tensor(eps, dtype=self.dtype) + temporary_self = op.Where(self <= one_minus_eps, self, one_minus_eps) + z = op.Where(temporary_self < eps, eps, temporary_self) -@torch_op("aten::logit", trace_only=True) -def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16: - """logit(Tensor self, float? eps=None) -> Tensor""" - if eps is None: - return _aten_logit_onnx(self) - return _aten_logit_clamp_onnx(self, eps) + return op.Log(op.Div(z, op.Sub(one, z))) def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> TensorType: @@ -4809,11 +4941,11 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T raise NotImplementedError() -@torch_op("aten::logsumexp", traceable=True) +@torch_op("aten::logsumexp", trace_only=True) def aten_logsumexp(self: TFloat, dim: INT64, keepdim: int = False) -> TFloat: """logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor""" - if IsScalar(self): + if len(self.shape) == 0: # A scalar result = self else: @@ -4821,12 +4953,6 @@ def aten_logsumexp(self: TFloat, dim: INT64, keepdim: int = False) -> TFloat: return result -def aten_lshift(self: TensorType, other: TensorType) -> TensorType: - """__lshift__.Tensor(Tensor self, Tensor other) -> Tensor""" - - raise NotImplementedError() - - def aten_lstm_cell( input: TensorType, hx: Sequence[TensorType], @@ -4861,26 +4987,24 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less", "_operator::lt")) -def aten_lt(self: TReal, other: TReal) -> BOOL: +@torch_op( + ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), + trace_only=True, +) +def aten_lt(self: TTensor, other: TTensor) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + # self, other, self < other + # F, F, F + # F, T, T + # T, F, F + # T, T, F + return op.And(other, op.Not(self)) + return op.Less(self, other) -@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less")) -def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL: - """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - - # self, other, self < other - # F, F, F - # F, T, T - # T, F, F - # T, T, F - - return op.And(other, op.Not(self)) - - def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: """lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor""" @@ -4910,9 +5034,6 @@ def aten_mH(self: TRealOrUInt8) -> TRealOrUInt8: def aten_mH_complex(self: TFloat) -> TFloat: """mH(Tensor(a) self) -> Tensor(a)""" - # TODO(#834): Allow calling scripted functions from other - # scripted functions and remove trace only. - # c is the last dimension being the real and imaginary parts trasposed = op.Einsum(self, equation="...ijc->...jic") return _complex_conjugate(trasposed) @@ -4946,8 +5067,8 @@ def aten_margin_ranking_loss( @torch_op( - ("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"), - traceable=True, + ("aten::masked_fill.Scalar", "aten::masked_fill.Tensor"), + trace_only=True, ) def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: """masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor""" @@ -4957,10 +5078,26 @@ def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: return op.Where(mask, value_cast, self) -def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType: +@torch_op(("aten::masked_scatter"), trace_only=True) +def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor: """masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor""" - raise NotImplementedError() + if len(mask.shape) < len(self.shape): + mask = op.Expand(mask, op.Shape(self)) + else: + self = op.Expand(self, op.Shape(mask)) + index = op.Transpose(op.NonZero(mask), perm=[1, 0]) + + # NOTE: source can have more elements than needed. + # It could also have arbitrary shape. + # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. + source = op.Reshape(source, op.Constant(value_ints=[-1])) + axes = op.Constant(value_ints=[0]) + starts = op.Constant(value_ints=[0]) + ends = op.Gather(op.Shape(index), op.Constant(value_ints=[0]), axis=0) + source = op.Slice(source, starts, ends, axes) + + return op.ScatterND(self, index, source) def aten_masked_select(self: TensorType, mask: TensorType) -> TensorType: @@ -4977,7 +5114,7 @@ def aten_masked_select_backward( raise NotImplementedError() -@torch_op("aten::matmul") +@torch_op("aten::matmul", trace_only=True) def aten_matmul( self: TRealUnlessInt16OrInt8, other: TRealUnlessInt16OrInt8 ) -> TRealUnlessInt16OrInt8: @@ -5018,27 +5155,18 @@ def aten_matrix_power(self: TensorType, n: int) -> TensorType: raise NotImplementedError() -@torch_op("aten::max") +@torch_op("aten::max", trace_only=True) def aten_max(self: TReal) -> TReal: """max(Tensor self) -> Tensor""" - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Reshape(self, op.Constant(value_ints=[-1])) - - result = op.ReduceMax(self, keepdims=False) - - if self_is_scalar: - result = op.Squeeze(result) - - return result + return op.ReduceMax(self, keepdims=False) -@torch_op("aten::max.dim", traceable=True) +@torch_op("aten::max.dim", trace_only=True) def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, INT64]: """max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" - if IsScalar(self): + if len(self.shape) == 0: result = self indices = op.Constant(value_int=0) else: @@ -5048,18 +5176,14 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I return result, indices -@torch_op(("aten::maximum", "aten::max.other")) -def aten_maximum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::maximum", trace_only=True) +def aten_maximum(self: TTensor, other: TTensor) -> TTensor: """maximum(Tensor self, Tensor other) -> Tensor""" - return op.Max(self, other) - - -@torch_op(("aten::maximum", "aten::max.other")) -def aten_maximum_bool(self: BOOL, other: BOOL) -> BOOL: - """maximum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) - return op.Or(self, other) + return op.Max(self, other) @torch_op("aten::mean") @@ -5070,16 +5194,15 @@ def aten_mean(self: TReal) -> TReal: return op.Squeeze(result) -@torch_op("aten::mean.dim", traceable=True) +@torch_op("aten::mean.dim", trace_only=True) def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: """mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor""" - if IsScalar(self): + if len(self.shape) == 0: result = self else: - if IsScalar(dim): - dim = op.Unsqueeze(dim, axes=0) - result = op.ReduceMean(self, dim, keepdims=keepdim) + dims = op.Reshape(dim, op.Constant(value_ints=[-1])) + result = op.ReduceMean(self, dims, keepdims=keepdim) return result @@ -5095,17 +5218,17 @@ def aten_meshgrid(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::min") +@torch_op("aten::min", trace_only=True) def aten_min(self: TReal) -> TReal: """min(Tensor self) -> Tensor""" return op.ReduceMin(self, keepdims=False) -@torch_op("aten::min.dim", traceable=True) +@torch_op("aten::min.dim", trace_only=True) def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, TInt]: """min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" - if IsScalar(self): + if len(self.shape) == 0: result = self indices = op.Constant(value_int=0) else: @@ -5116,18 +5239,14 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T return result, indices -@torch_op(("aten::minimum", "aten::min.other")) -def aten_minimum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::minimum", trace_only=True) +def aten_minimum(self: TTensor, other: TTensor) -> TTensor: """minimum(Tensor self, Tensor other) -> Tensor""" - return op.Min(self, other) - + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) -@torch_op(("aten::minimum", "aten::min.other")) -def aten_minimum_bool(self: BOOL, other: BOOL) -> BOOL: - """minimum(Tensor self, Tensor other) -> Tensor""" - - return op.And(self, other) + return op.Min(self, other) def aten_miopen_batch_norm( @@ -5398,7 +5517,7 @@ def aten_mkldnn_max_pool3d_backward( raise NotImplementedError() -@torch_op("aten::mm") +@torch_op("aten::mm", trace_only=True) def aten_mm( self: TRealUnlessInt16OrInt8, mat2: TRealUnlessInt16OrInt8 ) -> TRealUnlessInt16OrInt8: @@ -5466,27 +5585,28 @@ def aten_msort(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul")) -def aten_mul(self: TReal, other: TReal) -> TReal: +@torch_op( + ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"), + trace_only=True, +) +def aten_mul(self: TTensor, other: TTensor) -> TTensor: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Mul(self, other) - - -@torch_op(("aten::mul", "aten::mul.Tensor")) -def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: - """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) - # TODO(justinchuby): Handle cases where type reconcilation is not enough, - # since different ONNX operators are used based on different data types. - - return op.And(self, other) + return op.Mul(self, other) -@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), complex=True) +@torch_op( + ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), + trace_only=True, + complex=True, +) def aten_mul_complex(self: TReal, other: TReal) -> TReal: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" + # TODO(justinchuby): Maybe use Split to simplify the logic self_real = op.Slice(self, [0], [1], axes=[-1]) self_imag = op.Slice(self, [1], [2], axes=[-1]) other_real = op.Slice(other, [0], [1], axes=[-1]) @@ -5506,22 +5626,22 @@ def aten_mul_complex(self: TReal, other: TReal) -> TReal: return op.Concat(real, imag, axis=-1) -@torch_op("aten::multinomial") +@torch_op("aten::multinomial", trace_only=True) def aten_multinomial( self: TFloat, num_samples: int, - replacement: bool = False, # pylint: disable=unused-argument + replacement: bool = False, ) -> TInt: """multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor""" # ONNX Multinomial doesn't support 1D input - if Rank(self) == 1: + if len(self.shape) == 1: unsqueezed_input = op.Unsqueeze(self, axes=0) else: unsqueezed_input = self # ONNX multinomial expects log probability log_input = op.Log(unsqueezed_input) result = op.Multinomial(log_input, dtype=INT64.dtype, sample_size=num_samples) - if Rank(self) == 1: + if len(self.shape) == 1: result = op.Squeeze(result) return result @@ -5532,10 +5652,11 @@ def aten_multiply(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::mv", trace_only=True) def aten_mv(self: TensorType, vec: TensorType) -> TensorType: """mv(Tensor self, Tensor vec) -> Tensor""" - raise NotImplementedError() + return op.MatMul(self, vec) def aten_mvlgamma(self: TensorType, p: int) -> TensorType: @@ -5595,18 +5716,13 @@ def aten_nansum( raise NotImplementedError() -@torch_op("aten::narrow", traceable=True) +@torch_op("aten::narrow", trace_only=True) def aten_narrow(self: TTensor, dim: INT64, start: INT64, length: INT64) -> TTensor: """narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)""" - if IsScalar(dim): - dim = op.Reshape(dim, op.Constant(value_ints=[-1])) - - if IsScalar(start): - start = op.Reshape(start, op.Constant(value_ints=[-1])) - - if IsScalar(length): - length = op.Reshape(length, op.Constant(value_ints=[-1])) + dim = op.Reshape(dim, op.Constant(value_ints=[-1])) + start = op.Reshape(start, op.Constant(value_ints=[-1])) + length = op.Reshape(length, op.Constant(value_ints=[-1])) end = op.Add(start, length) return op.Slice(self, start, end, dim) @@ -5682,14 +5798,18 @@ def aten_native_batch_norm( axes.pop(1) axes = op.Constant(value_ints=axes) if running_mean is None: # Using input mean - running_mean = op.Squeeze(op.ReduceMean(input, axes)) + running_mean = op.ReduceMean(input, axes, keepdims=False) if running_var is None: # Using input var mean = op.ReduceMean(input, axes) input_sub_mean = op.Sub(input, mean) sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) - running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes)) + running_var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False) + # TODO: This is a temporary fix for the issue that BatchNormalization + # is forced to be in training mode in PyTorch, and ORT currently + # only supports training mode with opset version lower than 14. + training = False # We have to split to two private functions, because BatchNormalization returns # three outputs when training_mode=True and one when it is False. if training: @@ -5717,7 +5837,6 @@ def aten_native_batch_norm( return norm, input_mean, input_rstd -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_training_onnx( input: TFloat, weight: TFloat, @@ -5769,7 +5888,6 @@ def _aten_native_batch_norm_training_onnx( return norm, mean, rstd, running_mean, new_running_var -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_inference_onnx( input: TFloat, weight: TFloat, @@ -5829,14 +5947,18 @@ def aten__native_batch_norm_legit_functional( axes.pop(1) axes = op.Constant(value_ints=axes) if running_mean is None: # Using input mean - running_mean = op.Squeeze(op.ReduceMean(input, axes)) + running_mean = op.ReduceMean(input, axes, keepdims=False) if running_var is None: # Using input var mean = op.ReduceMean(input, axes) input_sub_mean = op.Sub(input, mean) sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) - running_var = op.Squeeze(op.ReduceMean(sqr_input_sub_mean, axes)) + running_var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False) + # TODO: This is a temporary fix for the issue that BatchNormalization + # is forced to be in training mode in PyTorch, and ORT currently + # only supports training mode with opset version lower than 14. + training = False # We have to split to two private functions, because BatchNormalization returns # three outputs when training_mode=True and one when it is False. if training: @@ -5899,16 +6021,10 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType: raise NotImplementedError() -@torch_op("aten::native_dropout") -def aten_native_dropout( - input: TFloatOrBFloat16, p: float, train: bool = True -) -> Tuple[TFloatOrBFloat16, BOOL]: +@torch_op("aten::native_dropout", trace_only=True) +def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]: """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)""" - # Python bool attributes need to be explicitly converted to BOOL - # because the underlying attribute type is int - # TODO(#872): Allow ONNX Script to handle this conversion - train = op.Cast(train, to=BOOL.dtype) result, mask = op.Dropout(input, p, train) return result, mask @@ -5926,9 +6042,9 @@ def aten_native_group_norm( input: TFloat, weight: Optional[TFloat] = None, bias: Optional[TFloat] = None, - N: Optional[INT64] = None, # pylint: disable=unused-argument - C: Optional[INT64] = None, # pylint: disable=unused-argument - HxW: Optional[INT64] = None, # pylint: disable=unused-argument + N: Optional[INT64] = None, + C: Optional[INT64] = None, + HxW: Optional[INT64] = None, group: int = 1, eps: float = 1e-05, ) -> Tuple[TFloat, TFloat, TFloat]: @@ -5941,22 +6057,10 @@ def aten_native_group_norm( if bias is None: # Set to 0.0 as default, the shape is Channel size bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) - # Accoding to Torch, return rstd instead of var - norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps) - return norm, mean, rstd - - -@torch_op("aten::native_group_norm", private=True) -def _aten_native_group_norm_onnx( - input: TFloat, - weight: TFloat, - bias: TFloat, - group: INT64, - eps: float, -) -> Tuple[TFloat, TFloat, TFloat]: # Because onnx.GroupNorm() need size=group for weight and bias # But the torch's aten function's input need size=channel, the size mismatched # So we have to use onnx.InstanceNorm() to simulate + # This implementation should be simplified after opset 21 neg_1 = op.Constant(value_ints=[-1]) # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter group_tensor = op.Reshape(group, neg_1) @@ -5969,14 +6073,16 @@ def _aten_native_group_norm_onnx( input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps ) # Reshape back to input's shape - norm = op.Reshape(norm, op.Shape(input)) + norm = op.Reshape(norm, op.Shape(input), allowzero=True) # Using the input weight and bias to do affine # But need to unsqueeze to the target shape for broading cast easy input_rank = Rank(input) axes_unsqueeze = op.Range(1, input_rank - 1, 1) weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze) + weight_full_shape = op.CastLike(weight_full_shape, norm) norm_mul_weight = op.Mul(norm, weight_full_shape) + bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight) norm_result = op.Add(norm_mul_weight, bias_full_shape) # Compute mean and rstd, but using Torch algorithm # The returned shape for mean and vstd should be [N, group, -1] @@ -5991,7 +6097,9 @@ def _aten_native_group_norm_onnx( sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) # In Pytorch, vstd = 1/(sqrt(var + eps)) var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False) - rstd = op.Div(1.0, op.Sqrt(var + eps)) + eps = op.Constant(value=ir.tensor(eps, dtype=input.dtype)) + one = op.Constant(value=ir.tensor(1.0, dtype=input.dtype)) + rstd = op.Div(one, op.Sqrt(op.Add(var, eps))) # Get the correct shape [N, group] for mean again mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=False) return norm_result, mean, rstd @@ -6017,7 +6125,7 @@ def aten_native_group_norm_backward( @torch_op("aten::native_layer_norm", trace_only=True) def aten_native_layer_norm( input: TReal, - normalized_shape: INT64, + normalized_shape: Sequence[int], weight: Optional[TReal] = None, bias: Optional[TReal] = None, eps: float = 1e-05, @@ -6066,14 +6174,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType: raise NotImplementedError() -@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne")) +@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"), trace_only=True) def aten_ne(self: TReal, other: TReal) -> BOOL: """ne.Tensor(Tensor self, Tensor other) -> Tensor""" return op.Not(op.Equal(self, other)) -@torch_op(("aten::neg", "_operator::neg")) +@torch_op(("aten::neg", "_operator::neg"), trace_only=True) def aten_neg(self: TReal) -> TReal: """neg(Tensor self) -> Tensor""" @@ -6086,111 +6194,94 @@ def aten_negative(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::new_empty") -def aten_new_empty(self: TTensor, size: INT64) -> TTensor: - """new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - - # using zero to simulate empty array - result = op.ConstantOfShape(size) - return op.CastLike(result, self) - - -@torch_op("aten::new_empty") -def aten_new_empty_dtype( - self: TTensor, # pylint: disable=unused-argument +@torch_op("aten::new_empty", trace_only=True) +def aten_new_empty( + self: TTensor, size: INT64, - dtype: int, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TTensor: """new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" # using zero to simulate empty array result = op.ConstantOfShape(size) + if dtype == -1: + return op.CastLike(result, self) return op.Cast(result, to=dtype) -@torch_op("aten::new_empty_strided") +@torch_op("aten::new_empty_strided", trace_only=True) def aten_new_empty_strided( self: TTensor, size: INT64, - stride: INT64, # pylint: disable=unused-argument -) -> TTensor: - """new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - - # using zero to simulate empty array - zero = op.ConstantOfShape(size) - return op.CastLike(zero, self) - - -@torch_op("aten::new_empty_strided") -def aten_new_empty_strided_dtype( - self: TTensor, # pylint: disable=unused-argument - size: INT64, - stride: INT64, # pylint: disable=unused-argument - dtype: int, + stride: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TTensor: """new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" # using zero to simulate empty array zero = op.ConstantOfShape(size) + if dtype == -1: + return op.CastLike(zero, self) return op.Cast(zero, to=dtype) -@torch_op("aten::new_full") -def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: - # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - - fill_value = op.CastLike(fill_value, self) - return op.Expand(fill_value, size) - - -@torch_op("aten::new_full") -def aten_new_full_dtype( - self: TTensor, # pylint: disable=unused-argument +@torch_op("aten::new_full", trace_only=True) +def aten_new_full( + self: TTensor, size: INT64, - fill_value: TTensor, - dtype: int, + fill_value: TensorType, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TTensor: # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - fill_value = op.Cast(fill_value, to=dtype) + if dtype == -1: + fill_value = op.CastLike(fill_value, self) + else: + fill_value = op.Cast(fill_value, to=dtype) return op.Expand(fill_value, size) -@torch_op("aten::new_ones") -def aten_new_ones(self: TReal, size: INT64) -> TReal: # pylint: disable=unused-argument +@torch_op("aten::new_ones", trace_only=True) +def aten_new_ones( + self: TReal, + size: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TReal: """new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" one = op.Constant(value_float=1.0) result = op.Expand(one, size) - return op.CastLike(result, self) + if dtype == -1: + return op.CastLike(result, self) + return op.Cast(result, to=dtype) -@torch_op("aten::new_ones") -def aten_new_ones_dtype( - self: TReal, # pylint: disable=unused-argument +@torch_op("aten::new_zeros", trace_only=True) +def aten_new_zeros( + self: TReal, size: INT64, - dtype: int, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TReal: - one = op.Constant(value_float=1.0) - result = op.Expand(one, size) - return op.Cast(result, to=dtype) - - -@torch_op("aten::new_zeros") -def aten_new_zeros(self: TReal, size: INT64) -> TReal: """new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" result = op.ConstantOfShape(size) - return op.CastLike(result, self) - - -@torch_op("aten::new_zeros") -def aten_new_zeros_dtype( - self: TReal, # pylint: disable=unused-argument - size: INT64, - dtype: int, -) -> TReal: - result = op.ConstantOfShape(size) + if dtype == -1: + return op.CastLike(result, self) return op.Cast(result, to=dtype) @@ -6200,7 +6291,7 @@ def aten_nextafter(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::nonzero") +@torch_op("aten::nonzero", trace_only=True) def aten_nonzero(self: TTensor) -> INT64: """nonzero(Tensor self) -> Tensor""" # NOTE: In torch the return shape is [n, d], while in onnx [d, n], @@ -6220,7 +6311,7 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp raise NotImplementedError() -@torch_op(("aten::normal", "aten::normal_functional"), traceable=True) +@torch_op("aten::normal_functional", trace_only=True) def aten_normal( self: TTensor, mean: float = 0.0, @@ -6228,26 +6319,28 @@ def aten_normal( ) -> TFloat: # type: ignore[type-var] """normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor""" - if IsScalar(self): + if len(self.shape) == 0: self = op.Reshape(self, op.Constant(value_ints=[-1])) result = op.RandomNormalLike(self, mean=mean, scale=std) return result -@torch_op("aten::normal.float_float") +@torch_op("aten::normal.float_float", trace_only=True) def aten_normal_float_float( mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype ) -> TensorType: """normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" + if dtype == -1: + dtype = FLOAT.dtype # Create a dummy tensor for RandomNormalLike to get the shape dummy_tensor = op.ConstantOfShape(size) result = op.RandomNormalLike(dummy_tensor, mean=mean, scale=std) return op.Cast(result, to=dtype) -@torch_op("aten::normal.float_Tensor") +@torch_op("aten::normal.float_Tensor", trace_only=True) def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: """normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6257,7 +6350,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: return op.Add(op.Mul(std, sampled), mean_casted) -@torch_op("aten::normal.Tensor_float") +@torch_op("aten::normal.Tensor_float", trace_only=True) def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: """normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor""" @@ -6266,7 +6359,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean) -@torch_op("aten::normal.Tensor_Tensor") +@torch_op("aten::normal.Tensor_Tensor", trace_only=True) def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat: """normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6287,10 +6380,17 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType: raise NotImplementedError() -@torch_op("aten::ones") -def aten_ones(size: IntType, dtype: int = FLOAT.dtype): +@torch_op("aten::ones", trace_only=True) +def aten_ones( + size: IntType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +): """ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - + if dtype == -1: + dtype = FLOAT.dtype size = op.Cast(size, to=INT64.dtype) one = op.Constant(value_float=1.0) one = op.Cast(one, to=dtype) @@ -6298,26 +6398,26 @@ def aten_ones(size: IntType, dtype: int = FLOAT.dtype): @torch_op("aten::ones_like", trace_only=True) -def aten_ones_like(self: TTensor, dtype: int = -1) -> TTensor: - """ones_like. +def aten_ones_like( + self: TTensor, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, + memory_format: str = "", +) -> TTensor: + """ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype before calling this function. """ - # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. + if dtype is None: + dtype = -1 if dtype == -1: one = op.CastLike(1, self) else: one = op.Cast(1, to=dtype) - return _aten_ones_like_onnx(self, one) - - -@torch_op("aten::ones_like", private=True) -def _aten_ones_like_onnx(self: TTensor, one) -> TTensor: shape = op.Shape(self) return op.Expand(one, shape) @@ -6403,16 +6503,42 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType: raise NotImplementedError() -def aten_pixel_shuffle(self: TensorType, upscale_factor: int) -> TensorType: +@torch_op("aten::pixel_shuffle", trace_only=True) +def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: """pixel_shuffle(Tensor self, int upscale_factor) -> Tensor""" + if len(self.shape) == 4: + return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD") - raise NotImplementedError() + # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) + + reshaped_self = op.Reshape( + self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) + ) + depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD") + final_dims = op.Shape(depth_to_space, start=1) + output_shape = op.Concat(batch_dims, final_dims, axis=0) + return op.Reshape(depth_to_space, output_shape, allowzero=True) -def aten_pixel_unshuffle(self: TensorType, downscale_factor: int) -> TensorType: +@torch_op("aten::pixel_unshuffle", trace_only=True) +def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor""" + if len(self.shape) == 4: + return op.SpaceToDepth(self, blocksize=downscale_factor) - raise NotImplementedError() + # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) + + reshaped_self = op.Reshape( + self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) + ) + space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor) + final_dims = op.Shape(space_to_depth, start=1) + output_shape = op.Concat(batch_dims, final_dims, axis=0) + return op.Reshape(space_to_depth, output_shape, allowzero=True) def aten_poisson(self: TensorType, generator: Optional[str] = None) -> TensorType: @@ -6455,19 +6581,49 @@ def aten_positive(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op( - ("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow") -) +@torch_op(("aten::pow.Tensor_Tensor", "_operator::pow"), trace_only=True) def aten_pow(self: TReal, exponent: TTensor) -> TReal: """pow(Tensor self, Tensor exponent) -> Tensor""" - + # TODO(justinchuby): Add type promotion return op.Pow(self, exponent) -def aten_prelu(self: TensorType, weight: TensorType) -> TensorType: +@torch_op("aten::pow.Tensor_Scalar", trace_only=True) +def aten_pow_tensor_scalar(self: TReal, exponent: float) -> TReal: + """pow(Tensor self, Scalar exponent) -> Tensor""" + if self.dtype.is_floating_point(): + # Handle cases when e.g. (1) self is float16 or int + return op.Pow(self, ir.tensor(exponent, dtype=self.dtype)) + # For integer types, we need to cast self to the exponent type + if isinstance(exponent, int): + # The scalar exponent can be an int + return op.Pow(self, ir.tensor(exponent, dtype=self.dtype)) + + # exponent is float so we cast self to match the exponent type. + # More precisely if self is float64, we should cast exponent to float64; but + # this is uncommon and should be fixed when we create a general type promotion + # mechanism for torchlib + return op.Pow(op.Cast(self, to=FLOAT.dtype), exponent) + + +@torch_op("aten::pow.Scalar", trace_only=True) +def aten_pow_scalar(self: float, exponent: TTensor) -> TTensor: + """pow.Scalar(Scalar self, Tensor exponent) -> Tensor""" + return op.Pow(op.Cast(self, to=exponent.dtype), exponent) + + +@torch_op(("aten::prelu", "aten::_prelu_kernel"), trace_only=True) +def aten_prelu(self: TReal, weight: TReal) -> TReal: """prelu(Tensor self, Tensor weight) -> Tensor""" - raise NotImplementedError() + rank = len(self.shape) + if rank == 0: + # e.g. self: [], weight: [1] + weight = op.Squeeze(weight) + elif rank >= 2: + # e.g. self: [5,10,5], weight: [10] + weight = op.Reshape(weight, [1, -1] + [1] * (rank - 2)) + return op.PRelu(self, weight) def aten_prelu_backward( @@ -6478,10 +6634,22 @@ def aten_prelu_backward( raise NotImplementedError() -def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType: +@torch_op("aten::prod", trace_only=True) +def aten_prod(self: TReal, dtype: int = -1) -> TReal: """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor""" - raise NotImplementedError() + if dtype != -1 and dtype is not None: + self = op.Cast(self, to=dtype) + return op.ReduceProd(self) + + +@torch_op("aten::prod.dim_int", trace_only=True) +def aten_prod_dim_int(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal: + """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor""" + + if dtype != -1 and dtype is not None: + self = op.Cast(self, to=dtype) + return op.ReduceProd(self, axes=[dim], keepdims=keepdim) def aten_promote_types(type1: int, type2: int) -> int: @@ -6701,37 +6869,48 @@ def aten_quantized_rnn_tanh_cell( raise NotImplementedError() -@torch_op("aten::rad2deg", traceable=True) +@torch_op("aten::rad2deg", trace_only=True) def aten_rad2deg(self: TFloat) -> TFloat: """rad2deg(Tensor self) -> Tensor""" return op.Mul(self, op.CastLike(180.0 / _MATH_PI, self)) -@torch_op("aten::rand") -def aten_rand(size: INT64, dtype: int = FLOAT.dtype) -> TReal: +@torch_op("aten::rand", trace_only=True) +def aten_rand( + size: INT64, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TReal: """rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - + if dtype == -1: + dtype = FLOAT.dtype shaper = op.ConstantOfShape(size) return op.RandomUniformLike(shaper, dtype=dtype) -@torch_op("aten::rand_like") -def aten_rand_like(self: TFloat) -> TFloat: - """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - - return op.RandomUniformLike(self) - - -@torch_op("aten::rand_like") -def aten_rand_like_dtype(self: TensorType, dtype: int) -> TensorType: +@torch_op("aten::rand_like", trace_only=True) +def aten_rand_like( + self: TFloat, dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False +) -> TFloat: """rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + if dtype == -1: + return op.RandomUniformLike(self) return op.RandomUniformLike(self, dtype=dtype) -@torch_op("aten::randint") -def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorType: +@torch_op("aten::randint", trace_only=True) +def aten_randint( + high: INT64, + size: INT64, + dtype: int = INT64.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" shaper = op.ConstantOfShape(size) @@ -6743,9 +6922,15 @@ def aten_randint(high: INT64, size: INT64, dtype: int = INT64.dtype) -> TensorTy return op.Cast(rand_int, to=dtype) -@torch_op("aten::randint.low") +@torch_op("aten::randint.low", trace_only=True) def aten_randint_low( - low: INT64, high: INT64, size: INT64, dtype: int = INT64.dtype + low: INT64, + high: INT64, + size: INT64, + dtype: int = INT64.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> TensorType: """randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" @@ -6760,21 +6945,15 @@ def aten_randint_low( return op.Cast(rand_int, to=dtype) -@torch_op("aten::randint_like") -def aten_randint_like(self: TensorType, high: INT64) -> IntType: - """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - - self_float = op.Cast(self, to=FLOAT.dtype) - rand = op.RandomUniformLike(self_float) - # Scale to [0, high] first - rand_scaled = op.Mul(rand, op.CastLike(high, rand)) - # Round to ints - rand_int = op.Floor(rand_scaled) - return op.CastLike(rand_int, self) - - -@torch_op("aten::randint_like") -def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> TensorType: +@torch_op("aten::randint_like", trace_only=True) +def aten_randint_like( + self: TensorType, + high: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> IntType: """randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" self_float = op.Cast(self, to=FLOAT.dtype) @@ -6783,11 +6962,21 @@ def aten_randint_like_dtype(self: TensorType, high: INT64, dtype: int) -> Tensor rand_scaled = op.Mul(rand, op.CastLike(high, rand)) # Round to ints rand_int = op.Floor(rand_scaled) + if dtype == -1: + return op.CastLike(rand_int, self) return op.Cast(rand_int, to=dtype) -@torch_op("aten::randint_like.low_dtype") -def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> IntType: +@torch_op("aten::randint_like.low_dtype", trace_only=True) +def aten_randint_like_low_dtype( + self: TensorType, + low: INT64, + high: INT64, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> IntType: """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor This is the TorchLib overload for aten::randint_like.low_dtype when dtype is None. @@ -6801,55 +6990,47 @@ def aten_randint_like_low_dtype(self: TensorType, low: INT64, high: INT64) -> In rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low) # Round to ints rand_int = op.Floor(rand_translated) - return op.CastLike(rand_int, self) - - -@torch_op("aten::randint_like.low_dtype") -def aten_randint_like_low_dtype_dtype( - self: TensorType, low: INT64, high: INT64, dtype: int -) -> TensorType: - """randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - - self_float = op.Cast(self, to=FLOAT.dtype) - rand = op.RandomUniformLike(self_float) - # Translate to [low, high] first - high = op.Cast(high, to=FLOAT.dtype) - low = op.Cast(low, to=FLOAT.dtype) - rand_translated = op.Add(op.Mul(rand, op.Sub(high, low)), low) - # Round to ints - rand_int = op.Floor(rand_translated) + if dtype == -1: + return op.CastLike(rand_int, self) return op.Cast(rand_int, to=dtype) -@torch_op("aten::randn") -def aten_randn(size: INT64, dtype: int = FLOAT.dtype) -> TReal: +@torch_op("aten::randn", trace_only=True) +def aten_randn( + size: INT64, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TReal: """randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" shaper = op.ConstantOfShape(size) return op.RandomNormalLike(shaper, dtype=dtype) -@torch_op("aten::randn_like") -def aten_randn_like(self: TFloat) -> TFloat: - """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - - return op.RandomNormalLike(self) - - -@torch_op("aten::randn_like") -def aten_randn_like_dtype(self: TensorType, dtype: int) -> TensorType: +@torch_op("aten::randn_like", trace_only=True) +def aten_randn_like( + self: TFloat, dtype: int = -1, layout: str = "", device: str = "", pin_memory: bool = False +) -> TFloat: """randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" + if dtype == -1: + return op.RandomNormalLike(self) return op.RandomNormalLike(self, dtype=dtype) -def aten_randperm(n: int) -> TensorType: +def aten_randperm( + n: int, layout: str = "", device: str = "", pin_memory: bool = False +) -> TensorType: """randperm(int n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" raise NotImplementedError() -def aten_range(start: float, end: float) -> TensorType: +def aten_range( + start: float, end: float, layout: str = "", device: str = "", pin_memory: bool = False +) -> TensorType: """range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" raise NotImplementedError() @@ -6867,8 +7048,8 @@ def aten_real(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::reciprocal") -def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::reciprocal", trace_only=True) +def aten_reciprocal(self: TFloat) -> TFloat: """reciprocal(Tensor self) -> Tensor""" return op.Reciprocal(self) @@ -6886,10 +7067,15 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op("aten::remainder") -def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op( + ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True +) +def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype.is_integer(): + return op.Mod(self, other) + # TODO(justinchuby): Improve fp16 precision by following the logic in # https://github.com/pytorch/pytorch/blob/3a823e46170778cc32783f27596c77d0103084a9/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L264-L277 @@ -6899,13 +7085,6 @@ def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrB return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op("aten::remainder") -def aten_remainder_int(self: TInt, other: TInt) -> TInt: - """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.Mod(self, other) - - def aten_rename(self: TensorType, names: Optional[str]) -> TensorType: """rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)""" @@ -6918,39 +7097,141 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT raise NotImplementedError() -@torch_op("aten::repeat") -def aten_repeat(self: TTensor, repeats: TInt) -> TTensor: +@torch_op("aten::repeat", trace_only=True) +def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: """repeat(Tensor self, SymInt[] repeats) -> Tensor""" - if op.Size(repeats) == 0: - result = self + if len(repeats) == 0: + return self + self_expanded = op.Expand(self, [1] * len(repeats)) + return op.Tile(self_expanded, repeats) + + +@torch_op("aten::repeat_interleave.self_int", trace_only=True) +def aten_repeat_interleave_self_int( + self: TensorType, repeats: int, dim: Optional[int] = None +) -> TensorType: + """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + + The trick is to repeat in one direction orthogonal to reshape. + + .. code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat_interleave(2, dim=0) + + is equivalent to: + + .. code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat((1, 2)).reshape((-1, t.shape[1])) + """ + if dim is None: + raise NotImplementedError("No conversion available yet when dim is None.") + + self_rank = len(self.shape) + pos_dim = (dim + self_rank) % self_rank + unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) + if isinstance(repeats, int): + tiles = [1] * (self_rank + 1) + tiles[pos_dim + 1] = repeats + tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) else: - # TODO(justinchuby): Make ones_like a function when onnxscript supports it - repeats = op.Cast(repeats, to=INT64.dtype) - # shape = ones_like(repeats) := { - one = op.Constant(value_int=1) - repeats_shape = op.Shape(repeats) - shape = op.Expand(one, repeats_shape) - # } - self_expanded = op.Expand(self, shape) - result = op.Tile(self_expanded, repeats) - return result + # repeats is a symbolic tensor + tile_repeat = op.Concat( + op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)), + op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), + op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), + axis=0, + ) + tiled = op.Expand(unsqueezed, tile_repeat) + if self_rank == 1: + return op.Identity(tiled) + final_shape = op.Concat( + op.Shape(self, start=0, end=dim), + op.Constant(value_ints=[-1]), + op.Shape(self, start=pos_dim + 1), + axis=0, + ) + return op.Reshape(tiled, final_shape) -def aten_repeat_interleave( - repeats: TensorType, output_size: Optional[int] = None +@torch_op("aten::repeat_interleave.Tensor", trace_only=True) +def aten_repeat_interleave_Tensor( + self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None ) -> TensorType: - """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor""" + """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor - raise NotImplementedError() + When `repeats` is a tensor, each line is multiplied + by a different number. + There are multiple strategies. Here is one. + .. code-block:: python -@torch_op("aten::reshape") -def aten_reshape(self: TTensor, shape: IntType) -> TTensor: - """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)""" + import torch - # Reshape only support INT64 as 'shape' - shape = op.Cast(shape, to=INT64.dtype) + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + times = torch.tensor([2, 3], dtype=torch.int64) + y = x.repeat_interleave(times, dim=0) + print("repeat_interleave") + print(y) + + ci = times.cumsum(dim=0) + rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1)) + srows = times.shape[0] - rows.to(torch.int64).sum(axis=0) + indices = srows.reshape((-1, )) + print("decomposed") + print(x[indices, :]) + """ + if repeats is None: + repeats = self + self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1) + if dim is None: + # flatten + self = op.Reshape(self, [-1]) + rank = 1 + else: + rank = len(self.shape) + + if rank > 2: + shape_x0 = op.Shape(self, start=0, end=1) + shape_x = op.Shape(self, start=1) + self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0)) + elif rank == 1: + shape_x = None + self = op.Reshape(self, [-1, 1]) + else: + if rank != 2: + raise NotImplementedError( + f"rank(self)={rank} not implemented for repeat_interleave" + ) + shape_x = None + + ci = op.CumSum(repeats, [0]) + last_ci = op.Gather(ci, [-1]) + trange = op.Range(0, op.Squeeze(last_ci, [0]), 1) + rows = op.Less(trange, op.Unsqueeze(ci, [-1])) + srows = op.Sub( + op.Shape(self, start=0, end=1), + op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]), + ) + indices = op.Reshape(srows, [-1]) + values = op.GatherND(self, op.Unsqueeze(indices, [-1])) + if rank == 2: + return values + # shape_x is None at this stage. + assert shape_x is None # for mypy + return op.Reshape( + values, + op.Concat([-1], shape_x, axis=0) if shape_x else [-1], + ) + + +@torch_op("aten::reshape", trace_only=True) +def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor: + """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)""" + shape = common_ops.merge_dims(shape) return op.Reshape(self, shape) @@ -6960,14 +7241,14 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::resolve_conj") +@torch_op("aten::resolve_conj", trace_only=True) def aten_resolve_conj(self: TTensor) -> TTensor: """resolve_conj(Tensor(a) self) -> Tensor(a)""" return op.Identity(self) -@torch_op("aten::resolve_neg") +@torch_op("aten::resolve_neg", trace_only=True) def aten_resolve_neg(self: TTensor) -> TTensor: """resolve_neg(Tensor(a) self) -> Tensor(a)""" @@ -7019,74 +7300,84 @@ def aten_rnn_tanh_cell( @torch_op("aten::roll", trace_only=True) -def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor: +def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 0: - return self + return op.Identity(self) elif self.shape[0] == 0: # empty tensor - return self + return op.Identity(self) + + # NOTE: In pytorch, default value of dims is an empty list. + if len(dims) == 0: # Empty sequence + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + return _aten_roll_shift_no_dim_onnx(self, shifts[0]) else: - # NOTE: In pytorch, default value of dims is an empty list. - if len(dims) == 0: # Empty sequence - # assert isinstance(shifts, int) - return _aten_roll_shift_no_dim_onnx(self, shifts) - else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list - result = self - for i in range(len(shifts)): # pylint: disable=consider-using-enumerate - shift = op.Gather(shifts, i, axis=0) - dim = dims[i] - result = _aten_roll_shift_and_dim_onnx(result, shift, dim) - return result + assert len(shifts) == len(dims) + result = self + for i, shift in enumerate(shifts): + dim = dims[i] + result = _aten_roll_shift_and_dim_onnx(result, shift, dim) + return result @torch_op("aten::roll", trace_only=True, complex=True) -def aten_roll_complex(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor: +def aten_roll_complex( + self: TTensor, shifts: Sequence[int], dims: Sequence[int] = () +) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 1: - return self + return op.Identity(self) if self.shape[0] == 0: # empty tensor - return self + return op.Identity(self) self_real = op.Slice(self, [0], [1], axes=[-1]) self_imag = op.Slice(self, [1], [2], axes=[-1]) if not dims: - # assert isinstance(shifts, int) - shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts) - shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts) + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0]) + shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0]) result = op.Concat(shift_real, shift_imag, axis=-1) else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list + assert len(shifts) == len(dims) for i, dim in enumerate(dims): - shift = op.Gather(shifts, i, axis=0) - self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim) - self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim) + self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim) + self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim) result = op.Concat(self_real, self_imag, axis=-1) return result -@torch_op("aten::roll", private=True) -def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: +def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D] self_flatten = op.Reshape(self, neg_1) # Compute slice length - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: + if shift < 0: # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end - slice_length = -shift_tensor + slice_length = op.Constant(value_ints=[-shift]) else: # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end # The effect equals to move [D] to the beginning - slice_length = op.Size(self_flatten) - shift_tensor + slice_length = op.Size(self_flatten) - op.Constant(value_ints=[shift]) # Get second part of the tensor, e.g. [A,B,C] suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length) # Get first part of the tensor, e.g. [D] @@ -7096,15 +7387,13 @@ def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: return op.Reshape(result, op.Shape(self)) -@torch_op("aten::roll", private=True) -def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor: +def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) - dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1) - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: - slice_length = -shift_tensor + dim_tensor = op.Constant(value_ints=[dim]) + if shift < 0: + slice_length = op.Constant(value_ints=[-shift]) else: - slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor + slice_length = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift]) # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) @@ -7118,7 +7407,7 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te raise NotImplementedError() -@torch_op("aten::round") +@torch_op("aten::round", trace_only=True) def aten_round(self: TFloat) -> TFloat: """round(Tensor self) -> Tensor""" @@ -7167,50 +7456,49 @@ def aten_rrelu( raise NotImplementedError() -def aten_rshift(self: TensorType, other: TensorType) -> TensorType: - """__rshift__.Tensor(Tensor self, Tensor other) -> Tensor""" - - raise NotImplementedError() - - -@torch_op("aten::rsqrt") -def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::rsqrt", trace_only=True) +def aten_rsqrt(self: TFloat) -> TFloat: """rsqrt(Tensor self) -> Tensor""" return op.Reciprocal(op.Sqrt(self)) -@torch_op(("aten::rsub", "aten::rsub.Scalar")) +# Do not register rsub. It will be decomposed and type promoted by torch def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - return op.Sub(other, op.Mul(self, alpha)) - - -@torch_op(("aten::rsub", "aten::rsub.Scalar"), trace_only=True, complex=True) -def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: - """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - - return aten_rsub(self, other, alpha) + raise NotImplementedError @torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> RealType: +def aten_scalar_tensor( + s: TensorType, + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" + if dtype == -1: + dtype = FLOAT.dtype - # Set trace_only=True because different if branches return different dtypes - # which is not supported in an ONNX function return common_ops.cast_to(s, dtype=dtype) @torch_op("aten::scalar_tensor", trace_only=True, complex=True) def aten_scalar_tensor_complex( - s: Union[FLOAT, DOUBLE], dtype: int = COMPLEX64.dtype + s: Union[FLOAT, DOUBLE], + dtype: int = COMPLEX64.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, ) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" # NOTE: When the input is originally in complex, this function is invoked. # On the other hand, when the input is originally in real, aten_scalar_tensor is used. # is invoked. + if dtype == -1: + dtype = COMPLEX64.dtype if dtype == COMPLEX128.dtype: result = op.Cast(s, to=DOUBLE.dtype) elif dtype == COMPLEX64.dtype: @@ -7222,16 +7510,38 @@ def aten_scalar_tensor_complex( return result -@torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor_sym_number(s: RealType, dtype: int = FLOAT.dtype) -> RealType: - """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" +@torch_op("aten::scatter.src", trace_only=True) +def aten_scatter_src( + self: TTensor, + dim: int, # we have to use int here because ScatterElements() will use this attribute + index: TInt, + src: TTensor, +) -> TTensor: + """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" + if len(index.shape) == 0: + index = op.Unsqueeze(index, [0]) + if len(src.shape) == 0: + src = op.Unsqueeze(src, [0]) + return op.ScatterElements(self, index, src, axis=dim) - # Set trace_only=True because different if branches return different dtypes - # which is not supported in an ONNX function - return common_ops.cast_to(s, dtype=dtype) +@torch_op("aten::scatter.value", trace_only=True) +def aten_scatter_value( + self: TTensor, + dim: int, # we have to use int here because ScatterElements() will use this attribute + index: TInt, + value: float, +) -> TTensor: + """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor""" + # Ensure value is a scalar tensor and expand it to match index shape + if len(index.shape) == 0: + index = op.Unsqueeze(index, [0]) + scalar_tensor = ir.tensor([value], dtype=self.dtype) + src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor) + return op.ScatterElements(self, index, src, axis=dim) -@torch_op("aten::scatter_add") + +@torch_op("aten::scatter_add", trace_only=True) def aten_scatter_add( self: TReal, dim: int, # we have to use int here because ScatterElements() will use this attribute @@ -7244,14 +7554,14 @@ def aten_scatter_add( return op.ScatterElements(self, index, src, axis=dim, reduction="add") -@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True) +@torch_op("aten::scatter_reduce.two", trace_only=True) def aten_scatter_reduce( self: TReal, dim: int, # we have to use int here because ScatterElements() will use this attribute index: TInt, src: TReal, reduce: str, - include_self: bool = True, # pylint: disable=unused-argument + include_self: bool = True, ): """scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor""" @@ -7263,24 +7573,66 @@ def aten_scatter_reduce( "amax": "max", } onnx_reduce = reduce_mode[reduce] - return _aten_scatter_reduce_onnx(self, index, src, dim, onnx_reduce) + dtype = src.dtype or self.dtype + assert dtype is not None, "dtype should be not None" - -@torch_op("aten::scatter_reduce", private=True) -def _aten_scatter_reduce_onnx( - self: TReal, - index: TInt, - src: TReal, - dim: int, - onnx_reduce: str, -): - self_is_scalar = IsScalar(self) + self_is_scalar = len(self.shape) == 0 if self_is_scalar: # assert (index_rank == 0 and rank_src == 0) neg_1 = op.Constant(value_ints=[-1]) self = op.Reshape(self, neg_1) index = op.Reshape(index, neg_1) src = op.Reshape(src, neg_1) + + if not include_self: + # onnx standard always assume the value from self is part of the reduction. + # A first step is added to replace the impacted value by another one + # chosen in a way that the results of the reduction is not changed + # whether or not it takes part in it. + # It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul. + # mean is not supported. + if onnx_reduce == "max": + if dtype in { + ir.DataType.FLOAT16, + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + }: + value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) + elif dtype == ir.DataType.BFLOAT16: + value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype) + elif dtype == ir.DataType.BOOL: + value = ir.tensor([False], dtype=dtype) + else: + value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) + reduction_init = "min" + elif onnx_reduce == "min": + if dtype in { + ir.DataType.FLOAT16, + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + }: + value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) + elif dtype == ir.DataType.BFLOAT16: + value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype) + elif dtype == ir.DataType.BOOL: + value = ir.tensor([True], dtype=dtype) + else: + value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) + reduction_init = "max" + elif onnx_reduce == "add": + value = ir.tensor([0], dtype=dtype) + reduction_init = "none" + elif onnx_reduce == "mul": + value = ir.tensor([1], dtype=dtype) + reduction_init = "none" + else: + value = ir.tensor([0], dtype=dtype) + reduction_init = "none" + + cst = op.ConstantOfShape(op.Shape(src), value=value) + self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init) + result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce) + if self_is_scalar: result = op.Squeeze(result) return result @@ -7314,7 +7666,7 @@ def aten_segment_reduce( raise NotImplementedError() -@torch_op(("aten::select", "aten::select.int")) +@torch_op("aten::select.int", trace_only=True) def aten_select(self: TTensor, dim: int, index: int) -> TTensor: """select(Tensor self, int dim, int index) -> Tensor""" @@ -7329,13 +7681,13 @@ def aten_select_backward( raise NotImplementedError() -@torch_op("aten::select_scatter") +@torch_op("aten::select_scatter", trace_only=True) def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int) -> TensorType: """select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor""" # Change src rank to self rank according to dim # e.g. if self is [2,3,4], src is [2,4], dim=1, then update is [2,1,4] - update = op.Unsqueeze(src, axes=dim) + update = op.Unsqueeze(src, axes=[dim]) # Change index rank to the same as 'update' [2,1,4] indices = op.Expand(index, op.Shape(update)) return op.ScatterElements(self, indices, update, axis=dim, reduction="none") @@ -7360,8 +7712,8 @@ def aten_sgn(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::sigmoid") -def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::sigmoid", trace_only=True) +def aten_sigmoid(self: TFloat) -> TFloat: """sigmoid(Tensor self) -> Tensor""" return op.Sigmoid(self) @@ -7380,21 +7732,36 @@ def aten_signbit(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::sin") +@torch_op("aten::sin", trace_only=True) def aten_sin(self: TFloat) -> TFloat: """sin(Tensor self) -> Tensor""" return op.Sin(self) -@torch_op("aten::sinh") +@torch_op("aten::sinh", trace_only=True) def aten_sinh(self: TFloat) -> TFloat: """sinh(Tensor self) -> Tensor""" return op.Sinh(self) -@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True) +@torch_op(("aten::slice.Tensor"), trace_only=True, complex=True) +def aten_slice_complex( + self: TTensor, + dim: int = 0, + start: Optional[INT64] = None, + end: Optional[INT64] = None, + step: Optional[INT64] = None, +) -> TTensor: + """slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)""" + if dim < 0: + # Account for the complex dimension in ONNX + dim = len(self.shape) + dim - 1 + return aten_slice(self, dim, start, end, step) + + +@torch_op(("aten::slice.Tensor"), trace_only=True) def aten_slice( self: TTensor, dim: int = 0, @@ -7488,7 +7855,7 @@ def aten_slice_scatter( zero, op.Unsqueeze(step, zero), ) - index_base = op.Unsqueeze(index_base, -1) + index_base = op.Unsqueeze(index_base, [-1]) # Use trace only to construct the perm attribute in Transpose dims = None @@ -7522,15 +7889,15 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), trace_only=True) -def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16: +@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) +def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" - self_is_scalar = IsScalar(self) + self_is_scalar = len(self.shape) == 0 if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.Softmax(self, axis=dim) - if dtype != -1: + if dtype != -1 and dtype is not None: result = op.Cast(result, to=dtype) if self_is_scalar: # Convert to scalar when input is scalar @@ -7539,27 +7906,20 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB return result -@torch_op(("aten::softmax", "aten::softmax.int", "aten::special_softmax"), traceable=True) -def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: - """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" - - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.Softmax(self, axis=dim) - if self_is_scalar: - # Convert to scalar when input is scalar - result = op.Squeeze(result) - - return result - - +@torch_op("aten::sort", trace_only=True) def aten_sort( - self: TensorType, dim: int = -1, descending: bool = False -) -> tuple[TensorType, TensorType]: - """sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)""" + self: TReal, dim: int = -1, descending: bool = False, stable: bool = False +) -> tuple[TReal, INT64]: + """sort(Tensor self, int dim=-1, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices)""" - raise NotImplementedError() + self_is_scalar = len(self.shape) == 0 + if self_is_scalar: + return op.Identity(self), op.Constant(value_int=0) + shape = op.Shape(self) + dim_size = op.Gather(shape, dim, axis=0) + dim_size = op.Reshape(dim_size, op.Constant(value_ints=[1])) + values, indices = op.TopK(self, dim_size, axis=dim, largest=descending, sorted=True) + return values, indices def aten_sparse_dim(self: TensorType) -> int: @@ -7602,8 +7962,8 @@ def aten_split_with_sizes_copy( raise NotImplementedError() -@torch_op("aten::sqrt") -def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::sqrt", trace_only=True) +def aten_sqrt(self: TFloat) -> TFloat: """sqrt(Tensor self) -> Tensor""" return op.Sqrt(self) @@ -7615,25 +7975,18 @@ def aten_square(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::squeeze") +@torch_op("aten::squeeze", trace_only=True) def aten_squeeze(self: TTensor) -> TTensor: """squeeze(Tensor(a) self) -> Tensor(a)""" return op.Squeeze(self) -@torch_op("aten::squeeze.dim") +@torch_op("aten::squeeze.dim", trace_only=True) def aten_squeeze_dim(self: TTensor, dim: int) -> TTensor: - result = self - if Rank(self) > 0: # type: ignore[operator] - # check if specified dimension is 1, do squeeze - shape = op.Shape(self) - dim_size = op.Gather(shape, dim, axis=0) - if dim_size == 1: - dims = op.Reshape(dim, op.Constant(value_ints=[-1])) - result = op.Squeeze(self, dims) - - return result + if len(self.shape) == 0: + return op.Identity(self) + return op.Squeeze(self, [dim]) @torch_op("aten::squeeze.dim", complex=True, trace_only=True) @@ -7642,6 +7995,9 @@ def aten_squeeze_dim_complex(self: TTensor, dim: int) -> TTensor: # Account for the complex dimension in ONNX dim = dim - 1 + if len(self.shape) == 1: + # The single dimension is the complex dimension + return op.Identity(self) return aten_squeeze_dim(self, dim) @@ -7668,168 +8024,126 @@ def aten_stack_complex(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTen return aten_stack(tensors, dim) -@torch_op("aten::stack") +@torch_op("aten::stack", trace_only=True) def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString: """stack(Tensor[] tensors, int dim=0) -> Tensor""" + if isinstance(tensors, Sequence): + unsqueezed = [op.Unsqueeze(t, op.Constant(value_ints=[dim])) for t in tensors] + return op.Concat(*unsqueezed, axis=dim) return op.ConcatFromSequence(tensors, axis=dim, new_axis=1) -def aten_std(self: TensorType, unbiased: bool = True) -> TensorType: +# std is decomposed by PyTroch +def aten_std(self: TReal, unbiased: bool = True) -> TReal: """std(Tensor self, bool unbiased=True) -> Tensor""" + var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False) + return op.Sqrt(var) - raise NotImplementedError() +# std_dim is decomposed by PyTroch +def aten_std_dim( + self: TReal, + dim: Sequence[int], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, +) -> TReal: + """std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor""" -def aten_std_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, TensorType]: - """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" + var = _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim) + return op.Sqrt(var) - raise NotImplementedError() +# std is decomposed by PyTroch +def aten_std_correction( + self: TReal, + # FIXME(justinchuby): Make dim Optional[Sequence[int]] + dim: Optional[int] = None, + correction: Optional[float] = None, + keepdim: bool = False, +) -> TReal: + """std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor""" -@torch_op("aten::stft", private=True) -def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]: - signal_rank = Rank(self) - if signal_rank == 1: - # Add a batch dimension - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - return self, signal_rank - - -@torch_op("aten::stft", private=True) -def _center_window_around_zeros_if_needed( - window: TFloatOrBFloat16, n_fft: int -) -> TFloatOrBFloat16: - # first dimension - n_win = op.Shape(window, start=0, end=1) - # Center window around zeros if needed (required by ONNX's STFT) - if n_win < n_fft: - left = (n_fft - n_win) / 2 - - right = n_fft - left - n_win - left = op.Reshape(left, op.Constant(value_ints=[1])) - right = op.Reshape(right, op.Constant(value_ints=[1])) - - left_win = op.Expand(op.Constant(value_ints=[0]), left) - right_win = op.Expand(op.Constant(value_ints=[0]), right) - right_win = op.CastLike(right_win, window) - left_win = op.CastLike(left_win, window) - window = op.Concat(left_win, window, right_win, axis=0) - return window - - -@torch_op("aten::stft", private=True) -def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16: - left = (n_fft - win_length) / 2 - - right = n_fft - left - win_length - left = op.Reshape(left, op.Constant(value_ints=[1])) - right = op.Reshape(right, op.Constant(value_ints=[1])) - win_length = op.Reshape(win_length, op.Constant(value_ints=[1])) - - left_win = op.Expand(op.Constant(value_ints=[0]), left) - right_win = op.Expand(op.Constant(value_ints=[0]), right) - window_list = op.Expand(op.Constant(value_ints=[1]), win_length) - return op.Concat(left_win, window_list, right_win, axis=0) - - -@torch_op("aten::stft", private=True) -def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16: - n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) - window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) - return window - - -@torch_op("aten::stft", private=True) -def _normalize_fft_result( - signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int -) -> TFloatOrBFloat16: - n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) - sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) - result = result / sqrt_nfft - return result + if correction is None: + correction = 1.0 + if dim is None: + var = _aten_var_onnx(self, correction=correction, keepdim=keepdim) + else: + var = _aten_var_dim_onnx(self, dims=dim, correction=correction, keepdim=keepdim) + return op.Sqrt(var) -@torch_op("aten::stft", private=True) -def _aten_stft_onnx( - signal: TFloatOrBFloat16, - frame_step_const: INT64, - window: Union[TFloatOrBFloat16, INT64], - frame_length_const: INT64, - signal_rank: INT64, - onesided: int, -) -> TFloatOrBFloat16: - window = op.CastLike(window, signal) - result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided) - result = op.Transpose(result, perm=[0, 2, 1, 3]) - # Remove batch dimension, if needed - if signal_rank == 1: - result = op.Squeeze(result, op.Constant(value_ints=[0])) - return result +# std_mean is decomposed by PyTroch +def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: + """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" -@torch_op("aten::stft", trace_only=True) -def aten_stft( - self: TFloatOrBFloat16, - n_fft: int, - hop_length: Optional[int] = None, - win_length: Optional[int] = None, - window: Optional[TFloatOrBFloat16] = None, - normalized: bool = False, - onesided: Optional[bool] = None, - return_complex: Optional[bool] = None, -) -> TFloatOrBFloat16: - """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" - - # NOTE: regarless of the value of return_complex, we always return a real representation. - del return_complex - - # Get STFT sizes - if hop_length is None: - # core dump - # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) - hop_length = n_fft // 4 - frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) - frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1])) - - # Pre-process input if needed - self, signal_rank = _add_batch_dimension(self) - - # Get window and make sure it's the same size as `win_length` or `n_fft` - if window is not None and window.shape[0] is not None: - window = _center_window_around_zeros_if_needed(window, n_fft) - elif window is None: - if win_length is not None: - window = _create_window_from_win_length(win_length, n_fft) - else: - window = _create_window_from_n_fft(n_fft) + # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" + # If not this case, should be explicitly set correction value according to unbiased value + var, mean = _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False) + return op.Sqrt(var), mean - if onesided is None or onesided: - onesided = 1 - else: - onesided = 0 - # remove batch dimension included - result = _aten_stft_onnx( - self, frame_step_const, window, frame_length_const, signal_rank, onesided + +# std_mean is decomposed by PyTroch +def aten_std_mean_dim( + self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False +) -> Tuple[TReal, TReal]: + """std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" + + # Although dim is Optional in signature, but we assume it must have value for this overload + # Assert(dim is not None) + var, mean = _aten_var_mean_dim_onnx( + self, dims=dim, correction=float(unbiased), keepdim=keepdim ) + return op.Sqrt(var), mean - # Normalize, if needed - if normalized: - result = _normalize_fft_result(self, result, n_fft) - return result +# std_mean is decomposed by PyTroch +def aten_std_mean_correction( + self: TReal, + # FIXME(justinchuby): Make dim Optional[Sequence[int]] + dim: Optional[int] = None, + correction: Optional[float] = None, + keepdim: bool = False, +) -> Tuple[TReal, TReal]: + """std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)""" + + if correction is None: + correction = 1.0 + + if dim is None: + var, mean = _aten_var_mean_onnx(self, correction=correction, keepdim=keepdim) + else: + var, mean = _aten_var_mean_dim_onnx( + self, dims=dim, correction=correction, keepdim=keepdim + ) + return op.Sqrt(var), mean -@torch_op(("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub")) +@torch_op( + ( + "aten::sub.Tensor", + "aten::sub.Scalar", + "aten::subtract.Tensor", + "aten::subtract.Scalar", + "_operator::sub", + ), + trace_only=True, +) def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - alpha = op.CastLike(alpha, other) - other = op.Mul(other, alpha) - + if alpha != 1.0: + alpha = op.CastLike(alpha, other) + other = op.Mul(other, alpha) return op.Sub(self, other) @torch_op( - ("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"), + ( + "aten::sub.Tensor", + "aten::sub.Scalar", + "aten::subtract.Tensor", + "aten::subtract.Scalar", + ), trace_only=True, complex=True, ) @@ -7839,53 +8153,35 @@ def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: return aten_sub(self, other, alpha=alpha) -@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True) -def aten_sum_dim_IntList( - self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1 -) -> TReal: - """sum(Tensor self, SymInt dim, bool keepdim, *, ScalarType? dtype=None) -> Tensor""" - - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - # TODO: Combine the overloads when OptionalHasElement() works - if dim is None: - result = _aten_sum_dim_none(self, keepdim=keepdim) +@torch_op("aten::sum", trace_only=True) +def aten_sum(self: TReal, dtype: int = -1) -> TReal: + """sum(Tensor self, *, ScalarType? dtype=None) -> Tensor""" + if len(self.shape) == 0: + result = op.Identity(self) else: - result = _aten_sum_dim_onnx(self, dim, keepdim=keepdim) - - if dtype != -1: + result = op.ReduceSum(self, keepdims=False) + if dtype != -1 and dtype is not None: result = op.Cast(result, to=dtype) - return result -@torch_op("aten::sum", private=True, traceable=True) -def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Reshape(self, op.Constant(value_ints=[-1])) - - if IsScalar(dim): +@torch_op("aten::sum.dim_IntList", trace_only=True) +def aten_sum_dim_IntList( + self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1 +) -> TReal: + """sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor""" + if len(self.shape) == 0: + result = op.Identity(self) + elif dim is None: + result = op.ReduceSum(self, keepdims=keepdim) + else: dim = op.Reshape(dim, op.Constant(value_ints=[-1])) dim = op.Cast(dim, to=INT64.dtype) - result = op.ReduceSum(self, dim, keepdims=keepdim) + result = op.ReduceSum(self, dim, keepdims=keepdim) - if self_is_scalar: - result = op.Squeeze(result) - return result - - -@torch_op("aten::sum", private=True) -def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal: - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Reshape(self, op.Constant(value_ints=[-1])) - - result = op.ReduceSum(self, keepdims=keepdim) + if dtype != -1 and dtype is not None: + result = op.Cast(result, to=dtype) - if self_is_scalar: - result = op.Squeeze(result) return result @@ -7915,17 +8211,10 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType: raise NotImplementedError() -@torch_op("aten::sym_size") -def aten_sym_size(self: TReal, dim: int = 0) -> TReal: - """sym_size(Tensor self, int dim) -> Tensor""" - # NOTE: onnxscript doesn't support attribute process, - # so op.Shape(self, start=dim, end=dim + 1) is not supported. - shape = op.Shape(self) - # Reshape helps dim from int to tensor, and - # input arguments support attribute processing. - start = op.Reshape(dim, op.Constant(value_ints=[1])) - end = op.Reshape(dim + 1, op.Constant(value_ints=[1])) - return op.Slice(shape, start, end) +@torch_op("aten::sym_size.int", trace_only=True) +def aten_sym_size(self: TensorType, dim: int = 0) -> INT64: + """sym_size.int(Tensor self, int dim) -> SymInt""" + return op.Squeeze(op.Shape(self, end=dim + 1, start=dim)) def aten_symeig( @@ -7936,7 +8225,7 @@ def aten_symeig( raise NotImplementedError() -@torch_op("aten::t", traceable=True) +@torch_op("aten::t", trace_only=True) def aten_t(self: TTensor) -> TTensor: """t(Tensor(a) self) -> Tensor(a)""" @@ -7969,33 +8258,33 @@ def aten_take_along_dim( raise NotImplementedError() -@torch_op("aten::tan") +@torch_op("aten::tan", trace_only=True) def aten_tan(self: TFloat) -> TFloat: """tan(Tensor self) -> Tensor""" return op.Tan(self) -@torch_op("aten::tanh") +@torch_op("aten::tanh", trace_only=True) def aten_tanh(self: TFloat) -> TFloat: """tanh(Tensor self) -> Tensor""" return op.Tanh(self) -@torch_op("aten::tensor.bool") +@torch_op("aten::tensor.bool", trace_only=True) def aten_tensor_bool(self: bool, dtype: int) -> TensorType: tensor = op.Constant(value_int=self) return op.Cast(tensor, to=dtype) -@torch_op("aten::tensor.float") +@torch_op("aten::tensor.float", trace_only=True) def aten_tensor_float(self: float, dtype: int) -> TensorType: tensor = op.Constant(value_float=self) return op.Cast(tensor, to=dtype) -@torch_op("aten::tensor.int") +@torch_op("aten::tensor.int", trace_only=True) def aten_tensor_int(self: int, dtype: int) -> TensorType: tensor = op.Constant(value_int=self) return op.Cast(tensor, to=dtype) @@ -8045,7 +8334,7 @@ def aten_tile(self: TTensor, dims: INT64) -> TTensor: exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d) self_shape = op.Shape(self) self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0) - self = op.Reshape(self, self_final_shape) + self = op.Reshape(self, self_final_shape, allowzero=True) return op.Tile(self, dims) @@ -8112,20 +8401,14 @@ def aten_to_sparse_csr(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::topk", traceable=True) +@torch_op("aten::topk", trace_only=True) def aten_topk( - self: TReal, k: INT64, dim: int = -1, largest: bool = True, sorted: bool = True + self: TReal, k: int, dim: int = -1, largest: bool = True, sorted: bool = True ) -> Tuple[TReal, INT64]: """topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)""" - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - k = op.Reshape(op.Cast(k, to=INT64.dtype), op.Constant(value_ints=[1])) - values, indices = op.TopK(self, k, axis=dim, largest=largest, sorted=sorted) - if self_is_scalar: - values = op.Squeeze(values, op.Constant(value_ints=[0])) - indices = op.Squeeze(indices, op.Constant(value_ints=[0])) + # We do not handle scalar inputs for topk + values, indices = op.TopK(self, [k], axis=dim, largest=largest, sorted=sorted) return values, indices @@ -8141,8 +8424,8 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType: raise NotImplementedError() -@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True) -def aten_transpose(self, dim0: int, dim1: int): +@torch_op("aten::transpose.int", trace_only=True) +def aten_transpose(self: TTensor, dim0: int, dim1: int) -> TTensor: """transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)""" # Use trace only to construct the prem attribute in Transpose @@ -8161,8 +8444,8 @@ def aten_transpose(self, dim0: int, dim1: int): return result -@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True, complex=True) -def aten_transpose_complex(self, dim0: int, dim1: int): +@torch_op("aten::transpose.int", trace_only=True, complex=True) +def aten_transpose_complex(self: TTensor, dim0: int, dim1: int) -> TTensor: """transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)""" # Use trace only to construct the prem attribute in Transpose @@ -8199,7 +8482,7 @@ def aten_triangular_solve( raise NotImplementedError() -@torch_op("aten::tril") +@torch_op("aten::tril", trace_only=True) def aten_tril(self: TTensor, diagonal: int = 0) -> TTensor: """tril(Tensor self, int diagonal=0) -> Tensor""" @@ -8227,7 +8510,7 @@ def aten_triplet_margin_loss( raise NotImplementedError() -@torch_op("aten::triu") +@torch_op("aten::triu", trace_only=True) def aten_triu(self: TTensor, diagonal: int = 0) -> TTensor: """triu(Tensor self, int diagonal=0) -> Tensor""" @@ -8240,40 +8523,43 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType: raise NotImplementedError() -@torch_op("aten::trunc") -def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op("aten::trunc", trace_only=True) +def aten_trunc(self: TFloat) -> TFloat: """trunc(Tensor self) -> Tensor""" - - # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126 - integer_parts = op.Floor(op.Abs(self)) - is_negative = op.Less(self, 0.0) - return op.Where(is_negative, op.Neg(integer_parts), integer_parts) + # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-2658170591 + return op.Floor(op.Abs(self)) * op.Sign(self) -def aten_type_as(self: TensorType, other: TensorType) -> TensorType: +@torch_op("aten::type_as", trace_only=True) +def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: """type_as(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.CastLike(self, other) -@torch_op(("aten::unbind", "aten::unbind.int")) +@torch_op("aten::unbind.int", trace_only=True) def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - split_sizes = op.Constant(value_int=1) - return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False) + if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): + # We can create a definitive split op if the input shape is static + # Only torch>=2.7 supports correctly generating the correct number of outputs for Split + outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) + return [op.Squeeze(out, [dim]) for out in outputs] + + return op.SplitToSequence(self, axis=dim, keepdims=False) -@torch_op("aten::unflatten") -def aten_unflatten(self: TReal, dim: INT64, sizes: INT64): +@torch_op("aten::unflatten.int", trace_only=True) +def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]): """unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)""" self_size = op.Shape(self) # PyTorch accepts negative dim as reversed counting - self_rank = op.Size(self_size) - dim = self_rank + dim - dim = dim % self_rank + self_rank = len(self.shape) + if dim < 0: + dim = self_rank + dim head_start_idx = op.Constant(value_ints=[0]) head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1])) @@ -8283,9 +8569,17 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64): tail_end_idx = op.Constant(value_ints=[_INT64_MAX]) tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx) - final_shape = op.Concat(head_part_rank, sizes, tail_part_rank, axis=0) + sizes = [op.Reshape(size, op.Constant(value_ints=[1])) for size in sizes] - return op.Reshape(self, final_shape) + # corner case 1: head part is None + if dim == 0: + final_shape = op.Concat(*sizes, tail_part_rank, axis=0) + # corner case 2: tail part is None + elif dim == self_rank - 1: + final_shape = op.Concat(head_part_rank, *sizes, axis=0) + else: + final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0) + return op.Reshape(self, final_shape, allowzero=True) @torch_op("aten::unfold", trace_only=True) @@ -8294,44 +8588,42 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor: self_rank = len(self.shape) if self_rank == 0: - result = op.Unsqueeze(self, 0) + result = op.Unsqueeze(self, [0]) else: # Handle negative dimension if dimension < 0: dimension = dimension + self_rank - dim_size = self.shape[dimension] - target_end = (dim_size - size) // step + 1 - if target_end >= 1: # the rank of final reuslt will be self_rank + 1 - self_rank = self_rank + 1 + + input_shape = op.Shape(self) + dim_size = op.Gather(input_shape, op.Constant(value_ints=[dimension])) + + # Create indices for each window + window_starts = op.Range(0, op.Sub(dim_size, size - 1), step) + + # Create the base indices for one window + window_indices = list(range(size)) + + # Broadcast to create all indices + starts_expanded = op.Unsqueeze(window_starts, [1]) # [num_windows, 1] + indices_expanded = op.Unsqueeze(window_indices, [0]) # [1, size] + all_indices = op.Add(starts_expanded, indices_expanded) # [num_windows, size] + + # Gather along the specified dimension + result = op.Gather(self, all_indices, axis=dimension) + + # The result shape is now [..., num_windows, size, ...] with num_windows at position 'dimension'. + # We need to move the size dimension to the end: + # Current shape: [..., num_windows, size, ...] + # Target shape: [..., num_windows, ..., size] + + # Move the size dimension (at position dimension+1) to the end # perm need to be list[int], so have to be generated in trace_only mode - perm = list(range(self_rank)) - # from [0,1,2,3,4] -> [0,1,3,4,2] when dimension=1 + perm = list(range(self_rank + 1)) perm.append(perm.pop(dimension + 1)) - result = _aten_unfold_onnx(self, dimension, size, step, target_end, perm) - return result + result = op.Transpose(result, perm=perm) -@torch_op("aten::unfold", private=True) -def _aten_unfold_onnx( - self: TTensor, dim: int, size: int, step: int, target_end: int, perm: Sequence[int] -) -> TTensor: - dims = op.Reshape(op.Constant(value_int=dim), op.Constant(value_ints=[-1])) - # FIXME(justinchuby): obtain the dtype for SequenceEmpty, currently it assumes float - seq_result = op.SequenceEmpty() - i = op.Constant(value_int=0) - cond = i < target_end - while cond: # because for loop cannot work here, so use while loop - starts = op.Reshape(i * step, [-1]) # starts is [0, step, step*2, step*3, ...] - ends = starts + size # ends is [0+size, step+size, step*2+size, step*3+size, ...] - slice_result = op.Slice(self, starts, ends, dims) - # sequence only support float32 - slice_result_float32 = op.Cast(slice_result, to=FLOAT.dtype) - seq_result = op.SequenceInsert(seq_result, slice_result_float32) - i = i + 1 - cond = i < target_end - concat_result = op.ConcatFromSequence(seq_result, axis=dim, new_axis=1) - result = op.Transpose(concat_result, perm=perm) - return op.CastLike(result, self) + return result def aten_unfold_backward( @@ -8359,16 +8651,84 @@ def aten_unique_consecutive( raise NotImplementedError() +@torch_op("aten::_unique", trace_only=True) +def aten__unique( + self: TensorType, + sorted: bool = True, # pylint: disable=unused-argument + return_inverse: bool = False, +) -> tuple[TensorType, TensorType]: + """_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)""" + + unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True) + input_size = op.Shape(self) + if return_inverse: + inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True) + else: + input_numel = op.ReduceProd(input_size, keepdims=False) + if input_numel == 0: + inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True) + else: + inverse_indices = op.ConstantOfShape([0]) + inverse_indices = op.Cast(inverse_indices, to=INT64.dtype) + return unique_values, inverse_indices + + +@torch_op("aten::_unique2", trace_only=True) +def aten__unique2( + self: TensorType, + sorted: bool = True, # pylint: disable=unused-argument + return_inverse: bool = False, + return_counts: bool = False, +) -> tuple[TensorType, TensorType, TensorType]: + """_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)""" + + unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True) + input_size = op.Shape(self) + if return_inverse: + inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True) + else: + input_numel = op.ReduceProd(input_size, keepdims=False) + if input_numel == 0: + inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True) + else: + inverse_indices = op.ConstantOfShape([0]) + inverse_indices = op.Cast(inverse_indices, to=INT64.dtype) + if not return_counts: + counts = op.ConstantOfShape([0]) + counts = op.Cast(counts, to=INT64.dtype) + return unique_values, inverse_indices, counts + + +@torch_op("aten::unique_dim", trace_only=True) def aten_unique_dim( self: TensorType, dim: int, - sorted: bool = True, + sorted: bool = True, # pylint: disable=unused-argument return_inverse: bool = False, return_counts: bool = False, ) -> tuple[TensorType, TensorType, TensorType]: """unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)""" - raise NotImplementedError() + unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True) + input_size = op.Shape(self) + # Normalize dim to be non-negative + input_ndim = op.Max(op.Size(input_size), op.Constant(value_ints=[1])) + dim = op.Mod(dim, input_ndim) + if return_inverse: + inverse_indices = op.Reshape( + inverse_indices, + op.Reshape(op.Slice(input_size, dim, dim + 1), op.Constant(value_ints=[-1])), + ) + else: + inverse_indices = op.ConstantOfShape([0]) + inverse_indices = op.Cast(inverse_indices, to=INT64.dtype) + if return_counts: + output_size = op.Shape(unique_values) + counts = op.Reshape(counts, op.Reshape(op.Slice(output_size, dim, dim + 1), [-1])) + else: + counts = op.ConstantOfShape([0]) + counts = op.Cast(counts, to=INT64.dtype) + return unique_values, inverse_indices, counts def aten_unique_dim_consecutive( @@ -8385,10 +8745,11 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType raise NotImplementedError() -def aten_unsafe_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType: +@torch_op("aten::unsafe_split.Tensor") +def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]: """unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]""" - raise NotImplementedError() + return op.SplitToSequence(self, split_size, axis=dim) def aten_unsafe_split_with_sizes( @@ -8399,12 +8760,11 @@ def aten_unsafe_split_with_sizes( raise NotImplementedError() -@torch_op("aten::unsqueeze") +@torch_op("aten::unsqueeze", trace_only=True) def aten_unsqueeze(self: TTensor, dim: int) -> TTensor: """unsqueeze(Tensor(a) self, int dim) -> Tensor(a)""" - dim = op.Cast(dim, to=INT64.dtype) - return op.Unsqueeze(self, dim) + return op.Unsqueeze(self, [dim]) def aten_unsqueeze_copy(self: TensorType, dim: int) -> TensorType: @@ -8441,7 +8801,7 @@ def aten_vander( raise NotImplementedError() -@torch_op("aten::var", trace_only=True) +# var is decomposed by PyTroch def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal: """var(Tensor self, bool unbiased=True) -> Tensor""" @@ -8450,7 +8810,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal: return _aten_var_onnx(self, correction=float(unbiased), keepdim=False) -@torch_op("aten::var.dim", trace_only=True) +# var is decomposed by PyTroch def aten_var_dim( self: TReal, dim: Sequence[int], @@ -8462,7 +8822,7 @@ def aten_var_dim( return _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim) -@torch_op("aten::var.correction", trace_only=True) +# var is decomposed by PyTroch def aten_var_correction( self: TReal, # FIXME(justinchuby): Make dim Optional[Sequence[int]] @@ -8482,7 +8842,7 @@ def aten_var_correction( return var -@torch_op("aten::var", private=True, traceable=True) +# var is decomposed by PyTroch def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal: mean = op.ReduceMean(self, keepdims=keepdim) sub_mean = op.Sub(self, mean) @@ -8499,7 +8859,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TRe return var -@torch_op("aten::var.dim", private=True, traceable=True) +# var is decomposed by PyTroch def _aten_var_dim_onnx( self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False ) -> TReal: @@ -8520,7 +8880,7 @@ def _aten_var_dim_onnx( return var -@torch_op("aten::var_mean", trace_only=True) +# var_mean is decomposed by PyTroch def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: """var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" @@ -8529,7 +8889,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: return _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False) -@torch_op("aten::var_mean.dim", trace_only=True) +# var_mean is decomposed by PyTroch def aten_var_mean_dim( self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False ) -> Tuple[TReal, TReal]: @@ -8540,7 +8900,7 @@ def aten_var_mean_dim( return _aten_var_mean_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim) -@torch_op("aten::var_mean.correction", trace_only=True) +# var_mean is decomposed by PyTroch def aten_var_mean_correction( self: TReal, # FIXME(justinchuby): Make dim Optional[Sequence[int]] @@ -8562,7 +8922,7 @@ def aten_var_mean_correction( return var, mean -@torch_op("aten::var_mean", private=True) +# var_mean is decomposed by PyTroch def _aten_var_mean_onnx( self: TReal, correction: float = 1.0, keepdim: bool = False ) -> Tuple[TReal, TReal]: @@ -8582,7 +8942,7 @@ def _aten_var_mean_onnx( return var, mean -@torch_op("aten::var_mean.dim", private=True) +# var_mean is decomposed by PyTroch def _aten_var_mean_dim_onnx( self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False ) -> Tuple[TReal, TReal]: @@ -8610,32 +8970,31 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::view") -def aten_view(self: TTensor, size: IntType) -> TTensor: +@torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True) +def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input - return op.Reshape(self, size) + size = common_ops.merge_dims(size) + return op.Reshape(self, size, allowzero=True) -@torch_op("aten::view", complex=True) -def aten_view_complex(self: TTensor, size: IntType) -> TTensor: +@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True) +def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input - complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0) - return op.Reshape(self, complex_size) + complex_size = common_ops.merge_dims([*size, 2]) + return op.Reshape(self, complex_size, allowzero=True) -@torch_op("aten::view_as") +@torch_op("aten::view_as", trace_only=True) def aten_view_as(self: TTensor, other: TTensor2) -> TTensor: """view_as(Tensor(a) self, Tensor other) -> Tensor(a)""" size = op.Shape(other) - return op.Reshape(self, size) + return op.Reshape(self, size, allowzero=True) -@torch_op("aten::view_as_complex") +@torch_op("aten::view_as_complex", trace_only=True) def aten_view_as_complex(self: TTensor) -> TTensor: """view_as_complex(Tensor(a) self) -> Tensor(a)""" @@ -8644,7 +9003,7 @@ def aten_view_as_complex(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_as_complex_copy") +@torch_op("aten::view_as_complex_copy", trace_only=True) def aten_view_as_complex_copy(self: TTensor) -> TTensor: """view_as_complex_copy(Tensor self) -> Tensor""" @@ -8653,7 +9012,7 @@ def aten_view_as_complex_copy(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_as_real", complex=True) +@torch_op("aten::view_as_real", complex=True, trace_only=True) def aten_view_as_real(self: TTensor) -> TTensor: """view_as_real(Tensor(a) self) -> Tensor(a)""" @@ -8662,7 +9021,7 @@ def aten_view_as_real(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_as_real_copy", complex=True) +@torch_op("aten::view_as_real_copy", complex=True, trace_only=True) def aten_view_as_real_copy(self: TTensor) -> TTensor: """view_as_real_copy(Tensor self) -> Tensor""" @@ -8671,15 +9030,15 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_copy") -def aten_view_copy(self: TTensor, size: IntType) -> TTensor: +@torch_op("aten::view_copy", trace_only=True) +def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor: """view_copy(Tensor self, SymInt[] size) -> Tensor""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + size = common_ops.merge_dims(size) return op.Reshape(self, size) -@torch_op("aten::vstack") +# Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918 def aten_vstack(tensors: Sequence[TTensor]) -> TTensor: """vstack(Tensor[] tensors) -> Tensor""" @@ -8697,7 +9056,15 @@ def reshape_to_2d(tensor): return op.ConcatFromSequence(tensors_2d, axis=0) -@torch_op(("aten::where", "aten::where.self")) +@torch_op( + ( + "aten::where.Scalar", + "aten::where.ScalarSelf", + "aten::where.ScalarOther", + "aten::where.self", + ), + trace_only=True, +) def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor: """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor""" @@ -8710,33 +9077,44 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::zeros") -def aten_zeros(size: IntType, dtype: int = FLOAT.dtype): +@torch_op("aten::zeros", trace_only=True) +def aten_zeros( + size: Sequence[INT64], + dtype: int = FLOAT.dtype, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" + if dtype == -1: + dtype = FLOAT.dtype - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @torch_op("aten::zeros_like", trace_only=True) -def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor: +def aten_zeros_like( + self: TTensor, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, + memory_format: str = "", +) -> TTensor: """zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" # NOTE: trace_only because both if branches need to be the same type, but we have # a cast in the if branch. + if dtype is None: + dtype = -1 if dtype == -1: zero = op.CastLike(0, self) else: zero = op.Cast(0, to=dtype) - return _aten_zeros_like_onnx(self, zero) - - -@torch_op("aten::zeros_like", private=True) -def _aten_zeros_like_onnx(self: TTensor, zero) -> TTensor: shape = op.Shape(self) return op.Expand(zero, shape) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index f35b4f611b..ea92dc347d 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -21,98 +21,33 @@ from onnxscript.onnx_types import TensorType -@torch_op( - ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), - private=True, - complex=True, - traceable=True, -) def _fftn_onnx_normalization( - self, - transformed: TFloat, + self: TFloat, normalization: int, - forward: bool, - dims: Sequence[int], -) -> TFloat: - # Obtain the total_sample_count (n) for normalization - self_shape = op.Shape(self) - total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) - total_sample_count = op.CastLike(total_sample_count, transformed) - - # Normalize the result - # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn - # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 - if normalization == 1: - # "forward" - normalize by 1/n - if forward: - result = op.Div(transformed, op.Sqrt(total_sample_count)) - else: - result = op.Mul(transformed, op.Sqrt(total_sample_count)) - elif normalization == 2: - # "ortho" - normalize by 1/sqrt(n) - if forward: - result = op.Div(transformed, total_sample_count) - else: - result = transformed - else: - # "backward" - no normalization - if forward: - result = transformed - else: - result = op.Mul(transformed, total_sample_count) - - return result - - -@torch_op( - ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), - trace_only=True, - private=True, - complex=True, -) -def _fftn_onnx( - self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool + signal_size: INT64, + inverse: bool = False, ) -> TFloat: - """Standard complex to complex or real to complex FFT (forward or backward). - - This is a private shared function for implementing the various FFT functions. - - Args: - self: The input tensor. - dims: The dimensions to apply FFT. - normalization: The normalization mode. - inverse: Whether to compute the inverse FFT. - onesided: Whether to compute the one-sided FFT, which retains only the - positive frequencies. - - Returns: - The transformed tensor. - """ - - # NOTE: trace_only because we need to process each dimension in a loop - # NOTE: SymInt dim is not support because DFT-17 needs a static axis - # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support - - # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new - # dimension at the beginning to represent the batch dimension. - transformed = op.Unsqueeze(self, axes=[0]) - - # Add 1 to account for the batch dimension when counting axes from the left - new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] - - for dim in new_dims[:-1]: - transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False) - - # Torch computers one-sided FFT on the last dimension only. - if onesided: - transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True) + """Normalize in forward or backward direction.""" + # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131 + # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19 + # Modes: + # 0: no normalization (backward) + # 1: "ortho" - divide by 1/sqrt(signal_size) (ortho) + # 2: divide by signal_size (forward) + signal_size = op.CastLike(signal_size, self) + if not inverse: + # Forward normalization + if normalization == 1: + self = op.Div(self, op.Sqrt(signal_size)) + elif normalization == 2: + self = op.Div(self, signal_size) else: - transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False) - - # Remove the batch dimension - transformed = op.Squeeze(transformed, axes=[0]) - - return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) + # Backward normalization, accounting for op.DFT already dividing by signal_size + if normalization == 0: + self = op.Mul(self, signal_size) + elif normalization == 1: + self = op.Mul(self, op.Sqrt(signal_size)) + return self @torch_op("aten::_fft_c2c", trace_only=True, complex=True) @@ -124,14 +59,34 @@ def aten__fft_c2c( Standard complex to complex FFT (forward or backward). """ - # NOTE: trace_only because we need to negate forward - # NOTE: SymInt dim is not support because DFT-17 needs a static axis - # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support + # NOTE: SymInt dim is not supported because DFT-17 needs a static axis # ONNX DFT input assumes the last dimension is the complex dimension. - # Thus dim=-1 in PyTorch is dim=-2 in ONNX. - dim = [d - 1 if d < 0 else d for d in dim] - return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False) + + unsqueeze_first_dim = 0 in dim + # 1. Add a new dimension for the end and batch dimension, if needed + # 2. ONNX DFT input assumes the last dimension is the complex dimension. + # If needed, add 1 to account for the batch dimension. + + if unsqueeze_first_dim: + transformed = op.Unsqueeze(self, axes=[0]) + dim = [d + 1 for d in dim] + else: + transformed = self + + for dimension in reversed(dim): + transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False) + transformed = _fftn_onnx_normalization( + transformed, + normalization, + op.Shape(transformed, start=dimension, end=dimension + 1), + not forward, + ) + + if unsqueeze_first_dim: + transformed = op.Squeeze(transformed, axes=[0]) + + return transformed @torch_op("aten::_fft_c2r", trace_only=True, complex=True) @@ -139,24 +94,52 @@ def aten__fft_c2r( self: TFloat, dim: Sequence[int], normalization: int, - last_dim_size: INT64, # pylint: disable=unused-argument + last_dim_size: INT64, ) -> TFloat: """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor - Complex to real inverse FFT. + Complex to real inverse FFT. Assumes that input tensor is output of previous FFT operation. """ - - # TODO(justinchuby): Figure out what last_dim_size does - - self_rank = len(self.shape) - # ONNX DFT input assumes the last dimension is the complex dimension. - # Thus dim=-1 in PyTorch is dim=-2 in ONNX. - dim = [(d - 1) + self_rank if d < 0 else d for d in dim] - transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) - # Take only the real part - real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) - - return op.Squeeze(real_part, axes=[-1]) + if len(dim) != 1: + raise NotImplementedError("Only one dimension is supported for inverse FFT") + + dimension = dim[0] + unsqueeze_first_dim = dimension == 0 + # 1. Add a new dimension for batch dimension, if needed + # 2. ONNX DFT input assumes the last dimension is the complex dimension. + # If needed, add 1 to account for the batch dimension. + + if unsqueeze_first_dim: + transformed = op.Unsqueeze(self, axes=[0]) + dimension = 1 + else: + transformed = self + + # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed + # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we + # place no such restriction on the ONNX side. + transformed = op.DFT( + transformed, + dft_length=last_dim_size, + axis=dimension, + inverse=True, + onesided=False, + ) + transformed = _fftn_onnx_normalization( + transformed, + normalization, + op.Shape(transformed, start=dimension, end=dimension + 1), + inverse=True, + ) + + if unsqueeze_first_dim: + transformed = op.Squeeze(transformed, axes=[0]) + + # Remove the imaginary part + transformed = op.Slice(transformed, [0], [1], [-1]) + transformed = op.Squeeze(transformed, axes=[-1]) + + return transformed @torch_op("aten::_fft_r2c", trace_only=True) @@ -168,17 +151,37 @@ def aten__fft_r2c( Real to complex forward FFT. """ - # Add a new dimension at the end - signal = op.Unsqueeze(self, axes=[-1]) # No need to fill the imaginary part because ONNX DFT accepts real inputs # https://onnx.ai/onnx/operators/onnx__DFT.html#inputs - self_rank = len(self.shape) - # ONNX DFT input assumes the last dimension is the complex dimension. - # Thus dim=-1 in PyTorch is dim=-2 in ONNX. - dim = [(d - 1) + self_rank if d < 0 else d for d in dim] + unsqueeze_first_dim = 0 in dim + # 1. Add a new dimension for the end and batch dimension, if needed + # 2. ONNX DFT input assumes the last dimension is the complex dimension. + # If needed, add 1 to account for the batch dimension. + + if unsqueeze_first_dim: + transformed = op.Unsqueeze(self, axes=[0, -1]) + dim = [d + 1 for d in dim] + else: + transformed = op.Unsqueeze(self, axes=[-1]) + + for idx, dimension in enumerate(reversed(dim)): + transformed = _fftn_onnx_normalization( + transformed, + normalization, + op.Shape(transformed, start=dimension, end=dimension + 1), + inverse=False, + ) + if idx > 0: + transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False) + else: + # Torch computes one-sided FFT on the last dimension only. + transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=onesided) + + if unsqueeze_first_dim: + transformed = op.Squeeze(transformed, axes=[0]) - return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided) + return transformed def aten_fft_fft( diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 7890fb1c0b..c9d870bd86 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -12,17 +12,15 @@ from __future__ import annotations +import math from typing import Optional, Sequence -from onnxscript import BOOL, FLOAT, INT64 -from onnxscript.function_libs.torch_lib.ops import common as common_ops +from onnxscript import BOOL from onnxscript.function_libs.torch_lib.registration import torch_op -from onnxscript.function_libs.torch_lib.tensor_typing import TFloat +from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TTensor from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType -IsScalar = common_ops.IsScalar - def aten_linalg_cholesky(self: TensorType, upper: bool = False) -> TensorType: """linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor""" @@ -44,13 +42,14 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType: raise NotImplementedError() -def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> TensorType: +def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor: """linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor""" + # Same implementation as aten_cross raise NotImplementedError() -@torch_op(("aten::linalg_det", "aten::det")) +@torch_op(("aten::_linalg_det", "aten::linalg_det", "aten::det")) def aten_linalg_det(A: TFloat) -> TFloat: """linalg_det(Tensor A) -> Tensor""" @@ -326,73 +325,30 @@ def aten_linalg_vector_norm( if dtype != -1: self = op.Cast(self, to=dtype) - if dim is None or (isinstance(dim, tuple) and len(dim) == 0): + if dim is None: self = op.Reshape(self, op.Constant(value_ints=[-1])) keepdim = False - return _aten_linalg_vector_norm_no_dim_onnx(self, ord, keepdim) - else: - return _aten_linalg_vector_norm_onnx(self, ord, dim, keepdim) - - -@torch_op("aten::linalg_vector_norm", private=True) -def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool) -> TFloat: - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Unsqueeze(self, axes=[0]) - - self = op.Abs(self) - ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT - # TODO(justinchuby): Evaluate IsInf in trace mode - if op.IsInf(ord, detect_negative=0, detect_positive=1): - result = op.ReduceMax(self, keepdims=keepdim) - elif op.IsInf(ord, detect_negative=1, detect_positive=0): - result = op.ReduceMin(self, keepdims=keepdim) - elif ord == 0.0: # sum(x!=0) means count non-zero elements - self_bool = op.Cast(self, to=BOOL.dtype) - self_0_1 = op.CastLike(self_bool, self) - result = op.ReduceSum(self_0_1, keepdims=False) - # TODO(microsoft/onnxruntime#18338): Use ReduceL1/L2 when ONNX Runtime is fixed else: - ord_float = op.CastLike(ord, self) - self_pow = op.Pow(self, ord_float) - result = op.Pow(op.ReduceSum(self_pow, keepdims=keepdim), op.Div(1.0, ord_float)) - - if self_is_scalar: - result = op.Squeeze(result) - - return result - - -@torch_op("aten::linalg_vector_norm", private=True) -def _aten_linalg_vector_norm_onnx( - self: TFloat, ord: float, dim: INT64, keepdim: bool -) -> TFloat: - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Unsqueeze(self, axes=[0]) - - dim = op.Reshape(dim, op.Constant(value_ints=[-1])) - self = op.Abs(self) - ord = op.Cast(ord, to=FLOAT.dtype) # Must be FLOAT, due to op.IsInf() needs FLOAT - # TODO(justinchuby): Evaluate IsInf in trace mode - if op.IsInf(ord, detect_negative=0, detect_positive=1): - result = op.ReduceMax(self, dim, keepdims=keepdim) - elif op.IsInf(ord, detect_negative=1, detect_positive=0): - result = op.ReduceMin(self, dim, keepdims=keepdim) + dim = op.Reshape(dim, op.Constant(value_ints=[-1])) + + if math.isinf(ord): + self = op.Abs(self) + if ord > 0: + return op.ReduceMax(self, dim, keepdims=keepdim) + else: + return op.ReduceMin(self, dim, keepdims=keepdim) elif ord == 0.0: # sum(x!=0) means count non-zero elements self_bool = op.Cast(self, to=BOOL.dtype) self_0_1 = op.CastLike(self_bool, self) - result = op.ReduceSum(self_0_1, dim, keepdims=keepdim) + return op.ReduceSum(self_0_1, dim, keepdims=keepdim) elif ord == 1.0: - result = op.ReduceL1(self, dim, keepdims=keepdim) + return op.ReduceL1(self, dim, keepdims=keepdim) elif ord == 2.0: - result = op.ReduceL2(self, dim, keepdims=keepdim) + return op.ReduceL2(self, dim, keepdims=keepdim) else: - ord_float = op.CastLike(ord, self) - self_pow = op.Pow(self, ord_float) - result = op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), op.Div(1.0, ord_float)) - - if self_is_scalar: - result = op.Squeeze(result) - - return result + if ord < 0 or ord % 2 != 0: + # Not an even integer (could be odd, fractional or negative), use Abs + self = op.Abs(self) + self_pow = op.Pow(self, ord) + exp = op.CastLike(1 / ord, self) + return op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), exp) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 7730008efb..2a7a46ec28 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -17,16 +17,14 @@ import math from typing import Optional, Sequence, Tuple, TypeVar, Union -import onnx - -from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64 +from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, TFloat, - TFloatOrBFloat16, TFloatOrUInt8, + TInt, TReal, TTensor, ) @@ -36,62 +34,14 @@ _MATH_PI = math.pi Rank = common_ops.Rank +_INT64_MAX = 9223372036854775807 +_INT64_MIN = -9223372036854775808 + # All float types but float32 TFloatUnlessFloat32 = TypeVar("TFloatUnlessFloat32", bound=Union[BFLOAT16, FLOAT16, DOUBLE]) -@torch_op("aten::aten_adaptive_avg_pool1d", traceable=True) -def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat: - """adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor""" - - # assert output_size == [1] - # TODO(justinchuby): Specify input constraints - - if Rank(self) == 2: - # Unbatched case - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - pooled = op.GlobalAveragePool(self) - result = op.Squeeze(pooled, op.Constant(value_ints=[0])) - else: - result = op.GlobalAveragePool(self) - - return result - - -@torch_op("aten::aten_adaptive_avg_pool2d", traceable=True) -def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat: - """adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor""" - - # assert output_size == [1, 1] - # TODO(justinchuby): Specify input constraints - - if Rank(self) == 3: - # Unbatched case - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - pooled = op.GlobalAveragePool(self) - result = op.Squeeze(pooled, op.Constant(value_ints=[0])) - else: - result = op.GlobalAveragePool(self) - - return result - - -@torch_op("aten::aten_adaptive_avg_pool3d", traceable=True) -def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat: - """adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor""" - - # assert output_size == [1, 1, 1] - # TODO(justinchuby): Specify input constraints - - if Rank(self) == 4: - # Unbatched case - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - pooled = op.GlobalAveragePool(self) - result = op.Squeeze(pooled, op.Constant(value_ints=[0])) - else: - result = op.GlobalAveragePool(self) - - return result +# NOTE: Implementations of adaptive_average_pool are handled by torch decomp def aten_adaptive_max_pool1d( @@ -206,7 +156,7 @@ def aten_avg_pool2d( padding: Sequence[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, # pylint: disable=unused-argument + divisor_override: Optional[int] = None, ) -> TFloat: """avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor""" @@ -267,7 +217,7 @@ def aten_avg_pool3d( padding: Sequence[int] = (0, 0, 0), ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, # pylint: disable=unused-argument + divisor_override: Optional[int] = None, ) -> TFloat: """avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor""" @@ -343,21 +293,17 @@ def aten_binary_cross_entropy_backward( raise NotImplementedError() -@torch_op("aten::celu") -def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT: +@torch_op("aten::celu", trace_only=True) +def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - return op.Celu(self, alpha=alpha) # op.Celu only support float32 - + if self.dtype != FLOAT.dtype: + self_upcasted = op.Cast(self, to=FLOAT.dtype) -@torch_op("aten::celu", traceable=True) -def aten_celu_type_promoted( - self: TFloatUnlessFloat32, alpha: float = 1.0 -) -> TFloatUnlessFloat32: - """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" + # op.Celu only support float32 + return op.Cast(op.Celu(self_upcasted, alpha=alpha), to=self.dtype) - self_upcasted = op.Cast(self, to=FLOAT.dtype) - return op.CastLike(op.Celu(self_upcasted, alpha=alpha), self) + return op.Celu(self, alpha=alpha) @torch_op("aten::col2im", trace_only=True) @@ -409,15 +355,15 @@ def aten_conv_depthwise3d( raise NotImplementedError() -@torch_op("aten::cross_entropy_loss", traceable=True) +@torch_op("aten::cross_entropy_loss", trace_only=True) def aten_cross_entropy_loss( - self: TFloatOrBFloat16, + self: TFloat, target: IntType, - weight: Optional[TFloatOrBFloat16] = None, + weight: Optional[TFloat] = None, reduction: int = 1, # default is 'mean' ignore_index: int = -100, label_smoothing: float = 0.0, # this was ignored due to ONNX not support -) -> TFloatOrBFloat16: +) -> TFloat: """cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor""" if reduction == 0: # "none" @@ -436,7 +382,7 @@ def aten_cross_entropy_loss( return result -@torch_op("aten::elu") +@torch_op("aten::elu", trace_only=True) def aten_elu( self: TFloat, alpha: float = 1.0, @@ -445,9 +391,10 @@ def aten_elu( ) -> TFloat: """elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor""" - # del scale - # del input_scale - return op.Elu(self, alpha=alpha) + input_scale = op.CastLike(input_scale, self) + scale = op.CastLike(scale, self) + self = op.Mul(self, input_scale) + return op.Mul(op.Elu(self, alpha=alpha), scale) def aten_elu_backward( @@ -526,34 +473,32 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal: return result -@torch_op("aten::gelu", private=True) def _aten_gelu_approximate_none(self: TReal) -> TReal: """gelu(Tensor self, *, str approximate='none') -> Tensor""" # GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)] - inner = op.Div(self, 1.4142135623730951) + inner = op.Div(self, ir.tensor(1.4142135623730951, dtype=self.dtype)) erf = op.Erf(inner) - inner = op.Add(erf, 1) - inner = op.Mul(self, inner) - result = op.Mul(0.5, inner) + inner = op.Add(erf, ir.tensor(1, dtype=self.dtype)) + inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner) + result = op.Mul(self, inner) return result -@torch_op("aten::gelu", private=True) def _aten_gelu_approximate_tanh(self: TReal) -> TReal: """gelu(Tensor self, *, str approximate='none') -> Tensor""" # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]} - cubed = op.Pow(self, 3) - inner = op.Mul(0.044715, cubed) + cubed = op.Pow(self, ir.tensor(3, dtype=self.dtype)) + inner = op.Mul(ir.tensor(0.044715, dtype=self.dtype), cubed) inner = op.Add(self, inner) - # Prefer explicit graph construction over precomputed constants for clarity. - two_over_pi = op.CastLike(op.Div(2.0, _MATH_PI), self) - inner = op.Mul(op.Sqrt(two_over_pi), inner) + # math.sqrt(2.0/math.pi) = 0.7978845608028654 + sqrt_two_over_pi = ir.tensor(0.7978845608028654, dtype=self.dtype) + inner = op.Mul(sqrt_two_over_pi, inner) inner = op.Tanh(inner) - inner = op.Add(inner, 1) - inner = op.Mul(self, inner) - result = op.Mul(0.5, inner) + inner = op.Add(inner, ir.tensor(1, dtype=self.dtype)) + inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner) + result = op.Mul(self, inner) return result @@ -565,10 +510,13 @@ def aten_gelu_backward( raise NotImplementedError() -def aten_glu(self: TensorType, dim: int = -1) -> TensorType: +@torch_op("aten::glu") +def aten_glu(self: TFloat, dim: int = -1) -> TFloat: """glu(Tensor self, int dim=-1) -> Tensor""" - raise NotImplementedError() + first, second = op.Split(self, axis=dim, num_outputs=2) + result = op.Mul(first, op.Sigmoid(second)) + return result def aten_glu_backward(grad_output: TensorType, self: TensorType, dim: int) -> TensorType: @@ -590,13 +538,63 @@ def aten_glu_backward_jvp( raise NotImplementedError() +@torch_op("aten::group_norm", trace_only=True) +def aten_group_norm( + input: TFloat, + num_groups: int, + weight: Optional[TFloat] = None, + bias: Optional[TFloat] = None, + eps: float = 1e-05, + cudnn_enabled: bool = True, +) -> TensorType: + """group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor""" + + # Actually we don't need N,C,HxW value because the input tensor has that information + if weight is None: # Set to 1.0 as default, the shape is Channel size + weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2)) + + if bias is None: # Set to 0.0 as default, the shape is Channel size + bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) + + # Because onnx.GroupNorm() need size=group for weight and bias + # But the torch's aten function's input need size=channel, the size mismatched + # So we have to use onnx.InstanceNorm() to simulate + neg_1 = op.Constant(value_ints=[-1]) + # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter + group_tensor = op.Reshape(num_groups, neg_1) + # 0 in the shape list keeps dimension value unchanged, for InstanceNorm need [0,group,-1] + shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0) + input_reshaped = op.Reshape(input, shape_input) + weight_inst_norm = op.Expand( + op.CastLike(op.Constant(value_float=1.0), input), group_tensor + ) + bias_inst_norm = op.Expand(op.CastLike(op.Constant(value_float=0.0), input), group_tensor) + norm = op.InstanceNormalization( + input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps + ) + # Reshape back to input's shape + norm = op.Reshape(norm, op.Shape(input)) + # Using the input weight and bias to do affine + # But need to unsqueeze to the target shape for broading cast easy + input_rank = Rank(input) + one = op.Constant(value_int=1) + axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one) + weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) + bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze) + weight_full_shape = op.CastLike(weight_full_shape, norm) + norm_mul_weight = op.Mul(norm, weight_full_shape) + bias_full_shape = op.CastLike(bias_full_shape, norm_mul_weight) + norm_result = op.Add(norm_mul_weight, bias_full_shape) + return norm_result + + def aten_glu_jvp(glu: TensorType, x: TensorType, dx: TensorType, dim: int) -> TensorType: """glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor""" raise NotImplementedError() -@torch_op("aten::hardsigmoid") +@torch_op("aten::hardsigmoid", trace_only=True) def aten_hardsigmoid(self: TFloat) -> TFloat: """hardsigmoid(Tensor self) -> Tensor""" @@ -629,12 +627,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T return op.Clip(self, min_val, max_val) +@torch_op("aten::hardtanh_backward", trace_only=True) def aten_hardtanh_backward( grad_output: TensorType, self: TensorType, min_val: float, max_val: float ) -> TensorType: """hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor""" - raise NotImplementedError() + max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0) + min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0) + return op.Mul(op.Mul(grad_output, max_mask), min_mask) def aten_huber_loss( @@ -653,16 +654,138 @@ def aten_huber_loss_backward( raise NotImplementedError() +def _get_im2col_indices_along_dim( + input_d: TInt, + kernel_size_d: int, + dilation_d: int, + padding_d: int, + stride_d: int, +): + # Input is always 4-D (N, C, H, W) + # Calculate indices of sliding blocks along spatial dimension + # Slide kernel over input each dim d: + # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) + # with steps = stride + + blocks_d = input_d + ((padding_d * 2) - (dilation_d * (kernel_size_d - 1))) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = op.Range(0, blocks_d, stride_d) + blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0]) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d) + kernel_mask = op.Unsqueeze(kernel_grid, [1]) + + # Broadcast and add kernel staring positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + block_mask = op.Add(blocks_d_indices, kernel_mask) + + return block_mask + + +def _get_im2col_padded_input(input, padding_h, padding_w): + # Input is always 4-D tensor (N, C, H, W) + # Padding tensor has the following format: (padding_h, padding_w) + # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) + pad = op.Concat( + op.Constant(value_ints=[0, 0]), + op.Unsqueeze(padding_h, [0]), + op.Unsqueeze(padding_w, [0]), + op.Constant(value_ints=[0, 0]), + op.Unsqueeze(padding_h, [0]), + op.Unsqueeze(padding_w, [0]), + axis=0, + ) + return op.Pad(input, pad) + + +def _get_im2col_output_shape(input, kernel_h, kernel_w): + input_shape = op.Shape(input) + batch_dim = op.Gather(input_shape, 0, axis=0) + channel_dim = op.Gather(input_shape, 1, axis=0) + channel_unfolded = op.Mul(channel_dim, kernel_h * kernel_w) + + return op.Concat( + op.Unsqueeze(batch_dim, [0]), + op.Unsqueeze(channel_unfolded, [0]), + op.Constant(value_ints=[-1]), + axis=0, + ) + + +@torch_op("aten::im2col", trace_only=True) def aten_im2col( - self: TensorType, + self: TReal, kernel_size: Sequence[int], - dilation: Sequence[int], - padding: Sequence[int], - stride: Sequence[int], + dilation: Sequence[int] = (1, 1), + padding: Sequence[int] = (0, 0), + stride: Sequence[int] = (1, 1), ) -> TensorType: - """im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor""" + """im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor""" - raise NotImplementedError() + input_shape = op.Shape(self) + input_h = op.Gather(input_shape, 2, axis=0) + input_w = op.Gather(input_shape, 3, axis=0) + + if not isinstance(kernel_size, Sequence): + kernel_size = (kernel_size, kernel_size) + kernel_sizes = list(kernel_size) + + if not isinstance(dilation, Sequence): + dilation = (dilation, dilation) + dilations = list(dilation) + + if not isinstance(padding, Sequence): + padding = (padding, padding) + pads = list(padding) + + if isinstance(stride, int): + stride = (stride, stride) + strides = list(stride) + + stride_h, stride_w = strides[0], strides[1] + padding_h, padding_w = pads[0], pads[1] + dilation_h, dilation_w = dilations[0], dilations[1] + kernel_h, kernel_w = kernel_sizes[0], kernel_sizes[1] + + blocks_row_indices = _get_im2col_indices_along_dim( + input_h, kernel_h, dilation_h, padding_h, stride_h + ) + blocks_col_indices = _get_im2col_indices_along_dim( + input_w, kernel_w, dilation_w, padding_w, stride_w + ) + + output_shape = _get_im2col_output_shape(self, kernel_h, kernel_w) + padded_input = _get_im2col_padded_input(self, padding_h, padding_w) + + # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 + # [[[[1., 2., 3.,], + # [4., 5., 6.,], + # [7., 8., 9.,]]]] + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[1., 2., 3.], + # [4., 5., 6.]], + # [[4., 5., 6.], + # [7., 8., 9.]]]]] + # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[[1., 2.], + # [4., 5.]], + # [[2., 3.], + # [5., 6]]], + # [[[4., 5.], + # [7., 8.]], + # [[5., 6.], + # [8., 9.]]]]]] + # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: + # [[[1., 2., 4., 5.], + # [2., 3., 5., 6.], + # [4., 5., 7., 8.], + # [5., 6., 8., 9.]]] + output = op.Gather(padded_input, blocks_row_indices, axis=2) + output = op.Gather(output, blocks_col_indices, axis=4) + output = op.Transpose(output, perm=[0, 1, 2, 4, 3, 5]) + return op.Reshape(output, output_shape) def aten_infinitely_differentiable_gelu_backward( @@ -679,8 +802,8 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te raise NotImplementedError() -@torch_op("aten::leaky_relu") -def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16: +@torch_op("aten::leaky_relu", trace_only=True) +def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat: """leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor""" return op.LeakyRelu(self, alpha=negative_slope) @@ -694,31 +817,27 @@ def aten_leaky_relu_backward( raise NotImplementedError() -@torch_op("aten::linear") -def aten_linear(input: TFloat, weight: TFloat) -> TFloat: +@torch_op("aten::linear", trace_only=True) +def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat: """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor""" - # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases - # Optimizers may consider this path and replace it with Gemm - # We do not use Gemm here because input can have batch dimensions, which Gemm does not support - weight_transposed = op.Transpose(weight, perm=[1, 0]) - return op.MatMul(input, weight_transposed) - - -@torch_op("aten::linear") -def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat: - """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor""" - - # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases - # Optimizers may consider this path and replace it with Gemm - # We do not use Gemm here because input can have batch dimensions, which Gemm does not support - weight_transposed = op.Transpose(weight, perm=[1, 0]) + if len(input.shape) == 2 and len(weight.shape) == 2: + # Use Gemm for the rank 2 input + return op.Gemm(input, weight, bias, transB=True) + if len(weight.shape) == 1: + # In rare cases the weight can be 1d + weight_transposed = op.Unsqueeze(weight, [1]) + else: + assert len(weight.shape) == 2 + weight_transposed = op.Transpose(weight, perm=[1, 0]) mul = op.MatMul(input, weight_transposed) + if bias is None: + return mul return op.Add(mul, bias) @torch_op("aten::log_sigmoid") -def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log_sigmoid(self: TFloat) -> TFloat: """log_sigmoid(Tensor self) -> Tensor""" return op.Log(op.Sigmoid(self)) @@ -871,7 +990,6 @@ def aten_max_pool2d( return _aten_max_pool_onnx(self, kernel_shape, strides, pads, dilations, ceil_mode, 3) -@torch_op("internal::max_pool", private=True, traceable=True) def _aten_max_pool_onnx( self: TFloatOrUInt8, kernel_shape: Sequence[int], @@ -883,7 +1001,7 @@ def _aten_max_pool_onnx( ) -> TFloatOrUInt8: self_rank_is_unbatched_rank = Rank(self) == unbatched_rank if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1 - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + self = op.Unsqueeze(self, [0]) pool_result, _ = op.MaxPool( self, @@ -895,7 +1013,7 @@ def _aten_max_pool_onnx( ) if self_rank_is_unbatched_rank: - pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0])) + pool_result = op.Squeeze(pool_result, [0]) return pool_result @@ -1003,7 +1121,6 @@ def aten_max_pool3d_with_indices( ) -@torch_op("internal::max_pool_with_indices", private=True, traceable=True) def _aten_max_pool_with_indices_onnx( self: TFloatOrUInt8, kernel_size: Sequence[int], @@ -1018,7 +1135,7 @@ def _aten_max_pool_with_indices_onnx( ) -> Tuple[TFloatOrUInt8, INT64]: self_rank_is_unbatched_rank = Rank(self) == unbatched_rank if self_rank_is_unbatched_rank: - self = op.Unsqueeze(self, axes=0) + self = op.Unsqueeze(self, axes=[0]) pool_result, indices = op.MaxPool( self, @@ -1073,8 +1190,8 @@ def _aten_max_pool_with_indices_onnx( indices = op.Sub(indices, delta) if self_rank_is_unbatched_rank: - pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0])) - indices = op.Squeeze(indices, op.Constant(value_ints=[0])) + pool_result = op.Squeeze(pool_result, [0]) + indices = op.Squeeze(indices, [0]) return (pool_result, indices) @@ -1159,7 +1276,7 @@ def aten_mkldnn_reorder_conv3d_weight( raise NotImplementedError() -@torch_op("aten::mse_loss", traceable=True) +@torch_op("aten::mse_loss", trace_only=True) def aten_mse_loss(self: TReal, target: TReal, reduction: int = 1) -> TReal: """mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor""" # FIXME: When reduction=0, the shape(result) will be different than other case @@ -1235,10 +1352,11 @@ def aten_multilabel_margin_loss_forward( raise NotImplementedError() -@torch_op("aten::nll_loss", traceable=True) +@torch_op("aten::nll_loss", trace_only=True) def aten_nll_loss( self: TFloat, target: INT64, + weight: Optional[TFloat] = None, reduction: int = 1, ignore_index: int = -100, ) -> TFloat: @@ -1246,62 +1364,22 @@ def aten_nll_loss( self_rank_is_1 = Rank(self) == 1 if self_rank_is_1: # self rank should be at least 2 - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + self = op.Unsqueeze(self, [0]) rank_target = Rank(target) if rank_target == 0: # target rank should be at least 1 - target = op.Unsqueeze(target, op.Constant(value_ints=[0])) + target = op.Unsqueeze(target, [0]) if reduction == 0: - result = op.NegativeLogLikelihoodLoss( - self, target, ignore_index=ignore_index, reduction="none" - ) + reduction_str = "none" elif reduction == 1: - result = op.NegativeLogLikelihoodLoss( - self, target, ignore_index=ignore_index, reduction="mean" - ) + reduction_str = "mean" else: # assert reduction == 2 - result = op.NegativeLogLikelihoodLoss( - self, target, ignore_index=ignore_index, reduction="sum" - ) - - if self_rank_is_1: - result = op.Squeeze(result) - - return result - - -@torch_op("aten::nll_loss", traceable=True) -def aten_nll_loss_weight( - self: TFloat, - target: INT64, - weight: TFloat, - reduction: int = 1, - ignore_index: int = -100, -) -> TFloat: - """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor""" - - self_rank_is_1 = Rank(self) == 1 - if self_rank_is_1: - # self rank should be at least 2 - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - - rank_target = Rank(target) - if rank_target == 0: # target rank should be at least 1 - target = op.Unsqueeze(target, op.Constant(value_ints=[0])) + reduction_str = "sum" - if reduction == 0: - result = op.NegativeLogLikelihoodLoss( - self, target, weight, ignore_index=ignore_index, reduction="none" - ) - elif reduction == 1: - result = op.NegativeLogLikelihoodLoss( - self, target, weight, ignore_index=ignore_index, reduction="mean" - ) - else: - result = op.NegativeLogLikelihoodLoss( - self, target, weight, ignore_index=ignore_index, reduction="sum" - ) + result = op.NegativeLogLikelihoodLoss( + self, target, weight, ignore_index=ignore_index, reduction=reduction_str + ) if self_rank_is_1: result = op.Squeeze(result) @@ -1361,16 +1439,23 @@ def aten_nll_loss_backward( raise NotImplementedError() +@torch_op("aten::nll_loss_forward", trace_only=True) def aten_nll_loss_forward( self: TensorType, target: TensorType, weight: Optional[TensorType], reduction: int, - ignore_index: INT64, + ignore_index: int, ) -> tuple[TensorType, TensorType]: """nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)""" - raise NotImplementedError() + output = aten_nll_loss(self, target, weight, reduction, ignore_index) + # FIXME: Fake a total_weight tensor for now. It should be different based on weight, reduction and ignore_index + if weight is None: + total_weight = op.CastLike(op.Size(output), self) + else: + total_weight = op.CastLike(op.Size(output), weight) + return output, total_weight def aten_nll_loss_nd( @@ -1391,12 +1476,56 @@ def aten_one_hot(self: TensorType, num_classes: int = -1) -> TensorType: raise NotImplementedError() +def _process_padding(padding: Sequence[INT64 | int], rank: int) -> INT64: + """Convert PyTorch padding for ONNX Pad.""" + assert isinstance(padding, (list, tuple)) + if all(isinstance(pad, int) for pad in padding): + paddings = padding + zeros = [0] * (rank * 2 - len(paddings)) + paddings = [*paddings, *zeros] + paddings = paddings[-2::-2] + paddings[-1::-2] + return op.Constant(value=ir.tensor(paddings, dtype=ir.DataType.INT64)) + else: + paddings = [] + for pad in padding: + if isinstance(pad, int): + paddings.append(op.Constant(value_ints=[pad])) + else: + # Dynamic value + paddings.append(op.Reshape(pad, [-1])) + # Create a series of 1d zero tensors + zero = op.Constant(value_ints=[0]) + zeros = [zero] * (rank * 2 - len(paddings)) + paddings = [*paddings, *zeros] + # Interleave the padding values + paddings = paddings[-2::-2] + paddings[-1::-2] + return op.Concat(*paddings, axis=0) + + +@torch_op("aten::pad", trace_only=True) def aten_pad( - self: TensorType, pad: INT64, mode: str = "constant", value: Optional[float] = None + self: TensorType, + pad: Sequence[INT64], + mode: str = "constant", + value: Optional[float] = None, ) -> TensorType: """pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor""" - raise NotImplementedError() + rank = len(self.shape) + paddings = _process_padding(pad, rank) + const_value = ( + op.Constant(value=ir.tensor(value, dtype=ir.DataType(self.dtype))) + if value is not None + else None + ) + onnx_mode = { + "constant": "constant", + "reflect": "reflect", + "replicate": "edge", + "circular": "wrap", + }[mode] + + return op.Pad(self, paddings, constant_value=const_value, mode=onnx_mode) def aten_pad_sequence( @@ -1407,18 +1536,15 @@ def aten_pad_sequence( raise NotImplementedError() -@torch_op("aten::reflection_pad1d") -def aten_reflection_pad1d(self: TFloat, padding: INT64) -> TFloat: +@torch_op("aten::reflection_pad1d", trace_only=True) +def aten_reflection_pad1d(self: TFloat, padding: Sequence[INT64]) -> TFloat: """reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor""" # assert len(padding) == 2 # Input of padding argument should be [x,y], need change to onnx format [0, x, 0, y] - start = op.Slice(padding, [0], [1], axes=[0]) - end = op.Slice(padding, [1], [2], axes=[0]) - padding_onnx = op.Concat( - op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0 - ) - return op.Pad(self, padding_onnx, mode="reflect") + rank = len(self.shape) + paddings = _process_padding(padding, rank) + return op.Pad(self, paddings, mode="reflect") def aten_reflection_pad1d_backward( @@ -1429,37 +1555,12 @@ def aten_reflection_pad1d_backward( raise NotImplementedError() -@torch_op("aten::reflection_pad2d") -def aten_reflection_pad2d(self: TTensor, padding: INT64) -> TTensor: +@torch_op("aten::reflection_pad2d", trace_only=True) +def aten_reflection_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor: """reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor""" - # Convert torch padding format to onnx padding format - # Python code is: - # dim = len(self.shape) - # paddings = list(padding[:]) + [0] * (dim * 2 - len(padding)) - # paddings = paddings[-2::-2] + paddings[-1::-2] - - neg_1 = op.Constant(value_ints=[-1]) - zero = op.Constant(value_ints=[0]) - # [0] * (rank * 2 - len(padding)) - rank = Rank(self) - zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1) - zeros = op.Expand(zero, zero_count) - # list(padding[:]) + [0] * (dim * 2 - len(padding)) - torch_paddings = op.Concat(padding, zeros, axis=0) - # paddings[-2::-2] - size_d = op.Size(torch_paddings) - steps = op.Constant(value_ints=[-2]) - starts = steps - ends = op.Sub(starts, size_d) - odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - # paddings[-1::-2] - starts = neg_1 - ends = op.Sub(starts, size_d) - even_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - # paddings[-2::-2] + paddings[-1::-2] - onnx_padding = op.Concat(odd_elements, even_elements, axis=0) - - return op.Pad(self, onnx_padding, mode="reflect") + rank = len(self.shape) + paddings = _process_padding(padding, rank) + return op.Pad(self, paddings, mode="reflect") def aten_reflection_pad2d_backward( @@ -1470,10 +1571,12 @@ def aten_reflection_pad2d_backward( raise NotImplementedError() -def aten_reflection_pad3d(self: TensorType, padding: INT64) -> TensorType: +@torch_op("aten::reflection_pad3d", trace_only=True) +def aten_reflection_pad3d(self: TensorType, padding: Sequence[INT64]) -> TensorType: """reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor""" - - raise NotImplementedError() + rank = len(self.shape) + paddings = _process_padding(padding, rank) + return op.Pad(self, paddings, mode="reflect") def aten_reflection_pad3d_backward( @@ -1484,14 +1587,14 @@ def aten_reflection_pad3d_backward( raise NotImplementedError() -@torch_op("aten::relu") +@torch_op("aten::relu", trace_only=True) def aten_relu(self: TReal) -> TReal: """relu(Tensor self) -> Tensor""" return op.Relu(self) -@torch_op("aten::relu6", traceable=True) +@torch_op("aten::relu6", trace_only=True) def aten_relu6(self: TReal) -> TReal: """relu6(Tensor self) -> Tensor""" @@ -1499,18 +1602,13 @@ def aten_relu6(self: TReal) -> TReal: return op.Min(op.Relu(self), six) -@torch_op("aten::replication_pad1d") -def aten_replication_pad1d(self: TensorType, padding: INT64) -> TensorType: +@torch_op("aten::replication_pad1d", trace_only=True) +def aten_replication_pad1d(self: TensorType, padding: Sequence[INT64]) -> TensorType: """replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor""" - # assert len(padding) == 2 - # Input of padding argument should be [x,y], need change to onnx format [0, x, 0, y] - start = op.Slice(padding, [0], [1], axes=[0]) - end = op.Slice(padding, [1], [2], axes=[0]) - padding_onnx = op.Concat( - op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0 - ) - return op.Pad(self, padding_onnx, mode="edge") + rank = len(self.shape) + paddings = _process_padding(padding, rank) + return op.Pad(self, paddings, mode="edge") def aten_replication_pad1d_backward( @@ -1521,32 +1619,13 @@ def aten_replication_pad1d_backward( raise NotImplementedError() -@torch_op("aten::replication_pad2d") -def aten_replication_pad2d(self: TTensor, padding: INT64) -> TTensor: +@torch_op("aten::replication_pad2d", trace_only=True) +def aten_replication_pad2d(self: TTensor, padding: Sequence[INT64]) -> TTensor: """replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor""" - neg_1 = op.Constant(value_ints=[-1]) - zero = op.Constant(value_ints=[0]) - # [0] * (rank * 2 - len(padding)) - rank = Rank(self) - zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1) - zeros = op.Expand(zero, zero_count) - # list(padding[:]) + [0] * (dim * 2 - len(padding)) - torch_paddings = op.Concat(padding, zeros, axis=0) - # paddings[-2::-2] - size_d = op.Size(torch_paddings) - steps = op.Constant(value_ints=[-2]) - starts = steps - ends = op.Sub(starts, size_d) - odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - # paddings[-1::-2] - starts = neg_1 - ends = op.Sub(starts, size_d) - even_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - # paddings[-2::-2] + paddings[-1::-2] - onnx_padding = op.Concat(odd_elements, even_elements, axis=0) - - return op.Pad(self, onnx_padding, mode="edge") + rank = len(self.shape) + paddings = _process_padding(padding, rank) + return op.Pad(self, paddings, mode="edge") def aten_replication_pad2d_backward( @@ -1557,32 +1636,13 @@ def aten_replication_pad2d_backward( raise NotImplementedError() -@torch_op("aten::replication_pad3d") -def aten_replication_pad3d(self: TTensor, padding: INT64) -> TTensor: +@torch_op("aten::replication_pad3d", trace_only=True) +def aten_replication_pad3d(self: TTensor, padding: Sequence[INT64]) -> TTensor: """replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor""" - neg_1 = op.Constant(value_ints=[-1]) - zero = op.Constant(value_ints=[0]) - # [0] * (rank * 2 - len(padding)) - rank = Rank(self) - zero_count = op.Reshape(op.Sub(op.Mul(rank, 2), op.Size(padding)), neg_1) - zeros = op.Expand(zero, zero_count) - # list(padding[:]) + [0] * (dim * 2 - len(padding)) - torch_paddings = op.Concat(padding, zeros, axis=0) - # paddings[-2::-2] - size_d = op.Size(torch_paddings) - steps = op.Constant(value_ints=[-2]) - starts = steps - ends = op.Sub(starts, size_d) - odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - # paddings[-1::-2] - starts = neg_1 - ends = op.Sub(starts, size_d) - even_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - # paddings[-2::-2] + paddings[-1::-2] - onnx_padding = op.Concat(odd_elements, even_elements, axis=0) - - return op.Pad(self, onnx_padding, mode="edge") + rank = len(self.shape) + paddings = _process_padding(padding, rank) + return op.Pad(self, paddings, mode="edge") def aten_replication_pad3d_backward( @@ -1620,7 +1680,6 @@ def aten_rrelu_with_noise_backward( raise NotImplementedError() -@torch_op("aten::scaled_dot_product_attention", private=True) def _causal_attention_mask(query: TFloat, key: TFloat) -> TFloat: """Create a causal mask for the given query and key tensors. @@ -1636,20 +1695,30 @@ def _causal_attention_mask(query: TFloat, key: TFloat) -> TFloat: Returns: Tensor of shape [L, S] """ - target_length = op.Shape(query)[-2:-1] - source_length = op.Shape(key)[-2:-1] + q_shape = op.Shape(query) + k_shape = op.Shape(key) + + target_length = op.Slice( + q_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1]) + ) + source_length = op.Slice( + k_shape, op.Constant(value_ints=[-2]), op.Constant(value_ints=[-1]) + ) # attn_mask = torch.ones(L, S) := { size = op.Concat(target_length, source_length, axis=0) - attn_mask = op.Expand(1.0, size) + attn_mask = op.Expand(op.Constant(value_float=1.0), size) # } attn_mask = op.Trilu(attn_mask, upper=0) # The causal mask has 0s in the lower triangle and -inf in the upper triangle. - attn_mask = op.Where(op.Equal(attn_mask, 0.0), op.Constant(value_float=-float("inf")), 0.0) + attn_mask = op.Where( + op.Equal(attn_mask, op.Constant(value_float=0.0)), + op.Constant(value_float=-float("inf")), + op.Constant(value_float=0.0), + ) attn_mask = op.CastLike(attn_mask, query) return attn_mask -@torch_op("aten::scaled_dot_product_attention", private=True) def _attention_scale(query: TFloat) -> TFloat: """Calculate the scale factor for the attention result. @@ -1659,22 +1728,85 @@ def _attention_scale(query: TFloat) -> TFloat: Returns: Scalar scale factor := 1 / math.sqrt(query.size(-1)) """ - embedding_size = op.CastLike(op.Shape(query)[-1], query) - scale = op.Div(1.0, op.Sqrt(embedding_size)) + q_shape = op.Shape(query) + q_last_dim = op.Gather(q_shape, op.Constant(value_ints=[-1])) + embedding_size = op.CastLike(q_last_dim, query) + one = op.Constant(value_float=1.0) + cast_one = op.CastLike(one, query) + scale = op.Div(cast_one, op.Sqrt(embedding_size)) return scale +def _attention_repeat_kv_for_group_query( + query: TFloat, key: TFloat, value: TFloat +) -> Tuple[TFloat, TFloat]: + """Expand key and value for group query attention. + + repeat_interleave is applied on key and value to match the number of heads in query. + + Args: + query: Tensor of shape [B, q_num_heads, q_S, E] + key: Tensor of shape [B, k_num_heads, kv_S, E] + value: Tensor of shape [B, v_num_heads, kv_S, E] + + Returns: + Tuple of (expanded_key, expanded_value) where: + - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E] + - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E + """ + + assert ( + query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0 + ), ( + "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0" + ) + + # NOTE: QKV are expected to be 4D tensors + + batch_size = op.Shape(query, start=0, end=1) # [B] + q_num_heads = op.Shape(query, start=1, end=2) # [Hq] + kv_num_heads = op.Shape(key, start=1, end=2) # [Hk] + qk_head_size = op.Shape(key, start=3, end=4) # [Dk] + v_head_size = op.Shape(value, start=3, end=4) # [Dv] + new_kv_seq_len = op.Shape(key, start=2, end=3) # [T] + + interleave_dim = op.Div(q_num_heads, kv_num_heads) # Hq / Hk + two = op.Constant(value_int=2) + k_unsqueezed = op.Unsqueeze(key, two) # [B, Hk, 1, T, Dk] + v_unsqueezed = op.Unsqueeze(value, two) # [B, Hv, 1, T, Dv] + + k_expand_shape = op.Concat( + batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0 + ) + k_expand = op.Expand(k_unsqueezed, k_expand_shape) + v_expand_shape = op.Concat( + batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0 + ) + v_expand = op.Expand(v_unsqueezed, v_expand_shape) + + k_attention_shape = op.Concat( + batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0 + ) + v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0) + + expanded_key = op.Reshape(k_expand, k_attention_shape) + expanded_value = op.Reshape(v_expand, v_attention_shape) + + return expanded_key, expanded_value + + @torch_op("aten::scaled_dot_product_attention", trace_only=True) def aten_scaled_dot_product_attention( query: TFloat, key: TFloat, value: TFloat, - attn_mask: Optional[TFloat] = None, + attn_mask: Optional[TensorType] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, + enable_gqa: bool = False, ) -> TFloat: - """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor + """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html @@ -1690,9 +1822,13 @@ def aten_scaled_dot_product_attention( L is the target sequence length, S is the source sequence length, and E is the embedding size. """ # Use trace_only to handle optional inputs - assert (not is_causal) or ( - is_causal and attn_mask is None - ), "is_causal and attn_mask cannot be set at the same time" + assert (not is_causal) or (is_causal and attn_mask is None), ( + "is_causal and attn_mask cannot be set at the same time" + ) + + assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( + "only 4D query, key, and value are supported" + ) # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html if scale is None: @@ -1702,17 +1838,28 @@ def aten_scaled_dot_product_attention( if is_causal: attn_mask = _causal_attention_mask(query, key) + if enable_gqa: + key, value = _attention_repeat_kv_for_group_query(query, key, value) + else: + assert query.shape[1] == key.shape[1] == value.shape[1], ( + "SDPA (MHA) requires q_num_heads = kv_num_heads" + ) + if attn_mask is None: return _aten_scaled_dot_product_attention_no_mask_onnx( query, key, value, scale, dropout_p ) + if attn_mask.dtype == ir.DataType.BOOL: + return _aten_scaled_dot_product_attention_bool_mask_onnx( + query, key, value, attn_mask, scale, dropout_p + ) + return _aten_scaled_dot_product_attention_float_mask_onnx( query, key, value, attn_mask, scale, dropout_p ) -@torch_op("aten::_scaled_dot_product_flash_attention", private=True) def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs( query: TFloat, ) -> Tuple[FLOAT, INT64, INT64, FLOAT]: @@ -1720,15 +1867,11 @@ def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs( op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3]) ) logsumexp = op.Expand(0.0, query_first_three_dims) - # TODO: Eliminate `make_tensor` usage when ORT supports empty tensor. - empty_tensor_int = op.Cast( - op.ConstantOfShape( - op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], [])) - ), - to=INT64.dtype, + empty_tensor_int = op.ConstantOfShape( + op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64)) ) empty_tensor_float = op.ConstantOfShape( - op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], [])) + op.Constant(value=ir.tensor([], dtype=ir.DataType.FLOAT)) ) empty_int = op.Constant(value_int=0) @@ -1742,7 +1885,7 @@ def aten__scaled_dot_product_flash_attention( value: TFloat, dropout_p: float = 0.0, is_causal: bool = False, - return_debug_mask: bool = False, # pylint: disable=unused-argument + return_debug_mask: bool = False, scale: Optional[float] = None, ) -> Tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]: """_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) @@ -1779,7 +1922,6 @@ def aten__scaled_dot_product_flash_attention( ) -@torch_op("aten::_scaled_dot_product_efficient_attention", private=True, traceable=True) def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query: TFloat, compute_log_sumexp: bool, @@ -1788,9 +1930,9 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query = op.Transpose(query, perm=[0, 2, 1, 3]) query_shape = op.Shape(query) - query_first_dims = query_shape[:1] - query_second_dims = query_shape[1:2] - num_heads = query_shape[-2:-1] + query_first_dims = op.Slice(query_shape, op.Constant(value_ints=[_INT64_MIN]), [1]) + query_second_dims = op.Slice(query_shape, [1], [2]) + num_heads = op.Slice(query_shape, [-2], [-1]) if compute_log_sumexp: logsumexp_dim = op.Cast( @@ -1803,22 +1945,50 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0)) # See Note [Seed and Offset]: - empty_tensor_int = op.Cast( - op.ConstantOfShape( - op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], [])) - ), - to=INT64.dtype, + empty_tensor_int = op.ConstantOfShape( + op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64)) ) return logsum_exp, empty_tensor_int +@torch_op("aten::_scaled_dot_product_flash_attention_for_cpu", trace_only=True) +def aten__scaled_dot_product_flash_attention_for_cpu( + query: TFloat, + key: TFloat, + value: TFloat, + dropout_p: float = 0.0, + is_causal: bool = False, + attn_mask: Optional[TFloat] = None, + scale: Optional[float] = None, +) -> Tuple[TFloat, FLOAT]: + """_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)""" + result = aten_scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + query_shape = op.Shape(query) + query_first_dims = op.Slice(query_shape, [0], [1]) + query_second_dims = op.Slice(query_shape, [1], [2]) + num_heads = op.Slice(query_shape, [-2], [-1]) + logsumexp_dim = op.Cast( + op.Ceil(op.Cast(query_second_dims, to=FLOAT.dtype) / 32.0) * 32.0, to=INT64.dtype + ) + logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, logsumexp_dim, axis=0)) + return result, logsum_exp + + @torch_op("aten::_scaled_dot_product_efficient_attention", trace_only=True) def aten__scaled_dot_product_efficient_attention( query: TFloat, key: TFloat, value: TFloat, - attn_bias: Optional[TFloat], # pylint: disable=unused-argument + attn_bias: Optional[TFloat], compute_log_sumexp: bool, dropout_p: float = 0.0, is_causal: bool = False, @@ -1827,7 +1997,7 @@ def aten__scaled_dot_product_efficient_attention( """_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)""" result = aten_scaled_dot_product_attention( - query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale + query, key, value, attn_bias, dropout_p=dropout_p, is_causal=is_causal, scale=scale ) # The followings are not comsumed by the graph. @@ -1846,58 +2016,6 @@ def aten__scaled_dot_product_efficient_attention( ) -@torch_op("aten::scaled_dot_product_attention", trace_only=True) -def aten_scaled_dot_product_attention_bool_mask( - query: TFloat, - key: TFloat, - value: TFloat, - attn_mask: Optional[BOOL] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, -) -> TFloat: - """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor - - Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - - Equivalent to the PyTorch code:: - scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale - attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask - attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask - attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p) - return attn_weight @ V - - where Q, K, V are the query, key, and value tensors, respectively. - L is the target sequence length, S is the source sequence length, and E is the embedding size. - """ - # Use trace_only to handle optional inputs - assert (not is_causal) or ( - is_causal and attn_mask is None - ), "is_causal and attn_mask cannot be set at the same time" - - if scale is None: - scale = _attention_scale(query) - scale = op.CastLike(scale, query) - - if is_causal: - attn_mask = _causal_attention_mask(query, key) - # The causal mask is always float - return _aten_scaled_dot_product_attention_float_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - if attn_mask is None: - return _aten_scaled_dot_product_attention_no_mask_onnx( - query, key, value, scale, dropout_p - ) - - return _aten_scaled_dot_product_attention_bool_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - -@torch_op("aten::scaled_dot_product_attention", private=True) def _aten_scaled_dot_product_attention_no_mask_onnx( query: TFloat, key: TFloat, @@ -1907,9 +2025,9 @@ def _aten_scaled_dot_product_attention_no_mask_onnx( ) -> TFloat: # Swap the last two axes of key key_shape = op.Shape(key) - key_last_dim = key_shape[-1:] - key_second_last_dim = key_shape[-2:-1] - key_first_dims = key_shape[:-2] + key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX])) + key_second_last_dim = op.Slice(key_shape, [-2], [-1]) + key_first_dims = op.Slice(key_shape, op.Constant(value_ints=[_INT64_MIN]), [-2]) # Contract the dimensions that are not the last two so we can transpose # with a static permutation. key_squeezed_shape = op.Concat( @@ -1928,11 +2046,11 @@ def _aten_scaled_dot_product_attention_no_mask_onnx( op.MatMul(query_scaled, key_transposed_scaled), axis=-1, ) - attn_weight, _ = op.Dropout(attn_weight, dropout_p) + if dropout_p != 0: + attn_weight, _ = op.Dropout(attn_weight, dropout_p) return op.MatMul(attn_weight, value) -@torch_op("aten::scaled_dot_product_attention", private=True) def _aten_scaled_dot_product_attention_bool_mask_onnx( query: TFloat, key: TFloat, @@ -1943,9 +2061,9 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( ) -> TFloat: # Swap the last two axes of key key_shape = op.Shape(key) - key_last_dim = key_shape[-1:] - key_second_last_dim = key_shape[-2:-1] - key_first_dims = key_shape[:-2] + key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX])) + key_second_last_dim = op.Slice(key_shape, [-2], [-1]) + key_first_dims = op.Slice(key_shape, op.Constant(value_ints=[_INT64_MIN]), [-2]) # Contract the dimensions that are not the last two so we can transpose # with a static permutation. key_squeezed_shape = op.Concat( @@ -1961,16 +2079,24 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( query_scaled = op.Mul(query, op.Sqrt(scale)) key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) - attn_mask = op.Where(attn_mask, 0.0, op.Constant(value_float=-float("inf"))) + zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype)) + neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype)) + attn_mask = op.Where(attn_mask, zero, neg_inf) attn_weight = op.Softmax( op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), axis=-1, ) - attn_weight, _ = op.Dropout(attn_weight, dropout_p) + # When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values + # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output. + # This is because there's no safe/masked softmax imp in ONNX, so we need to handle NaN values explicitly to match + # the behavior of PyTorch with boolean masks. + # Reference: https://github.com/pytorch/pytorch/issues/103749 + attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight) + if dropout_p != 0: + attn_weight, _ = op.Dropout(attn_weight, dropout_p) return op.MatMul(attn_weight, value) -@torch_op("aten::scaled_dot_product_attention", private=True) def _aten_scaled_dot_product_attention_float_mask_onnx( query: TFloat, key: TFloat, @@ -1981,9 +2107,9 @@ def _aten_scaled_dot_product_attention_float_mask_onnx( ) -> TFloat: # Swap the last two axes of key key_shape = op.Shape(key) - key_last_dim = key_shape[-1:] - key_second_last_dim = key_shape[-2:-1] - key_first_dims = key_shape[:-2] + key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX])) + key_second_last_dim = op.Slice(key_shape, [-2], [-1]) + key_first_dims = op.Slice(key_shape, op.Constant(value_ints=[_INT64_MIN]), [-2]) # Contract the dimensions that are not the last two so we can transpose # with a static permutation. key_squeezed_shape = op.Concat( @@ -2002,7 +2128,8 @@ def _aten_scaled_dot_product_attention_float_mask_onnx( op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), axis=-1, ) - attn_weight, _ = op.Dropout(attn_weight, dropout_p) + if dropout_p != 0: + attn_weight, _ = op.Dropout(attn_weight, dropout_p) return op.MatMul(attn_weight, value) @@ -2012,10 +2139,11 @@ def aten_sigmoid_backward(grad_output: TensorType, output: TensorType) -> Tensor raise NotImplementedError() -def aten_silu(self: TensorType) -> TensorType: +@torch_op("aten::silu", trace_only=True) +def aten_silu(self: TFloat) -> TFloat: """silu(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Mul(self, op.Sigmoid(self)) def aten_silu_backward(grad_output: TensorType, self: TensorType) -> TensorType: @@ -2202,27 +2330,20 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str: return "align_corners" if align_corners else "pytorch_half_pixel" -@torch_op( - ( - "aten::upsample_bicubic2d", - "aten::upsample_bilinear2d", - "aten::upsample_nearest1d", - "aten::upsample_nearest2d", - "aten::upsample_nearest3d", - ), - private=True, -) def _aten_upsample_output_size( self: TReal, output_size: INT64, mode: str, coordinate_transformation_mode: str, + antialias: int = 0, ) -> TReal: - self_shape = op.Shape(self) - starts = op.Constant(value_ints=[0]) - ends = op.Constant(value_ints=[2]) - batch_channel = op.Slice(self_shape, starts, ends) - output_size = op.Concat(batch_channel, output_size, axis=0) + batch_and_channel = op.Shape(self, end=2, start=0) + # When output_size is passed in as a list of integers, the torch.onnx + # graph builder when handling op.Concat may fail + # to determine the output type. We cast it to INT64 to ensure the output + output_size = op.Cast(output_size, to=INT64.dtype) + # Append the batch and channel dimensions to the requested output size + output_size = op.Concat(batch_and_channel, output_size, axis=0) return op.Resize( self, None, @@ -2231,25 +2352,28 @@ def _aten_upsample_output_size( mode=mode, coordinate_transformation_mode=coordinate_transformation_mode, nearest_mode="floor", + antialias=antialias, ) -@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True) def _aten_upsample_scales( self: TReal, - scale_factors: TFloat, + scale_factors: Sequence[float], mode: str, coordinate_transformation_mode: str, + antialias: int = 0, ) -> TReal: - scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) - scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) return op.Resize( self, None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + op.Constant( + value_floats=[1.0, 1.0, *scale_factors] + ), # format should be: [1.0, 1.0, scale_h, scale_w] None, mode=mode, coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", + antialias=antialias, ) @@ -2274,6 +2398,28 @@ def aten_upsample_bicubic2d( ) +@torch_op("aten::_upsample_bicubic2d_aa", trace_only=True) +def aten__upsample_bicubic2d_aa( + self: TReal, + output_size: INT64, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> TReal: + """_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, + antialias=1, + ) + + @torch_op("aten::upsample_bicubic2d.vec", trace_only=True) def aten_upsample_bicubic2d_vec( self: TReal, @@ -2287,7 +2433,7 @@ def aten_upsample_bicubic2d_vec( if scale_factors is not None: result = _aten_upsample_scales( self, - op.Constant(value_floats=scale_factors), + scale_factors, mode="cubic", coordinate_transformation_mode=coordinate_transformation_mode, ) @@ -2336,6 +2482,28 @@ def aten_upsample_bilinear2d( ) +@torch_op("aten::_upsample_bilinear2d_aa", trace_only=True) +def aten__upsample_bilinear2d_aa( + self: TReal, + output_size: INT64, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> TReal: + """_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + coordinate_transformation_mode=coordinate_transformation_mode, + mode="linear", + antialias=1, + ) + + @torch_op("aten::upsample_bilinear2d.vec", trace_only=True) def aten_upsample_bilinear2d_vec( self: TReal, @@ -2349,11 +2517,12 @@ def aten_upsample_bilinear2d_vec( if scale_factors is not None: result = _aten_upsample_scales( self, - op.Constant(value_floats=scale_factors), + scale_factors, mode="linear", coordinate_transformation_mode=coordinate_transformation_mode, ) else: + assert output_size is not None result = _aten_upsample_output_size( self, output_size, @@ -2382,9 +2551,8 @@ def aten_upsample_linear1d( self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None ) -> TReal: """upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor""" - # FIXME(justinchuby): Support when scales is provided and align_corners is False - del scales coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + # scales is ignored in PyTorch return _aten_upsample_output_size( self, output_size, @@ -2407,31 +2575,35 @@ def aten_upsample_linear1d_backward( @torch_op("aten::upsample_nearest1d", trace_only=True) def aten_upsample_nearest1d( - self: TReal, size: INT64, scale_factor: Optional[float] = None + self: TReal, output_size: INT64, scales: Optional[float] = None ) -> TReal: """upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor""" - if size is not None: - return _aten_upsample_output_size(self, size, "nearest", "asymmetric") + if scales is not None: + return _aten_upsample_scales(self, [scales], "nearest", "asymmetric") else: - return _aten_upsample_nearest1d_scales(self, scale_factor) + return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric") -@torch_op("aten::upsample_nearest1d", private=True) -def _aten_upsample_nearest1d_scales( - self: TReal, - scale_factors: TFloat, +@torch_op( + ( + "aten::upsample_nearest1d.vec", + "aten::upsample_nearest2d.vec", + "aten::upsample_nearest3d.vec", + ), + trace_only=True, +) +def aten_upsample_nearestnd_vec( + input: TReal, + output_size: Optional[INT64] = None, + scale_factors: Optional[Sequence[float]] = None, ) -> TReal: - scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) - scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) - return op.Resize( - self, - None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode="nearest", - coordinate_transformation_mode="asymmetric", - nearest_mode="floor", - ) + """upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor""" + + if scale_factors is not None: + return _aten_upsample_scales(input, scale_factors, "nearest", "asymmetric") + else: + assert output_size is not None + return _aten_upsample_output_size(input, output_size, "nearest", "asymmetric") def aten_upsample_nearest1d_backward( @@ -2448,18 +2620,21 @@ def aten_upsample_nearest1d_backward( @torch_op("aten::upsample_nearest2d", trace_only=True) def aten_upsample_nearest2d( self: TReal, - size: INT64, + output_size: INT64, scales_h: Optional[float] = None, scales_w: Optional[float] = None, ) -> TReal: """upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor""" - # NOTE: trace_only because optional attributes are not supported by ONNX - # TODO(justinchuby): Conditionally use scales - del scales_h - del scales_w - - return _aten_upsample_output_size(self, size, "nearest", "asymmetric") + if scales_h is not None and scales_w is not None: + return _aten_upsample_scales( + self, + [scales_h, scales_w], + "nearest", + "asymmetric", + ) + else: + return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric") def aten_upsample_nearest2d_backward( @@ -2477,18 +2652,22 @@ def aten_upsample_nearest2d_backward( @torch_op("aten::upsample_nearest3d", trace_only=True) def aten_upsample_nearest3d( self: TReal, - size: INT64, + output_size: INT64, scales_d: Optional[float] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, ) -> TReal: """upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor""" - del scales_h - del scales_w - del scales_d - - return _aten_upsample_output_size(self, size, "nearest", "asymmetric") + if scales_d is not None and scales_h is not None and scales_w is not None: + return _aten_upsample_scales( + self, + [scales_d, scales_h, scales_w], + "nearest", + "asymmetric", + ) + else: + return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric") def aten_upsample_nearest3d_backward( @@ -2528,6 +2707,33 @@ def aten_upsample_trilinear3d( ) +@torch_op("aten::upsample_trilinear3d.vec", trace_only=True) +def aten_upsample_trilinear3d_vec( + self: TReal, + output_size: INT64, + align_corners: bool, + scale_factors: Optional[Sequence[float]], +) -> TReal: + """upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" + + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + if scale_factors is not None: + result = _aten_upsample_scales( + self, + scale_factors, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + else: + result = _aten_upsample_output_size( + self, + output_size, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + return result + + def aten_upsample_trilinear3d_backward( grad_output: TensorType, output_size: INT64, diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py index 3136559b13..f53e9c1133 100644 --- a/onnxscript/function_libs/torch_lib/ops/prims.py +++ b/onnxscript/function_libs/torch_lib/ops/prims.py @@ -19,31 +19,35 @@ from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import RealType, TTensor from onnxscript.onnx_opset import opset18 as op -from onnxscript.onnx_types import TensorType +from onnxscript.onnx_types import BOOL, TensorType -def prims_abs(self: TensorType) -> TensorType: +@torch_op("prims::abs", trace_only=True) +def prims_abs(self: TTensor) -> TTensor: """abs(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Abs(self) +@torch_op("prims::acos", trace_only=True) def prims_acos(self: TensorType) -> TensorType: """acos(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Acos(self) +@torch_op("prims::acosh", trace_only=True) def prims_acosh(self: TensorType) -> TensorType: """acosh(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Acosh(self) -def prims_add(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::add", trace_only=True) +def prims_add(self: TTensor, other: TTensor) -> TTensor: """add(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Add(self, other) def prims_amax( @@ -78,22 +82,25 @@ def prims_as_strided_scatter( raise NotImplementedError() -def prims_asin(self: TensorType) -> TensorType: +@torch_op("prims::asin", trace_only=True) +def prims_asin(self: TTensor) -> TTensor: """asin(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Asin(self) -def prims_asinh(self: TensorType) -> TensorType: +@torch_op("prims::asinh", trace_only=True) +def prims_asinh(self: TTensor) -> TTensor: """asinh(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Asinh(self) -def prims_atan(self: TensorType) -> TensorType: +@torch_op("prims::atan", trace_only=True) +def prims_atan(self: TTensor) -> TTensor: """atan(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Atan(self) def prims_atan2(self: TensorType, other: TensorType) -> TensorType: @@ -102,10 +109,11 @@ def prims_atan2(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -def prims_atanh(self: TensorType) -> TensorType: +@torch_op("prims::atanh", trace_only=True) +def prims_atanh(self: TTensor) -> TTensor: """atanh(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Atanh(self) def prims_bessel_i0(self: TensorType) -> TensorType: @@ -168,12 +176,33 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("prims::broadcast_in_dim", trace_only=True) def prims_broadcast_in_dim( - a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int] + a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int] ) -> TensorType: """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)""" - raise NotImplementedError() + target_rank = len(shape) + + if not broadcast_dimensions: + # Special case: no broadcast dimensions - all target dims should be 1 + return op.Expand(a, common_ops.merge_dims(shape)) + + # Create base shape of all 1s + ones = [1] * target_rank + + # For each broadcast dimension, we'll replace the 1 with the actual input dimension + # Since broadcast_dimensions is compile-time known, we can do this with individual operations + intermediate_shape = ones + + for i, broadcast_dim in enumerate(broadcast_dimensions): + # Get the input dimension value + input_dim_value = op.Shape(a, start=i, end=i + 1) + intermediate_shape[broadcast_dim] = input_dim_value + + # Reshape input to intermediate shape and expand to target + reshaped = op.Reshape(a, common_ops.merge_dims(intermediate_shape)) + return op.Expand(reshaped, shape) def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType: @@ -188,10 +217,11 @@ def prims_cbrt(self: TensorType) -> TensorType: raise NotImplementedError() -def prims_ceil(self: TensorType) -> TensorType: +@torch_op("prims::ceil", trace_only=True) +def prims_ceil(self: TTensor) -> TTensor: """ceil(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Ceil(self) def prims_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType: @@ -239,16 +269,18 @@ def prims_copy_to(a: TensorType, b: TensorType) -> TensorType: raise NotImplementedError() -def prims_cos(self: TensorType) -> TensorType: +@torch_op("prims::cos", trace_only=True) +def prims_cos(self: TTensor) -> TTensor: """cos(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Cos(self) -def prims_cosh(self: TensorType) -> TensorType: +@torch_op("prims::cosh", trace_only=True) +def prims_cosh(self: TTensor) -> TTensor: """cosh(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Cosh(self) @torch_op("prims::device_put") @@ -268,10 +300,11 @@ def prims_digamma(self: TensorType) -> TensorType: raise NotImplementedError() -def prims_div(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::div", trace_only=True) +def prims_div(self: TTensor, other: TTensor) -> TTensor: """div(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Div(self, other) def prims_empty(shape: INT64, dtype: int, device: str, requires_grad: bool) -> TensorType: @@ -288,16 +321,18 @@ def prims_empty_strided( raise NotImplementedError() -def prims_eq(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::eq", trace_only=True) +def prims_eq(self: TTensor, other: TTensor) -> TTensor: """eq(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Equal(self, other) -def prims_erf(self: TensorType) -> TensorType: +@torch_op("prims::erf", trace_only=True) +def prims_erf(self: TTensor) -> TTensor: """erf(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Erf(self) def prims_erf_inv(self: TensorType) -> TensorType: @@ -318,10 +353,11 @@ def prims_erfcx(self: TensorType) -> TensorType: raise NotImplementedError() -def prims_exp(self: TensorType) -> TensorType: +@torch_op("prims::exp", trace_only=True) +def prims_exp(self: TTensor) -> TTensor: """exp(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Exp(self) def prims_exp2(self: TensorType) -> TensorType: @@ -360,10 +396,11 @@ def prims_fill(self: TensorType, value: float) -> TensorType: raise NotImplementedError() -def prims_floor(self: TensorType) -> TensorType: +@torch_op("prims::floor", trace_only=True) +def prims_floor(self: TTensor) -> TTensor: """floor(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Floor(self) def prims_fmax(self: TensorType, other: TensorType) -> TensorType: @@ -406,16 +443,18 @@ def prims_gcd(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -def prims_ge(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::ge", trace_only=True) +def prims_ge(self: TTensor, other: TTensor) -> TTensor: """ge(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.GreaterOrEqual(self, other) -def prims_gt(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::gt", trace_only=True) +def prims_gt(self: TTensor, other: TTensor) -> TTensor: """gt(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Greater(self, other) def prims_hypot(self: TensorType, other: TensorType) -> TensorType: @@ -462,10 +501,11 @@ def prims_item(a: TensorType) -> float: raise NotImplementedError() +@torch_op("prims::le", trace_only=True) def prims_le(self: TensorType, other: TensorType) -> TensorType: """le(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.LessOrEqual(self, other) def prims_lgamma(self: TensorType) -> TensorType: @@ -474,10 +514,11 @@ def prims_lgamma(self: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("prims::log", trace_only=True) def prims_log(self: TensorType) -> TensorType: """log(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Log(self) def prims_log10(self: TensorType) -> TensorType: @@ -498,10 +539,11 @@ def prims_log2(self: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("prims::lt", trace_only=True) def prims_lt(self: TensorType, other: TensorType) -> TensorType: """lt(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Less(self, other) def prims_maximum(self: TensorType, other: TensorType) -> TensorType: @@ -528,10 +570,11 @@ def prims_minium_value(dtype: int) -> float: raise NotImplementedError() -def prims_mul(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::mul", trace_only=True) +def prims_mul(self: TTensor, other: TTensor) -> TTensor: """mul(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Mul(self, other) def prims_ndtri(self: TensorType) -> TensorType: @@ -540,16 +583,18 @@ def prims_ndtri(self: TensorType) -> TensorType: raise NotImplementedError() -def prims_ne(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::ne", trace_only=True) +def prims_ne(self: TTensor, other: TTensor) -> TTensor: """ne(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Not(op.Equal(self, other)) -def prims_neg(self: TensorType) -> TensorType: +@torch_op("prims::neg", trace_only=True) +def prims_neg(self: TTensor) -> TTensor: """neg(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Neg(self) def prims_nextafter(self: TensorType, other: TensorType) -> TensorType: @@ -566,10 +611,11 @@ def prims_normal( raise NotImplementedError() -def prims_pow(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::pow", trace_only=True) +def prims_pow(self: TTensor, other: TTensor) -> TTensor: """pow(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Pow(self, other) def prims_prod( @@ -598,16 +644,18 @@ def prims_remainder(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -def prims_reshape(a: TensorType, shape: INT64) -> TensorType: +@torch_op("prims::reshape", trace_only=True) +def prims_reshape(a: TTensor, shape: INT64) -> TTensor: """reshape(Tensor a, SymInt[] shape) -> Tensor""" - raise NotImplementedError() + return op.Reshape(a, shape) +@torch_op("prims::resize", trace_only=True) def prims_resize(a: TensorType, shape: INT64) -> TensorType: """resize(Tensor a, SymInt[] shape) -> Tensor""" - raise NotImplementedError() + return op.Expand(a, shape) def prims_rev(a: TensorType, dims: Sequence[int]) -> TensorType: @@ -616,10 +664,11 @@ def prims_rev(a: TensorType, dims: Sequence[int]) -> TensorType: raise NotImplementedError() +@torch_op("prims::round", trace_only=True) def prims_round(self: TensorType) -> TensorType: """round(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Round(self) def prims_rsqrt(self: TensorType) -> TensorType: @@ -660,16 +709,18 @@ def prims_signbit(self: TensorType) -> TensorType: raise NotImplementedError() -def prims_sin(self: TensorType) -> TensorType: +@torch_op("prims::sin", trace_only=True) +def prims_sin(self: TTensor) -> TTensor: """sin(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Sin(self) -def prims_sinh(self: TensorType) -> TensorType: +@torch_op("prims::sinh", trace_only=True) +def prims_sinh(self: TTensor) -> TTensor: """sinh(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Sinh(self) def prims_slice( @@ -700,22 +751,25 @@ def prims_split_dim(a: TensorType, dim: int, outer_length: INT64) -> TensorType: raise NotImplementedError() -def prims_sqrt(self: TensorType) -> TensorType: +@torch_op("prims::sqrt", trace_only=True) +def prims_sqrt(self: TTensor) -> TTensor: """sqrt(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Sqrt(self) -def prims_squeeze(a: TensorType, dimensions: Sequence[int]) -> TensorType: +@torch_op("prims::squeeze", trace_only=True) +def prims_squeeze(a: TTensor, dimensions: Sequence[int]) -> TTensor: """squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)""" - raise NotImplementedError() + return op.Squeeze(a, axes=dimensions) -def prims_sub(self: TensorType, other: TensorType) -> TensorType: +@torch_op("prims::sub", trace_only=True) +def prims_sub(self: TTensor, other: TTensor) -> TTensor: """sub(Tensor self, Tensor other) -> Tensor""" - raise NotImplementedError() + return op.Sub(self, other) def prims_sum( @@ -732,22 +786,25 @@ def prims_svd(A: TensorType, full_matrices: bool) -> tuple[TensorType, TensorTyp raise NotImplementedError() -def prims_tan(self: TensorType) -> TensorType: +@torch_op("prims::tan", trace_only=True) +def prims_tan(self: TTensor) -> TTensor: """tan(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Tan(self) -def prims_tanh(self: TensorType) -> TensorType: +@torch_op("prims::tanh", trace_only=True) +def prims_tanh(self: TTensor) -> TTensor: """tanh(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Tanh(self) +@torch_op("prims::transpose", trace_only=True) def prims_transpose(a: TensorType, permutation: Sequence[int]) -> TensorType: """transpose(Tensor(a) a, int[] permutation) -> Tensor(a)""" - raise NotImplementedError() + return op.Transpose(a, perm=permutation) def prims_trunc(self: TensorType) -> TensorType: @@ -764,6 +821,7 @@ def prims_uniform( raise NotImplementedError() +@torch_op("prims::var", trace_only=True) def prims_var( inp: TensorType, dims: Optional[Sequence[int]], @@ -772,7 +830,26 @@ def prims_var( ) -> TensorType: """var(Tensor inp, int[]? dims, *, int correction, ScalarType? output_dtype=None) -> Tensor""" - raise NotImplementedError() + if not dims: + # dims can be empty in practice. We just use a None so it is not added in the ONNX graph + dims = None + sub_mean = op.Sub(inp, op.ReduceMean(inp, dims, keepdims=True)) + sqr_mean = op.Mul(sub_mean, sub_mean) + var = op.ReduceMean(sqr_mean, dims, keepdims=False) + # Adjust var according to correction value + if correction != 0: + inp_shape = op.Shape(inp) + dim_size = op.Gather(inp_shape, dims, axis=0) + numel_float = op.CastLike(op.ReduceProd(dim_size, keepdims=False), inp) + mul = op.Mul(var, numel_float) + # Subtract the correction value + sub = op.Sub(numel_float, op.CastLike(correction, inp)) + var = op.Div(mul, sub) + + if output_dtype is not None and output_dtype != -1: + var = op.Cast(var, to=output_dtype) + + return var def prims_view_of(a: TensorType) -> TensorType: @@ -781,10 +858,11 @@ def prims_view_of(a: TensorType) -> TensorType: raise NotImplementedError() -def prims_where(pred: TensorType, a: TensorType, b: TensorType) -> TensorType: +@torch_op("prims::where", trace_only=True) +def prims_where(pred: BOOL, a: TTensor, b: TTensor) -> TTensor: """where(Tensor pred, Tensor a, Tensor b) -> Tensor""" - raise NotImplementedError() + return op.Where(pred, a, b) def prims_zeta(self: TensorType, other: TensorType) -> TensorType: diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py new file mode 100644 index 0000000000..92962a9ea6 --- /dev/null +++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" +# pylint: disable=unused-argument +"""quantized_decomposed ops defined in https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py + +- No inplace operators. +- All functions should not have the script() decorator. This is because + we want to delay the compilation of the function. +""" + +from __future__ import annotations + +from onnxscript.function_libs.torch_lib.ops import common +from onnxscript.function_libs.torch_lib.registration import torch_op +from onnxscript.onnx_opset import opset18 as op +from onnxscript.onnx_types import TensorType + + +@torch_op( + ( + "quantized_decomposed::quantize_per_tensor", + "quantized_decomposed::quantize_per_tensor.tensor", + "quantized_decomposed::quantize_per_tensor.tensor2", + ), + trace_only=True, +) +def quantized_decomposed_quantize_per_tensor( + input: TensorType, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: int, +) -> TensorType: + # TODO(justinchuby): Use dtype when we use opset 21 + return op.QuantizeLinear(input, scale, common.constant(zero_point, dtype=dtype)) + + +@torch_op( + ( + "quantized_decomposed::dequantize_per_tensor", + "quantized_decomposed::dequantize_per_tensor.tensor", + "quantized_decomposed::dequantize_per_tensor.tensor2", + ), + trace_only=True, +) +def quantized_decomposed_dequantize_per_tensor( + input: TensorType, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: int, + out_dtype: int = -1, +) -> TensorType: + # TODO(justinchuby): Use dtype when we use opset 21 + dequantized = op.DequantizeLinear(input, scale, common.constant(zero_point, dtype=dtype)) + if out_dtype in (-1, None): + # out_dtype can be None as well + return dequantized + assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}" + return op.Cast(dequantized, to=out_dtype) diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index 6719581f62..1b123394d3 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -15,14 +15,12 @@ import math from typing import Optional, Sequence -from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op -from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16 +from onnxscript.function_libs.torch_lib.tensor_typing import TFloat from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType _MATH_PI = math.pi -IsScalar = common_ops.IsScalar def aten_special_airy_ai(x: TensorType) -> TensorType: @@ -92,21 +90,21 @@ def aten_special_entr(self: TensorType) -> TensorType: @torch_op(("aten::erf", "aten::special_erf")) -def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erf(self: TFloat) -> TFloat: """erf(Tensor self) -> Tensor""" return op.Erf(self) @torch_op(("aten::erfc", "aten::special_erfc")) -def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erfc(self: TFloat) -> TFloat: """erfc(Tensor self) -> Tensor""" return op.Sub(1, op.Erf(self)) @torch_op("aten::special_erfcx") -def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erfcx(self: TFloat) -> TFloat: """special_erfcx(Tensor self) -> Tensor""" return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self))) @@ -130,10 +128,11 @@ def aten_special_expit(self: TensorType) -> TensorType: raise NotImplementedError() -def aten_special_expm1(self: TensorType) -> TensorType: +@torch_op(("aten::expm1", "aten::special_expm1")) +def aten_special_expm1(self: TFloat) -> TFloat: """special_expm1(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Sub(op.Exp(self), 1) def aten_special_gammainc(self: TensorType, other: TensorType) -> TensorType: @@ -214,15 +213,13 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::log_softmax", "aten::special_log_softmax"), trace_only=True) -def aten_special_log_softmax( - self: TFloatOrBFloat16, dim: int, dtype: int = -1 -) -> TFloatOrBFloat16: +@torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True) +def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: """special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor""" - self_is_scalar = IsScalar(self) + self_is_scalar = len(self.shape) == 0 if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + self = op.Unsqueeze(self, [0]) result = op.LogSoftmax(self, axis=dim) if dtype != -1: result = op.Cast(result, to=dtype) @@ -364,8 +361,8 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::xlogy") -def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +@torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other")) +def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat: """special_xlogy(Tensor self, Tensor other) -> Tensor""" # https://pytorch.org/docs/stable/special.html#torch.special.xlogy diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 05d8f62179..162d69d747 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -1,9 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Registry for aten functions.""" from __future__ import annotations import re -from types import FunctionType from typing import Any, Callable, Generator, Optional import onnxscript @@ -99,8 +100,7 @@ def torch_op( trace_only: bool = False, private: bool = False, complex: bool = False, - traceable: bool = False, -) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: +) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: """Register a torch op. Args: @@ -113,24 +113,12 @@ def torch_op( private: Whether the function is private (not directly exposed). It should be true for all functions with names starting with "_". complex: Whether the function expects complex-valued inputs. - traceable: Whether the function can also be traced. This is an **experimental** flag. - A function is traceable if it can both be scripted and traced to produce - the same result for a given input. Specifically: - - - A function _can_ be tagged with traceable if its if branches (if any) - can be statically evaluated. - - A function _should_ be tagged with traceable if it contains if branches - and/or CastLike nodes so that they can be evaluated away with the - EXPERIMENTAL_PREFER_TRACING on. - - A function without if branches or CastLike nodes _should not_ be tagged - with traceable because inlining will do the same thing. - - A function with `@graph` defined for a `Scan` op is not traceable yet. """ if registry is None: registry = default_registry def wrapper( - func: FunctionType, + func: Callable, ) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction: # Compile the function custom_opset = onnxscript.values.Opset(domain=_constants.DOMAIN, version=1) @@ -139,9 +127,7 @@ def wrapper( if trace_only: processed_func = onnxscript.values.TracedOnnxFunction(custom_opset, func) else: - assert isinstance(func, FunctionType) processed_func = onnxscript.script(opset=custom_opset)(func) - processed_func.experimental_traceable = traceable assert registry is not None for name_ in _check_and_normalize_names(name): diff --git a/onnxscript/function_libs/torch_lib/tensor_typing.py b/onnxscript/function_libs/torch_lib/tensor_typing.py index 7b5287f417..1f27c0cff0 100644 --- a/onnxscript/function_libs/torch_lib/tensor_typing.py +++ b/onnxscript/function_libs/torch_lib/tensor_typing.py @@ -42,7 +42,7 @@ INT64, UINT8, ] -_FloatType = Union[FLOAT16, FLOAT, DOUBLE] +_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16] IntType = Union[INT8, INT16, INT32, INT64] RealType = Union[ BFLOAT16, @@ -61,7 +61,6 @@ TTensor2 = TypeVar("TTensor2", bound=_TensorType) TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING]) TFloat = TypeVar("TFloat", bound=_FloatType) -TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]) TInt = TypeVar("TInt", bound=IntType) TReal = TypeVar("TReal", bound=RealType) diff --git a/onnxscript/ir/README.md b/onnxscript/ir/README.md index dae5c09a5b..21d5cd124d 100644 --- a/onnxscript/ir/README.md +++ b/onnxscript/ir/README.md @@ -1,22 +1,3 @@ -# ONNX IR +# Where is the code? -An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation. - -## Features ✨ - -- Full ONNX spec support: all valid models representable by ONNX protobuf, and a subset of invalid models (so you can load and fix them). -- Low memory footprint: mmap'ed external tensors; unified interface for ONNX TensorProto, Numpy arrays and PyTorch Tensors etc. No tensor size limitation. Zero copies. -- Straightforward access patterns: Access value information and traverse the graph topology at ease. -- Robust mutation: Create as many iterators as you like on the graph while mutating it. -- Speed: Performant graph manipulation, serialization/deserialization to Protobuf. -- Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way. -- No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format. - -## Code Organization 🗺️ - -- [`_protocols.py`](_protocols.py): Interfaces defined for all entities in the IR. -- [`_core.py`](_core.py): Implementation of the core entities in the IR, including `Model`, `Graph`, `Node`, `Value`, and others. -- [`_enums.py`](_enums.py): Definition of the type enums that correspond to the `DataType` and `AttributeType` in `onnx.proto`. -- [`_name_authority.py`](_name_authority.py): The authority for giving names to entities in the graph, used internally. -- [`_linked_list.py`](_linked_list.py): The data structure as the node container in the graph that supports robust iteration and mutation. Internal. -- [`_metadata.py`](_metadata.py): Metadata store for all entities in the IR. +The ONNX IR has migrated to https://github.com/onnx/ir-py as a standalone project. The original onnxscript APIs are aliased here for compatibility. diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 7bfbeabad9..6240347886 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -1,129 +1,4 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""In-memory intermediate representation for ONNX graphs.""" - -__all__ = [ - # Modules - "serde", - # IR classes - "Attr", - "AttrFloat32", - "AttrFloat32s", - "AttrGraph", - "AttrGraphs", - "AttrInt64", - "AttrInt64s", - "AttrSparseTensor", - "AttrSparseTensors", - "AttrString", - "AttrStrings", - "AttrTensor", - "AttrTensors", - "TypeAndShape", - "AttrTypeProto", - "AttrTypeProtos", - "SymbolicDim", - "ExternalTensor", - "StringTensor", - "Function", - "Graph", - "GraphView", - "Input", - "Model", - "Node", - "RefAttr", - "Shape", - "Tensor", - "Value", - "TensorType", - "OptionalType", - "SequenceType", - "SparseTensorType", - # Protocols - "ArrayCompatible", - "DLPackCompatible", - "TensorProtocol", - "ValueProtocol", - "ModelProtocol", - "NodeProtocol", - "GraphProtocol", - "GraphViewProtocol", - "AttributeProtocol", - "ReferenceAttributeProtocol", - "SparseTensorProtocol", - "SymbolicDimProtocol", - "ShapeProtocol", - "TypeProtocol", - "MapTypeProtocol", - "FunctionProtocol", - # Enums - "AttributeType", - "DataType", - # Types - "OperatorIdentifier", - # Protobuf compatible types - "TensorProtoTensor", -] - -from onnxscript.ir import serde -from onnxscript.ir._core import ( - Attr, - AttrFloat32, - AttrFloat32s, - AttrGraph, - AttrGraphs, - AttrInt64, - AttrInt64s, - AttrSparseTensor, - AttrSparseTensors, - AttrString, - AttrStrings, - AttrTensor, - AttrTensors, - AttrTypeProto, - AttrTypeProtos, - ExternalTensor, - Function, - Graph, - GraphView, - Input, - Model, - Node, - OptionalType, - RefAttr, - SequenceType, - Shape, - SparseTensorType, - StringTensor, - SymbolicDim, - Tensor, - TensorType, - TypeAndShape, - Value, -) -from onnxscript.ir._enums import ( - AttributeType, - DataType, -) -from onnxscript.ir._protocols import ( - ArrayCompatible, - AttributeProtocol, - DLPackCompatible, - FunctionProtocol, - GraphProtocol, - GraphViewProtocol, - MapTypeProtocol, - ModelProtocol, - NodeProtocol, - OperatorIdentifier, - ReferenceAttributeProtocol, - ShapeProtocol, - SparseTensorProtocol, - SymbolicDimProtocol, - TensorProtocol, - TypeProtocol, - ValueProtocol, -) -from onnxscript.ir.serde import TensorProtoTensor +# pylint: disable=wildcard-import,unused-wildcard-import +from onnx_ir import * # type: ignore # noqa: F403 diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py deleted file mode 100644 index 7eba1cb283..0000000000 --- a/onnxscript/ir/_convenience.py +++ /dev/null @@ -1,287 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Convenience methods for constructing and manipulating the IR. - -This is an internal only module. We should choose to expose some of the methods -after they are proven to be useful. -""" - -from __future__ import annotations - -__all__ = [ - "convert_attribute", - "convert_attributes", - "replace_all_uses_with", -] - -from typing import Mapping, Sequence, Union - -import onnx - -from onnxscript.ir import _core, _enums, _protocols, serde - -SupportedAttrTypes = Union[ - str, - int, - float, - Sequence[int], - Sequence[float], - Sequence[str], - _protocols.TensorProtocol, # This includes all in-memory tensor types - onnx.TensorProto, - _core.Attr, - _core.RefAttr, - _protocols.GraphProtocol, - Sequence[_protocols.GraphProtocol], - _protocols.TypeProtocol, - Sequence[_protocols.TypeProtocol], - None, -] - - -def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType: - """Infer the attribute type based on the type of the Python object.""" - if isinstance(attr, int): - return _enums.AttributeType.INT - if isinstance(attr, float): - return _enums.AttributeType.FLOAT - if isinstance(attr, str): - return _enums.AttributeType.STRING - if isinstance(attr, (_core.Attr, _core.RefAttr)): - return attr.type - if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr): - return _enums.AttributeType.INTS - if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr): - return _enums.AttributeType.FLOATS - if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr): - return _enums.AttributeType.STRINGS - if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)): - # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower - return _enums.AttributeType.TENSOR - if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)): - return _enums.AttributeType.GRAPH - if isinstance(attr, Sequence) and all( - isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr - ): - return _enums.AttributeType.GRAPHS - if isinstance( - attr, - (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol), - ): - return _enums.AttributeType.TYPE_PROTO - if isinstance(attr, Sequence) and all( - isinstance( - x, - ( - _core.TensorType, - _core.SequenceType, - _core.OptionalType, - _protocols.TypeProtocol, - ), - ) - for x in attr - ): - return _enums.AttributeType.TYPE_PROTOS - raise TypeError(f"Unsupported attribute type: '{type(attr)}'") - - -def convert_attribute( - name: str, - attr: SupportedAttrTypes, - attr_type: _enums.AttributeType | None = None, -) -> _core.Attr | _core.RefAttr: - """Convert a Python object to a _core.Attr object. - - This method is useful when constructing nodes with attributes. It infers the - attribute type based on the type of the Python value. - - Args: - name: The name of the attribute. - attr: The value of the attribute. - attr_type: The type of the attribute. This is required when attr is None. - When provided, it overrides the inferred type. - - Returns: - A ``Attr`` object. - - Raises: - ValueError: If :param:`attr` is ``None`` and :param:`attr_type` is not provided. - TypeError: If the type of the attribute is not supported. - """ - if attr is None: - if attr_type is None: - raise ValueError("attr_type must be provided when attr is None") - return _core.Attr(name, attr_type, None) - - if isinstance(attr, (_core.Attr, _core.RefAttr)): - if attr.name != name: - raise ValueError( - f"Attribute name '{attr.name}' does not match provided name '{name}'" - ) - if attr_type is not None and attr.type != attr_type: - raise ValueError( - f"Attribute type '{attr.type}' does not match provided type '{attr_type}'" - ) - return attr - - if attr_type is None: - attr_type = _infer_attribute_type(attr) - - if attr_type == _enums.AttributeType.INT: - return _core.AttrInt64(name, attr) # type: ignore - if attr_type == _enums.AttributeType.FLOAT: - return _core.AttrFloat32(name, attr) # type: ignore - if attr_type == _enums.AttributeType.STRING: - return _core.AttrString(name, attr) # type: ignore - if attr_type == _enums.AttributeType.INTS: - return _core.AttrInt64s(name, attr) # type: ignore - if attr_type == _enums.AttributeType.FLOATS: - return _core.AttrFloat32s(name, attr) # type: ignore - if attr_type == _enums.AttributeType.STRINGS: - return _core.AttrStrings(name, attr) # type: ignore - if attr_type == _enums.AttributeType.TENSOR: - if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)): - return _core.AttrTensor(name, attr) - if isinstance(attr, onnx.TensorProto): - return _core.AttrTensor(name, serde.TensorProtoTensor(attr)) - if attr_type == _enums.AttributeType.GRAPH: - return _core.AttrGraph(name, attr) # type: ignore[arg-type] - if attr_type == _enums.AttributeType.GRAPHS: - return _core.AttrGraphs(name, attr) # type: ignore[arg-type] - if attr_type == _enums.AttributeType.TYPE_PROTO: - return _core.AttrTypeProto(name, attr) # type: ignore[arg-type] - if attr_type == _enums.AttributeType.TYPE_PROTOS: - return _core.AttrTypeProtos(name, attr) # type: ignore[arg-type] - raise TypeError(f"Unsupported attribute type: '{type(attr)}'") - - -def convert_attributes( - attrs: Mapping[str, SupportedAttrTypes], -) -> list[_core.Attr | _core.RefAttr]: - """Convert a dictionary of attributes to a list of _core.Attr objects. - - It infers the attribute type based on the type of the value. The supported - types are: int, float, str, Sequence[int], Sequence[float], Sequence[str], - :class:`_core.Tensor`, and :class:`_core.Attr`:: - - >>> from onnxscript import ir - >>> import onnx - >>> import numpy as np - >>> attrs = { - ... "int": 1, - ... "float": 1.0, - ... "str": "hello", - ... "ints": [1, 2, 3], - ... "floats": [1.0, 2.0, 3.0], - ... "strings": ["hello", "world"], - ... "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])), - ... "tensor_proto": - ... onnx.TensorProto( - ... dims=[3], - ... data_type=onnx.TensorProto.FLOAT, - ... float_data=[1.0, 2.0, 3.0], - ... name="proto", - ... ), - ... "graph": ir.Graph([], [], nodes=[], name="graph0"), - ... "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")], - ... "type_proto": ir.TensorType(ir.DataType.FLOAT), - ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)], - ... } - >>> convert_attributes(attrs) - [AttrInt64('int', 1), AttrFloat32('float', 1.0), AttrString('str', 'hello'), AttrInt64s('ints', [1, 2, 3]), AttrFloat32s('floats', [1.0, 2.0, 3.0]), AttrStrings('strings', ['hello', 'world']), AttrTensor('tensor', Tensor(array([1., 2., 3.]), name='')), AttrTensor('tensor_proto', TensorProtoTensor(name='proto')), AttrInt64s('graph', Graph( - name='graph0', - inputs=( - - ), - outputs=( - - ), - len()=0 - )), AttrGraphs('graphs', [Graph( - name='graph1', - inputs=( - - ), - outputs=( - - ), - len()=0 - ), Graph( - name='graph2', - inputs=( - - ), - outputs=( - - ), - len()=0 - )]), AttrTypeProto('type_proto', Tensor(FLOAT)), AttrTypeProtos('type_protos', [Tensor(FLOAT), Tensor(FLOAT)])] - - Args: - attrs: A dictionary of {: } to convert. - - Returns: - A list of _core.Attr objects. - """ - attributes: list[_core.Attr | _core.RefAttr] = [] - for name, attr in attrs.items(): - attributes.append(convert_attribute(name, attr)) - return attributes - - -def replace_all_uses_with( - values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], - replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], -) -> None: - """Replace all uses of the given values with the replacements. - - This is useful when nodes in the graph are replaced with new nodes, where - the old users need to be updated to use the outputs of the new nodes. - - For example, suppose we have the following graph:: - - A -> {B, C} - - We want to replace the node A with a new node D:: - - >>> from onnxscript import ir - >>> input = ir.Input("input") - >>> node_a = ir.Node("", "A", [input]) - >>> node_b = ir.Node("", "B", node_a.outputs) - >>> node_c = ir.Node("", "C", node_a.outputs) - >>> node_d = ir.Node("", "D", [input]) - >>> replace_all_uses_with(node_a.outputs, node_d.outputs) - >>> len(node_b.inputs) - 1 - >>> node_b.inputs[0].producer().op_type - 'D' - >>> len(node_c.inputs) - 1 - >>> node_c.inputs[0].producer().op_type - 'D' - >>> len(node_a.outputs[0].uses()) - 0 - - When values and replacements are sequences, they are zipped into pairs. All - users of the first value is replaced with the first replacement, and so on. - - .. note:: - You still need to update the graph outputs if any of the values being - replaced are part of the graph outputs. Be sure to remove the old nodes - from the graph using ``graph.remove()`` if they are no longer needed. - - Args: - values: The value or values to be replaced. - replacements: The new value or values to use as inputs. - """ - if not isinstance(values, Sequence): - values = (values,) - if not isinstance(replacements, Sequence): - replacements = (replacements,) - if len(values) != len(replacements): - raise ValueError("The number of values and replacements must match.") - for value, replacement in zip(values, replacements): - for user_node, index in tuple(value.uses()): - user_node.replace_input_with(index, replacement) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py deleted file mode 100644 index 1dedd0b6a7..0000000000 --- a/onnxscript/ir/_core.py +++ /dev/null @@ -1,2618 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""data structures for the intermediate representation.""" - -# NOTES for developers: -# NOTE: None of these classes will have a "to_onnx" or "from_protobuf" method because -# We cannot assume that the build tool chain has protoc installed and would like -# to keep this module protobuf free. This way we separate the concerns of the IR -# and the serialization/deserialization. -# -# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead. - -from __future__ import annotations - -import abc -import contextlib -import dataclasses -import math -import mmap -import os -import sys -import textwrap -import typing -from typing import ( - AbstractSet, - Any, - Collection, - Generic, - Iterable, - Iterator, - OrderedDict, - Sequence, - Union, -) - -import numpy as np - -from onnxscript.ir import ( - _display, - _enums, - _linked_list, - _metadata, - _name_authority, - _protocols, - _type_casting, -) - -if typing.TYPE_CHECKING: - import numpy.typing as npt - from typing_extensions import TypeGuard - -TArrayCompatible = typing.TypeVar( - "TArrayCompatible", - bound=Union[_protocols.ArrayCompatible, _protocols.DLPackCompatible], -) - -# System is little endian -_IS_LITTLE_ENDIAN = sys.byteorder == "little" -# Data types that are not supported by numpy -_NON_NUMPY_NATIVE_TYPES = frozenset( - ( - _enums.DataType.BFLOAT16, - _enums.DataType.FLOAT8E4M3FN, - _enums.DataType.FLOAT8E4M3FNUZ, - _enums.DataType.FLOAT8E5M2, - _enums.DataType.FLOAT8E5M2FNUZ, - _enums.DataType.INT4, - _enums.DataType.UINT4, - ) -) - - -def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]: - """Use this function to check if an object is compatible with numpy. - - Avoid isinstance checks with the ArrayCompatible protocol for performance reasons. - """ - return hasattr(obj, "__array__") - - -def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]: - """Use this function to check if an object is compatible with DLPack. - - Avoid isinstance checks with the DLPackCompatible protocol for performance reasons. - """ - return hasattr(obj, "__dlpack__") - - -class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable): - """Convenience Shared methods for classes implementing TensorProtocol.""" - - __slots__ = () - - def _printable_type_shape(self) -> str: - """Return a string representation of the shape and data type.""" - return f"{self.dtype},{self.shape}" - - def _repr_base(self) -> str: - """Base string for the repr method. - - Example: Tensor - """ - return f"{self.__class__.__name__}<{self._printable_type_shape()}>" - - @property - def size(self) -> int: - """The number of elements in the tensor.""" - return np.prod(self.shape.numpy()) # type: ignore[return-value,attr-defined] - - @property - def nbytes(self) -> int: - """The number of bytes in the tensor.""" - # Use math.ceil because when dtype is INT4, the itemsize is 0.5 - return math.ceil(self.dtype.itemsize * self.size) - - def display(self, *, page: bool | None = None) -> None: - rich = _display.require_rich() - - if rich is None: - status_manager = contextlib.nullcontext() - else: - import rich.status # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel - - status_manager = rich.status.Status(f"Computing tensor stats for {self!r}") - - from onnxscript._thirdparty import ( # pylint: disable=import-outside-toplevel - asciichartpy, - ) - - with status_manager: - # Construct the text to display - lines = [] - array = self.numpy().flatten() - lines.append(repr(self)) - lines.append("") - nan_values = np.isnan(array) - nan_count = np.count_nonzero(nan_values) - inf_count = np.count_nonzero(np.isinf(array)) - numbers = array[~nan_values] - lines.append( - f"Min: {np.min(numbers)}, Max: {np.max(numbers)}, " - f"NaN count: {nan_count}, " - f"Inf count: {inf_count}" - ) - # Compute sparsity - sparse_threathold = 1e-6 - # NOTE: count_nonzero() is faster than sum() for boolean arrays - sparsity = np.count_nonzero(np.abs(array) < sparse_threathold) / array.size - lines.append(f"Sparsity (abs<{sparse_threathold}): {sparsity:.2f}") - - # Compute histogram - finite_numbers = array[np.isfinite(array)] - lines.append("Histogram:") - hist, bin_edges = np.histogram(finite_numbers, bins=80, density=False) - lines.append( - asciichartpy.plot( - hist, bin_edges=bin_edges, cfg={"height": 8, "format": "{:8.0f}"} - ) - ) - - text = "\n".join(lines) - - if rich is None: - print(text) - elif page: - import rich.console # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel - - console = rich.console.Console() - with console.pager(styles=True): - console.print(text) - else: - rich.print(text) - - -def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) -> None: - """Check if the numpy array dtype matches the IR data type. - - When the dtype is not one of the numpy native dtypes, the value needs need to be: - - - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits. - - ``uint8`` for uint4. - - ``uint8`` for 8-bit data types. - - ``uint16`` for bfloat16 - """ - if dtype in _NON_NUMPY_NATIVE_TYPES: - if dtype.itemsize == 2 and array.dtype != np.uint16: - # TODO(justinchuby): Support the storage dtypes like uint16 for bfloat16. - raise TypeError( - f"The numpy array dtype must be uint16 (not {array.dtype}) for IR data type {dtype}." - ) - if dtype.itemsize == 1 and array.dtype != np.uint8: - raise TypeError( - f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}." - ) - if dtype == _enums.DataType.INT4: - if array.dtype not in (np.int8, np.uint8): - raise TypeError( - f"The numpy array dtype must be int8 or uint8 (not {array.dtype}) for IR data type {dtype}." - ) - if dtype == _enums.DataType.UINT4: - if array.dtype != np.uint8: - raise TypeError( - f"The numpy array dtype must be uint8 (not {array.dtype}) for IR data type {dtype}." - ) - return - - try: - dtype_numpy = _enums.DataType.from_numpy(array.dtype) - except TypeError as e: - raise TypeError( - "Failed to convert the numpy dtype to an IR data type. " - "If you are using a non-native dtype, be sure to specify the corresponding IR dtype when " - "creating a Tensor." - ) from e - - if dtype_numpy != dtype: - raise TypeError( - f"The numpy array dtype {array.dtype} does not match the IR data type {dtype}." - ) - - -class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): - """An immutable concrete tensor. - - This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array - compatible object (e.g. ``np.ndarray``, ``torch.Tensor``) or a ``DLPack`` compatible object. - The tensor is immutable and the data is not copied at initialization. - - To create a tensor from a numpy array:: - - >>> import numpy as np - >>> array = np.array([1, 2, 3]) - >>> tensor = Tensor(array) - >>> # The tensor itself can be treated as a numpy array because it implements the __array__ method - >>> np.allclose(tensor, array) - True - - To get a numpy array from the tensor, call :meth:`numpy`. To convert the tensor - to a byte string for serialization, call :meth:`tobytes`. - - It is recommended to check the size of the tensor first before accessing the - underlying data, because accessing the data may be expensive and incur IO - overhead. - - Subclass this class to efficiently handle different types of tensors from different frameworks. - - Attributes: - name: The name of the tensor. - shape: The shape of the tensor. - dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum. - doc_string: Documentation string. - raw: The raw data behind this tensor. It can be anything. - size: The number of elements in the tensor. - nbytes: The number of bytes in the tensor. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - __slots__ = ( - "_raw", - "_dtype", - "_shape", - "name", - "doc_string", - "_metadata_props", - "_metadata", - ) - - def __init__( - self, - value: TArrayCompatible, - dtype: _enums.DataType | None = None, - *, - shape: Shape | None = None, - name: str = "", - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> None: - """Initialize a tensor. - - Args: - value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object. - When the dtype is not one of the numpy native dtypes, the value needs - to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16 - when the value is a numpy array; :param:`dtype` must be specified in this case. - dtype: The data type of the tensor. It can be None only when value is a numpy array. - Users are responsible for making sure the dtype matches the value when value is not a numpy array. - shape: The shape of the tensor. If None, the shape is obtained from the value. - name: The name of the tensor. - doc_string: The documentation string. - metadata_props: The metadata properties. - - Raises: - TypeError: If the value is not a numpy array compatible or a DLPack compatible object. - TypeError: If the value is a numpy array and the dtype is specified but does not match the dtype of the array. - ValueError: If the shape is not specified and the value does not have a shape attribute. - ValueError: If the dtype is not specified and the value is not a numpy array. - """ - # NOTE: We should not do any copying here for performance reasons - if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value): - raise TypeError(f"Expected an array compatible object, got {type(value)}") - if shape is None: - # Obtain the shape from the value - if not hasattr(value, "shape"): - raise ValueError( - f"Expected an object with a shape attribute, but {type(value)} does not have shape. " - "Please specify the shape explicitly." - ) - self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009 - else: - self._shape = shape - self._shape._frozen = True - if dtype is None: - if isinstance(value, np.ndarray): - self._dtype = _enums.DataType.from_numpy(value.dtype) - else: - raise ValueError( - "The dtype must be specified when the value is not a numpy array." - ) - else: - if isinstance(value, np.ndarray): - # Make sure the dtype matches the value - _check_numpy_representation_type(value, dtype) - # Users are responsible for making sure the dtype matches the value - # when value is not a numpy array - self._dtype = dtype - self._raw = value - self.name = name - self.doc_string = doc_string - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props = metadata_props - - def __array__(self, dtype: Any = None) -> np.ndarray: - if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw): - return self._raw.__array__(dtype) - assert _compatible_with_dlpack( - self._raw - ), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}" - return np.from_dlpack(self._raw) - - def __dlpack__(self, *, stream: Any = None) -> Any: - if _compatible_with_dlpack(self._raw): - return self._raw.__dlpack__(stream=stream) - return self.__array__().__dlpack__(stream=stream) - - def __dlpack_device__(self) -> tuple[int, int]: - if _compatible_with_dlpack(self._raw): - return self._raw.__dlpack_device__() - return self.__array__().__dlpack_device__() - - def __repr__(self) -> str: - return f"{self._repr_base()}({self._raw!r}, name={self.name!r})" - - @property - def dtype(self) -> _enums.DataType: - """The data type of the tensor. Immutable.""" - return self._dtype - - @property - def shape(self) -> Shape: - """The shape of the tensor. Immutable.""" - return self._shape - - @property - def raw(self) -> TArrayCompatible: - """Backing data of the tensor. Immutable.""" - return self._raw # type: ignore[return-value] - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array. - - When the data type is not supported by numpy, the value is the bit representation - of the dtype: - - - ``int8`` for int4, with the sign bit extended to 8 bits. - - ``uint8`` for uint4. - - ``uint8`` for 8-bit data types like float8. - - ``uint16`` for bfloat16. - """ - if isinstance(self._raw, np.ndarray): - return self._raw - # We do not cache the value to save memory - return self.__array__() - - def tobytes(self) -> bytes: - """Returns the value as bytes encoded in little endian. - - Override this method for more efficient serialization when the raw - value is not a numpy array. - """ - # TODO(justinchuby): Support DLPack - array = self.numpy() - if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}: - # Pack the array into int4 - array = _type_casting.pack_int4(array) - else: - assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" - if not _IS_LITTLE_ENDIAN: - array = array.view(array.dtype.newbyteorder("<")) - return array.tobytes() - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - -class ExternalTensor(TensorBase, _protocols.TensorProtocol): - """An immutable concrete tensor with its data store on disk. - - This class uses memory mapping to avoid loading the tensor into memory, - when the data type is supported by numpy. Otherwise, the tensor is loaded - into memory lazily when accessed. - - Calling :attr:`shape` does not incur IO. Checking shape before loading - the tensor is recommended if IO overhead and memory usage is a concern. - - To obtain an array, call :meth:`numpy`. To obtain the bytes, - call :meth:`tobytes`. - - The :attribute:`path` can be a relative path or an absolute path. - Serializers should handle the path correctly to conform with the ONNX spec. - - Attributes: - path: The path to the data file. This can be a relative path or an absolute path. - offset: The offset in bytes from the start of the file. - length: The length of the data in bytes. - dtype: The data type of the tensor. - shape: The shape of the tensor. - name: The name of the tensor. It must be specified. - doc_string: The documentation string. - metadata_props: The metadata properties. - """ - - __slots__ = ( - "_path", - "_offset", - "_length", - "_dtype", - "_shape", - "name", - "doc_string", - "_array", - "raw", - "_metadata_props", - "_metadata", - ) - - def __init__( - self, - path: os.PathLike | str, - offset: int | None, - length: int | None, - dtype: _enums.DataType, - *, - shape: Shape, - name: str, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> None: - self._path = path - self._offset: int | None = offset - self._length: int | None = length - self._dtype: _enums.DataType = dtype - self.name: str = name # mutable - self._shape: Shape = shape - self._shape._frozen = True - self.doc_string: str | None = doc_string # mutable - self._array: np.ndarray | None = None - self.raw: mmap.mmap | None = None - self._metadata_props = metadata_props - self._metadata: _metadata.MetadataStore | None = None - - @property - def path(self) -> str | os.PathLike: - # Immutable - return self._path - - @property - def offset(self) -> int | None: - # Immutable - return self._offset - - @property - def length(self) -> int | None: - # Immutable - return self._length - - @property - def dtype(self) -> _enums.DataType: - # Immutable - return self._dtype - - @property - def shape(self) -> Shape: - # Immutable - return self._shape - - def _load(self): - assert self._array is None, "Bug: The array should be loaded only once." - # Map the whole file into the memory - # TODO(justinchuby): Verify if this would exhaust the memory address space - with open(self._path, "rb") as f: - self.raw = mmap.mmap( - f.fileno(), - 0, - access=mmap.ACCESS_READ, - ) - # Handle the byte order correctly by always using little endian - dt = np.dtype(self.dtype.numpy()).newbyteorder("<") - self._array = np.frombuffer( - self.raw, dtype=dt, offset=self.offset or 0, count=self.size - ).reshape(self.shape.numpy()) - - def __array__(self, dtype: Any = None) -> np.ndarray: - if self._array is None: - self._load() - assert self._array is not None - return self._array.__array__(dtype) - - def __dlpack__(self, *, stream: Any = None) -> Any: - return self.numpy().__dlpack__(stream=stream) - - def __repr__(self) -> str: - return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})" - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array. - - The data will be memory mapped into memory and will not taken up physical memory space. - """ - if self._array is None: - self._load() - assert self._array is not None - return self._array - - def tobytes(self) -> bytes: - """Return the bytes of the tensor. - - This will load the tensor into memory. - """ - if self.raw is None: - self._load() - assert self.raw is not None - offset = self._offset or 0 - length = self._length or self.nbytes - return self.raw[offset : offset + length] - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - -class StringTensor(TensorBase, _protocols.TensorProtocol): - """Multidimensional array of strings (as binary data to match the string_data field in TensorProto).""" - - __slots__ = ( - "_raw", - "_shape", - "name", - "doc_string", - "_metadata_props", - "_metadata", - ) - - def __init__( - self, - value: Sequence[bytes] | npt.NDArray[np.bytes_], - *, - shape: Shape | None = None, - name: str = "", - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> None: - """Initialize a tensor. - - Args: - value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes. - shape: The shape of the tensor. If None, the shape is obtained from the value. - name: The name of the tensor. - doc_string: The documentation string. - metadata_props: The metadata properties. - """ - if shape is None: - if not hasattr(value, "shape"): - raise ValueError( - f"Expected an object with a shape attribute, but {type(value)} does not have shape. " - "Please specify the shape explicitly." - ) - self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009 - else: - self._shape = shape - self._shape._frozen = True - self._raw = value - self.name = name - self.doc_string = doc_string - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props = metadata_props - - def __array__(self, dtype: Any = None) -> np.ndarray: - if isinstance(self._raw, np.ndarray): - return self._raw - assert isinstance( - self._raw, Sequence - ), f"Bug: Expected a sequence, got {type(self._raw)}" - return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy()) - - def __dlpack__(self, *, stream: Any = None) -> Any: - del stream # unused - raise TypeError("StringTensor does not support DLPack") - - def __dlpack_device__(self) -> tuple[int, int]: - raise TypeError("StringTensor does not support DLPack") - - def __repr__(self) -> str: - return f"{self._repr_base()}({self._raw!r}, name={self.name!r})" - - @property - def dtype(self) -> _enums.DataType: - """The data type of the tensor. Immutable.""" - return _enums.DataType.STRING - - @property - def shape(self) -> Shape: - """The shape of the tensor. Immutable.""" - return self._shape - - @property - def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]: - """Backing data of the tensor. Immutable.""" - return self._raw # type: ignore[return-value] - - def numpy(self) -> npt.NDArray[np.bytes_]: - """Return the tensor as a numpy array.""" - return self.__array__() - - def tobytes(self) -> bytes: - raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.") - - def string_data(self) -> Sequence[bytes]: - """Return the string data of the tensor.""" - if isinstance(self._raw, np.ndarray): - return self._raw.flatten().tolist() - return self._raw - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - -class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): - __slots__ = ("_value",) - - def __init__(self, value: str | None) -> None: - """Initialize a symbolic dimension. - - Args: - value: The value of the dimension. It should not be an int. - """ - if isinstance(value, int): - raise TypeError("The value of a SymbolicDim cannot be an int") - self._value = value - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SymbolicDim): - return self.value == other - return self.value == other.value - - def __hash__(self) -> int: - return hash(self.value) - - @property - def value(self) -> str | None: - return self._value - - def __str__(self) -> str: - return f"{self._value}" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._value})" - - -class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): - __slots__ = ("_dims", "_frozen") - - def __init__( - self, - dims: Iterable[int | SymbolicDim | str | None], - /, - denotations: Iterable[str | None] | None = None, - frozen: bool = False, - ) -> None: - """Initialize a shape. - - Args: - dims: The dimensions of the shape. Each dimension can be an integer or a - SymbolicDim or any Python object. When a ``dim`` is not an integer or a - SymbolicDim, it is converted to a SymbolicDim. - denotations: The denotations of the dimensions. If None, the denotations are not set. - Standard denotation can optionally be used to denote tensor - dimensions with standard semantic descriptions to ensure - that operations are applied to the correct axis of a tensor. - Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition - for pre-defined dimension denotations. - frozen: If True, the shape is immutable and cannot be modified. This - is useful when the shape is initialized by a Tensor. - """ - self._dims: list[int | SymbolicDim] = [ - SymbolicDim(dim) if not isinstance(dim, (int, SymbolicDim)) else dim - for dim in dims - ] - self._denotations: list[str | None] = ( - list(denotations) if denotations is not None else [None] * len(self._dims) - ) - if len(self._denotations) != len(self._dims): - raise ValueError( - "The number of denotations, when provided, must be equal to the number of dimensions." - ) - self._frozen: bool = frozen - - @property - def dims(self) -> tuple[int | SymbolicDim, ...]: - """All dimensions in the shape. - - This property is read-only. Use __getitem__ and __setitem__ to modify the shape or create a new shape. - """ - return tuple(self._dims) - - def rank(self) -> int: - """The rank of the shape.""" - return len(self._dims) - - def numpy(self) -> tuple[int, ...]: - if any(not isinstance(dim, int) for dim in self._dims): - raise ValueError(f"Cannot convert the shape {self} to a tuple of ints") - return tuple(dim for dim in self._dims) # type: ignore - - def __len__(self) -> int: - return len(self._dims) - - def __iter__(self) -> Iterator[int | SymbolicDim]: - return iter(self._dims) - - @typing.overload - def __getitem__(self, index: int) -> int | SymbolicDim: ... - - @typing.overload - def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ... - - def __getitem__(self, index): - return tuple(self._dims)[index] - - def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None: - """Set the dimension at the index. - - Args: - index: The index of the dimension. - value: The value of the dimension. - - Raises: - TypeError: If the shape is frozen and cannot be modified. - TypeError: If the value is not an int or SymbolicDim. - """ - if self._frozen: - raise TypeError("The shape is frozen and cannot be modified.") - if isinstance(value, str) or value is None: - value = SymbolicDim(value) - if not isinstance(value, (int, SymbolicDim)): - raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'") - - self._dims[index] = value - - def get_denotation(self, index: int) -> str | None: - """Return the denotation of the dimension at the index. - - Args: - index: The index of the dimension. - - Returns: - The denotation of the dimension. - """ - return self._denotations[index] - - def set_denotation(self, index: int, denotation: str | None) -> None: - """Set the denotation of the dimension at the index. - - Args: - index: The index of the dimension. - denotation: The denotation of the dimension. - """ - self._denotations[index] = denotation - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._dims!r})" - - def __str__(self) -> str: - """Return a string representation of the shape. - - E.g. [n,1,3] - """ - return f"[{','.join([str(dim) for dim in self._dims])}]" - - def __eq__(self, other: object) -> bool: - """Return True if the shapes are equal. - - Two shapes are eqaul if all their dimensions are equal. - """ - if isinstance(other, Shape): - return self._dims == other._dims - if not isinstance(other, Iterable): - return False - return self._dims == list(other) - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) - - -def _quoted(string: str) -> str: - """Return a quoted string. - - This function is used to quote value/node names in the IR for better readability. - """ - return f'"{string}"' - - -class Node(_protocols.NodeProtocol, _display.PrettyPrintable): - """IR Node. - - If the ``graph`` is provided, the node will be added to the graph. Otherwise, - user is responsible to call ``graph.append(node)`` (or other mutation methods - in :class:`Graph`) to add the node to the graph. - - After the node is initialized, it will add itself as a user of the input values. - - The output values of the node are created during node initialization and are immutable. - To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with - the new output values by calling :meth:`replace_input_with` on the using nodes - of this node's outputs. - """ - - __slots__ = ( - "_name", - "_domain", - "_op_type", - "_inputs", - "_outputs", - "_attributes", - "_overload", - "_version", - "doc_string", - "_metadata", - "_metadata_props", - "_graph", - ) - - def __init__( - self, - domain: str, - op_type: str, - inputs: Iterable[Value | None], - attributes: Iterable[Attr | RefAttr] = (), - *, - overload: str = "", - num_outputs: int | None = None, - outputs: Sequence[Value] | None = None, - version: int | None = None, - graph: Graph | None = None, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ): - """Initialize a node and add it as a user of the input values. - - Args: - domain: The domain of the operator. For onnx operators, this is an empty string. - op_type: The name of the operator. - inputs: The input values. When an input is None, it is an empty input. - attributes: The attributes. RefAttr can be used only when the node is defined in a Function. - overload: The overload name when the node is invoking a function. - num_outputs: The number of outputs of the node. If not specified, the number is 1. - outputs: The output values. If None, the outputs are created during initialization. - version: The version of the operator. If None, the version is unspecified and will follow that of the graph. - graph: The graph that the node belongs to. If None, the node is not added to any graph. - A `Node` must belong to zero or one graph. - name: The name of the node. If None, the node is anonymous. - doc_string: The documentation string. - metadata_props: The metadata properties. - - Raises: - TypeError: If the attributes are not Attr or RefAttr. - ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs. - ValueError: If an output value is None, when outputs is specified. - ValueError: If an output value has a producer set already, when outputs is specified. - """ - self._name = name - self._domain: str = domain - self._op_type: str = op_type - # NOTE: Make inputs immutable with the assumption that they are not mutated - # very often. This way all mutations can be tracked. - # If necessary, we can cache the inputs and outputs as tuples. - self._inputs: tuple[Value | None, ...] = tuple(inputs) - # Values belong to their defining nodes. The values list is immutable - self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs) - attributes = tuple(attributes) - if attributes and not isinstance(attributes[0], (Attr, RefAttr)): - raise TypeError( - f"Expected the attributes to be Attr or RefAttr, got {type(attributes[0])}. " - "If you are copying the attributes from another node, make sure you call " - "node.attributes.values() because it is a dictionary." - ) - self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict( - (attr.name, attr) for attr in attributes - ) - self._overload: str = overload - # TODO(justinchuby): Potentially support a version range - self._version: int | None = version - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - self._graph: Graph | None = graph - self.doc_string = doc_string - - # Add the node as a use of the inputs - for i, input_value in enumerate(self._inputs): - if input_value is not None: - input_value._add_usage(self, i) # pylint: disable=protected-access - - # Add the node to the graph if graph is specified - if self._graph is not None: - self._graph.append(self) - - def _create_outputs( - self, num_outputs: int | None, outputs: Sequence[Value] | None - ) -> tuple[Value, ...]: - """Check the parameters and create outputs for the node. - - Args: - num_outputs: The number of outputs of the node. - outputs: The output values of the node. - - Returns: - The output values of the node. - - Raises: - ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs. - ValueError: If an output value is None. - ValueError: If an output value has a producer set already. - """ - # Check num_outputs and outputs are consistent - if num_outputs is not None and outputs is not None and num_outputs != len(outputs): - raise ValueError( - "num_outputs must be the same as len(outputs) when num_outputs is specified." - "num_outputs: {num_outputs}, outputs: {outputs}" - ) - # 1. If outputs is specified (can be empty []), use the outputs - if outputs is not None: - # Check all output values are valid first - for output in outputs: - if output is None: - raise ValueError(f"Output value cannot be None. All outputs: {outputs}") - if output.producer() is not None: - raise ValueError( - f"Supplied output value cannot have a producer when used for initializing a Node. " - f"Output: {output}. All outputs: {outputs}" - ) - result = [] - for i, output in enumerate(outputs): - output._producer = self # pylint: disable=protected-access - output._index = i # pylint: disable=protected-access - result.append(output) - return tuple(result) - - # 2. If num_outputs is specified, create num_outputs outputs - if num_outputs is None: - # Default to 1 output - num_outputs = 1 - assert num_outputs is not None - return tuple(Value(self, index=i) for i in range(num_outputs)) - - def __str__(self) -> str: - node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * ( - self._overload != "" - ) - inputs_text = ( - "(" - + ", ".join( - [ - ( - f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}" - if x is not None - else "None" - ) - for x in self._inputs - ] - ) - + ")" - ) - attributes_text = ( - (" {" + ", ".join([f"{k}={v}" for k, v in self._attributes.items()]) + "}") - if self._attributes - else "" - ) - outputs_text = ", ".join(str(x) for x in self._outputs) - - return f"{outputs_text} ⬅️ {node_type_text}{inputs_text}{attributes_text}" - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(name={self._name!r}, domain={self._domain!r}, " - f"op_type={self._op_type!r}, inputs={self._inputs!r}, attributes={self._attributes!r}, " - f"overload={self._overload!r}, outputs={self._outputs!r}, " - f"version={self._version!r}, doc_string={self.doc_string!r})" - ) - - @property - def name(self) -> str | None: - return self._name - - @name.setter - def name(self, value: str | None) -> None: - self._name = value - - @property - def domain(self) -> str: - return self._domain - - @domain.setter - def domain(self, value: str) -> None: - self._domain = value - - @property - def version(self) -> int | None: - return self._version - - @version.setter - def version(self, value: int | None) -> None: - self._version = value - - @property - def op_type(self) -> str: - return self._op_type - - @op_type.setter - def op_type(self, value: str) -> None: - self._op_type = value - - @property - def overload(self) -> str: - return self._overload - - @overload.setter - def overload(self, value: str) -> None: - self._overload = value - - @property - def inputs(self) -> Sequence[Value | None]: - return self._inputs - - @inputs.setter - def inputs(self, _: Any) -> None: - raise AttributeError( - "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." - ) - - def replace_input_with(self, index: int, value: Value | None) -> None: - """Replace an input with a new value.""" - if index < 0 or index >= len(self.inputs): - raise ValueError(f"Index out of range: {index}") - old_input = self.inputs[index] - self._inputs = tuple( - value if i == index else old_input for i, old_input in enumerate(self.inputs) - ) - if old_input is not None: - old_input._remove_usage(self, index) # pylint: disable=protected-access - if value is not None: - value._add_usage(self, index) # pylint: disable=protected-access - - def prepend(self, /, nodes: Node | Iterable[Node]) -> None: - """Insert a node before this node in the list of nodes in the graph. - - It is the same as calling ``graph.insert_before(self, nodes)``. - - Example:: - - Before: previous_node -> self - previous_node' -> node -> next_node' - After: previous_node -> node -> self - previous_node' -> next_node' - - Args: - nodes: A node or a sequence of nodes to put before this node. - """ - if self._graph is None: - raise ValueError("The node to prepend to does not belong to any graph.") - self._graph.insert_before(self, nodes) - - def append(self, /, nodes: Node | Iterable[Node]) -> None: - """Insert a node after this node in the list of nodes in the graph. - - It is the same as calling ``graph.insert_after(self, nodes)``. - - Example:: - - Before: previous_node -> self - previous_node' -> node -> next_node' - After: previous_node -> self -> node - previous_node' -> next_node' - - Args: - nodes: A node or a sequence of nodes to put after this node. - """ - if self._graph is None: - raise ValueError("The node to append to does not belong to any graph.") - self._graph.insert_after(self, nodes) - - @property - def outputs(self) -> Sequence[Value]: - return self._outputs - - @outputs.setter - def outputs(self, _: Sequence[Value]) -> None: - raise AttributeError("outputs is immutable. Please create a new node instead.") - - @property - def attributes(self) -> OrderedDict[str, Attr | RefAttr]: - return self._attributes - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def graph(self) -> Graph | None: - return self._graph - - @graph.setter - def graph(self, value: Graph | None) -> None: - self._graph = value - - def op_identifier(self) -> _protocols.OperatorIdentifier: - return self.domain, self.op_type, self.overload - - def display(self, *, page: bool | None = None) -> None: - # Add the node's name to the displayed text - print(f"Node: {self.name!r}") - if self.doc_string: - print(f"Doc: {self.doc_string}") - super().display(page=page) - - -class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable): - """Tensor types that are non recursive types.""" - - __slots__ = ("_dtype", "denotation") - - def __init__(self, dtype: _enums.DataType, *, denotation: str | None = None) -> None: - self._dtype = dtype - self.denotation = denotation - - @property - def dtype(self) -> _enums.DataType: - return self._dtype - - @dtype.setter - def dtype(self, value: _enums.DataType) -> None: - self._dtype = value - - @property - def elem_type(self) -> _enums.DataType: - """Return the element type of the tensor type""" - return self.dtype - - def __eq__(self, other: object) -> bool: - if self.__class__ is not other.__class__: - return False - return self.dtype == other.dtype # type: ignore[attr-defined] - - def __repr__(self) -> str: - # Remove "Type" from name for display - short_name = self.__class__.__name__[:-4] - return f"{short_name}({self.dtype!r})" - - -class TensorType(_TensorTypeBase): - """A type that represents a tensor.""" - - def __str__(self) -> str: - return f"{self.dtype}" - - -class SparseTensorType(_TensorTypeBase): - """A type that represents a sparse tensor.""" - - -class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable): - """Base for recursive types like Optional and Sequence.""" - - __slots__ = ("_elem_type", "denotation") - - def __init__( - self, elem_type: _protocols.TypeProtocol, *, denotation: str | None = None - ) -> None: - self._elem_type = elem_type - self.denotation = denotation - - @property - def dtype(self) -> _enums.DataType: - return self._elem_type.dtype - - @dtype.setter - def dtype(self, value: _enums.DataType) -> None: - self._elem_type.dtype = value - - @property - def elem_type(self) -> _protocols.TypeProtocol: - return self._elem_type - - def __eq__(self, other: object) -> bool: - if not isinstance(other, _RecursiveTypeBase): - return False - if self.__class__ != other.__class__: - return False - # Recursively compare the type of the elements - return self.elem_type == other.elem_type - - def __repr__(self) -> str: - # Remove "Type" from name for display - short_name = self.__class__.__name__[:-4] - return f"{short_name}({self.elem_type!r})" - - -class SequenceType(_RecursiveTypeBase): - """A type that represents a sequence of elements.""" - - -class OptionalType(_RecursiveTypeBase): - """A type that represents an optional element.""" - - -class Value(_protocols.ValueProtocol, _display.PrettyPrintable): - """IR Value. - - A value is a named entity that can be used to represent an input or output of a graph, - a function, or a node. The information it stores generalizes over ``ValueInfoProto`` - in the ONNX specification. - - A :class:`Value` is always not owned or owned by exactly one node. When the value is not - owned, it must be an input of a graph or a function. ``producer`` and ``index`` - are ``None``. - - When the value is owned by a node, it is an output of the node. - The node that produces the value can be accessed with :meth:`producer`. - The index of the output of the node that produces the value can be accessed with - :meth:`index`. - - To find all the nodes that use this value as an input, call :meth:`uses`. - - To check if the value is an output of a graph, call :meth:`is_graph_output`. - - Attributes: - name: The name of the value. A value is always named when it is part of a graph. - shape: The shape of the value. - type: The type of the value. - metadata_props: Metadata. - """ - - __slots__ = ( - "_const_value", - "_index", - "_metadata_props", - "_metadata", - "_name", - "_producer", - "_shape", - "_type", - "_uses", - "doc_string", - ) - - def __init__( - self, - producer: Node | None, - *, - index: int | None, - name: str | None = None, - shape: Shape | None = None, - type: _protocols.TypeProtocol | None = None, - doc_string: str | None = None, - const_value: _protocols.TensorProtocol - | Sequence[_protocols.TensorProtocol] - | None = None, - ) -> None: - # producer is None when the value is an input or an initializer - self._producer: Node | None = producer - self._index: int | None = index - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = None - - self._name: str | None = name - self._shape: Shape | None = shape - self._type: _protocols.TypeProtocol | None = type - # TODO(justinchuby): Handle initialization when a const value is provided - # We can get shape and type information from the const value - self._const_value = const_value - # Use a collection of (Node, int) to store uses. This is needed - # because a single use can use the same value multiple times. - # Use a dictionary to preserve insertion order so that the visiting order is deterministic - self._uses: dict[tuple[Node, int], None] = {} - self.doc_string = doc_string - - def __repr__(self) -> str: - value_name = self.name if self.name else "anonymous:" + str(id(self)) - producer = self.producer() - producer_text = ( - producer.name is not None or "anonymous_node:" + str(id(producer)) - if producer is not None - else None - ) - return f"{self.__class__.__name__}({value_name!r}, type={self.type!r}, shape={self.shape}, producer={producer_text}, index={self.index()})" - - def __str__(self) -> str: - value_name = self.name if self.name is not None else "anonymous:" + str(id(self)) - shape_text = str(self.shape) if self.shape is not None else "?" - type_text = str(self.type) if self.type is not None else "?" - - # Quote the name because in reality the names can have invalid characters - # that make them hard to read - return f"%{_quoted(value_name)}<{type_text},{shape_text}>" - - def producer(self) -> Node | None: - """The node that produces this value.""" - return self._producer - - def index(self) -> int | None: - """The index of the output of the defining node.""" - return self._index - - def uses(self) -> Collection[tuple[Node, int]]: - """Return a set of uses of the value. - - The set contains tuples of ``(Node, index)`` where the index is the index of the input - of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``. - """ - return self._uses.keys() - - def _add_usage(self, use: Node, index: int) -> None: - """Add a usage of this value. - - This is an internal method. It should only be called by the Node class. - """ - self._uses[(use, index)] = None - - def _remove_usage(self, use: Node, index: int) -> None: - """Remove a node from the uses of this value. - - This is an internal method. It should only be called by the Node class. - """ - self._uses.pop((use, index)) - - @property - def name(self) -> str | None: - return self._name - - @name.setter - def name(self, value: str | None) -> None: - self._name = value - - @property - def type(self) -> _protocols.TypeProtocol | None: - """The type of the tensor. - - Example types can be ``TensorType``, ``SparseTensorType``, ``SequenceType``, ``OptionalType``. - To obtain the data type of the tensor, use ``type.dtype`` or conveniently - :attribute:`dtype`. - """ - return self._type - - @type.setter - def type(self, value: _protocols.TypeProtocol | None) -> None: - self._type = value - - @property - def dtype(self) -> _enums.DataType | None: - """The data type of the tensor.""" - if self._type is None: - return None - return self._type.dtype - - @dtype.setter - def dtype(self, value: _enums.DataType) -> None: - """Set the data type of the tensor. - - If the type is not set, it will be initialized to a new TensorType. To - set the type as other types like ``SequenceType``, initialize the type - then set :attribute:`type` instead. - """ - if self._type is None: - self._type = TensorType(value) - else: - self._type.dtype = value - - @property - def shape(self) -> Shape | None: - return self._shape - - @shape.setter - def shape(self, value: Shape | None) -> None: - if value is None: - self._shape = None - return - if isinstance(value, Shape): - self._shape = value - return - raise TypeError(f"Expected value to be a Shape or None, got '{type(value)}'") - - @property - def const_value( - self, - ) -> _protocols.TensorProtocol | Sequence[_protocols.TensorProtocol] | None: - """A concrete value. - - The value can be backed by different raw data types, such as numpy arrays. - The only guarantee is that it conforms TensorProtocol. - """ - return self._const_value - - @const_value.setter - def const_value( - self, - value: _protocols.TensorProtocol | Sequence[_protocols.TensorProtocol] | None, - ) -> None: - self._const_value = value - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def is_graph_output(self) -> bool: - """Whether the value is an output of a graph.""" - if (producer := self.producer()) is None: - return False - if (graph := producer.graph) is None: - return False - # Cannot use `in` because __eq__ may be defined by subclasses, even though - # it is not recommended - return any(output is self for output in graph.outputs) - - -class Input(Value): - """Input of a Graph or a Function.""" - - # Slots already defined in Value - __slots__ = () - - def __init__( - self, - name: str | None = None, - shape: Shape | None = None, - type: _protocols.TypeProtocol | None = None, - doc_string: str | None = None, - ) -> None: - super().__init__( - None, index=None, name=name, shape=shape, type=type, doc_string=doc_string - ) - - -def _check_node_safe_to_remove( - node: Node, to_remove: AbstractSet[Node], graph_outputs: AbstractSet[Value] -) -> None: - """Check if a node is safe to remove. - - 1. It checks to make sure there are no users of the node that are not - to be removed before removing it. - 2. It checks the node does not contribute to any graph outputs. - - This check is typically O(1) assuming the number of uses of the node is small - - Args: - node: The node to check. - to_remove: A set of nodes that are to be removed. - This set is used to check if the node is still being used by other - nodes that are not to be removed. - graph_outputs: A set of values that are outputs of the graph. - - Raises: - ValueError: If the node does not belong to this graph or if there are users of the node. - ValueError: If the node is still being used by other nodes not to be removed. - """ - for output in node.outputs: - if output in graph_outputs: - raise ValueError( - f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True." - ) - for use, _ in output.uses(): - if use in to_remove: - continue - raise ValueError( - f"Node '{use!r}' is still being used by other nodes that are not to be " - f"removed. All of its uses: {list(output.uses())!r}" - ) - - -class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable): - """IR Graph. - - Graph represents a computation graph. In addition to the ONNX specification - specified fields, it also contains a mapping of :attr:`opset_imports`. This - allows different subgraphs to import different opsets. It is the responsibility - of the deserializer to reconcile the different opsets. - - The `nodes` are not guaranteed to be topologically sorted. But the - iteration order should be deterministic across different runs. It is the - responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Graph. The Graph can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(graph)``. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - __slots__ = ( - "name", - "_inputs", - "_outputs", - "_initializers", - "_doc_string", - "_opset_imports", - "_nodes", - "_metadata", - "_metadata_props", - "_name_authority", - ) - - def __init__( - self, - inputs: Sequence[Input], - outputs: Sequence[Value], - *, - nodes: Iterable[Node], - initializers: Sequence[_protocols.TensorProtocol] = (), - doc_string: str | None = None, - opset_imports: dict[str, int] | None = None, - name: str | None = None, - metadata_props: dict[str, str] | None = None, - ): - self.name = name - - # Private fields that are not to be accessed by any other classes - self._inputs = list(inputs) - self._outputs = list(outputs) - for initializer in initializers: - if isinstance(initializer, str): - raise TypeError( - "Initializer must be a TensorProtocol, not a string. " - "If you are copying the initializers from another graph, " - "make sure you call graph.initializers.values() because it is a dictionary." - ) - if initializer.name is None: - raise ValueError(f"Initializer must have a name: {initializer}") - self._initializers = {tensor.name: tensor for tensor in initializers} - self._doc_string = doc_string - self._opset_imports = opset_imports or {} - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet() - # Be sure the initialize the name authority before extending the nodes - # because it is used to name the nodes and their outputs - self._name_authority = _name_authority.NameAuthority() - # Call self.extend not self._nodes.extend so the graph reference is added to the nodes - self.extend(nodes) - - @property - def inputs(self) -> list[Input]: - return self._inputs - - @property - def outputs(self) -> list[Value]: - return self._outputs - - @property - def initializers(self) -> dict[str, _protocols.TensorProtocol]: - return self._initializers - - @property - def doc_string(self) -> str | None: - return self._doc_string - - @doc_string.setter - def doc_string(self, value: str | None) -> None: - self._doc_string = value - - @property - def opset_imports(self) -> dict[str, int]: - return self._opset_imports - - def __getitem__(self, index: int) -> Node: - return self._nodes[index] - - def __len__(self) -> int: - return len(self._nodes) - - def __iter__(self) -> Iterator[Node]: - return iter(self._nodes) - - def __reversed__(self) -> Iterator[Node]: - return reversed(self._nodes) - - def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node: - """Set the graph reference for the node and assign names to it and its outputs if they don't have one.""" - if node.graph is not None and node.graph is not self: - raise ValueError( - f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()." - ) - # Give the node and its output values names if they don't not have one - if node.name is None: - self._name_authority.name_node(node) - for value in node._outputs: # pylint: disable=protected-access - if value.name is None: - self._name_authority.name_value(value) - node.graph = self - return node - - # Mutation methods - def append(self, node: Node, /) -> None: - """Append a node to the graph in O(1) time. - - Args: - node: The node to append. - - Raises: - ValueError: If the node belongs to another graph. - """ - self._set_node_graph_to_self_and_assign_names(node) - self._nodes.append(node) - - def extend(self, nodes: Iterable[Node], /) -> None: - """Extend the graph with the given nodes in O(#new_nodes) time. - - Args: - nodes: The nodes to extend the graph with. - - Raises: - ValueError: If any node belongs to another graph. - """ - nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in nodes] - self._nodes.extend(nodes) - - def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None: - """Remove nodes from the graph in O(#num of nodes) time. - - If any errors are raise, to ensure the graph is not left in an inconsistent state, - the graph is not modified. - - Args: - nodes: The node to remove. - safe: If True, performs the following actions before removal: - 1. It checks to make sure there are no users of the node that are not - to be removed before removing it. - 2. It checks the node does not contribute to any graph outputs. - 3. It removes references to all inputs so it is no longer a user of other nodes. - - Raises: - ValueError: If any node to remove does not belong to this graph. - ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node. - ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed. - """ - if not isinstance(nodes, Iterable): - nodes_set: AbstractSet[Node] = {nodes} - else: - nodes_set = frozenset(nodes) - graph_outputs = frozenset(self.outputs) - for node in nodes_set: - if node.graph is not self: - raise ValueError(f"The node '{node!r}' does not belong to this graph.") - if safe: - # Check 1, 2 - _check_node_safe_to_remove(node, nodes_set, graph_outputs) - for node in nodes_set: - if safe: - # 3. Detach from all inputs so that it is no longer a user of other nodes - for i in range(len(node.inputs)): - node.replace_input_with(i, None) - # Set attributes to remove the node from this graph - node.graph = None - self._nodes.remove(node) - - def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: - """Insert new nodes after the given node in O(#new_nodes) time. - - Args: - node: The node to insert after. - new_nodes: The new nodes to insert. - - Raises: - ValueError: If any node belongs to another graph. - """ - if isinstance(new_nodes, Node): - new_nodes = (new_nodes,) - new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes] - self._nodes.insert_after(node, new_nodes) - - def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: - """Insert new nodes before the given node in O(#new_nodes) time. - - Args: - node: The node to insert before. - new_nodes: The new nodes to insert. - - Raises: - ValueError: If any node belongs to another graph. - """ - if isinstance(new_nodes, Node): - new_nodes = (new_nodes,) - new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes] - self._nodes.insert_before(node, new_nodes) - - def sort(self) -> None: - """Topologically sort the nodes in the graph.""" - raise NotImplementedError("Not implemented yet") - - # End of mutation methods - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def __str__(self) -> str: - return _graph_str(self) - - def __repr__(self) -> str: - return _graph_repr(self) - - -def _graph_str(graph: Graph | GraphView) -> str: - """Return a string representation of the graph.""" - # TODO(justinchuby): Show docstrings and metadata - inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs) - outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs) - initializers_text = ",\n".join(str(x) for x in graph.initializers.values()) - if initializers_text: - initializers_text = ( - "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n)," - ) - signature = f"""\ -graph( - name={graph.name or 'anonymous_graph:' + str(id(graph))}, - inputs=({textwrap.indent(inputs_text, ' '*8)} - ), - outputs=({textwrap.indent(outputs_text, ' '*8)} - ),{textwrap.indent(initializers_text, ' '*4)} -)""" - node_count = len(graph) - number_width = len(str(node_count)) - node_lines = [] - for i, node in enumerate(graph): - node_name = node.name if node.name else f":anonymous_node:{id(node)}" - node_text = f"# {node_name}\n{node}" - indented_node_text = textwrap.indent(node_text, " " * (number_width + 4)) - # Remove the leading spaces - indented_node_text = indented_node_text.strip() - node_lines.append(f"{i:>{number_width}} | {indented_node_text}") - returns = ", ".join(str(x) for x in graph.outputs) - body = ( - "{\n" - + textwrap.indent("\n".join(node_lines), " " * 4) - + textwrap.indent(f"\nreturn {returns}", " " * 4) - + "\n}" - ) - - return f"{signature} {body}" - - -def _graph_repr(graph: Graph | GraphView) -> str: - """Return an repr string of the graph.""" - inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs) - outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs) - initializers_text = ",\n".join(str(x) for x in graph.initializers.values()) - if initializers_text: - initializers_text = ( - "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n)," - ) - return f"""\ -{graph.__class__.__name__}( - name={graph.name or 'anonymous_graph:' + str(id(graph))!r}, - inputs=({textwrap.indent(inputs_text, ' '*8)} - ), - outputs=({textwrap.indent(outputs_text, ' '*8)} - ),{textwrap.indent(initializers_text, ' '*4)} - len()={len(graph)} -)""" - - -class GraphView(Sequence[Node], _display.PrettyPrintable): - """A read-only view on a graph. - - The GraphView is useful for analysis of a subgraph. It can be initialized - with a subset of nodes from a :class:`Graph`. Creating GraphView does not - change the ownership of the nodes, and so it is possible to create multiple - GraphViews that contain the same nodes. If the underlying nodes / connections - are mutated, the mutation will be reflected in all views as well. - - The graph view can be serialized to ONNX:: - - graph_proto = ir.serde.serialize_graph(graph_view) - - It can also be used to create a model:: - - model = ir.Model(graph_view, ir_version=8) - model_proto = ir.serde.serialize_model(model) - - The model created with a GraphView will have a fixed topology, and its graph - will remain read-only as a GraphView. No copying will be done during the - initialization process. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - __slots__ = ( - "name", - "inputs", - "outputs", - "initializers", - "doc_string", - "opset_imports", - "nodes", - "_metadata", - "_metadata_props", - ) - - def __init__( - self, - inputs: Sequence[Value], - outputs: Sequence[Value], - *, - nodes: Iterable[Node], - initializers: Sequence[_protocols.TensorProtocol] = (), - doc_string: str | None = None, - opset_imports: dict[str, int] | None = None, - name: str | None = None, - metadata_props: dict[str, str] | None = None, - ): - self.name = name - self.inputs = tuple(inputs) - self.outputs = tuple(outputs) - for initializer in initializers: - if initializer.name is None: - raise ValueError(f"Initializer must have a name: {initializer}") - self.initializers = {tensor.name: tensor for tensor in initializers} - self.doc_string = doc_string - self.opset_imports = opset_imports or {} - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - self._nodes: tuple[Node, ...] = tuple(nodes) - - def __getitem__(self, index: int) -> Node: - return self._nodes[index] - - def __len__(self) -> int: - return len(self._nodes) - - def __iter__(self) -> Iterator[Node]: - return iter(self._nodes) - - def __reversed__(self) -> Iterator[Node]: - return reversed(self._nodes) - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def __str__(self) -> str: - return _graph_str(self) - - def __repr__(self) -> str: - return _graph_repr(self) - - -class Model(_protocols.ModelProtocol, _display.PrettyPrintable): - __slots__ = ( - "graph", - "ir_version", - "producer_name", - "producer_version", - "domain", - "model_version", - "doc_string", - "_functions", - "_metadata", - "_metadata_props", - ) - """IR Model. - - A model is a container for a graph and metadata. - - Attributes: - graph: The graph of the model. - ir_version: The version of the IR. - producer_name: The name of the producer. - producer_version: The version of the producer. - domain: The domain of the model. - model_version: The version of the model. - doc_string: Documentation string. - functions: The functions defined in the model. - metadata_props: Metadata. - """ - - def __init__( - self, - graph: Graph, - *, - ir_version: int, - producer_name: str | None = None, - producer_version: str | None = None, - domain: str | None = None, - model_version: int | None = None, - doc_string: str | None = None, - functions: Sequence[Function] = (), - meta_data_props: dict[str, str] | None = None, - ) -> None: - self.graph: Graph = graph - self.ir_version = ir_version - self.producer_name = producer_name - self.producer_version = producer_version - self.domain = domain - self.model_version = model_version - self.doc_string = doc_string - self._functions = {func.identifier(): func for func in functions} - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = meta_data_props - - @property - def functions(self) -> dict[_protocols.OperatorIdentifier, Function]: - return self._functions - - @property - def opset_imports(self) -> dict[str, int]: - return self.graph.opset_imports - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def __str__(self) -> str: - # TODO(justinchuby): Show docstrings and metadata - signature = f"""\ -< - ir_version={self.ir_version!r}, - opset_imports={self.opset_imports!r}, - producer_name={self.producer_name!r}, - producer_version={self.producer_version!r}, - domain={self.domain!r}, - model_version={self.model_version!r}, ->""" - graph_text = str(self.graph) - functions_text = ",\n\n".join(str(func) for func in self.functions.values()) - return f"{signature}\n{graph_text}" + f"\n\n{functions_text}" * len(self.functions) - - def __repr__(self) -> str: - return f"""\ -Model( - ir_version={self.ir_version!r}, - opset_imports={self.opset_imports!r}, - producer_name={self.producer_name!r}, - producer_version={self.producer_version!r}, - domain={self.domain!r}, - model_version={self.model_version!r}, - functions={self.functions!r}, - graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()} -)""" - - -class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable): - """IR functions. - - Like a graph, a function can have nodes that are not topologically sorted. It is - the responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Function. The Function can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(function)``. - - Attributes: - name: The function name. - domain: The domain this function is defined in. - overload: The overload name when the function is overloaded. - inputs: The input values of the function. - attributes: The attributes this function defines. - outputs: The output values of the function. - opset_imports: Opsets imported by the function. - doc_string: Documentation string. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - __slots__ = ( - "_domain", - "_name", - "_overload", - "_graph", - "_attributes", - "_metadata", - "_metadata_props", - ) - - def __init__( - self, - domain: str, - name: str, - overload: str = "", - *, - # Ensure the inputs and outputs of the function belong to a graph - # and not from an outer scope - graph: Graph, - attributes: Sequence[Attr], - metadata_props: dict[str, str] | None = None, - ) -> None: - self._domain = domain - self._name = name - self._overload = overload - self._graph = graph - self._attributes = OrderedDict((attr.name, attr) for attr in attributes) - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - - def identifier(self) -> _protocols.OperatorIdentifier: - return self.domain, self.name, self.overload - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, value: str) -> None: - self._name = value - - @property - def domain(self) -> str: - return self._domain - - @domain.setter - def domain(self, value: str) -> None: - self._domain = value - - @property - def overload(self) -> str: - return self._overload - - @overload.setter - def overload(self, value: str) -> None: - self._overload = value - - @property - def inputs(self) -> list[Input]: - return self._graph.inputs - - @property - def outputs(self) -> list[Value]: - return self._graph.outputs - - @property - def attributes(self) -> OrderedDict[str, Attr]: - return self._attributes - - def __getitem__(self, index: int) -> Node: - return self._graph.__getitem__(index) - - def __len__(self) -> int: - return self._graph.__len__() - - def __iter__(self) -> Iterator[Node]: - return self._graph.__iter__() - - def __reversed__(self) -> Iterator[Node]: - return self._graph.__reversed__() - - @property - def doc_string(self) -> str | None: - return self._graph.doc_string - - @doc_string.setter - def doc_string(self, value: str | None) -> None: - self._graph.doc_string = value - - @property - def opset_imports(self) -> dict[str, int]: - return self._graph.opset_imports - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - # Mutation methods - def append(self, node: Node, /) -> None: - """Append a node to the function in O(1) time.""" - self._graph.append(node) - - def extend(self, nodes: Iterable[Node], /) -> None: - """Extend the function with the given nodes in O(#new_nodes) time.""" - self._graph.extend(nodes) - - def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None: - """Remove nodes from the graph in O(#num of nodes) time. - - If any errors are raise, to ensure the graph is not left in an inconsistent state, - the graph is not modified. - - Args: - nodes: The node to remove. - safe: If True, performs the following actions before removal: - 1. It checks to make sure there are no users of the node that are not - to be removed before removing it. - 2. It checks the node does not contribute to any graph outputs. - 3. It removes references to all inputs so it is no longer a user of other nodes. - - Raises: - ValueError: If any node to remove does not belong to this graph. - ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node. - ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed. - """ - self._graph.remove(nodes, safe=safe) - - def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None: - """Insert new nodes after the given node in O(#new_nodes) time.""" - self._graph.insert_after(node, new_nodes) - - def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None: - """Insert new nodes before the given node in O(#new_nodes) time.""" - self._graph.insert_before(node, new_nodes) - - def sort(self) -> None: - """Topologically sort the nodes in the function.""" - self._graph.sort() - - # End of mutation methods - - def __str__(self) -> str: - full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "") - inputs_text = ",\n".join(str(x) for x in self.inputs) - outputs_text = ",\n".join(str(x) for x in self.outputs) - attributes_text = ",\n".join( - f"{attr.name}: {attr.type}" + f" = {attr.value}" * (attr.value is None) - for attr in self.attributes.values() - ) - if attributes_text: - attributes_text = ( - "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}" - ) - signature = f"""\ -< - opset_imports={self.opset_imports!r}, -> -def {full_name}( - inputs=( -{textwrap.indent(inputs_text, ' '*8)} - ),{textwrap.indent(attributes_text, ' '*4)} - outputs=( -{textwrap.indent(outputs_text, ' '*8)} - ), -)""" - node_count = len(self) - number_width = len(str(node_count)) - node_lines = [] - for i, node in enumerate(self): - node_name = node.name if node.name else f":anonymous_node:{id(node)}" - node_text = f"# {node_name}\n{node}" - indented_node_text = textwrap.indent(node_text, " " * (number_width + 4)) - # Remove the leading spaces - indented_node_text = indented_node_text.strip() - node_lines.append(f"{i:>{number_width}} | {indented_node_text}") - returns = ", ".join(str(x) for x in self.outputs) - body = ( - "{\n" - + textwrap.indent("\n".join(node_lines), " " * 4) - + textwrap.indent(f"\nreturn {returns}", " " * 4) - + "\n}" - ) - - return f"{signature} {body}" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})" - - -class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable): - """Reference attribute.""" - - __slots__ = ("_name", "_ref_attr_name", "_type", "doc_string") - - def __init__( - self, - name: str, - ref_attr_name: str, - type: _enums.AttributeType, - *, - doc_string: str | None = None, - ) -> None: - self._name = name - self._ref_attr_name = ref_attr_name - self._type = type - self.doc_string = doc_string - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, value: str) -> None: - self._name = value - - @property - def ref_attr_name(self) -> str: - return self._ref_attr_name - - @ref_attr_name.setter - def ref_attr_name(self, value: str) -> None: - self._ref_attr_name = value - - @property - def type(self) -> _enums.AttributeType: - return self._type - - @type.setter - def type(self, value: _enums.AttributeType) -> None: - self._type = value - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})" - - -class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable): - """Base class for ONNX attributes.""" - - __slots__ = ("name", "type", "value", "doc_string") - - def __init__( - self, - name: str, - type: _enums.AttributeType, - value: Any, - *, - doc_string: str | None = None, - ): - self.name = name - self.type = type - self.value = value - self.doc_string = doc_string - - def __eq__(self, other: object) -> bool: - if not isinstance(other, _protocols.AttributeProtocol): - return False - - if self.name != other.name: - return False - if self.type != other.type: - return False - if self.value != other.value: - return False - if self.doc_string != other.doc_string: - return False - return True - - def __str__(self) -> str: - return str(self.value) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})" - - -class _SpecializedAttr(Attr): - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.name!r}, {self.value!r})" - - -# NOTE: The following classes are just supporting classes (partially applied) for convenience -# But I think they would be useful to have in the IR by having the type info -# explicitly in the class type. -class AttrFloat32(_SpecializedAttr): - def __init__(self, name: str, value: float, doc_string: str | None = None): - super().__init__( - name, - _enums.AttributeType.FLOAT, - value, - doc_string=doc_string, - ) - - -class AttrInt64(_SpecializedAttr): - def __init__(self, name: str, value: int, doc_string: str | None = None): - super().__init__( - name, - _enums.AttributeType.INT, - value, - doc_string=doc_string, - ) - - -class AttrString(_SpecializedAttr): - def __init__(self, name: str, value: str, doc_string: str | None = None): - super().__init__( - name, - _enums.AttributeType.STRING, - value, - doc_string=doc_string, - ) - - -class AttrTensor(_SpecializedAttr): - def __init__( - self, - name: str, - value: _protocols.TensorProtocol, - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.TENSOR, - value, - doc_string=doc_string, - ) - - -class AttrGraph(_SpecializedAttr): - def __init__( - self, - name: str, - value: Graph, - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.GRAPH, - value, - doc_string=doc_string, - ) - - def __str__(self) -> str: - return textwrap.indent("\n" + super().__str__(), " " * 4) - - -class AttrFloat32s(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[float], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.FLOATS, - value, - doc_string=doc_string, - ) - - -class AttrInt64s(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[int], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.INTS, - value, - doc_string=doc_string, - ) - - -class AttrStrings(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[str], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.STRINGS, - value, - doc_string=doc_string, - ) - - -class AttrTensors(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[_protocols.TensorProtocol], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.TENSORS, - value, - doc_string=doc_string, - ) - - -class AttrGraphs(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[Graph], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.GRAPHS, - value, - doc_string=doc_string, - ) - - -# NOTE: SparseTensor should be a sparse tensor proto -class AttrSparseTensor(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[_protocols.SparseTensorProtocol], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.SPARSE_TENSOR, - value, - doc_string=doc_string, - ) - - -class AttrSparseTensors(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[_protocols.SparseTensorProtocol], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.SPARSE_TENSORS, - value, - doc_string=doc_string, - ) - - -@dataclasses.dataclass -class TypeAndShape: - """Type and shape. - - Useful for constructing a type proto. - """ - - type: _protocols.TypeProtocol | None - shape: Shape | None - - -class AttrTypeProto(_SpecializedAttr): - def __init__( - self, - name: str, - value: TypeAndShape, - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.TYPE_PROTO, - value, - doc_string=doc_string, - ) - - -class AttrTypeProtos(_SpecializedAttr): - def __init__( - self, - name: str, - value: Sequence[TypeAndShape], - doc_string: str | None = None, - ): - super().__init__( - name, - _enums.AttributeType.TYPE_PROTOS, - value, - doc_string=doc_string, - ) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py deleted file mode 100644 index 103e5b1700..0000000000 --- a/onnxscript/ir/_core_test.py +++ /dev/null @@ -1,580 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from __future__ import annotations - -import pathlib -import tempfile -import unittest -from typing import Any - -import numpy as np -import onnx -import onnx.external_data_helper -import parameterized -import torch - -from onnxscript.ir import _core, _enums - - -class TensorTest(unittest.TestCase): - def test_initialize(self): - tensor = _core.Tensor( - np.random.rand(1, 2).astype(np.float32), - dtype=_enums.DataType.FLOAT, - shape=_core.Shape((1, 2)), - name="test", - ) - self.assertEqual(tensor.name, "test") - self.assertEqual(tensor.dtype, _enums.DataType.FLOAT) - self.assertEqual(tensor.shape, _core.Shape((1, 2))) - np.testing.assert_array_equal(tensor, tensor) - - def test_init_raises_when_value_is_not_array(self): - with self.assertRaises(TypeError): - _core.Tensor(42) - - def test_init_requires_type_when_value_is_not_np_array(self): - torch_tensor = torch.tensor(42) - with self.assertRaises(ValueError): - _core.Tensor(torch_tensor) - - @parameterized.parameterized.expand( - [ - ("bfloat16", np.uint16, _enums.DataType.BFLOAT16), - ( - "float8e4m3fn", - np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)})), - _enums.DataType.FLOAT8E4M3FN, - ), - ("float8e4m3fnuz", np.uint8, _enums.DataType.FLOAT8E4M3FNUZ), - ("float8e5m2", np.uint8, _enums.DataType.FLOAT8E5M2), - ("float8e5m2fnuz", np.uint8, _enums.DataType.FLOAT8E5M2FNUZ), - ("int4", np.int8, _enums.DataType.INT4), - ("int4_uint8", np.uint8, _enums.DataType.INT4), - ("uint4", np.uint8, _enums.DataType.UINT4), - ] - ) - def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: _enums.DataType): - array = np.array([0b1, 0b11], dtype=np_dtype) - tensor = _core.Tensor(array, dtype=dtype) - self.assertEqual(tensor.dtype, dtype) - np.testing.assert_array_equal(tensor, array) - - def test_initialize_with_just_np_array(self): - array = np.random.rand(1, 2) - tensor = _core.Tensor(array) - np.testing.assert_array_equal(tensor, array) - - def test_initialize_raises_when_numpy_dtype_doesnt_match(self): - array = np.random.rand(1, 2).astype(np.float32) - with self.assertRaises(TypeError): - _core.Tensor(array, dtype=_enums.DataType.INT64) - - def test_initialize_raises_when_numpy_dtype_doesnt_match_custom_dtype(self): - custom_dtype = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)})) - array = np.random.rand(1, 2).astype(custom_dtype) - with self.assertRaises(TypeError): - _core.Tensor(array, dtype=_enums.DataType.BFLOAT16) - - def test_initialize_with_torch_tensor(self): - array = np.random.rand(1, 2).astype(np.int64) - np_tensor = _core.Tensor(array) - torch_tensor = _core.Tensor(torch.tensor(array), dtype=_enums.DataType.INT64) - np.testing.assert_array_equal(torch_tensor, array) - np.testing.assert_array_equal(torch_tensor, np_tensor) - - def test_dlpack_np_to_torch(self): - array = np.random.rand(1, 2).astype(np.float32) - tensor = _core.Tensor(array) - torch_tensor = torch.from_dlpack(tensor) - np.testing.assert_array_equal(torch_tensor, array) - - def test_dlpack_torch_to_np(self): - torch_tensor = torch.rand(1, 2) - tensor = _core.Tensor(torch_tensor, dtype=_enums.DataType.FLOAT) - array = np.from_dlpack(tensor) - np.testing.assert_array_equal(array, torch_tensor) - - def test_repr(self): - tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32)) - self.assertIsInstance(repr(tensor), str) - - def test_dtype_returns_data_type_enum(self): - tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32)) - self.assertEqual(tensor.dtype, _enums.DataType.FLOAT) - - def test_shape(self): - tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32)) - self.assertEqual(tensor.shape, _core.Shape((1, 2))) - - def test_numpy_returns_np_array(self): - array = np.random.rand(1, 2).astype(np.float32) - tensor = _core.Tensor(array) - np.testing.assert_equal(tensor.numpy(), array) - - def test_numpy_returns_data_when_dtype_is_not_supported(self): - array = np.array([1], dtype=np.uint8) - tensor = _core.Tensor(array, dtype=_enums.DataType.INT4) - np.testing.assert_equal(tensor.numpy(), array) - - def test_tobytes(self): - array = np.random.rand(1, 2).astype(np.float32) - torch_tensor = torch.tensor(array) - tensor = _core.Tensor(torch_tensor, dtype=_enums.DataType.FLOAT) - self.assertEqual(tensor.tobytes(), array.tobytes()) - - def test_tobtyes_returns_packed_data_for_int4(self): - array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=_enums.DataType.INT4) - self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01") - - def test_tobtyes_returns_packed_data_for_uint4(self): - array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=_enums.DataType.UINT4) - self.assertEqual(tensor.tobytes(), b"\x10r\x0f") - - def test_metadata(self): - array = np.random.rand(1, 2).astype(np.float32) - tensor = _core.Tensor(array) - tensor.meta["test"] = 1 - self.assertEqual(tensor.meta["test"], 1) - tensor.metadata_props["test"] = "any string" - self.assertEqual(tensor.metadata_props["test"], "any string") - - -class ExternalTensorTest(unittest.TestCase): - """Test the memory mapped external tensor class.""" - - def setUp(self): - self.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with - self.external_data_name = "test_model.bin" - self.base_path = self.temp_dir.name - self.data = np.random.rand(2, 42).astype(np.float32) - self.data_float16 = np.random.rand(2, 42).astype(np.float16) - self.model = self._simple_model_with_external( - self.base_path, self.external_data_name, self.data - ) - - def tearDown(self) -> None: - self.temp_dir.cleanup() - - def _simple_model_with_external( - self, base_path: str, external_data_name: str, data: np.ndarray - ) -> onnx.ModelProto: - input = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [None]) - output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [None]) - raw_data = data.tobytes() - tensor = onnx.helper.make_tensor( - "input", onnx.TensorProto.FLOAT, data.shape, raw_data, raw=True - ) - raw_data2 = self.data_float16.tobytes() - tensor2 = onnx.helper.make_tensor( - "input2", onnx.TensorProto.FLOAT16, data.shape, raw_data2, raw=True - ) - onnx.external_data_helper.set_external_data( - tensor, external_data_name, offset=0, length=len(raw_data) - ) - onnx.external_data_helper.set_external_data( - tensor2, external_data_name, offset=len(raw_data), length=len(raw_data2) - ) - - node = onnx.helper.make_node("Identity", inputs=["input"], outputs=["output"]) - model = onnx.helper.make_model( - onnx.helper.make_graph( - [node], "test_graph", [input], [output], initializer=[tensor, tensor2] - ) - ) - tensor.ClearField("raw_data") - tensor2.ClearField("raw_data") - # Save the data to disk - with open(pathlib.Path(base_path) / external_data_name, "wb") as f: - f.write(raw_data) - f.write(raw_data2) - return model - - def test_initialize(self): - external_tensor = self.model.graph.initializer[0] - external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) - tensor = _core.ExternalTensor( - path=pathlib.Path(self.base_path) / external_info.location, - offset=external_info.offset, - length=external_info.length, - dtype=_enums.DataType.FLOAT, - name="input", - shape=_core.Shape(external_tensor.dims), - ) - self.assertEqual(tensor.dtype, _enums.DataType.FLOAT) - np.testing.assert_equal(tensor, self.data) - # Ensure repeated reads are consistent - np.testing.assert_equal(tensor, self.data) - - def test_totypes_returns_correct_data_in(self): - external_tensor = self.model.graph.initializer[0] - external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) - tensor = _core.ExternalTensor( - path=pathlib.Path(self.base_path) / external_info.location, - offset=external_info.offset, - length=external_info.length, - dtype=_enums.DataType.FLOAT, - name="input", - shape=_core.Shape(external_tensor.dims), - ) - external_tensor2 = self.model.graph.initializer[1] - external_info2 = onnx.external_data_helper.ExternalDataInfo(external_tensor2) - tensor2 = _core.ExternalTensor( - path=pathlib.Path(self.base_path) / external_info2.location, - offset=external_info2.offset, - length=external_info2.length, - dtype=_enums.DataType.FLOAT16, - name="input", - shape=_core.Shape(external_tensor2.dims), - ) - self.assertEqual(tensor.tobytes(), self.data.tobytes()) - self.assertEqual(tensor2.tobytes(), self.data_float16.tobytes()) - # Ensure repeated reads are consistent - self.assertEqual(tensor.tobytes(), self.data.tobytes()) - self.assertEqual(tensor2.tobytes(), self.data_float16.tobytes()) - - -class SymbolicDimTest(unittest.TestCase): - def test_init_raises_when_value_is_int(self): - # Static dimensions should be python integers - with self.assertRaises(TypeError): - _core.SymbolicDim(42) - - @parameterized.parameterized.expand([("str", "any string"), ("None", None)]) - def test_equality_with_other_dimensions(self, _: str, value: Any): - dim1 = _core.SymbolicDim(value) - dim2 = _core.SymbolicDim(value) - self.assertEqual(dim1, dim2) - - @parameterized.parameterized.expand([("str", "any string"), ("None", None)]) - def test_equality_with_python_values(self, _: str, value: Any): - dim = _core.SymbolicDim(value) - self.assertEqual(dim, value) - self.assertIn(value, [dim]) - self.assertIn(dim, [value]) - - @parameterized.parameterized.expand([("str", "any string"), ("None", None)]) - def test_it_is_hashable(self, _: str, value: Any): - dim = _core.SymbolicDim(value) - self.assertEqual(hash(dim), hash(value)) - self.assertIn(dim, {dim}) - self.assertIn(dim, {value}) - - -class ShapeTest(unittest.TestCase): - def test_init_raises_when_denotations_and_dims_have_different_lengths(self): - with self.assertRaisesRegex(ValueError, "denotations"): - _core.Shape([42], ["DATA_CHANNEL", "BATCH"]) - - def test_int_dimensions_are_python_ints(self): - shape = _core.Shape([42]) - self.assertIsInstance(shape[0], int) - - @parameterized.parameterized.expand( - [ - ("empty", (), ()), - ("1d", (42,), (42,)), - ("int", (42, 42), (42, 42)), - ("str", ("any string", "any string"), ("any string", "any string")), - ("None", (None, None), (None, None)), - ] - ) - def test_eq_with_other_shapes( - self, _: str, dims_1: tuple[Any, ...], dims_2: tuple[Any, ...] - ): - shape_1 = _core.Shape(dims_1) - shape_2 = _core.Shape(dims_2) - self.assertEqual(shape_1, shape_2) - - @parameterized.parameterized.expand( - [ - ("empty", ()), - ("1d", (42,)), - ("int", (42, 42)), - ("str", ("any string", "any string")), - ("None", (None, None)), - ] - ) - def test_eq_with_tuple(self, _: str, dims: tuple[Any, ...]): - shape = _core.Shape(dims) - self.assertEqual(shape, dims) - - @parameterized.parameterized.expand( - [ - ("empty", []), - ( - "1d", - [ - 42, - ], - ), - ("int", [42, 42]), - ("str", ["any string", "any string"]), - ("None", [None, None]), - ] - ) - def test_eq_with_list(self, _: str, dims: list[Any]): - shape = _core.Shape(dims) - self.assertEqual(shape, dims) - - def test_eq_with_np_shape(self): - dims = (42,) - array = np.zeros(dims) - shape = _core.Shape(dims) - self.assertEqual(shape, array.shape) - - @parameterized.parameterized.expand( - [ - ("empty", (), (1,)), - ("d", (42,), (0,)), - ("rank", (42, 42), (42, 42, 42)), - ("str", ("any string",), (42,)), - ("None", (None, None), (None, 42)), - ] - ) - def test_ne_with_other_shapes( - self, _: str, dims_1: tuple[Any, ...], dims_2: tuple[Any, ...] - ): - shape_1 = _core.Shape(dims_1) - shape_2 = _core.Shape(dims_2) - self.assertNotEqual(shape_1, shape_2) - - def test_ne_with_random_object(self): - shape = _core.Shape((42,)) - self.assertNotEqual(shape, 42) - - def test_setitem_raises_when_shape_is_frozen(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",), frozen=True) - with self.assertRaisesRegex(TypeError, "frozen"): - shape[0] = 1 - - def test_getitem(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) - self.assertEqual(shape[0], 42) - - def test_getitem_accepts_a_slice(self): - shape = _core.Shape([1, 2, 3, 4]) - self.assertEqual(shape[1:3], (2, 3)) - - @parameterized.parameterized.expand( - [ - ("int", 42), - ("str", "any string"), - ("None", None), - ("SymbolicDim", _core.SymbolicDim("any string")), - ] - ) - def test_setitem(self, _: str, value): - shape = _core.Shape([0]) - shape[0] = value - dim = shape[0] - if isinstance(dim, _core.SymbolicDim): - self.assertEqual(dim.value, value) - else: - self.assertEqual(dim, value) - - def test_get_denotation(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) - self.assertEqual(shape.get_denotation(0), "DATA_CHANNEL") - - def test_set_denotation(self): - shape = _core.Shape([42, 0], ["DATA_CHANNEL", "BATCH"]) - shape.set_denotation(1, "UPDATED") - self.assertEqual(shape.get_denotation(1), "UPDATED") - - def test_set_denotation_is_still_possible_when_shape_is_frozen(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",), frozen=True) - shape.set_denotation(0, "UPDATED") - self.assertEqual(shape.get_denotation(0), "UPDATED") - - -class ValueTest(unittest.TestCase): - def test_initialize(self): - _ = _core.Value(None, index=0) - - def test_meta(self): - value = _core.Value(None, index=0) - value.meta["test"] = 1 - self.assertEqual(value.meta["test"], 1) - value.metadata_props["test"] = "any string" - self.assertEqual(value.metadata_props["test"], "any string") - - # TODO(justinchuby): Test all methods - - -class NodeTest(unittest.TestCase): - def setUp(self) -> None: - self.v0 = _core.Value(None, index=None) - self.v1 = _core.Value(None, index=None) - self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3) - - def test_init_with_values(self): - self.assertEqual(self.node.domain, "test") - self.assertEqual(self.node.op_type, "TestOp") - self.assertEqual(self.node.inputs, (self.v0, self.v1)) - self.assertEqual(len(self.node.outputs), 3) - self.assertEqual(self.node.attributes, {}) - - def test_init_with_preinitialized_outputs(self): - out_1 = _core.Value( - None, - index=None, - name="out_1", - shape=_core.Shape([1]), - type=_core.TensorType(_enums.DataType.BFLOAT16), - ) - out_2 = _core.Value( - None, - index=None, - name="out_2", - shape=_core.Shape([2]), - type=_core.TensorType(_enums.DataType.INT4), - ) - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[out_1, out_2]) - self.assertEqual(node.outputs[0].name, "out_1") - self.assertEqual(node.outputs[0].shape, _core.Shape([1])) - self.assertEqual(node.outputs[0].dtype, _enums.DataType.BFLOAT16) - self.assertEqual(node.outputs[1].name, "out_2") - self.assertEqual(node.outputs[1].shape, _core.Shape([2])) - self.assertEqual(node.outputs[1].dtype, _enums.DataType.INT4) - self.assertIs(node.outputs[0], out_1) - self.assertIs(node.outputs[1], out_2) - self.assertIs(node.outputs[0].producer(), node) - self.assertIs(node.outputs[1].producer(), node) - self.assertIs(node.outputs[0].index(), 0) - self.assertIs(node.outputs[1].index(), 1) - - def test_init_raises_when_num_outputs_does_not_match_outputs(self): - with self.assertRaisesRegex(ValueError, "outputs"): - _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=2, outputs=[]) - - def test_init_with_zero_num_outputs(self): - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=0) - self.assertEqual(node.outputs, ()) - - def test_init_with_empty_outputs(self): - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[]) - self.assertEqual(node.outputs, ()) - - def test_init_produces_one_output_with_unspecified_output_argument(self): - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1)) - self.assertEqual(len(node.outputs), 1) - - def test_metadata(self): - self.node.meta["test"] = 1 - self.assertEqual(self.node.meta["test"], 1) - self.node.metadata_props["test"] = "any string" - self.assertEqual(self.node.metadata_props["test"], "any string") - - def test_it_is_added_to_a_graph_if_specified(self): - graph = _core.Graph( - (self.v0, self.v1), # type: ignore - self.node.outputs, - nodes=(self.node,), - opset_imports={"": 1}, - ) - self.assertIn(self.node, graph) - - # TODO(justinchuby): Test all methods - - -class GraphTest(unittest.TestCase): - def setUp(self) -> None: - self.v0 = _core.Input(name="v0") - self.v1 = _core.Input(name="v1") - self.node = _core.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1) - self.graph = _core.Graph( - (self.v0, self.v1), - self.node.outputs, - nodes=(self.node,), - opset_imports={"": 1}, - ) - - def test_initialize(self): - self.assertEqual(self.graph.inputs, [self.v0, self.v1]) - self.assertEqual(self.graph.outputs, [*self.node.outputs]) - self.assertEqual(self.graph.opset_imports, {"": 1}) - self.assertEqual(self.graph.initializers, {}) - self.assertIsNone(self.graph.doc_string) - - def test_it_is_iterable_of_nodes(self): - self.assertEqual(list(self.graph), [self.node]) - - def test_metadata(self): - self.graph.meta["test"] = 1 - self.assertEqual(self.graph.meta["test"], 1) - self.graph.metadata_props["test"] = "any string" - self.assertEqual(self.graph.metadata_props["test"], "any string") - - def test_remove_removes_node_from_graph(self): - self.graph.remove(self.node) - self.assertEqual(list(self.graph), []) - self.assertIsNone(self.node.graph) - - def test_remove_does_not_change_input_users(self): - self.graph.remove(self.node) - self.assertEqual(tuple(self.v0.uses()), ((self.node, 0),)) - self.assertEqual(tuple(self.v1.uses()), ((self.node, 1),)) - - def test_remove_does_not_change_graph_in_out(self): - self.graph.remove(self.node) - self.assertEqual(self.graph.inputs, [self.v0, self.v1]) - self.assertEqual(self.graph.outputs, list(self.node.outputs)) - - def test_remove_raises_when_node_does_not_belong_to_graph(self): - node = _core.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1) - with self.assertRaisesRegex(ValueError, "graph"): - self.graph.remove(node) - - def test_remove_safe_raises_when_node_output_is_graph_output(self): - with self.assertRaisesRegex(ValueError, "output"): - self.graph.remove(self.node, safe=True) - - def test_remove_safe_raises_when_node_has_users(self): - v0 = _core.Input(name="v0") - v1 = _core.Input(name="v1") - add_node = _core.Node("", "Add", inputs=(v0, v1), num_outputs=1) - identity_node = _core.Node("", "Identity", inputs=add_node.outputs, num_outputs=1) - graph = _core.Graph( - (v0, v1), - identity_node.outputs, - nodes=(add_node, identity_node), - opset_imports={"": 1}, - ) - with self.assertRaisesRegex(ValueError, "used by other nodes"): - graph.remove(add_node, safe=True) - - def test_remove_safe_removes_uses_of_removed_nodes(self): - v0 = _core.Input(name="v0") - v1 = _core.Input(name="v1") - add_node = _core.Node("", "Add", inputs=(v0, v1), num_outputs=1) - identity_node = _core.Node("", "Identity", inputs=add_node.outputs, num_outputs=1) - graph = _core.Graph( - (v0, v1), - identity_node.outputs, - nodes=(add_node, identity_node), - opset_imports={"": 1}, - ) - # Remove add_node and check that it is no longer a consumer of v0 and v1 - sub_node = _core.Node("", "Sub", inputs=(v0, v1), num_outputs=1) - identity_node.replace_input_with(0, sub_node.outputs[0]) - graph.insert_before(identity_node, sub_node) - graph.remove(add_node, safe=True) - self.assertEqual(tuple(v0.uses()), ((sub_node, 0),)) - self.assertEqual(tuple(v1.uses()), ((sub_node, 1),)) - self.assertEqual(tuple(graph), (sub_node, identity_node)) - self.assertEqual(add_node.inputs, (None, None)) - - # TODO(justinchuby): Test graph mutation methods - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_display.py b/onnxscript/ir/_display.py deleted file mode 100644 index 937af92995..0000000000 --- a/onnxscript/ir/_display.py +++ /dev/null @@ -1,58 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Internal utilities for displaying the intermediate representation of a model. - -NOTE: All third-party imports should be scoped and imported only when used to avoid -importing unnecessary dependencies. -""" -# pylint: disable=import-outside-toplevel - -from __future__ import annotations - -from typing import Any - -_LONG_TEXT_LIMIT = 3000 - - -def require_rich() -> Any: - """Raise an ImportError if rich is not installed.""" - try: - import rich - except ImportError: - return None - return rich - - -class PrettyPrintable: - def display(self, *, page: bool | None = None) -> None: - """Pretty print the object. - - Args: - page: Whether to page the output if it is too long. - """ - rich = require_rich() - text = str(self) - - if rich is None: - print(text) - # Color print this message - print( - f"\n\n\u001b[36mTip: Install the rich library with 'pip install rich' to pretty print this {self.__class__.__name__}.\u001b[0m" - ) - return - - if page is None and len(text) > _LONG_TEXT_LIMIT: - # By default, page the output if it is too long - page = True - if page: - import rich.console - import rich.syntax - - console = rich.console.Console() - syntax = rich.syntax.Syntax(text, "cpp", theme="ansi_light") - with console.pager(styles=True): - console.print(syntax) - else: - rich.print(text) diff --git a/onnxscript/ir/_display_test.py b/onnxscript/ir/_display_test.py deleted file mode 100644 index 33e603a9b2..0000000000 --- a/onnxscript/ir/_display_test.py +++ /dev/null @@ -1,24 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Test display() methods in various classes.""" - -import contextlib -import unittest - -import numpy as np - -import onnxscript.ir as ir - - -class DisplayTest(unittest.TestCase): - def test_tensor_display_does_not_raise_on_nan_values(self): - array_with_nan = np.array([np.inf, -np.inf, np.nan, 5, -10], dtype=np.float32) - tensor = ir.Tensor(array_with_nan, dtype=ir.DataType.FLOAT) - with contextlib.redirect_stdout(None): - tensor.display() - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py deleted file mode 100644 index f2835fdad6..0000000000 --- a/onnxscript/ir/_enums.py +++ /dev/null @@ -1,159 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""ONNX IR enums that matches the ONNX spec.""" - -from __future__ import annotations - -import enum - -import numpy as np - - -class AttributeType(enum.IntEnum): - """Enum for the types of ONNX attributes.""" - - UNDEFINED = 0 - FLOAT = 1 - INT = 2 - STRING = 3 - TENSOR = 4 - GRAPH = 5 - FLOATS = 6 - INTS = 7 - STRINGS = 8 - TENSORS = 9 - GRAPHS = 10 - SPARSE_TENSOR = 11 - SPARSE_TENSORS = 12 - TYPE_PROTO = 13 - TYPE_PROTOS = 14 - - def __repr__(self) -> str: - return self.name - - def __str__(self) -> str: - return self.__repr__() - - -class DataType(enum.IntEnum): - """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``.""" - - # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64, - # but we should stick to the names used in the ONNX spec for consistency. - UNDEFINED = 0 - FLOAT = 1 - UINT8 = 2 - INT8 = 3 - UINT16 = 4 - INT16 = 5 - INT32 = 6 - INT64 = 7 - STRING = 8 - BOOL = 9 - FLOAT16 = 10 - DOUBLE = 11 - UINT32 = 12 - UINT64 = 13 - COMPLEX64 = 14 - COMPLEX128 = 15 - BFLOAT16 = 16 - FLOAT8E4M3FN = 17 - FLOAT8E4M3FNUZ = 18 - FLOAT8E5M2 = 19 - FLOAT8E5M2FNUZ = 20 - UINT4 = 21 - INT4 = 22 - - @classmethod - def from_numpy(cls, dtype: np.dtype) -> DataType: - """Returns the ONNX data type for the numpy dtype. - - Raises: - TypeError: If the data type is not supported by ONNX. - """ - if dtype not in _NP_TYPE_TO_DATA_TYPE: - raise TypeError(f"Unsupported numpy data type: {dtype}") - return cls(_NP_TYPE_TO_DATA_TYPE[dtype]) - - @property - def itemsize(self) -> float: - """Returns the size of the data type in bytes.""" - return _ITEMSIZE_MAP[self] - - def numpy(self) -> np.dtype: - """Returns the numpy dtype for the ONNX data type. - - Raises: - TypeError: If the data type is not supported by numpy. - """ - if self not in _DATA_TYPE_TO_NP_TYPE: - raise TypeError(f"Numpy does not support ONNX data type: {self}") - return _DATA_TYPE_TO_NP_TYPE[self] - - def __repr__(self) -> str: - return self.name - - def __str__(self) -> str: - return self.__repr__() - - -_ITEMSIZE_MAP = { - DataType.FLOAT: 4, - DataType.UINT8: 1, - DataType.INT8: 1, - DataType.UINT16: 2, - DataType.INT16: 2, - DataType.INT32: 4, - DataType.INT64: 8, - DataType.STRING: 1, - DataType.BOOL: 1, - DataType.FLOAT16: 2, - DataType.DOUBLE: 8, - DataType.UINT32: 4, - DataType.UINT64: 8, - DataType.COMPLEX64: 8, - DataType.COMPLEX128: 16, - DataType.BFLOAT16: 2, - DataType.FLOAT8E4M3FN: 1, - DataType.FLOAT8E4M3FNUZ: 1, - DataType.FLOAT8E5M2: 1, - DataType.FLOAT8E5M2FNUZ: 1, - DataType.UINT4: 0.5, - DataType.INT4: 0.5, -} - - -_NP_TYPE_TO_DATA_TYPE = { - np.dtype("bool"): DataType.BOOL, - np.dtype("complex128"): DataType.COMPLEX128, - np.dtype("complex64"): DataType.COMPLEX64, - np.dtype("float16"): DataType.FLOAT16, - np.dtype("float32"): DataType.FLOAT, - np.dtype("float64"): DataType.DOUBLE, - np.dtype("int16"): DataType.INT16, - np.dtype("int32"): DataType.INT32, - np.dtype("int64"): DataType.INT64, - np.dtype("int8"): DataType.INT8, - np.dtype("object"): DataType.STRING, - np.dtype("uint16"): DataType.UINT16, - np.dtype("uint32"): DataType.UINT32, - np.dtype("uint64"): DataType.UINT64, - np.dtype("uint8"): DataType.UINT8, -} - -# ONNX DataType to Numpy dtype. This mapping does not capture ONNX data -# types that are not supported by numpy. -_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()} -_DATA_TYPE_TO_NP_TYPE.update( - { - DataType.FLOAT8E4M3FN: np.dtype("uint8"), - DataType.FLOAT8E4M3FNUZ: np.dtype("uint8"), - DataType.FLOAT8E5M2: np.dtype("uint8"), - DataType.FLOAT8E5M2FNUZ: np.dtype("uint8"), - DataType.UINT4: np.dtype("uint8"), - DataType.INT4: np.dtype("int8"), - DataType.BFLOAT16: np.dtype("uint16"), - } -) diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py deleted file mode 100644 index a08debf0bf..0000000000 --- a/onnxscript/ir/_enums_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -import numpy as np -import onnx - -from onnxscript.ir import _enums - - -class DataTypeTest(unittest.TestCase): - def test_enums_are_the_same_as_spec(self): - self.assertEqual(_enums.DataType.FLOAT, onnx.TensorProto.FLOAT) - self.assertEqual(_enums.DataType.UINT8, onnx.TensorProto.UINT8) - self.assertEqual(_enums.DataType.INT8, onnx.TensorProto.INT8) - self.assertEqual(_enums.DataType.UINT16, onnx.TensorProto.UINT16) - self.assertEqual(_enums.DataType.INT16, onnx.TensorProto.INT16) - self.assertEqual(_enums.DataType.INT32, onnx.TensorProto.INT32) - self.assertEqual(_enums.DataType.INT64, onnx.TensorProto.INT64) - self.assertEqual(_enums.DataType.STRING, onnx.TensorProto.STRING) - self.assertEqual(_enums.DataType.BOOL, onnx.TensorProto.BOOL) - self.assertEqual(_enums.DataType.FLOAT16, onnx.TensorProto.FLOAT16) - self.assertEqual(_enums.DataType.DOUBLE, onnx.TensorProto.DOUBLE) - self.assertEqual(_enums.DataType.UINT32, onnx.TensorProto.UINT32) - self.assertEqual(_enums.DataType.UINT64, onnx.TensorProto.UINT64) - self.assertEqual(_enums.DataType.COMPLEX64, onnx.TensorProto.COMPLEX64) - self.assertEqual(_enums.DataType.COMPLEX128, onnx.TensorProto.COMPLEX128) - self.assertEqual(_enums.DataType.BFLOAT16, onnx.TensorProto.BFLOAT16) - self.assertEqual(_enums.DataType.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FN) - self.assertEqual(_enums.DataType.FLOAT8E4M3FNUZ, onnx.TensorProto.FLOAT8E4M3FNUZ) - self.assertEqual(_enums.DataType.FLOAT8E5M2, onnx.TensorProto.FLOAT8E5M2) - self.assertEqual(_enums.DataType.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E5M2FNUZ) - self.assertEqual(_enums.DataType.UINT4, onnx.TensorProto.UINT4) - self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4) - self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED) - - def test_from_numpy_takes_np_dtype_and_returns_data_type(self): - array = np.array([], dtype=np.float64) - self.assertEqual(_enums.DataType.from_numpy(array.dtype), _enums.DataType.DOUBLE) - - def test_numpy_returns_np_dtype(self): - self.assertEqual(_enums.DataType.DOUBLE.numpy(), np.dtype(np.float64)) - - def test_itemsize_returns_size_of_data_type_in_bytes(self): - self.assertEqual(_enums.DataType.DOUBLE.itemsize, 8) - self.assertEqual(_enums.DataType.INT4.itemsize, 0.5) - - def test_repr_and_str_return_name(self): - self.assertEqual(str(_enums.DataType.DOUBLE), "DOUBLE") - self.assertEqual(repr(_enums.DataType.DOUBLE), "DOUBLE") - - -class AttributeTypeTest(unittest.TestCase): - def test_enums_are_the_same_as_spec(self): - self.assertEqual(_enums.AttributeType.FLOAT, onnx.AttributeProto.FLOAT) - self.assertEqual(_enums.AttributeType.INT, onnx.AttributeProto.INT) - self.assertEqual(_enums.AttributeType.STRING, onnx.AttributeProto.STRING) - self.assertEqual(_enums.AttributeType.TENSOR, onnx.AttributeProto.TENSOR) - self.assertEqual(_enums.AttributeType.GRAPH, onnx.AttributeProto.GRAPH) - self.assertEqual(_enums.AttributeType.FLOATS, onnx.AttributeProto.FLOATS) - self.assertEqual(_enums.AttributeType.INTS, onnx.AttributeProto.INTS) - self.assertEqual(_enums.AttributeType.STRINGS, onnx.AttributeProto.STRINGS) - self.assertEqual(_enums.AttributeType.TENSORS, onnx.AttributeProto.TENSORS) - self.assertEqual(_enums.AttributeType.GRAPHS, onnx.AttributeProto.GRAPHS) - self.assertEqual(_enums.AttributeType.SPARSE_TENSOR, onnx.AttributeProto.SPARSE_TENSOR) - self.assertEqual( - _enums.AttributeType.SPARSE_TENSORS, onnx.AttributeProto.SPARSE_TENSORS - ) - self.assertEqual(_enums.AttributeType.TYPE_PROTO, onnx.AttributeProto.TYPE_PROTO) - self.assertEqual(_enums.AttributeType.TYPE_PROTOS, onnx.AttributeProto.TYPE_PROTOS) - self.assertEqual(_enums.AttributeType.UNDEFINED, onnx.AttributeProto.UNDEFINED) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_graph_comparison.py b/onnxscript/ir/_graph_comparison.py deleted file mode 100644 index 788b4b4d54..0000000000 --- a/onnxscript/ir/_graph_comparison.py +++ /dev/null @@ -1,25 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Utilities for comparing IR graphs.""" - -from __future__ import annotations - -from onnxscript.ir import _core - -# NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs -# NOTE(justinchuby): A graph may be specified with a set of inputs and outputs - - -def topologically_equal(graph1: _core.Graph, graph2: _core.Graph) -> bool: - """Return true if the two graphs are topologically equivalent, without considering initializers. - - Args: - graph1: The first graph to compare. - graph2: The second graph to compare. - - Returns: - True if the graphs are equal, False otherwise. - """ - raise NotImplementedError() diff --git a/onnxscript/ir/_invariants.py b/onnxscript/ir/_invariants.py deleted file mode 100644 index 8d009c3cc9..0000000000 --- a/onnxscript/ir/_invariants.py +++ /dev/null @@ -1,60 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Utilities to enforce invariants on the IR.""" - -from __future__ import annotations - -import functools -from typing import Any, Callable - - -class InvariantError(Exception): - """Raised when an invariant is violated.""" - - -class PreconditionError(InvariantError): - """Raised when a precondition is violated.""" - - -class PostconditionError(InvariantError): - """Raised when a postcondition is violated.""" - - -def requires( - preconditions: Callable[..., str | None], -) -> Callable[..., Callable[..., Any]]: - """Decorator to enforce preconditions on a function.""" - # TODO(justinchuby): Preserve python function signature with this decorator - - def decorator(func: Callable[..., None]) -> Callable[..., None]: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> None: - message = preconditions(*args, **kwargs) - if message is not None: - raise PreconditionError(message) - return func(*args, **kwargs) - - return wrapper - - return decorator - - -def ensures( - postconditions: Callable[..., str | None], -) -> Callable[..., Callable[..., Any]]: - """Decorator to enforce postconditions on a function.""" - - def decorator(func: Callable[..., None]) -> Callable[..., None]: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> None: - result = func(*args, **kwargs) - message = postconditions(*args, **kwargs) - if message is not None: - raise PostconditionError(message) - return result - - return wrapper - - return decorator diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py deleted file mode 100644 index 059a88f2b9..0000000000 --- a/onnxscript/ir/_linked_list.py +++ /dev/null @@ -1,278 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Mutable list for nodes in a graph with safe mutation properties.""" - -from __future__ import annotations - -from typing import Generic, Iterable, Iterator, Sequence, TypeVar - -T = TypeVar("T") - - -class _LinkBox(Generic[T]): - """A link in a doubly linked list that has a reference to the actual object in the link. - - The :class:`_LinkBox` is a container for the actual object in the list. It is used to - maintain the links between the elements in the linked list. The actual object is stored in the - :attr:`value` attribute. - - By using a separate container for the actual object, we can safely remove the object from the - list without losing the links. This allows us to remove the object from the list during - iteration and place the object into a different list without breaking any chains. - - This is an internal class and should only be initialized by the :class:`DoublyLinkedSet`. - - Attributes: - prev: The previous box in the list. - next: The next box in the list. - erased: A flag to indicate if the box has been removed from the list. - owning_list: The :class:`DoublyLinkedSet` to which the box belongs. - value: The actual object in the list. - """ - - __slots__ = ("prev", "next", "value", "owning_list") - - def __init__(self, owner: DoublyLinkedSet[T], value: T | None) -> None: - """Create a new link box. - - Args: - owner: The linked list to which this box belongs. - value: The value to be stored in the link box. When the value is None, - the link box is considered erased (default). The root box of the list - should be created with a None value. - """ - self.prev: _LinkBox[T] = self - self.next: _LinkBox[T] = self - self.value: T | None = value - self.owning_list: DoublyLinkedSet[T] = owner - - @property - def erased(self) -> bool: - return self.value is None - - def erase(self) -> None: - """Remove the link from the list and detach the value from the box.""" - if self.value is None: - raise ValueError("_LinkBox is already erased") - # Update the links - prev, next_ = self.prev, self.next - prev.next, next_.prev = next_, prev - # Detach the value - self.value = None - - def __repr__(self) -> str: - return f"_LinkBox({self.value!r}, erased={self.erased}, prev={self.prev.value!r}, next={self.next.value!r})" - - -class DoublyLinkedSet(Generic[T], Sequence[T]): - """A doubly linked ordered set of nodes. - - The container can be viewed as a set as it does not allow duplicate values. The order of the - elements is maintained. One can typically treat it as a doubly linked list with list-like - methods implemented. - - Adding and removing elements from the set during iteration is safe. Moving elements - from one set to another is also safe. - - During the iteration: - - If new elements are inserted after the current node, the iterator will - iterate over them as well. - - If new elements are inserted before the current node, they will - not be iterated over in this iteration. - - If the current node is lifted and inserted in a different location, - iteration will start from the "next" node at the _original_ location. - - Time complexity: - Inserting and removing nodes from the set is O(1). Accessing nodes by index is O(n), - although accessing nodes at either end of the set is O(1). I.e. - ``linked_set[0]`` and ``linked_set[-1]`` are O(1). - - Values need to be hashable. ``None`` is not a valid value in the set. - """ - - __slots__ = ("_root", "_length", "_value_ids_to_boxes") - - def __init__(self, values: Iterable[T] | None = None) -> None: - # Using the root node simplifies the mutation implementation a lot - # The list is circular. The root node is the only node that is not a part of the list values - root_ = _LinkBox(self, None) - self._root: _LinkBox = root_ - self._length = 0 - self._value_ids_to_boxes: dict[int, _LinkBox] = {} - if values is not None: - self.extend(values) - - def __iter__(self) -> Iterator[T]: - """Iterate over the elements in the list. - - - If new elements are inserted after the current node, the iterator will - iterate over them as well. - - If new elements are inserted before the current node, they will - not be iterated over in this iteration. - - If the current node is lifted and inserted in a different location, - iteration will start from the "next" node at the _original_ location. - """ - box = self._root.next - while box is not self._root: - if box.owning_list is not self: - raise RuntimeError(f"Element {box!r} is not in the list") - if not box.erased: - assert box.value is not None - yield box.value - box = box.next - - def __reversed__(self) -> Iterator[T]: - """Iterate over the elements in the list in reverse order.""" - box = self._root.prev - while box is not self._root: - if not box.erased: - assert box.value is not None - yield box.value - box = box.prev - - def __len__(self) -> int: - assert self._length == len( - self._value_ids_to_boxes - ), "Bug in the implementation: length mismatch" - return self._length - - def __getitem__(self, index: int) -> T: - """Get the node at the given index. - - Complexity is O(n). - """ - if index >= self._length or index < -self._length: - raise IndexError( - f"Index out of range: {index} not in range [-{self._length}, {self._length})" - ) - if index < 0: - # Look up from the end of the list - iterator = reversed(self) - item = next(iterator) - for _ in range(-index - 1): - item = next(iterator) - else: - iterator = iter(self) # type: ignore[assignment] - item = next(iterator) - for _ in range(index): - item = next(iterator) - return item - - def _insert_one_after( - self, - box: _LinkBox[T], - new_value: T, - ) -> _LinkBox[T]: - """Insert a new value after the given box. - - All insertion methods should call this method to ensure that the list is updated correctly. - - Example:: - Before: A <-> B <-> C - ^v0 ^v1 ^v2 - Call: _insert_one_after(B, v3) - After: A <-> B <-> new_box <-> C - ^v0 ^v1 ^v3 ^v2 - - Args: - box: The box which the new value is to be inserted. - new_value: The new value to be inserted. - """ - if new_value is None: - raise TypeError(f"{self.__class__.__name__} does not support None values") - if box.value is new_value: - # Do nothing if the new value is the same as the old value - return box - if box.owning_list is not self: - raise ValueError(f"Value {box.value!r} is not in the list") - - if (new_value_id := id(new_value)) in self._value_ids_to_boxes: - # If the value is already in the list, remove it first - self.remove(new_value) - - # Create a new _LinkBox for the new value - new_box = _LinkBox(self, new_value) - # original_box <=> original_next - # becomes - # original_box <=> new_box <=> original_next - original_next = box.next - box.next = new_box - new_box.prev = box - new_box.next = original_next - original_next.prev = new_box - - # Be sure to update the length and mapping - self._length += 1 - self._value_ids_to_boxes[new_value_id] = new_box - - return new_box - - def _insert_many_after( - self, - box: _LinkBox[T], - new_values: Iterable[T], - ): - """Insert multiple new values after the given box.""" - insertion_point = box - for new_value in new_values: - insertion_point = self._insert_one_after(insertion_point, new_value) - - def remove(self, value: T) -> None: - """Remove a node from the list.""" - if (value_id := id(value)) not in self._value_ids_to_boxes: - raise ValueError(f"Value {value!r} is not in the list") - box = self._value_ids_to_boxes[value_id] - # Remove the link box and detach the value from the box - box.erase() - - # Be sure to update the length and mapping - self._length -= 1 - del self._value_ids_to_boxes[value_id] - - def append(self, value: T) -> None: - """Append a node to the list.""" - _ = self._insert_one_after(self._root.prev, value) - - def extend( - self, - values: Iterable[T], - ) -> None: - for value in values: - self.append(value) - - def insert_after( - self, - value: T, - new_values: Iterable[T], - ) -> None: - """Insert new nodes after the given node. - - Args: - value: The value after which the new values are to be inserted. - new_values: The new values to be inserted. - """ - if (value_id := id(value)) not in self._value_ids_to_boxes: - raise ValueError(f"Value {value!r} is not in the list") - insertion_point = self._value_ids_to_boxes[value_id] - return self._insert_many_after(insertion_point, new_values) - - def insert_before( - self, - value: T, - new_values: Iterable[T], - ) -> None: - """Insert new nodes before the given node. - - Args: - value: The value before which the new values are to be inserted. - new_values: The new values to be inserted. - """ - if (value_id := id(value)) not in self._value_ids_to_boxes: - raise ValueError(f"Value {value!r} is not in the list") - insertion_point = self._value_ids_to_boxes[value_id].prev - return self._insert_many_after(insertion_point, new_values) - - def __repr__(self) -> str: - return f"DoublyLinkedSet({list(self)})" diff --git a/onnxscript/ir/_linked_list_test.py b/onnxscript/ir/_linked_list_test.py deleted file mode 100644 index a82b0e172b..0000000000 --- a/onnxscript/ir/_linked_list_test.py +++ /dev/null @@ -1,380 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Unit tests for the _linked_list module.""" - -from __future__ import annotations - -import unittest - -import parameterized - -from onnxscript.ir import _linked_list - - -class _TestElement: - def __init__(self, value): - self.value = value - - def __repr__(self) -> str: - return f"_TestElement({self.value})" - - -class DoublyLinkedSetTest(unittest.TestCase): - def test_empty_list(self): - linked_list = _linked_list.DoublyLinkedSet() - self.assertEqual(len(linked_list), 0) - self.assertEqual(list(linked_list), []) - self.assertEqual(list(reversed(linked_list)), []) - with self.assertRaises(IndexError): - _ = linked_list[0] - with self.assertRaises(IndexError): - _ = linked_list[-1] - - def test_append_single_element(self): - linked_list = _linked_list.DoublyLinkedSet() - elem = _TestElement(0) - linked_list.append(elem) - - self.assertEqual(len(linked_list), 1) - self.assertEqual(linked_list[0], elem) - self.assertEqual(linked_list[-1], elem) - self.assertEqual(list(linked_list), [elem]) - self.assertEqual(list(reversed(linked_list)), [elem]) - with self.assertRaises(IndexError): - _ = linked_list[1] - with self.assertRaises(IndexError): - _ = linked_list[-2] - - def test_append_multiple_elements(self): - linked_list = _linked_list.DoublyLinkedSet() - elems = [_TestElement(i) for i in range(3)] - for elem in elems: - linked_list.append(elem) - - self.assertEqual(len(linked_list), 3) - self.assertEqual(linked_list[0], elems[0]) - self.assertEqual(linked_list[1], elems[1]) - self.assertEqual(linked_list[2], elems[2]) - self.assertEqual(linked_list[-1], elems[2]) - self.assertEqual(linked_list[-2], elems[1]) - self.assertEqual(linked_list[-3], elems[0]) - self.assertEqual(list(linked_list), elems) - self.assertEqual(list(reversed(linked_list)), list(reversed(elems))) - - def test_extend(self): - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - self.assertEqual(len(linked_list), 3) - self.assertEqual(linked_list[0], elems[0]) - self.assertEqual(linked_list[1], elems[1]) - self.assertEqual(linked_list[2], elems[2]) - self.assertEqual(linked_list[-1], elems[2]) - self.assertEqual(linked_list[-2], elems[1]) - self.assertEqual(linked_list[-3], elems[0]) - self.assertEqual(list(linked_list), elems) - self.assertEqual(list(reversed(linked_list)), list(reversed(elems))) - - @parameterized.parameterized.expand( - [ - ("single_element", [0], 0, [1], [0, 1]), - ("single_element_negative_index", [0], -1, [1], [0, 1]), - ("multiple_elements", [0], 0, [1, 2], [0, 1, 2]), - ("multiple_elements_negative_index", [0], -1, [1, 2], [0, 1, 2]), - ( - "multiple_original_elements_insert_at_start", - [0, 1, 2], - 0, - [42, 43], - [0, 42, 43, 1, 2], - ), - ( - "multiple_original_elements_insert_at_middle", - [0, 1, 2], - 1, - [42, 43], - [0, 1, 42, 43, 2], - ), - ( - "multiple_original_elements_insert_at_end", - [0, 1, 2], - 2, - [42, 43], - [0, 1, 2, 42, 43], - ), - ] - ) - def test_insert_after( - self, _: str, original: list[int], location: int, insertion: list[int], expected: list - ) -> None: - # Construct the original list - elems = [_TestElement(i) for i in original] - linked_list = _linked_list.DoublyLinkedSet(elems) - - # Create the new elements - new_elements = [_TestElement(i) for i in insertion] - linked_list.insert_after(elems[location], new_elements) - - # Check the list - self.assertEqual(len(linked_list), len(expected)) - self.assertEqual([elem.value for elem in linked_list], expected) - - @parameterized.parameterized.expand( - [ - ("single_element", [0], 0, [1], [1, 0]), - ("single_element_negative_index", [0], -1, [1], [1, 0]), - ("multiple_elements", [0], 0, [1, 3], [1, 3, 0]), - ("multiple_elements_negative_index", [0], -1, [1, 3], [1, 3, 0]), - ( - "multiple_original_elements_insert_at_start", - [0, 1, 2], - 0, - [42, 43], - [42, 43, 0, 1, 2], - ), - ( - "multiple_original_elements_insert_at_middle", - [0, 1, 2], - 1, - [42, 43], - [0, 42, 43, 1, 2], - ), - ( - "multiple_original_elements_insert_at_end", - [0, 1, 2], - 2, - [42, 43], - [0, 1, 42, 43, 2], - ), - ] - ) - def test_insert_before( - self, _: str, original: list[int], location: int, insertion: list[int], expected: list - ) -> None: - # Construct the original list - elems = [_TestElement(i) for i in original] - linked_list = _linked_list.DoublyLinkedSet(elems) - - # Create the new elements - new_elements = [_TestElement(i) for i in insertion] - linked_list.insert_before(elems[location], new_elements) - - # Check the list - self.assertEqual(len(linked_list), len(expected)) - self.assertEqual([elem.value for elem in linked_list], expected) - self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1]) - - @parameterized.parameterized.expand( - [ - ("start", 0, [1, 2]), - ("middle", 1, [0, 2]), - ("end", 2, [0, 1]), - ("start_negative", -1, [0, 1]), - ("middle_negative", -2, [0, 2]), - ("end_negative", -3, [1, 2]), - ] - ) - def test_remove(self, _: str, index: int, expected: list[int]) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - - linked_list.remove(elems[index]) - - self.assertEqual(len(linked_list), 2) - self.assertEqual([elem.value for elem in linked_list], expected) - self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1]) - - def test_remove_raises_when_element_not_found(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - - with self.assertRaises(ValueError): - linked_list.remove(_TestElement(3)) - - def test_remove_raises_when_element_is_already_removed(self) -> None: - linked_list = _linked_list.DoublyLinkedSet() - elem = _TestElement(0) - linked_list.append(elem) - linked_list.remove(elem) - - with self.assertRaises(ValueError): - linked_list.remove(elem) - - def test_append_self_does_nothing(self) -> None: - linked_list = _linked_list.DoublyLinkedSet() - elem = _TestElement(0) - linked_list.append(elem) - - linked_list.append(elem) - - self.assertEqual(len(linked_list), 1) - self.assertEqual(linked_list[0], elem) - self.assertEqual(list(linked_list), [elem]) - self.assertEqual(list(reversed(linked_list)), [elem]) - - def test_append_supports_appending_element_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - - linked_list.append(elems[1]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [0, 2, 1]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [1, 2, 0]) - - def test_extend_supports_extending_elements_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - linked_list.extend(elems[::-1]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [2, 1, 0]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [0, 1, 2]) - - def test_insert_after_supports_inserting_element_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - linked_list.insert_after(elems[0], [elems[2]]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [0, 2, 1]) - - def test_insert_before_supports_inserting_element_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - linked_list.insert_before(elems[0], [elems[2]]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [2, 0, 1]) - - def test_iterator_supports_mutation_during_iteration_current_element(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - for elem in linked_list: - if elem.value == 1: - linked_list.remove(elem) - - self.assertEqual(len(linked_list), 2) - self.assertEqual([elem.value for elem in linked_list], [0, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0]) - - def test_iterator_supports_mutation_during_iteration_previous_element(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - for elem in linked_list: - if elem.value == 1: - linked_list.remove(elem) - linked_list.remove(elems[0]) - - self.assertEqual(len(linked_list), 1) - self.assertEqual([elem.value for elem in linked_list], [2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2]) - - def test_iterator_supports_mutation_during_iteration_next_element(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - for elem in linked_list: - if elem.value == 1: - linked_list.remove(elems[2]) - linked_list.remove(elem) - - self.assertEqual(len(linked_list), 1) - self.assertEqual([elem.value for elem in linked_list], [0]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [0]) - - def test_iterator_supports_mutation_in_nested_iteration_right_of_iterator(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - iter1_visited = [] - iter2_visited = [] - for elem in linked_list: - iter1_visited.append(elem.value) - for elem2 in linked_list: - iter2_visited.append(elem2.value) - if elem2.value == 1: - linked_list.remove(elem2) - - self.assertEqual(len(linked_list), 2) - self.assertEqual(iter1_visited, [0, 2]) - self.assertEqual(iter2_visited, [0, 1, 2, 0, 2]) - self.assertEqual([elem.value for elem in linked_list], [0, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0]) - - def test_iterator_supports_mutation_in_nested_iteration_when_iter_is_self(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - iter1_visited = [] - iter2_visited = [] - for elem in linked_list: - iter1_visited.append(elem.value) - for elem2 in linked_list: - iter2_visited.append(elem2.value) - if elem2.value == 0: # Remove the element the current iterator points to - linked_list.remove(elem2) - - self.assertEqual(len(linked_list), 2) - self.assertEqual(iter1_visited, [0, 1, 2]) - self.assertEqual(iter2_visited, [0, 1, 2, 1, 2, 1, 2]) - self.assertEqual([elem.value for elem in linked_list], [1, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1]) - - def test_iterator_supports_mutation_in_nested_iteration_left_of_iterator(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - iter1_visited = [] - iter2_visited = [] - for elem in linked_list: - iter1_visited.append(elem.value) - for elem2 in linked_list: - iter2_visited.append(elem2.value) - if ( - elem.value == 1 and elem2.value == 0 - ): # Remove the element before the current iterator points to - linked_list.remove(elems[0]) - - self.assertEqual(len(linked_list), 2) - self.assertEqual(iter1_visited, [0, 1, 2]) - self.assertEqual(iter2_visited, [0, 1, 2, 0, 1, 2, 1, 2]) - self.assertEqual([elem.value for elem in linked_list], [1, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1]) - - def test_insert_after_supports_element_from_different_list_during_iteration(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - other_linked_list = _linked_list.DoublyLinkedSet() - other_elem = _TestElement(42) - other_linked_list.append(other_elem) - - for elem in linked_list: - if elem.value == 1: - linked_list.insert_after(elem, [other_elem]) - - self.assertEqual(len(linked_list), 4) - self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0]) - # Other list remains unchanged - self.assertEqual(len(other_linked_list), 1) - self.assertEqual([elem.value for elem in other_linked_list], [42]) - - def test_insert_after_supports_taking_elements_from_another_doubly_linked_list( - self, - ) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - other_linked_list = _linked_list.DoublyLinkedSet() - other_elem = _TestElement(42) - other_linked_list.append(other_elem) - - linked_list.insert_after(elems[1], other_linked_list) - - self.assertEqual(len(linked_list), 4) - self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0]) - # Other list remains unchanged - self.assertEqual(len(other_linked_list), 1) - self.assertEqual([elem.value for elem in other_linked_list], [42]) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_metadata.py b/onnxscript/ir/_metadata.py deleted file mode 100644 index bbb01a9596..0000000000 --- a/onnxscript/ir/_metadata.py +++ /dev/null @@ -1,46 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Class for storing metadata about the IR objects.""" - -from __future__ import annotations - -import collections -from typing import Any, Mapping - - -class MetadataStore(collections.UserDict): - """Class for storing metadata about the IR objects. - - Metadata is stored as key-value pairs. The keys are strings and the values - can be any Python object. - - The metadata store also supports marking keys as invalid. This is useful - when a pass wants to mark a key that needs to be recomputed. - """ - - def __init__(self, data: Mapping[str, Any] | None = None, /) -> None: - super().__init__(data) - self._invalid_keys: set[str] = set() - - def __setitem__(self, key: str, item: Any) -> None: - self.data[key] = item - self._invalid_keys.discard(key) - - def invalidate(self, key: str) -> None: - self._invalid_keys.add(key) - - def is_valid(self, key: str) -> bool: - """Returns whether the value is valid. - - Note that default values (None) are not necessarily invalid. For example, - a shape that is unknown (None) may be still valid if shape inference has - determined that the shape is unknown. - - Whether a value is valid is solely determined by the user that sets the value. - """ - return key not in self._invalid_keys - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.data!r}, invalid_keys={self._invalid_keys!r})" diff --git a/onnxscript/ir/_name_authority.py b/onnxscript/ir/_name_authority.py deleted file mode 100644 index 856c86247e..0000000000 --- a/onnxscript/ir/_name_authority.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Auxiliary class for managing names in the IR.""" - -from __future__ import annotations - -from onnxscript.ir import _core - - -class NameAuthority: - """Class for giving names to values and nodes in the IR. - - The names are generated in the format ``val_{value_counter}`` for values and - ``node_{op_type}_{node_counter}`` for nodes. The counter is incremented each time - a new value or node is named. - - The class does not keep track of the names it has given, so it is possible to - generate names that conflicts with existing names. It is the responsibility of the - user to ensure that the names are unique (typically by running a name-fixing pass - on the graph). - """ - - def __init__(self): - self._value_counter = 0 - self._node_counter = 0 - - def name_value(self, value: _core.Value) -> None: - value.name = f"val_{self._value_counter}" - self._value_counter += 1 - - def name_node(self, node: _core.Node) -> None: - node.name = f"node_{node.op_type}_{self._node_counter}" - self._node_counter += 1 diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py deleted file mode 100644 index 7e5b791208..0000000000 --- a/onnxscript/ir/_protocols.py +++ /dev/null @@ -1,590 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Protocols for the ONNX IR. - -This file defines the interfaces for tools to interact with the IR. The interfaces -are designed such that tools leveraging the IR can be decoupled from the IR -implementation. This allows for the implementation to evolve independently of the -tools. -""" - -# 👀 -# NOTE: Why are we using protocols, instead of abstract base classes? -# -# Protocols are more flexible than abstract base classes. Users can define their -# own classes that implement the protocols without having to inherit from a -# specific base class. For example, a user can define a custom tensor class that -# implements the TensorProtocol without explicitly inheriting, and the IR can -# work with that class without any changes. -# -# `isinstance` checks can be slower with protocols. Avoid using `isinstance` -# checks when you can. Always check for concrete classes first. -# -# NOTE: Why are we using protocols, instead of using concrete classes directly? -# -# Protocols define the interface that is typically more stable. If you find yourself -# updating the protocols, pause 🛑, and carefully make sure it is absolutely needed -# and will improve the design. If you are adding new methods, consider if the method -# should be part of the protocol or if it should be a higher level convenience function -# defined outside the protocol. - -from __future__ import annotations - -import typing -from typing import ( - Any, - Collection, - Iterable, - Iterator, - Mapping, - MutableMapping, - MutableSequence, - OrderedDict, - Protocol, - Sequence, - Tuple, -) - -from onnxscript.ir import _enums - -if typing.TYPE_CHECKING: - import numpy as np - from typing_extensions import TypeAlias - -# An identifier that will uniquely identify an operator. E.g (domain, op_type, overload) -OperatorIdentifier: TypeAlias = Tuple[str, str, str] - - -@typing.runtime_checkable -class ArrayCompatible(Protocol): - """Protocol for array-like objects. - - An example of an array-like object is a numpy ndarray or a PyTorch Tensor. - Read more at https://numpy.org/devdocs/user/basics.interoperability.html - """ - - def __array__(self, dtype: Any) -> np.ndarray: ... - - -@typing.runtime_checkable -class DLPackCompatible(Protocol): - """Protocol for objects that can support dlpack. - - Computation backends can call __dlpack__ to obtain the underlying data in a - tensor without copying the data. This allows use to use tensorflow tensors etc. - without copying the data. - """ - - def __dlpack__(self, *, stream: Any = ...) -> Any: - """Return PyCapsule.""" - ... - - def __dlpack_device__(self) -> Any: - """Return the device.""" - ... - - -@typing.runtime_checkable -class TensorProtocol(ArrayCompatible, Protocol): - """Concrete tensor backed by data. - - The protocol does not specify how the data is stored. That data is exposed - through the :attr:`raw` attribute for examination, but accessing :attr:`raw` - is typically not needed. - - To use the tensor as a numpy array, call :meth:`numpy`. To convert the tensor - to a byte string for serialization, call :meth:`tobytes`. - - It is recommended to check the size of the tensor first before accessing the - underlying data, because accessing the data may be expensive and incur IO - overhead. - - Attributes: - name: The name of the tensor. - shape: The shape of the tensor. - dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum. - doc_string: Documentation string. - raw: The raw data behind this tensor. It can be anything. - size: The number of elements in the tensor. - nbytes: The number of bytes in the tensor. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str - shape: ShapeProtocol - dtype: _enums.DataType - doc_string: str | None - raw: Any - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - @property - def size(self) -> int: ... - - @property - def nbytes(self) -> int: ... - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array.""" - ... - - def __array__(self, dtype: Any = None) -> np.ndarray: - """Return the tensor as a numpy array, compatible with np.array.""" - ... - - def tobytes(self) -> bytes: - """Return the tensor as a byte string conformed to the ONNX specification, in little endian.""" - ... - - -@typing.runtime_checkable -class ValueProtocol(Protocol): - """Protocol for values. - - A value is a named entity that can be used to represent an input or output of a graph, - a function, or a node. The information it stores generalizes over ``ValueInfoProto`` - in the ONNX specification. - - A :class:`Value` is always not owned or owned by exactly one node. When the value is not - owned, it must be an input of a graph or a function. ``producer`` and ``index`` - are ``None``. - - When the value is owned by a node, it is an output of the node. - The node that produces the value can be accessed with :meth:`producer`. - The index of the output of the node that produces the value can be accessed with - :meth:`index`. - - To find all the nodes that use this value as an input, call :meth:`uses`. - - To check if the value is an output of a graph, call :meth:`is_graph_output`. - - Attributes: - name: The name of the value. A value is always named when it is part of a graph. - shape: The shape of the value. - type: The type of the value. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - doc_string: Documentation string. - """ - - name: str - shape: ShapeProtocol | None - type: TypeProtocol | None - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - doc_string: str | None - - def producer(self) -> NodeProtocol | None: - """The node that produces this value.""" - ... - - def index(self) -> int | None: - """The index of the output of the node that produces this value.""" - ... - - def uses(self) -> Collection[tuple[NodeProtocol, int]]: - """The set of (node, input_index) with node being those that use this value as an input.""" - ... - - def is_graph_output(self) -> bool: - """Whether this value is an output of a graph.""" - ... - - -@typing.runtime_checkable -class NodeProtocol(Protocol): - """Protocol for nodes. - - A node represents an invocation of an operation on the :class:`Value` s in - the computational graph. - - A node can be optionally named. A name should typically be assigned when the - node is added to a graph. - - :attr:`domain`, :attr:`op_type`, and :attr:`overload` together uniquely identify - the operator, and are always strings. For ONNX operators, :attr:`domain` and :attr:`overload` - are both empty strings. - - :attr:`inputs` and :attr:`outputs` are the input and output values of the node. - - :attr:`attributes` are the attributes of the node. The attributes are stored in an - ordered dictionary to preserve the order of the attributes. This is a deviation from - the current ONNX spec where attributes are unordered, but it is helpful for tools - that rely on the order of the attributes, e.g. those converting to and from Python - function keyword arguments. - - :attr:`version` is unique to the IR and is not specified in the ONNX spec. This - allows the IR to represent a graph with mixed opset versions. Deserializers - should decide how to reconcile the different versions within the graph. A typical - graph will have a single version, declared in the :class:`Graph` object and - the nodes will have ``None`` as the version. - - Attributes: - domain: The domain of the operator. E.g. ``""`` for ONNX operators. - op_type: The operator name. - overload: The overload name when the node is invoking a function. - inputs: Input values. - outputs: Output values. - attributes: The attributes of the operator. - version: The version of the operator. - doc_string: Documentation string. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str | None - domain: str - op_type: str - overload: str - inputs: Sequence[ValueProtocol] - outputs: Sequence[ValueProtocol] - attributes: OrderedDict[str, AttributeProtocol | ReferenceAttributeProtocol] - version: int | None - doc_string: str | None - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def replace_input_with(self, index: int, value: ValueProtocol | None) -> None: - """Set the input at the given index to the given value, replacing the original value.""" - ... - - -@typing.runtime_checkable -class GraphProtocol(Protocol): - """Protocol for graphs. - - Graph represents a computation graph. In addition to the ONNX specification - specified fields, it also contains a mapping of :attr:`opset_imports`. This - allows different subgraphs to import different opsets. It is the responsibility - of the deserializer to reconcile the different opsets. - - The nodes are not guaranteed to be topologically sorted. But the - iteration order should be deterministic across different runs. It is the - responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Graph. The Graph can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(graph)``. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - # TODO(justinchuby): Support quantization_annotation - name: str | None - inputs: MutableSequence[ValueProtocol] - outputs: MutableSequence[ValueProtocol] - initializers: MutableMapping[str, TensorProtocol] - doc_string: str - opset_imports: MutableMapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def __getitem__(self, index: int) -> NodeProtocol: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[NodeProtocol]: ... - def __reversed__(self) -> Iterator[NodeProtocol]: ... - - # Mutation methods - def append(self, node: NodeProtocol, /) -> None: - """Append a node to the graph.""" - ... - - def extend(self, nodes: Iterable[NodeProtocol], /) -> None: - """Extend the graph with the given nodes.""" - ... - - def remove(self, node: NodeProtocol, /) -> None: - """Remove a node from the graph.""" - ... - - def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: - """Insert new nodes after the given node.""" - ... - - def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: - """Insert new nodes before the given node.""" - ... - - def sort(self) -> None: - """Topologically sort the nodes in the graph.""" - ... - - -@typing.runtime_checkable -class GraphViewProtocol(Protocol): - """Protocol for a read-only view on a graph. - - The GraphView is useful for analysis of a subgraph. It can be initialized - with a subset of nodes from a :class:`Graph`. Creating GraphView does not - change the ownership of the nodes, and so it is possible to create multiple - GraphViews that contain the same nodes. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str | None - inputs: Sequence[ValueProtocol] - outputs: Sequence[ValueProtocol] - initializers: Mapping[str, TensorProtocol] - doc_string: str - opset_imports: Mapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def __getitem__(self, index: int) -> NodeProtocol: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[NodeProtocol]: ... - def __reversed__(self) -> Iterator[NodeProtocol]: ... - - -@typing.runtime_checkable -class ModelProtocol(Protocol): - """Protocol for models. - - A model is a container for a graph and metadata. It is the top-level object - that represents an ONNX model. - - Attributes: - graph: The graph of the model. - ir_version: The version of the IR. - producer_name: The name of the producer. - producer_version: The version of the producer. - domain: The domain of the model. - model_version: The version of the model. - doc_string: Documentation string. - functions: The functions defined in the model. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - graph: GraphProtocol - ir_version: int - producer_name: str | None - producer_version: str | None - domain: str | None - model_version: int | None - doc_string: str | None - functions: MutableMapping[str, FunctionProtocol] - # TODO(justinchuby): Add training_info - opset_imports: MutableMapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - -@typing.runtime_checkable -class AttributeProtocol(Protocol): - """Protocol for ONNX attributes. - - Attributes: - name: The name of the attribute. - type: The type of the attribute. - value: The value of the attribute. - doc_string: Documentation string. - """ - - name: str - type: _enums.AttributeType - value: Any - doc_string: str | None - - -@typing.runtime_checkable -class ReferenceAttributeProtocol(Protocol): - """Protocol for a reference attribute. - - A reference attribute can only appear inside the definition body of a function. - - Attributes: - name: The name of the attribute. - ref_attr_name: The name of the attribute definition this attribute refers to. - type: The type of the attribute. - doc_string: Documentation string. - """ - - name: str - ref_attr_name: str - type: _enums.AttributeType - doc_string: str | None - - -@typing.runtime_checkable -class SparseTensorProtocol(Protocol): - values: TensorProtocol - indices: TensorProtocol - dims: Sequence[int] - - -@typing.runtime_checkable -class SymbolicDimProtocol(Protocol): - """Value of a single symbolic/dynamic dimension in a shape. - - Attributes: - value: The value of the dimension. - """ - - value: str | None # TODO(justinchuby): Maybe support sympy - - -@typing.runtime_checkable -class ShapeProtocol(Protocol): - """Protocol for ONNX shapes. - - A shape is a sequence of dimensions. - - Attributes: - dims: The dimensions of the shape. - """ - - dims: Sequence[int | SymbolicDimProtocol] - - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[int | SymbolicDimProtocol]: ... - @typing.overload - def __getitem__(self, index: int) -> int | SymbolicDimProtocol: ... - @typing.overload - def __getitem__(self, index: slice) -> tuple[int | SymbolicDimProtocol, ...]: ... - def __setitem__( - self, index: int, value: int | SymbolicDimProtocol | str | None - ) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __ne__(self, value: object) -> bool: ... - def get_denotation(self, index: int) -> str | None: ... - def set_denotation(self, index: int, denotation: str | None) -> None: ... - def numpy(self) -> Sequence[int]: ... - def rank(self) -> int: ... - - -@typing.runtime_checkable -class TypeProtocol(Protocol): - """Protocol for ONNX tensors, Sequence tensors, Optional tensors and Sparse tensors. - - These three types of tensors share the same attribute "elem_type" so they are - merged in the same interface. Unlike the ONNX TensorProto, shapes are not included - in the type and should be stored in the :class:`Value`. - - Attributes: - denotation: An optional denotation can be used to denote the whole - type with a standard semantic description as to what is - stored inside. - Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition - for pre-defined type denotations. - elem_type: The type of its elements for nested types like Sequence[Optional] tensors. - Or the DataType if the type is not nested. - dtype: The data type of the tensor or the nested tensor. - """ - - denotation: str | None - elem_type: TypeProtocol | _enums.DataType - dtype: _enums.DataType - - def __eq__(self, __value: object) -> bool: ... - - -@typing.runtime_checkable -class MapTypeProtocol(Protocol): - """Protocol for ONNX map types. - - TODO: This protocol is not yet implemented in the ONNX IR. - """ - - key_type: typing.Literal[ - _enums.DataType.STRING, - _enums.DataType.INT64, - _enums.DataType.INT32, - _enums.DataType.INT16, - _enums.DataType.INT8, - _enums.DataType.UINT64, - _enums.DataType.UINT32, - _enums.DataType.UINT16, - _enums.DataType.UINT8, - ] - value_type: _enums.DataType - - -@typing.runtime_checkable -class FunctionProtocol(Protocol): - """Protocol for ONNX functions. - - Like a graph, a function can have nodes that are not topologically sorted. It is - the responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Function. The Function can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(function)``. - - Attributes: - name: The function name. - domain: The domain this function is defined in. - overload: The overload name when the function is overloaded. - inputs: The input values of the function. - attributes: The attributes this function defines. - outputs: The output values of the function. - opset_imports: Opsets imported by the function. - doc_string: Documentation string. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str - domain: str - overload: str - inputs: Sequence[ValueProtocol] - attributes: OrderedDict[str, AttributeProtocol] - outputs: Sequence[ValueProtocol] - doc_string: str - opset_imports: MutableMapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def __getitem__(self, index: int) -> NodeProtocol: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[NodeProtocol]: ... - def __reversed__(self) -> Iterator[NodeProtocol]: ... - def identifier(self) -> OperatorIdentifier: - """Return the unique identifier of the function.""" - ... - - # Mutation methods - # End Block - def append(self, node: NodeProtocol, /) -> None: - """Append a node to the function.""" - ... - - def extend(self, nodes: Iterable[NodeProtocol], /) -> None: - """Extend the function with the given nodes.""" - ... - - def remove(self, node: NodeProtocol, /) -> None: - """Remove a node from the function.""" - ... - - def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: - """Insert new nodes after the given node.""" - ... - - def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: - """Insert new nodes before the given node.""" - ... - - def sort(self) -> None: - """Topologically sort the nodes in the function.""" - ... diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py new file mode 100644 index 0000000000..d4d88ab5bb --- /dev/null +++ b/onnxscript/ir/_schemas.py @@ -0,0 +1,548 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import collections.abc +import dataclasses +import inspect +import logging +import types +import typing +from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union + +import onnx + +import onnxscript +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +# A special value to indicate that the default value is not specified +class _Empty: + def __repr__(self): + return "_EMPTY_DEFAULT" + + +_EMPTY_DEFAULT = _Empty() + +# Map from python type to corresponding ONNX AttributeProto type +_PY_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOAT, + int: ir.AttributeType.INT, + str: ir.AttributeType.STRING, + bool: ir.AttributeType.INT, + ir.Tensor: ir.AttributeType.TENSOR, + ir.TensorProtocol: ir.AttributeType.TENSOR, + ir.Graph: ir.AttributeType.GRAPH, + ir.GraphProtocol: ir.AttributeType.GRAPH, +} + +# Map from python type to corresponding ONNX AttributeProto type, +# for repeated (i.e., list of) values +_LIST_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOATS, + int: ir.AttributeType.INTS, + str: ir.AttributeType.STRINGS, + bool: ir.AttributeType.INTS, + ir.Tensor: ir.AttributeType.TENSORS, + ir.TensorProtocol: ir.AttributeType.TENSORS, + ir.Graph: ir.AttributeType.GRAPHS, + ir.GraphProtocol: ir.AttributeType.GRAPHS, +} + +_ALL_VALUE_TYPES = ( + {ir.TensorType(dtype) for dtype in ir.DataType} + | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType} + | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType} +) + +# TypeAnnotationValue represents the (value of) valid type-annotations recognized +# by ONNX Script. Currently, it supports +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached +TypeAnnotationValue = Any + + +@dataclasses.dataclass(frozen=True) +class TypeConstraintParam: + """Type constraint for a parameter. + + Attributes: + name: Name of the parameter. E.g. "TFloat" + allowed_types: Allowed types for the parameter. + """ + + name: str + allowed_types: set[ir.TypeProtocol] + description: str = "" + + def __hash__(self) -> int: + return hash((self.name, tuple(self.allowed_types))) + + def __str__(self) -> str: + allowed_types_str = " | ".join(str(t) for t in self.allowed_types) + return f"{self.name}={allowed_types_str}" + + @classmethod + def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description) + + @classmethod + def any_value(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type] + + +@dataclasses.dataclass(frozen=True) +class Parameter: + """A formal parameter of an operator.""" + + name: str + type_constraint: TypeConstraintParam + required: bool + variadic: bool + default: Any = _EMPTY_DEFAULT + # TODO: Add other properties too + + def __str__(self) -> str: + type_str = self.type_constraint.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not _EMPTY_DEFAULT + + +@dataclasses.dataclass(frozen=True) +class AttributeParameter: + """A parameter in the function signature that represents an ONNX attribute.""" + + name: str + type: ir.AttributeType + required: bool + default: ir.Attr | None = None + + def __str__(self) -> str: + type_str = self.type.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not None + + +def _get_type_from_str( + type_str: str, +) -> ir.TensorType | ir.SequenceType | ir.OptionalType: + """Converter a type_str from ONNX OpSchema to ir.TypeProtocol. + + A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))". + """ + # Split the type_str a sequence types and dtypes + # 1. Remove the ending ")" + striped = type_str.rstrip(")") + # 2. Split the type_str by "(" + type_parts = striped.split("(") + + # Convert the dtype to ir.DataType + dtype = ir.DataType[type_parts[-1].upper()] + + # Create a place holder type first + type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED) + + # Construct the type + for type_part in reversed(type_parts[:-1]): + if type_part == "tensor": + type_ = ir.TensorType(dtype) + elif type_part == "seq": + type_ = ir.SequenceType(type_) + elif type_part == "optional": + type_ = ir.OptionalType(type_) + else: + raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'") + return type_ # type: ignore[return-value] + + +def _convert_formal_parameter( + param: onnx.defs.OpSchema.FormalParameter, + type_constraints: Mapping[str, TypeConstraintParam], +) -> Parameter: + """Convert a formal parameter from ONNX OpSchema to Parameter.""" + if param.type_str in type_constraints: + type_constraint = type_constraints[param.type_str] + else: + # param.type_str can be a plain type like 'int64'. + type_constraint = TypeConstraintParam( + name=param.name, + allowed_types={_get_type_from_str(param.type_str)}, + ) + return Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, + variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + ) + + +def _is_optional(type_: type) -> bool: + """Returns whether a type_ is an Optional.""" + origin_type = typing.get_origin(type_) + if origin_type is Union and type(None) in typing.get_args(type_): + # Python < 3.10 + return True + if origin_type is Optional: + # Python >= 3.10 + return True + if ( + hasattr(types, "UnionType") + and origin_type is types.UnionType + and type(None) in typing.get_args(type_) + ): + # Python >= 3.10 + return True + return False + + +def _get_attr_type(type_: type) -> ir.AttributeType: + """Obtain the type of the attribute from a Python class.""" + try: + if type_ in _PY_TYPE_TO_ATTR_TYPE: + return _PY_TYPE_TO_ATTR_TYPE[type_] + origin_type = typing.get_origin(type_) + if origin_type is None: + return ir.AttributeType.UNDEFINED + if origin_type in ( + collections.abc.Sequence, + Sequence, + typing.List, + list, + typing.Tuple, + tuple, + ): + inner_type = typing.get_args(type_)[0] + if inner_type in _LIST_TYPE_TO_ATTR_TYPE: + return _LIST_TYPE_TO_ATTR_TYPE[inner_type] + except TypeError: + logger.warning("TypeError when checking %s.", type_, exc_info=True) + return ir.AttributeType.UNDEFINED + + +def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None: + """Returns the name of the type constraint for a given type annotation. + + Args: + type_: A Python type. + + Returns: + The name of the type constraint if it is a TypeVar. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. + """ + if isinstance(type_, TypeVar): + return type_.__name__ + if _is_optional(type_): + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + type_param_name = _get_type_constraint_name(subtype) + return type_param_name if type_param_name else None + origin_type = typing.get_origin(type_) + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + type_param_name = _get_type_constraint_name(subtypes[0]) + return f"Sequence_{type_param_name}" if type_param_name else None + return None + + +def _get_allowed_types_from_type_annotation( + type_: TypeAnnotationValue, +) -> set[ir.TypeProtocol]: + """Obtain the allowed types from a type annotation.""" + if type_ is onnxscript.onnx_types.TensorType: + # Any tensor type + return {ir.TensorType(dtype) for dtype in ir.DataType} + + allowed_types: set[ir.TypeProtocol] + + if isinstance(type_, TypeVar): + allowed_types = set() + if constraints := type_.__constraints__: + for constraint in constraints: + allowed_types.update(_get_allowed_types_from_type_annotation(constraint)) + else: + bound = type_.__bound__ + if bound is None: + allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment] + else: + allowed_types.update(_get_allowed_types_from_type_annotation(bound)) + return allowed_types + if hasattr(type_, "dtype"): + # A single tensor type like INT64, FLOAT, etc. + return {ir.TensorType(ir.DataType(type_.dtype))} + if _is_optional(type_): + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful. + return allowed_types + + origin_type = typing.get_origin(type_) + if origin_type is Union: + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + assert subtype is not type(None), ( + "Union should not contain None type because it is handled by _is_optional." + ) + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + return allowed_types + + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + return { + ir.SequenceType(t) for t in _get_allowed_types_from_type_annotation(subtypes[0]) + } + + # Allow everything by default + return _ALL_VALUE_TYPES # type: ignore[return-value] + + +@dataclasses.dataclass +class OpSignature: + """Schema for an operator. + + Attributes: + domain: Domain of the operator. E.g. "". + name: Name of the operator. E.g. "Add". + overload: Overload name of the operator. + params: Input parameters. When the op is an ONNX function definition, + the order is according to the function signature. This mean we can + interleave ONNX inputs and ONNX attributes in the list. + outputs: Output parameters. + """ + + domain: str + name: str + overload: str + params: Sequence[Parameter | AttributeParameter] + outputs: Sequence[Parameter] + params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( + init=False, repr=False + ) + + def __post_init__(self): + self.params_map = {param.name: param for param in self.params} + + def get(self, name: str) -> Parameter | AttributeParameter: + return self.params_map[name] + + def __contains__(self, name: str) -> bool: + return name in self.params_map + + def __iter__(self) -> Iterator[Parameter | AttributeParameter]: + return iter(self.params) + + def __str__(self) -> str: + domain = self.domain or "''" + # TODO: Double check the separator for overload + overload = f"::{self.overload}" if self.overload else "" + params = ", ".join(str(param) for param in self.params) + outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs) + type_constraints = {} + for param in self.params: + if isinstance(param, Parameter): + type_constraints[param.type_constraint.name] = param.type_constraint + for param in self.outputs: + type_constraints[param.type_constraint.name] = param.type_constraint + type_constraints_str = ", ".join( + str(type_constraint) for type_constraint in type_constraints.values() + ) + return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" + + @classmethod + def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: + """Produce an OpSignature from an ONNX OpSchema.""" + type_constraints = { + constraint.type_param_str: TypeConstraintParam( + name=constraint.type_param_str, + allowed_types={ + _get_type_from_str(type_str) for type_str in constraint.allowed_type_strs + }, + description=constraint.description, + ) + for constraint in op_schema.type_constraints + } + + params = [ + _convert_formal_parameter(param, type_constraints) for param in op_schema.inputs + ] + + for param in op_schema.attributes.values(): + default_attr = ( + ir.serde.deserialize_attribute(param.default_value) + if param.default_value is not None + else None + ) + if default_attr is not None: + # Set the name of the default attribute because it may have a different name from the parameter + default_attr.name = param.name + params.append( + AttributeParameter( + name=param.name, + type=ir.AttributeType(param.type), # type: ignore[arg-type] + required=param.required, + default=default_attr, # type: ignore[arg-type] + ) + ) + + outputs = [ + _convert_formal_parameter(param, type_constraints) for param in op_schema.outputs + ] + + return cls( + domain=op_schema.domain, + name=op_schema.name, + overload="", + params=params, + outputs=outputs, + ) + + @classmethod + def from_function( + cls, func, domain: str, name: str | None = None, overload: str = "" + ) -> OpSignature: + """Produce an OpSignature from a function using type annotation.""" + + py_signature = inspect.signature(func) + # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases + # https://github.com/python/cpython/issues/102405 + type_hints = typing.get_type_hints(func) + + params: list[Parameter | AttributeParameter] = [] + # Create a mapping from type to a unique name + type_constraints: dict[str, TypeConstraintParam] = {} + + for param in py_signature.parameters.values(): + if param.name not in type_hints: + logger.warning( + "Missing annotation for parameter '%s' from %s. Treating as an Input.", + param.name, + py_signature, + ) + type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") + type_constraints[param.name] = type_constraint + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) + else: + type_ = type_hints[param.name] + if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + # Construct the default attribute + if param.default is not inspect.Parameter.empty: + # TODO: Use ir_convenience instead to handle int as float + default = ir.Attr(param.name, attr_type, param.default) + else: + default = None + params.append( + AttributeParameter( + name=param.name, + type=attr_type, + required=param.default is inspect.Parameter.empty, + default=default, + ) + ) + else: + # Obtain the type constraint from the type annotation + + # 1. Get a type constraint name from the type annotation + # If the type annotation is a TypeVar or Optional[TypeVar], get its name + # Otherwise, name it T_{param.name} + type_constraint_name = _get_type_constraint_name(type_) + if type_constraint_name is None: + type_constraint_name = f"T_{param.name}" + + # 2. If the type constraint param is already initialized, use it + if type_constraint_name in type_constraints: + type_constraint = type_constraints[type_constraint_name] + else: + # 3. Otherwise, create a new TypeConstraintParam + type_constraint = TypeConstraintParam( + name=type_constraint_name, + allowed_types=_get_allowed_types_from_type_annotation(type_), + ) + type_constraints[type_constraint_name] = type_constraint + # 4. Create Parameter + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) + + return_type = type_hints.get("return") + + outputs = [] + if return_type is None: + # No returns + pass + else: + if typing.get_origin(return_type) is tuple: + # Multiple returns + return_types = typing.get_args(return_type) + else: + return_types = [return_type] # type: ignore[assignment] + + for i, return_type_i in enumerate(return_types): + if ( + return_param_name := _get_type_constraint_name(return_type_i) + ) in type_constraints: + type_constraint = type_constraints[return_param_name] + else: + return_param_name = f"TReturn{i}" + type_constraint = TypeConstraintParam( + name=return_param_name, + allowed_types=_get_allowed_types_from_type_annotation(return_type_i), + ) + type_constraints[return_param_name] = type_constraint + outputs.append( + Parameter( + name=return_param_name, + type_constraint=type_constraint, + required=True, + variadic=False, + default=_EMPTY_DEFAULT, + ) + ) + + return cls( + domain=domain, + name=name or func.__name__, + overload=overload, + params=params, + outputs=outputs, + ) diff --git a/onnxscript/ir/_schemas_test.py b/onnxscript/ir/_schemas_test.py new file mode 100644 index 0000000000..82082d031f --- /dev/null +++ b/onnxscript/ir/_schemas_test.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest +from typing import Any, Optional, Sequence, TypeVar, Union + +import parameterized + +import onnxscript +from onnxscript import FLOAT, INT64, ir +from onnxscript.ir import _schemas + +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) + + +class TypeConversionFunctionsTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ( + "tensor_type_all", + onnxscript.onnx_types.TensorType, + {ir.TensorType(dtype) for dtype in ir.DataType}, + ), + ("tensor_type", INT64, {ir.TensorType(ir.DataType.INT64)}), + ( + "tensor_type_union", + Union[INT64, FLOAT], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "tensor_type_variadic_shape", + INT64[...], + {ir.TensorType(ir.DataType.INT64)}, + ), + ("tensor_type_shape", INT64[10], {ir.TensorType(ir.DataType.INT64)}), + ( + "type_var_constraints", + _TestTypeVarConstraints, + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "type_bound_one", + _TestTypeVarOneBound, + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "type_bound_two", + _TestTypeVarTwoBound, + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "optional_tensor_type_all", + Optional[onnxscript.onnx_types.TensorType], + {ir.TensorType(dtype) for dtype in ir.DataType}, + ), + ( + "optional_tensor_type", + Optional[INT64], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_tensor_type_union", + Optional[Union[INT64, FLOAT]], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "optional_tensor_type_variadic_shape", + Optional[INT64[...]], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_tensor_type_shape", + Optional[INT64[10]], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_type_var_constraints", + Optional[_TestTypeVarConstraints], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "optional_type_bound_one", + Optional[_TestTypeVarOneBound], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_type_bound_two", + Optional[_TestTypeVarTwoBound], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "sequence_type_all", + Sequence[onnxscript.onnx_types.TensorType], + {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}, + ), + ( + "sequence_type", + Sequence[INT64], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "union_sequence_type", + Union[Sequence[INT64], Sequence[FLOAT]], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ( + "sequence_type_variadic_shape", + Sequence[INT64[...]], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_shape", + Sequence[INT64[10]], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_var_constraints", + Sequence[_TestTypeVarConstraints], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ( + "sequence_type_bound_one", + Sequence[_TestTypeVarOneBound], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_bound_two", + Sequence[_TestTypeVarTwoBound], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ] + ) + def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]): + self.assertEqual(_schemas._get_allowed_types_from_type_annotation(pytype), expected) # pylint: disable=protected-access + + @parameterized.parameterized.expand( + [ + ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), + ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), + ( + "optional_type_var", + Optional[_TestTypeVarOneBound], + "_TestTypeVarOneBound", + ), + ( + "sequence_type_var", + Sequence[_TestTypeVarOneBound], + "Sequence__TestTypeVarOneBound", + ), + ("normal_type", INT64, None), + ("union_type", Union[INT64, FLOAT], None), + ("optional_type", Optional[INT64], None), + ("sequence_type", Sequence[INT64], None), + ("optional_sequence_type", Optional[Sequence[INT64]], None), + ("optional_union_type", Optional[Union[INT64, FLOAT]], None), + ] + ) + def test_get_type_constraint_name(self, _: str, pytype: Any, expected: str | None): + self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) # pylint: disable=protected-access + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py new file mode 100644 index 0000000000..78dce2739e --- /dev/null +++ b/onnxscript/ir/_tape.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Convenience methods for constructing the IR.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, Sequence + +from onnx_ir import tape + +if TYPE_CHECKING: + import onnx_ir as ir + + +# A type representing the domains/versions used in creating nodes in IR. +UsedOpsets = set[tuple[str, Optional[int]]] + + +class Builder(tape.Tape): + """An extension of the tape that provides a more convenient API for constructing the IR. + + Example: + >>> from onnxscript import ir + >>> from onnxscript.ir import _tape + >>> op = _tape.Builder() + >>> input = ir.Value(name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))) + >>> relu_val = op.Relu(input, _name="relu_node", _domain="", _version=18, _outputs=["relu_out"]) + + Note: When passing `_name`, ensure it is unique to avoid duplicate node names. + """ + + def __getattr__(self, op_type: str) -> Any: + return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) + + def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): + domain = kwargs.pop("_domain", "") + version = kwargs.pop("_version", None) + outputs = kwargs.pop("_outputs", 1) + name = kwargs.pop("_name", None) + + if isinstance(outputs, Sequence): + num_outputs = len(outputs) + else: + assert isinstance(outputs, int) + num_outputs = outputs + + if num_outputs == 1: + value = super().op( + op_type, + inputs=inputs, + attributes=kwargs, + domain=domain, + version=version, + name=name, + ) + if isinstance(outputs, Sequence): + value.name = outputs[0] + return value + values = super().op_multi_out( + op_type, + inputs=inputs, + attributes=kwargs, + domain=domain, + version=version, + name=name, + num_outputs=num_outputs, + ) + if isinstance(outputs, Sequence): + for value, name in zip(values, outputs): + value.name = name + return values diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py new file mode 100644 index 0000000000..f8210e7a0b --- /dev/null +++ b/onnxscript/ir/_tape_test.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +from onnxscript import ir +from onnxscript.ir import _tape + + +class TestTape(unittest.TestCase): + def test_op(self): + # Create a simple ONNX model with shape inference + # Define the model + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ] + + tape = ir.tape.Tape() + + _ = tape.op("Add", inputs=inputs) + + self.assertEqual([n.op_type for n in tape.nodes], ["Add"]) + + def test_initializers(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 1)), + const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), + ), + ] + + tape = ir.tape.Tape() + + # Shape and type are not explicitly set for the initializer but it should still work + initializer = tape.initializer( + ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT), name="initializer" + ) + val_add = tape.op("Add", inputs=inputs) + _ = tape.op("Mul", inputs=[val_add, initializer]) + + self.assertEqual([n.op_type for n in tape.nodes], ["Add", "Mul"]) + self.assertEqual(tape.initializers, (initializer,)) + + def test_op_multi_out(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 1)), + const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), + ), + ] + + tape = ir.tape.Tape() + + out1, out2, out3 = tape.op_multi_out("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking + _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) + + self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"]) + + +class TestBuilder(unittest.TestCase): + def test_op_name(self): + op = _tape.Builder() + + input_a = ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + input_b = ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + + add = op.Add(input_a, input_b, _name="add_node") + _ = op.Relu(add, _name="relu_node") + self.assertEqual(op.nodes[0].name, "add_node") + self.assertEqual(op.nodes[1].name, "relu_node") + + def test_op_name_multi_out(self): + op = _tape.Builder() + + input_a = ir.Value( + name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + + _ = op.CustomOp(input_a, _name="custom_node", _outputs=3) + self.assertEqual(op.nodes[0].name, "custom_node") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py deleted file mode 100644 index abe825f84b..0000000000 --- a/onnxscript/ir/_type_casting.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Numpy utilities for non-native type operation.""" -# TODO(justinchuby): Upstream the logic to onnx - -from __future__ import annotations - -import typing -from typing import Sequence - -import numpy as np - -if typing.TYPE_CHECKING: - import numpy.typing as npt - - -def pack_int4(array: np.ndarray) -> npt.NDArray[np.uint8]: - """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range.""" - # Create a 1D copy - array_flat = array.ravel().astype(np.uint8) - size = array.size - odd_sized = size % 2 == 1 - if odd_sized: - array_flat.resize([size + 1], refcheck=False) - array_flat &= 0x0F - array_flat[1::2] <<= 4 - return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type] - - -def unpack_uint4(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]: - """Convert a packed uint4 array to unpacked uint4 array represented as uint8. - - Args: - data: A numpy array. - dims: The dimensions are used to reshape the unpacked buffer. - - Returns: - A numpy array of int8/uint8 reshaped to dims. - """ - result = np.empty([data.size * 2], dtype=data.dtype) - array_low = data & np.uint8(0x0F) - array_high = data & np.uint8(0xF0) - array_high >>= np.uint8(4) - result[0::2] = array_low - result[1::2] = array_high - if result.size == np.prod(dims) + 1: - # handle single-element padding due to odd number of elements - result = result[:-1] - result.resize(dims, refcheck=False) - return result - - -def _int4_to_int8(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]: - """Extend 4-bit signed integer to 8-bit signed integer.""" - return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8) - - -def unpack_int4(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.int8]: - """Convert a packed (signed) int4 array to unpacked int4 array represented as int8. - - The sign bit is extended to the most significant bit of the int8. - - Args: - data: A numpy array. - dims: The dimensions are used to reshape the unpacked buffer. - - Returns: - A numpy array of int8 reshaped to dims. - """ - unpacked = unpack_uint4(data, dims) - return _int4_to_int8(unpacked) - - -def float32_to_bfloat16(array: npt.NDArray[np.float32]) -> npt.NDArray[np.uint16]: - """Convert a numpy array to uint16 representation of bfloat16.""" - bfloat16_array = array.astype(np.float32).view(np.uint32) - # Drop bottom 16-bits - # Round remaining bits using round-to-nearest-even - rounded = bfloat16_array >> 16 - rounded &= 1 - rounded += 0x7FFF - bfloat16_array += rounded # type: ignore[arg-type] - bfloat16_array >>= 16 - # NaN requires at least 1 significant bit set - bfloat16_array[np.isnan(array)] = 0x7FC0 # sign=0, exp=all-ones, sig=0b1000000 - return bfloat16_array.astype(np.uint16) diff --git a/onnxscript/ir/_type_casting_test.py b/onnxscript/ir/_type_casting_test.py deleted file mode 100644 index 544146e6b1..0000000000 --- a/onnxscript/ir/_type_casting_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -import numpy as np -import parameterized - -from onnxscript.ir import _type_casting - - -class TypeCastingTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ("signed", np.float32), - ("unsigned", np.uint32), - ] - ) - def test_pack_int4_even_sized_array(self, _: str, dtype): - array = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype) - expected = np.array([0x21, 0x43, 0x65, 0x87], dtype=np.uint8) - actual = _type_casting.pack_int4(array) - np.testing.assert_array_equal(actual, expected) - - @parameterized.parameterized.expand( - [ - ("signed", np.float32), - ("unsigned", np.uint32), - ] - ) - def test_pack_int4_odd_sized_array(self, _: str, dtype): - array = np.array([1, 2, 3, 4, 5], dtype=dtype) - expected = np.array([0x21, 0x43, 0x5], dtype=np.uint8) - actual = _type_casting.pack_int4(array) - np.testing.assert_array_equal(actual, expected) - - @parameterized.parameterized.expand( - [ - ("signed", np.float32), - ("unsigned", np.uint32), - ] - ) - def test_pack_int4_returns_flatten_array(self, _: str, dtype): - array = np.array([[[1, 2, 3, 4, 5]]], dtype=dtype) - expected = np.array([0x21, 0x43, 0x5], dtype=np.uint8) - actual = _type_casting.pack_int4(array) - np.testing.assert_array_equal(actual, expected) - - @parameterized.parameterized.expand( - [ - ("negative_infinity", np.uint16(0b1_11111111_0000000)), - ("negative_min_normal", np.uint16(0b1_11111110_1111111)), - ("negative_max_normal", np.uint16(0b1_00000001_0000000)), - ("negative_min_subnormal", np.uint16(0b1_00000000_1111111)), - ("negative_max_subnormal", np.uint16(0b1_00000000_0000001)), - ("negative_zero", np.uint16(0b1_00000000_0000000)), - ("positive_zero", np.uint16(0b0_00000000_0000000)), - ("positive_min_subnormal", np.uint16(0b0_00000000_0000001)), - ("positive_max_subnormal", np.uint16(0b0_00000000_1111111)), - ("positive_min_normal", np.uint16(0b0_00000001_0000000)), - ("positive_max_normal", np.uint16(0b0_11111110_1111111)), - ("positive_infinity", np.uint16(0b0_11111111_0000000)), - ("positive_nan", np.uint16(0b0_11111111_1000000)), - ("positive_one", np.uint16(0b0_00111111_0000000)), - ("negative_one", np.uint16(0b1_00111111_0000000)), - ] - ) - def test_float32_to_bfloat16(self, _: str, binary: np.uint16): - value = np.array([binary << 16]).astype(np.uint32).view(np.float32) - expected = np.array([binary]) - actual = _type_casting.float32_to_bfloat16(value) - np.testing.assert_array_equal(actual, expected) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py new file mode 100644 index 0000000000..e248a5fa84 --- /dev/null +++ b/onnxscript/ir/convenience.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pylint: disable=wildcard-import,unused-wildcard-import +from onnx_ir.convenience import * # type: ignore # noqa: F403 diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py new file mode 100644 index 0000000000..5310f1740a --- /dev/null +++ b/onnxscript/ir/passes/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +__all__ = [ + "PassBase", + "PassResult", + "PassManager", + "Sequential", + "InPlacePass", + "FunctionalPass", + # Errors + "InvariantError", + "PreconditionError", + "PostconditionError", + "PassError", +] + +from onnx_ir.passes import ( + FunctionalPass, + InPlacePass, + InvariantError, + PassBase, + PassError, + PassManager, + PassResult, + PostconditionError, + PreconditionError, + Sequential, +) diff --git a/onnxscript/ir/passes/common/__init__.py b/onnxscript/ir/passes/common/__init__.py new file mode 100644 index 0000000000..5a5ddbe52f --- /dev/null +++ b/onnxscript/ir/passes/common/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +__all__ = [ + "AddInitializersToInputsPass", + "CheckerPass", + "ClearMetadataAndDocStringPass", + "CommonSubexpressionEliminationPass", + "InlinePass", + "LiftConstantsToInitializersPass", + "LiftSubgraphInitializersToMainGraphPass", + "RemoveInitializersFromInputsPass", + "RemoveUnusedFunctionsPass", + "RemoveUnusedNodesPass", + "RemoveUnusedOpsetsPass", + "ShapeInferencePass", + "TopologicalSortPass", +] + +from onnx_ir.passes.common import ( + AddInitializersToInputsPass, + CheckerPass, + ClearMetadataAndDocStringPass, + CommonSubexpressionEliminationPass, + InlinePass, + LiftConstantsToInitializersPass, + LiftSubgraphInitializersToMainGraphPass, + RemoveInitializersFromInputsPass, + RemoveUnusedFunctionsPass, + RemoveUnusedNodesPass, + RemoveUnusedOpsetsPass, + ShapeInferencePass, + TopologicalSortPass, +) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py deleted file mode 100644 index 6060c881bc..0000000000 --- a/onnxscript/ir/serde.py +++ /dev/null @@ -1,1350 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Serialize and deserialize the intermediate representation to/from ONNX protos.""" - -# NOTES for developers: -# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead. -# -# NOTE: Protobuf serialization -# Initializing a protobuf message with initialized protobuf messages incurs -# a copy and is slow. Instead, use proto.add() to add to a repeated field. -# or initialize the message first and then set the fields if the fields are -# plain Python objects. - -from __future__ import annotations - -__all__ = [ - # Tensors - "TensorProtoTensor", - # Deserialization - "deserialize_attribute", - "deserialize_function", - "deserialize_graph", - "deserialize_model", - "deserialize_node", - "deserialize_opset_import", - "deserialize_tensor", - "deserialize_type_proto_for_shape", - "deserialize_type_proto_for_type", - "deserialize_value_info_proto", - # Serialization - "serialize_attribute_into", - "serialize_attribute", - "serialize_dimension_into", - "serialize_function_into", - "serialize_function", - "serialize_graph_into", - "serialize_graph", - "serialize_model_into", - "serialize_model", - "serialize_node_into", - "serialize_node", - "serialize_shape_into", - "serialize_reference_attribute_into", - "serialize_tensor_into", - "serialize_tensor", - "serialize_type_into", - "serialize_value_into", - "serialize_value", -] - -import collections -import logging -import os -import typing -from typing import Any, List, Mapping, Sequence - -import numpy as np -import onnx -import onnx.external_data_helper - -from onnxscript.ir import _core, _enums, _metadata, _protocols, _type_casting - -if typing.TYPE_CHECKING: - import google.protobuf.internal.containers as proto_containers - import numpy.typing as npt - -logger = logging.getLogger(__name__) - -_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = ( - 10 # ONNX IR version where value info in functions was introduced -) - - -def _little_endian_dtype(dtype) -> np.dtype: - """Create a small endian dtype on all platforms. - - This is useful because ONNX always stores raw_data in small endian. On big - endian platforms, we still need to interpret the raw_data in small endian. - """ - return np.dtype(dtype).newbyteorder("<") - - -def _unflatten_complex( - array: npt.NDArray[np.float32 | np.float64], -) -> npt.NDArray[np.complex64 | np.complex128]: - """Convert the real representation of a complex dtype to the complex dtype.""" - return array[::2] + 1j * array[1::2] - - -class TensorProtoTensor(_core.TensorBase): - """A tensor initialized from a tensor proto.""" - - def __init__(self, proto: onnx.TensorProto) -> None: - self._proto = proto - self._metadata_props: dict[str, str] | None = deserialize_metadata_props( - proto.metadata_props - ) - self._metadata: _metadata.MetadataStore | None = None - - @property - def name(self) -> str: - return self._proto.name - - @property - def shape(self) -> _core.Shape: - return _core.Shape(self._proto.dims, frozen=True) - - @property - def dtype(self) -> _enums.DataType: - return _enums.DataType(self._proto.data_type) - - @property - def doc_string(self) -> str: - return self._proto.doc_string - - @property - def raw(self) -> onnx.TensorProto: - return self._proto - - def __repr__(self) -> str: - # It is a little hard to display the content when there can be types - # unsupported by numpy - # Preferably we should display some content when the tensor is small - return f"{self._repr_base()}(name={self.name!r})" - - def __array__(self, dtype: Any = None) -> np.ndarray: - """Return the tensor as a numpy array, compatible with np.array.""" - return self.numpy().__array__(dtype) - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array. - - This is an improved version of onnx.numpy_helper.to_array. - It first reads the data using the dtype corresponding to the tensor - proto data field, then converts it to the correct dtype and shape. - Special cases are bfloat16, complex and int4 where we need to - reinterpret the data. Other types can simply be casted. - - When the data type is not supported by numpy, the value is the bit representation - of the dtype: - - - ``int8`` for int4, with the sign bit extended to 8 bits. - - ``uint8`` for uint4. - - ``uint8`` for 8-bit data types like float8. - - ``uint16`` for bfloat16. - - When the data type is a string, this method returns a numpy array - of bytes instead of a numpy array of strings, to follow the ONNX - specification. - - External tensors are not supported by this class. Use - :class:`onnxscript.ir.ExternalTensor` instead. - - Raises: - ValueError: If the data type is UNDEFINED. - """ - dtype = self.dtype - if dtype == _enums.DataType.UNDEFINED: - raise ValueError("Cannot convert UNDEFINED tensor to numpy array.") - if self._proto.data_location == onnx.TensorProto.EXTERNAL: - raise ValueError( - "Cannot convert external tensor to numpy array. " - "Use ir.ExternalTensor instead." - ) - - if self._proto.HasField("raw_data"): - array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")) - # Cannot return now, because we may need to unpack 4bit tensors - elif dtype == _enums.DataType.STRING: - return np.array(self._proto.string_data).reshape(self._proto.dims) - elif self._proto.int32_data: - array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32)) - if dtype == _enums.DataType.FLOAT16: - # Reinterpret the int32 as float16; bfloat16 is handled on the last line - array = array.astype(np.uint16).view(np.float16) - elif self._proto.int64_data: - array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64)) - elif self._proto.uint64_data: - array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64)) - elif self._proto.float_data: - array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32)) - if dtype == _enums.DataType.COMPLEX64: - array = _unflatten_complex(array) - elif self._proto.double_data: - array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64)) - if dtype == _enums.DataType.COMPLEX128: - array = _unflatten_complex(array) - else: - # Empty tensor - if not self._proto.dims: - # When dims not precent and there is no data, we return an empty array - return np.array([], dtype=dtype.numpy()) - else: - # Otherwise we return a size 0 array with the correct shape - return np.zeros(self._proto.dims, dtype=dtype.numpy()) - - if dtype == _enums.DataType.INT4: - return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims) - elif dtype == _enums.DataType.UINT4: - return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims) - else: - # Otherwise convert to the correct dtype and reshape - # Note we cannot use view() here because the storage dtype may not be the same size as the target - return array.astype(dtype.numpy()).reshape(self._proto.dims) - - def tobytes(self) -> bytes: - """Return the tensor as a byte string conformed to the ONNX specification, in little endian. - - Raises: - ValueError: If the tensor is a string tensor or an external tensor. - ValueError: If the tensor is of UNDEFINED data type. - """ - if self._proto.data_location == onnx.TensorProto.EXTERNAL: - raise ValueError( - "Cannot convert external tensor to bytes. Use ir.ExternalTensor instead." - ) - if self.dtype == _enums.DataType.STRING: - raise ValueError("Cannot convert string tensor to bytes.") - if self.dtype == _enums.DataType.UNDEFINED: - raise ValueError("Cannot convert UNDEFINED tensor to bytes.") - - if self._proto.HasField("raw_data"): - return self._proto.raw_data - if self._proto.float_data: - return np.array( - self._proto.float_data, dtype=_little_endian_dtype(np.float32) - ).tobytes() - if self._proto.int32_data: - array = np.array(self._proto.int32_data, dtype=np.int32) - if self.dtype in { - _enums.DataType.INT16, - _enums.DataType.UINT16, - _enums.DataType.FLOAT16, - _enums.DataType.BFLOAT16, - }: - return array.astype(_little_endian_dtype(np.uint16)).tobytes() - if self.dtype in { - _enums.DataType.INT8, - _enums.DataType.UINT8, - _enums.DataType.BOOL, - _enums.DataType.FLOAT8E4M3FN, - _enums.DataType.FLOAT8E4M3FNUZ, - _enums.DataType.FLOAT8E5M2, - _enums.DataType.FLOAT8E5M2FNUZ, - _enums.DataType.INT4, - _enums.DataType.UINT4, - }: - # uint4 and int4 values are already packed, even when stored as int32 - # so we don't need to pack them again - return array.astype(_little_endian_dtype(np.uint8)).tobytes() - assert self.dtype == _enums.DataType.INT32 - return array.tobytes() - if self._proto.int64_data: - return np.array( - self._proto.int64_data, dtype=_little_endian_dtype(np.int64) - ).tobytes() - if self._proto.double_data: - return np.array( - self._proto.double_data, dtype=_little_endian_dtype(np.float64) - ).tobytes() - if self._proto.uint64_data: - array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64)) - if self.dtype == _enums.DataType.UINT32: - return array.astype(_little_endian_dtype(np.uint32)).tobytes() - assert self.dtype == _enums.DataType.UINT64 - return array.tobytes() - # The repeating fields can be empty and still valid. - # For example, int32_data can be empty and still be a valid tensor. - return b"" - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attribute:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - -def _get_field(proto: Any, field: str) -> Any: - if proto.HasField(field): - return getattr(proto, field) - return None - - -# Deserialization - - -def deserialize_opset_import( - protos: Sequence[onnx.OperatorSetIdProto], -) -> dict[str, int]: - return {opset.domain: opset.version for opset in protos} - - -def _parse_experimental_function_value_info_name( - name: str, -) -> tuple[str, str, str] | None: - """Get the function domain, name and value name if the value info is for a function. - - The experimental format is: - {function_domain}::{function_name}/{value_name} - - Args: - name: The name stored in the value info. - - Returns: - A tuple of the function domain, function name and value name if the value info is for a function. - None otherwise. - """ - parts = name.split("/") - expected_parts = 2 - if len(parts) != expected_parts: - return None - function, value_name = parts - parts = function.split("::") - if len(parts) != expected_parts: - return None - # NOTE: There will not be overload because overloads are introduced in ONNX IR v10, which also - # introduces the ValueInfoProto for functions - function_domain, function_name = parts - return function_domain, function_name, value_name - - -def deserialize_model(proto: onnx.ModelProto) -> _core.Model: - graph = _deserialize_graph(proto.graph, []) - graph.opset_imports.update(deserialize_opset_import(proto.opset_import)) - - functions = [] - for func in proto.functions: - functions.append(deserialize_function(func)) - - model = _core.Model( - graph, - ir_version=proto.ir_version, - producer_name=_get_field(proto, "producer_name"), - producer_version=_get_field(proto, "producer_version"), - domain=_get_field(proto, "domain"), - model_version=_get_field(proto, "model_version"), - doc_string=_get_field(proto, "doc_string"), - functions=functions, - meta_data_props=deserialize_metadata_props(proto.metadata_props), - ) - - # Handle experimental value info for functions created by the dynamo exporter in IR version 9 - if model.ir_version < _FUNCTION_VALUE_INFO_SUPPORTED_VERSION: - _deserialized_experimental_value_info_for_function_ir9( - model.functions, proto.graph.value_info - ) - - return model - - -def _deserialized_experimental_value_info_for_function_ir9( - functions: Mapping[_protocols.OperatorIdentifier, _core.Function], - value_info_protos: Sequence[onnx.ValueInfoProto], -) -> None: - """Deserialize value info for functions when they are stored in an experimental format. - - The experimental format is: - {function_domain}::{function_name}/{value_name} - """ - # Parse value info for functions from the main graph - function_value_value_info_mapping: collections.defaultdict[ - _protocols.OperatorIdentifier, - dict[str, onnx.ValueInfoProto], - ] = collections.defaultdict(dict) - for value_info_proto in value_info_protos: - if ( - parsed := _parse_experimental_function_value_info_name(value_info_proto.name) - ) is None: - continue - function_domain, function_name, value_name = parsed - function_overload = "" - # TODO(justinchuby): Create a constructor for OperatorIdentifier so we don't create tuples manually - function_id = (function_domain, function_name, function_overload) - function = functions.get(function_id) - if function is None: - # Function not found - logger.debug( - "Function with ID '%s' not found in model functions. Value info '%s' will be ignored.", - function_id, - value_info_proto.name, - ) - continue - function_value_value_info_mapping[function_id][value_name] = value_info_proto - for function_id, function in functions.items(): - for input in function.inputs: - if input.name in function_value_value_info_mapping[function_id]: - deserialize_value_info_proto( - function_value_value_info_mapping[function_id][input.name], input - ) - for node in function: - for output in node.outputs: - if output.name in function_value_value_info_mapping[function_id]: - deserialize_value_info_proto( - function_value_value_info_mapping[function_id][output.name], - output, - ) - # The function outputs are handled as well because they are also node outputs - - -def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph: - return _deserialize_graph(proto, []) - - -def _deserialize_graph( - proto: onnx.GraphProto, scoped_values: list[dict[str, _core.Value]] -) -> _core.Graph: - """Deserialize a graph proto, recursively if needed. - - Args: - proto: The graph proto to deserialize. - scoped_values: A list of dictionaries mapping value names to their corresponding Value objects. - Every time we enter a new graph, a new scope is created and appended to this list to include - all values defined in the scope. - scoped_value_info: A list of dictionaries mapping value names to their corresponding ValueInfoProto. - """ - # Create values for initializers and inputs - initializers = [deserialize_tensor(tensor) for tensor in proto.initializer] - inputs = [_core.Input(info.name) for info in proto.input] - for info, value in zip(proto.input, inputs): - deserialize_value_info_proto(info, value) - - # Initialize the values dictionary for this graph scope with the inputs and initializers - values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc] - scoped_values.append(values) - for initializer in initializers: - if initializer.name in values: - # The initializer is for an input - values[initializer.name].const_value = initializer - else: - # The initializer is for some other value. Create this value first - initializer_value = _core.Value( - None, - index=None, - name=initializer.name, - # TODO(justinchuby): Fix type hinting for shape and dtype - shape=initializer.shape, # type: ignore - type=_core.TensorType(initializer.dtype), - const_value=initializer, - ) - values[initializer.name] = initializer_value - - # Add ValueInfos for this graph scope - value_info = {info.name: info for info in proto.value_info} - - # Deserialize nodes with all known values - nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node] - - # Fill in values for graph outputs - outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output] - scoped_values.pop() - return _core.Graph( - inputs, - outputs, - nodes=nodes, - # TODO(justinchuby): Attach the values associated with the initializers - initializers=initializers, - doc_string=_get_field(proto, "doc_string"), - name=_get_field(proto, "name"), - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - - -def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: - inputs = [_core.Input(name) for name in proto.input] - values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc] - value_info = {info.name: info for info in getattr(proto, "value_info", [])} - - # TODO(justinchuby): Handle unsorted nodes - nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node] - outputs = [values[name] for name in proto.output] - graph = _core.Graph( - inputs, - outputs, - nodes=nodes, - initializers=(), - doc_string=_get_field(proto, "doc_string"), - opset_imports=deserialize_opset_import(proto.opset_import), - name=( - f"{proto.name}_{proto.domain}" + f"__{proto.overload}" - if hasattr(proto, "overload") and proto.overload - else "" - ), - ) - attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto] - # Attributes without defaults - attributes += [ - _core.Attr(name, _enums.AttributeType.UNDEFINED, None) for name in proto.attribute - ] - return _core.Function( - domain=proto.domain, - name=proto.name, - overload=getattr(proto, "overload", ""), - graph=graph, - attributes=typing.cast(List[_core.Attr], attributes), - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - - -def deserialize_value_info_proto( - proto: onnx.ValueInfoProto, value: _core.Value | None -) -> _core.Value: - if value is None: - value = _core.Value(None, index=None, name=proto.name) - value.shape = deserialize_type_proto_for_shape(proto.type) - value.type = deserialize_type_proto_for_type(proto.type) - metadata_props = deserialize_metadata_props(proto.metadata_props) - if metadata_props is not None: - value.metadata_props.update(metadata_props) - value.doc_string = _get_field(proto, "doc_string") - return value - - -def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None: - if proto.HasField("tensor_type"): - if (shape_proto := _get_field(proto.tensor_type, "shape")) is None: - return None - # This logic handles when the shape is [] as well - dim_protos = shape_proto.dim - deserialized_dim_denotations = [ - deserialize_dimension(dim_proto) for dim_proto in dim_protos - ] - dims = [dim for dim, _ in deserialized_dim_denotations] - denotations = [denotation for _, denotation in deserialized_dim_denotations] - return _core.Shape(dims, denotations=denotations, frozen=True) - if proto.HasField("sparse_tensor_type"): - if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None: - return None - dim_protos = shape_proto.dim - deserialized_dim_denotations = [ - deserialize_dimension(dim_proto) for dim_proto in dim_protos - ] - dims = [dim for dim, _ in deserialized_dim_denotations] - denotations = [denotation for _, denotation in deserialized_dim_denotations] - return _core.Shape(dims, denotations=denotations, frozen=True) - if proto.HasField("sequence_type"): - if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None: - return None - return deserialize_type_proto_for_shape(elem_type) - if proto.HasField("optional_type"): - if (elem_type := _get_field(proto.optional_type, "elem_type")) is None: - return None - return deserialize_type_proto_for_shape(elem_type) - if proto.HasField("map_type"): - # TODO(justinchuby): Do we need to support map types? - raise NotImplementedError("Map types are not supported yet") - - return None - - -def deserialize_type_proto_for_type( - proto: onnx.TypeProto, -) -> _protocols.TypeProtocol | None: - denotation = _get_field(proto, "denotation") - if proto.HasField("tensor_type"): - if (elem_type := _get_field(proto.tensor_type, "elem_type")) is None: - return None - return _core.TensorType(_enums.DataType(elem_type), denotation=denotation) - if proto.HasField("sparse_tensor_type"): - if (elem_type := _get_field(proto.sparse_tensor_type, "elem_type")) is None: - return None - return _core.SparseTensorType(_enums.DataType(elem_type), denotation=denotation) - if proto.HasField("sequence_type"): - # FIXME(justinchuby): Allow nested types being None - if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None: - raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}") - nested_type = deserialize_type_proto_for_type(elem_type) - if nested_type is None: - raise ValueError(f"SequenceType must have elem_type set: {proto}") - return _core.SequenceType(nested_type, denotation=denotation) - if proto.HasField("optional_type"): - # FIXME(justinchuby): Allow nested types being None - if (elem_type := _get_field(proto.optional_type, "elem_type")) is None: - raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}") - nested_type = deserialize_type_proto_for_type(elem_type) - if nested_type is None: - raise ValueError(f"SequenceType must have elem_type set: {proto}") - return _core.OptionalType(nested_type, denotation=denotation) - if proto.HasField("map_type"): - # TODO(justinchuby): Do we need to support map types? - raise NotImplementedError("Map types are not supported yet") - - return None - - -def deserialize_dimension( - proto: onnx.TensorShapeProto.Dimension, -) -> tuple[int | _core.SymbolicDim, str | None]: - """Deserialize a dimension proto into (dimension, denotation). - - Args: - proto: The dimension proto to deserialize. - - Returns: - A tuple of the dimension and its denotation. - """ - value_field = proto.WhichOneof("value") - denotation = _get_field(proto, "denotation") - if value_field is not None: - value = getattr(proto, value_field) - if value_field == "dim_value": - return value, denotation - if value_field == "dim_param": - return _core.SymbolicDim(value), denotation - return _core.SymbolicDim(None), denotation - - -def deserialize_tensor( - proto: onnx.TensorProto, base_path: str | os.PathLike = "" -) -> _protocols.TensorProtocol: - # TODO: Sanitize base_path - if proto.data_location == onnx.TensorProto.EXTERNAL: - external_info = onnx.external_data_helper.ExternalDataInfo(proto) - return _core.ExternalTensor( - path=os.path.join(base_path, external_info.location), - offset=external_info.offset, - length=external_info.length, - dtype=_enums.DataType(proto.data_type), - name=proto.name, - shape=_core.Shape(proto.dims), - doc_string=proto.doc_string, - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - if proto.data_type == _enums.DataType.STRING: - name = _get_field(proto, "name") - doc_string = _get_field(proto, "doc_string") - metadata_props = deserialize_metadata_props(proto.metadata_props) - return _core.StringTensor( - proto.string_data, - shape=_core.Shape(proto.dims), - name=name, - doc_string=doc_string, - metadata_props=metadata_props, - ) - return TensorProtoTensor(proto) - - -def deserialize_metadata_props( - proto: Sequence[onnx.StringStringEntryProto], -) -> dict[str, str] | None: - if len(proto) == 0: - # Avoid creating an empty dictionary to save memory - return None - return {entry.key: entry.value for entry in proto} - - -def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr: - return _deserialize_attribute(proto, []) - - -def _deserialize_attribute( - proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]] -) -> _core.Attr | _core.RefAttr: - name = proto.name - doc_string = _get_field(proto, "doc_string") - type_ = _enums.AttributeType(proto.type) - ref_attr_name = _get_field(proto, "ref_attr_name") - if ref_attr_name: - return _core.RefAttr(name, ref_attr_name, type_, doc_string=doc_string) - - if type_ == _enums.AttributeType.INT: - return _core.AttrInt64(name, proto.i, doc_string=doc_string) - if type_ == _enums.AttributeType.FLOAT: - return _core.AttrFloat32(name, proto.f, doc_string=doc_string) - if type_ == _enums.AttributeType.STRING: - return _core.AttrString(name, proto.s.decode("utf-8"), doc_string=doc_string) - if type_ == _enums.AttributeType.INTS: - return _core.AttrInt64s(name, proto.ints, doc_string=doc_string) - if type_ == _enums.AttributeType.FLOATS: - return _core.AttrFloat32s(name, proto.floats, doc_string=doc_string) - if type_ == _enums.AttributeType.STRINGS: - return _core.AttrStrings( - name, [s.decode("utf-8") for s in proto.strings], doc_string=doc_string - ) - if type_ == _enums.AttributeType.TENSOR: - return _core.AttrTensor(name, deserialize_tensor(proto.t), doc_string=doc_string) - if type_ == _enums.AttributeType.GRAPH: - return _core.AttrGraph( - name, _deserialize_graph(proto.g, scoped_values), doc_string=doc_string - ) - if type_ == _enums.AttributeType.TENSORS: - return _core.AttrTensors( - name, - [deserialize_tensor(t) for t in proto.tensors], - doc_string=doc_string, - ) - if type_ == _enums.AttributeType.GRAPHS: - return _core.AttrGraphs( - name, - [_deserialize_graph(g, scoped_values) for g in proto.graphs], - doc_string=doc_string, - ) - if type_ == _enums.AttributeType.SPARSE_TENSOR: - raise NotImplementedError("Sparse tensors are not supported yet") - if type_ == _enums.AttributeType.SPARSE_TENSORS: - raise NotImplementedError("Sparse tensors are not supported yet") - if type_ == _enums.AttributeType.TYPE_PROTO: - ir_type = deserialize_type_proto_for_type(proto.tp) - shape = deserialize_type_proto_for_shape(proto.tp) - return _core.AttrTypeProto( - name, _core.TypeAndShape(ir_type, shape), doc_string=doc_string - ) - if type_ == _enums.AttributeType.TYPE_PROTOS: - type_and_shapes = [] - for type_proto in proto.type_protos: - ir_type = deserialize_type_proto_for_type(type_proto) - shape = deserialize_type_proto_for_shape(type_proto) - type_and_shapes.append(_core.TypeAndShape(ir_type, shape)) - return _core.AttrTypeProtos(name, type_and_shapes, doc_string=doc_string) - if type_ == _enums.AttributeType.UNDEFINED: - return _core.Attr(name, type_, None, doc_string=doc_string) - raise ValueError(f"Unsupported attribute type: '{type_}'") - - -def deserialize_node(proto: onnx.NodeProto) -> _core.Node: - return _deserialize_node(proto, scoped_values=[], value_info={}) - - -def _deserialize_node( - proto: onnx.NodeProto, - scoped_values: list[dict[str, _core.Value]], - value_info: dict[str, onnx.ValueInfoProto], -) -> _core.Node: - node_inputs: list[_core.Value | None] = [] - for input_name in proto.input: - if input_name == "": - # Empty input - node_inputs.append(None) - continue - - # Find the input in all value scopes - found = False - for values in reversed(scoped_values): - if input_name not in values: - continue - node_inputs.append(values[input_name]) - found = True - del values # Remove the reference so it is not used by mistake - break - if not found: - # If the input is not found, we know the graph may be unsorted and - # the input may be a supposed-to-be initializer or an output of a node that comes later. - # Here we create the value with the name and add it to the current scope. - # Nodes need to check the value pool for potentially initialized outputs - logger.warning( - "Input '%s' of node '%s(%s::%s:%s)' not found in any scope. " - "The graph may be unsorted. Creating a new input (current depth: %s) .", - input_name, - proto.name, - proto.domain, - proto.op_type, - getattr(proto, "overload", ""), - len(scoped_values), - ) - if len(scoped_values) > 1: - logger.warning( - "Caveat: The value is created in the subgraph. If " - "the node is referencing a value that is not in the current graph, " - "it is impossible to create it in the correct scope.", - ) - value = _core.Value(None, index=None, name=input_name) - # Fill in shape/type information if they exist - if input_name in value_info: - deserialize_value_info_proto(value_info[input_name], value) - node_inputs.append(value) - # We can only create the value in the current scope. If the subgraph is - # referencing a value that is not in the current scope, it is impossible - # to create it in the correct scope. - scoped_values[-1][input_name] = value - - # Build the output values for the node. - node_outputs: list[_core.Value] = [] - for output_name in proto.output: - if output_name == "": - # Empty output - node_outputs.append(_core.Value(None, index=None, name="")) - continue - - # 1. When the graph is unsorted, we may be able to find the output already created - # as an input to some other nodes in the current scope. - # Note that a value is always owned by the producing node. Even though a value - # can be created when parsing inputs of other nodes, the new node created here - # that produces the value will assume ownership. It is then impossible to transfer - # the ownership to any other node. - - # The output can only be found in the current scope. It is impossible for - # a node to produce an output that is not in its own scope. - current_scope = scoped_values[-1] - if output_name in current_scope: - value = current_scope[output_name] - else: - # 2. Common scenario: the graph is sorted and this is the first time we see the output. - # Create the value and add it to the current scope. - value = _core.Value(None, index=None, name=output_name) - current_scope[output_name] = value - # Fill in shape/type information if they exist - if output_name in value_info: - deserialize_value_info_proto(value_info[output_name], value) - else: - logger.debug( - "ValueInfoProto not found for output '%s' in node '%s' of type '%s'", - output_name, - proto.name, - proto.op_type, - ) - node_outputs.append(value) - return _core.Node( - proto.domain, - proto.op_type, - node_inputs, - [_deserialize_attribute(a, scoped_values) for a in proto.attribute], - overload=getattr(proto, "overload", ""), - outputs=node_outputs, - name=proto.name, - doc_string=_get_field(proto, "doc_string"), - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - - -# Serialization - - -def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto: - return serialize_model_into(onnx.ModelProto(), from_=model) - - -def serialize_model_into( - model_proto: onnx.ModelProto, from_: _protocols.ModelProtocol -) -> onnx.ModelProto: - """Serialize an IR model to an ONNX model proto.""" - model_proto.ir_version = from_.ir_version - if from_.producer_name: - model_proto.producer_name = from_.producer_name - if from_.producer_version: - model_proto.producer_version = from_.producer_version - if from_.domain: - model_proto.domain = from_.domain - if from_.model_version: - model_proto.model_version = from_.model_version - if from_.doc_string: - model_proto.doc_string = from_.doc_string - # Sort names for deterministic serialization - _serialize_opset_imports_into(model_proto.opset_import, from_.opset_imports) - if from_.metadata_props: - _serialize_metadata_props_into(model_proto.metadata_props, from_.metadata_props) - serialize_graph_into(model_proto.graph, from_.graph) - - create_value_info_in_functions = from_.ir_version >= _FUNCTION_VALUE_INFO_SUPPORTED_VERSION - for func in from_.functions.values(): - serialize_function_into( - model_proto.functions.add(), - from_=func, - create_value_info=create_value_info_in_functions, - ) - if not create_value_info_in_functions: - # Create them in the main graph instead - _serialize_experimental_value_info_for_function_ir9_into(model_proto.graph, func) - return model_proto - - -def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool: - """Check if value info should be created for a value. - - Args: - value: The value to check. - - Returns: - True if value info should be created for the value. - """ - # No need to serialize value info if it is not set - return not (value.shape is None and value.type is None) - - -def _serialize_experimental_value_info_for_function_ir9_into( - graph_proto: onnx.GraphProto, function: _protocols.FunctionProtocol -) -> None: - """Serialize value info for functions in an experimental format for IR version 9. - - Because IRv9 and older does not have ValueInfoProto for functions, we give the value info - special names and store them in the main graph instead. - - The experimental format is: - {function_domain}::{function_name}/{value_name} - - Args: - graph_proto: The graph proto to create ValueInfoProto in. - function: The function to serialize. - """ - # TODO(justinchuby): In the future, we can decide if it is a good idea to simply iterate over - # all values in the function and call serialize_value_into instead. - function_qualified_name = f"{function.domain}::{function.name}" - - def format_name(value_name: str) -> str: - return f"{function_qualified_name}/{value_name}" - - for input in function.inputs: - if not input.name: - logging.warning( - "Function '%s': Value name not set for function input: %s", - function_qualified_name, - input, - ) - continue - if not _should_create_value_info_for_value(input): - # No need to serialize value info if it is not set - continue - serialize_value_into(graph_proto.value_info.add(), input, name=format_name(input.name)) - for node in function: - for node_output in node.outputs: - if not node_output.name: - logging.warning( - "Function '%s': Value name not set for node output: %s", - function_qualified_name, - node_output, - ) - continue - if not _should_create_value_info_for_value(node_output): - # No need to serialize value info if it is not set - continue - serialize_value_into( - graph_proto.value_info.add(), - node_output, - name=format_name(node_output.name), - ) - - -def _serialize_opset_imports_into( - opset_ids: proto_containers.RepeatedCompositeFieldContainer[onnx.OperatorSetIdProto], - from_: Mapping[str, int], -) -> None: - """Serialize opset imports into a repeated field of OperatorSetId protos. - - Args: - opset_ids: The repeated field to serialize into. - from_: The mapping of opset domains to versions to serialize. - """ - # Sort names for deterministic serialization - for domain, version in from_.items(): - opset_ids.add(domain=domain, version=version) - - -def _serialize_metadata_props_into( - string_string_entries: proto_containers.RepeatedCompositeFieldContainer[ - onnx.StringStringEntryProto - ], - from_: Mapping[str, str], -) -> None: - """Serialize metadata properties into a repeated field of string-string entries. - - Args: - string_string_entries: The repeated field to serialize into. - from_: The mapping of metadata properties to serialize. - """ - # Sort names for deterministic serialization - for key in sorted(from_): - string_string_entries.add(key=key, value=from_[key]) - - -def serialize_graph( - graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol, -) -> onnx.GraphProto: - graph_proto = onnx.GraphProto() - serialize_graph_into(graph_proto, from_=graph) - return graph_proto - - -def serialize_graph_into( - graph_proto: onnx.GraphProto, - from_: _protocols.GraphProtocol | _protocols.GraphViewProtocol, -) -> None: - if from_.name: - graph_proto.name = from_.name - if from_.doc_string: - graph_proto.doc_string = from_.doc_string - for input_ in from_.inputs: - serialize_value_into(graph_proto.input.add(), input_) - # TODO(justinchuby): Support sparse_initializer - for initializer in from_.initializers.values(): - serialize_tensor_into(graph_proto.initializer.add(), from_=initializer) - for node in from_: - serialize_node_into(graph_proto.node.add(), from_=node) - for node_output in node.outputs: - if not _should_create_value_info_for_value(node_output): - # No need to serialize value info if it is not set - continue - if node_output.is_graph_output(): - # No need to serialize value info for these outputs because they are also graph outputs - continue - serialize_value_into(graph_proto.value_info.add(), node_output) - for output in from_.outputs: - serialize_value_into(graph_proto.output.add(), from_=output) - if from_.metadata_props: - _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props) - - -def serialize_function( - function: _protocols.FunctionProtocol, *, create_value_info: bool = True -) -> onnx.FunctionProto: - """Serialize an IR function as a FunctionProto. - - Args: - function: The function to serialize. - create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported - starting from ONNX IR version 10. - """ - function_proto = onnx.FunctionProto() - serialize_function_into( - function_proto, from_=function, create_value_info=create_value_info - ) - return function_proto - - -def serialize_function_into( - function_proto: onnx.FunctionProto, - from_: _protocols.FunctionProtocol, - *, - create_value_info: bool = True, -) -> None: - """Serialize an IR function into a FunctionProto. - - Args: - function_proto: The proto to serialize into. - from_: The function to serialize. - create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported - starting from ONNX IR version 10. - """ - if from_.domain: - function_proto.domain = from_.domain - if from_.name: - function_proto.name = from_.name - if from_.overload: - function_proto.overload = from_.overload - if from_.doc_string: - function_proto.doc_string = from_.doc_string - if from_.opset_imports: - # A valid ONNX graph should have at least one opset import, that is - # the default ONNX opset. - # Here we check for emptiness before serializing to keep the logic consistent - _serialize_opset_imports_into(function_proto.opset_import, from_.opset_imports) - if from_.metadata_props: - _serialize_metadata_props_into(function_proto.metadata_props, from_.metadata_props) - for input_ in from_.inputs: - function_proto.input.append(input_.name) - if not _should_create_value_info_for_value(input_): - # No need to serialize value info if it is not set - continue - if not create_value_info: - continue - serialize_value_into(function_proto.value_info.add(), input_) - for attr in from_.attributes.values(): - if attr.value is not None: - serialize_attribute_into(function_proto.attribute_proto.add(), from_=attr) - else: - # ONNX does not record type information if the attribute does not have a default - function_proto.attribute.append(attr.name) - for func_output in from_.outputs: - function_proto.output.append(func_output.name) - # No need to serialize value info for function outputs because they are - # also node outputs - for node in from_: - serialize_node_into(function_proto.node.add(), from_=node) - # Record value info for outputs - for node_output in node.outputs: - if not _should_create_value_info_for_value(node_output): - # No need to serialize value info if it is not set - continue - if not create_value_info: - continue - serialize_value_into(function_proto.value_info.add(), node_output) - - -def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto: - node_proto = onnx.NodeProto() - serialize_node_into(node_proto, from_=node) - return node_proto - - -def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None: - node_proto.op_type = from_.op_type - if from_.domain: - # If the domain is "", we can assume the default domain and not set it - node_proto.domain = from_.domain - if from_.name: - node_proto.name = from_.name - if from_.overload: - node_proto.overload = from_.overload - if from_.doc_string: - node_proto.doc_string = from_.doc_string - if from_.metadata_props: - _serialize_metadata_props_into(node_proto.metadata_props, from_.metadata_props) - for input_ in from_.inputs: - if input_ is None: - node_proto.input.append("") - else: - node_proto.input.append(input_.name) - for output in from_.outputs: - node_proto.output.append(output.name) - for attr in from_.attributes.values(): - if isinstance(attr, _core.Attr): - serialize_attribute_into(node_proto.attribute.add(), from_=attr) - elif isinstance(attr, _core.RefAttr): - serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr) - # Handle protocol attributes for completeness. We do not check them first because - # calling isinstance on a protocol can be slow. - # Most of the time, we will have Attr or RefAttr so the two branches below - # will not be taken. - elif isinstance(attr, _protocols.AttributeProtocol): - serialize_attribute_into(node_proto.attribute.add(), from_=attr) - elif isinstance(attr, _protocols.ReferenceAttributeProtocol): - serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr) - else: - raise TypeError(f"Unsupported attribute type: {type(attr)}") - - -def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto: - tensor_proto = onnx.TensorProto() - serialize_tensor_into(tensor_proto, from_=tensor) - return tensor_proto - - -def serialize_tensor_into( - tensor_proto: onnx.TensorProto, from_: _protocols.TensorProtocol -) -> None: - if isinstance(from_, TensorProtoTensor): - # Directly copy from the tensor proto if it is available - tensor_proto.CopyFrom(from_.raw) - if from_.metadata_props: - _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props) - return - - tensor_proto.name = from_.name - if from_.doc_string: - tensor_proto.doc_string = from_.doc_string - tensor_proto.data_type = from_.dtype.value - tensor_proto.dims.extend(from_.shape.numpy()) - if isinstance(from_, _core.ExternalTensor): - # Store external tensors as is - tensor_proto.data_location = onnx.TensorProto.EXTERNAL - for k, v in { - "location": os.fspath(from_.path), - "offset": from_.offset, - "length": from_.length, - }.items(): - if v is not None: - entry = tensor_proto.external_data.add() - entry.key = k - entry.value = str(v) - elif isinstance(from_, _core.StringTensor): - tensor_proto.string_data.extend(from_.string_data()) - else: - tensor_proto.raw_data = from_.tobytes() - _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props) - - -def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.AttributeProto: - attribute_proto = onnx.AttributeProto() - serialize_attribute_into(attribute_proto, from_=attribute) - return attribute_proto - - -def serialize_attribute_into( - attribute_proto: onnx.AttributeProto, from_: _protocols.AttributeProtocol -) -> None: - attribute_proto.name = from_.name - if from_.doc_string: - attribute_proto.doc_string = from_.doc_string - _fill_in_value_for_attribute(attribute_proto, from_.type, from_.value) - - -def _fill_in_value_for_attribute( - attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any -) -> None: - if type_ == _enums.AttributeType.INT: - # value: int - attribute_proto.i = value - attribute_proto.type = onnx.AttributeProto.INT - elif type_ == _enums.AttributeType.FLOAT: - # value: float - attribute_proto.f = value - attribute_proto.type = onnx.AttributeProto.FLOAT - elif type_ == _enums.AttributeType.STRING: - # value: str - attribute_proto.s = value.encode("utf-8") - attribute_proto.type = onnx.AttributeProto.STRING - elif type_ == _enums.AttributeType.INTS: - # value: Sequence[int] - attribute_proto.ints.extend(value) - attribute_proto.type = onnx.AttributeProto.INTS - elif type_ == _enums.AttributeType.FLOATS: - # value: Sequence[float] - attribute_proto.floats.extend(value) - attribute_proto.type = onnx.AttributeProto.FLOATS - elif type_ == _enums.AttributeType.STRINGS: - # value: Sequence[str] - attribute_proto.strings.extend([s.encode("utf-8") for s in value]) - attribute_proto.type = onnx.AttributeProto.STRINGS - elif type_ == _enums.AttributeType.TENSOR: - # value: _protocols.TensorProtocol - serialize_tensor_into(attribute_proto.t, value) - attribute_proto.type = onnx.AttributeProto.TENSOR - elif type_ == _enums.AttributeType.GRAPH: - # value: _protocols.GraphProtocol - serialize_graph_into(attribute_proto.g, value) - attribute_proto.type = onnx.AttributeProto.GRAPH - elif type_ == _enums.AttributeType.TENSORS: - # value: Sequence[_protocols.TensorProtocol] - for tensor in value: - serialize_tensor_into(attribute_proto.tensors.add(), tensor) - attribute_proto.type = onnx.AttributeProto.TENSORS - elif type_ == _enums.AttributeType.GRAPHS: - # value: Sequence[_protocols.GraphProtocol] - for graph in value: - serialize_graph_into(attribute_proto.graphs.add(), graph) - attribute_proto.type = onnx.AttributeProto.GRAPHS - elif type_ == _enums.AttributeType.SPARSE_TENSOR: - raise NotImplementedError("Sparse tensors are not supported yet") - elif type_ == _enums.AttributeType.SPARSE_TENSORS: - raise NotImplementedError("Sparse tensors are not supported yet") - elif type_ == _enums.AttributeType.TYPE_PROTO: - # value: _core.TypeAndShape - if value.type is not None: - serialize_type_into(attribute_proto.tp, value.type) - # Need to create the type _before_ writing the shape - if value.shape is not None: - serialize_shape_into(attribute_proto.tp, value.shape) - attribute_proto.type = onnx.AttributeProto.TYPE_PROTO - elif type_ == _enums.AttributeType.TYPE_PROTOS: - for ir_type in value: - # ir_type: _core.TypeAndShape - type_proto = attribute_proto.type_protos.add() - if ir_type.type is not None: - serialize_type_into(type_proto, ir_type.type) - # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto - if ir_type.shape is not None: - serialize_shape_into(type_proto, ir_type.shape) - attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS - else: - raise TypeError(f"Unsupported attribute type: {type_}") - - -def serialize_reference_attribute_into( - attribute_proto: onnx.AttributeProto, from_: _protocols.ReferenceAttributeProtocol -) -> None: - attribute_proto.name = from_.name - attribute_proto.ref_attr_name = from_.ref_attr_name - if from_.doc_string: - attribute_proto.doc_string = from_.doc_string - attribute_proto.type = typing.cast(onnx.AttributeProto.AttributeType, from_.type.value) - - -def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.ValueInfoProto: - """Serialize a value into a ValueInfoProto. - - Args: - value: The proto to serialize into. - from_: The value to serialize. - name: A custom name to set for the value info. If not provided, the name from the value will be used. - """ - value_info_proto = onnx.ValueInfoProto() - serialize_value_into(value_info_proto, value, name=name) - return value_info_proto - - -def serialize_value_into( - value_info_proto: onnx.ValueInfoProto, - from_: _protocols.ValueProtocol, - *, - name: str = "", -) -> None: - """Serialize a value into a ValueInfoProto. - - Args: - value_info_proto: The proto to serialize into. - from_: The value to serialize. - name: A custom name to set for the value info. If not provided, the name from the value will be used. - """ - if name: - value_info_proto.name = name - else: - value_info_proto.name = from_.name - if from_.metadata_props: - _serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props) - if from_.type is not None: - serialize_type_into(value_info_proto.type, from_.type) - # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto - if from_.shape is not None: - serialize_shape_into(value_info_proto.type, from_.shape) - if from_.doc_string: - value_info_proto.doc_string = from_.doc_string - - -def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None: - if from_.denotation: - type_proto.denotation = from_.denotation - if isinstance(from_, _core.TensorType): - tensor_type_proto = type_proto.tensor_type - tensor_type_proto.elem_type = from_.dtype.value - elif isinstance(from_, _core.SparseTensorType): - sparse_tensor_type_proto = type_proto.sparse_tensor_type - sparse_tensor_type_proto.elem_type = from_.dtype.value - elif isinstance(from_, _core.SequenceType): - sequence_type_proto = type_proto.sequence_type - serialize_type_into(sequence_type_proto.elem_type, from_.elem_type) - elif isinstance(from_, _core.OptionalType): - optional_type_proto = type_proto.optional_type - serialize_type_into(optional_type_proto.elem_type, from_.elem_type) - else: - raise TypeError(f"Unsupported type: {from_}") - - -def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None: - value_field = type_proto.WhichOneof("value") - tensor_type = getattr(type_proto, value_field) - while not isinstance(tensor_type.elem_type, int): - # Find the leaf type that has the shape field - type_proto = tensor_type.elem_type - value_field = type_proto.WhichOneof("value") - tensor_type = getattr(type_proto, value_field) - # When from is empty, we still need to set the shape field to an empty list by touching it - tensor_type.shape.ClearField("dim") - for i, dim in enumerate(from_): - denotation = from_.get_denotation(i) - serialize_dimension_into(tensor_type.shape.dim.add(), dim, denotation) - - -def serialize_dimension_into( - dim_proto: onnx.TensorShapeProto.Dimension, - dim: int | _protocols.SymbolicDimProtocol, - denotation: str | None = None, -) -> None: - if denotation: - dim_proto.denotation = denotation - if isinstance(dim, int): - dim_proto.dim_value = dim - elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)): - if dim.value is not None: - # TODO(justinchuby): None is probably not a valid value for dim_param - dim_proto.dim_param = str(dim.value) diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py deleted file mode 100644 index 64512d9066..0000000000 --- a/onnxscript/ir/serde_test.py +++ /dev/null @@ -1,211 +0,0 @@ -import unittest -from typing import Callable - -import numpy as np -import onnx -import parameterized - -from onnxscript import ir -from onnxscript.ir import serde - - -class TensorProtoTensorTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ("FLOAT", onnx.TensorProto.FLOAT), - ("BOOL", onnx.TensorProto.BOOL), - ("FLOAT16", onnx.TensorProto.FLOAT16), - ("DOUBLE", onnx.TensorProto.DOUBLE), - ] - ) - def test_tensor_proto_tensor(self, _: str, dtype: int): - tensor_proto = onnx.helper.make_tensor( - "test_tensor", dtype, [1, 9], [-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0] - ) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - - def test_tensor_proto_tensor_bfloat16(self): - expected_array = np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]]) - tensor_proto = onnx.helper.make_tensor( - "test_tensor", onnx.TensorProto.BFLOAT16, [1, 9], expected_array - ) - tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal( - onnx.numpy_helper.bfloat16_to_float32(tensor.numpy()), expected_array - ) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - - @parameterized.parameterized.expand( - [ - ( - "FLOAT8E4M3FN", - onnx.TensorProto.FLOAT8E4M3FN, - lambda x: onnx.numpy_helper.float8e4m3_to_float32(x, fn=True), - ), - ( - "FLOAT8E4M3FNUZ", - onnx.TensorProto.FLOAT8E4M3FNUZ, - lambda x: onnx.numpy_helper.float8e4m3_to_float32(x, fn=True, uz=True), - ), - ( - "FLOAT8E5M2", - onnx.TensorProto.FLOAT8E5M2, - onnx.numpy_helper.float8e5m2_to_float32, - ), - ( - "FLOAT8E5M2FNUZ", - onnx.TensorProto.FLOAT8E5M2FNUZ, - lambda x: onnx.numpy_helper.float8e5m2_to_float32(x, fn=True, uz=True), - ), - ] - ) - def test_tensor_proto_tensor_float8(self, _: str, dtype: int, to_float32_func: Callable): - expected_array = np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]]) - tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 9], expected_array) - tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal(to_float32_func(tensor.numpy()), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - if dtype in (onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FNUZ): - # TODO: Remove the fix when ONNX 1.17 releases - self.skipTest("ONNX to_array fails: https://github.com/onnx/onnx/pull/6124") - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - - @parameterized.parameterized.expand( - [ - ("INT8", onnx.TensorProto.INT8), - ("INT16", onnx.TensorProto.INT16), - ("INT32", onnx.TensorProto.INT32), - ("INT64", onnx.TensorProto.INT64), - ("INT4", onnx.TensorProto.INT4), - ] - ) - def test_tensor_proto_tensor_int(self, _: str, dtype: int): - tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 4], [-1, 0, 1, 8]) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array( - tensor_proto - ) # [-1, 0, 1, 7], 8 is clamped to 7 - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - - @parameterized.parameterized.expand( - [ - ("UINT8", onnx.TensorProto.UINT8), - ("UINT16", onnx.TensorProto.UINT16), - ("UINT32", onnx.TensorProto.UINT32), - ("UINT64", onnx.TensorProto.UINT64), - ("UINT4", onnx.TensorProto.UINT4), - ] - ) - def test_tensor_proto_tensor_uint(self, _: str, dtype: int): - tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 3], [0, 1, 8]) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - - @parameterized.parameterized.expand( - [ - ("COMPLEX64", onnx.TensorProto.COMPLEX64, np.complex64), - ("COMPLEX128", onnx.TensorProto.COMPLEX128, np.complex128), - ] - ) - def test_tensor_proto_tensor_complex(self, _: str, dtype: int, np_dtype: np.dtype): - expected_array = np.array([[0.0 + 1j, 0.2 - 1j, 0.3]], dtype=np_dtype) - tensor_proto = onnx.helper.make_tensor( - "test_tensor", dtype, [1, 3], [0.0 + 1j, 0.2 - 1j, 0.3] - ) - tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - - def test_tensor_proto_tensor_empty_tensor(self): - tensor_proto = onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [0], []) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - - -class DeserializeGraphTest(unittest.TestCase): - def test_deserialize_graph_handles_unsorted_graph(self): - node_0 = ir.Node( - "", - "Op_0", - inputs=[ir.Input("input_0"), ir.Input("input_1")], - num_outputs=2, - name="node_0", - ) - node_1 = ir.Node( - "", - "Op_1", - inputs=[node_0.outputs[0]], - num_outputs=1, - name="node_1", - ) - graph = ir.Graph( - inputs=node_0.inputs, # type: ignore - outputs=[node_1.outputs[0]], - # Unsorted nodes - nodes=[node_1, node_0], - name="test_graph", - ) - graph_proto = serde.serialize_graph(graph) - deserialized_graph = serde.deserialize_graph(graph_proto) - self.assertEqual(deserialized_graph[0].op_type, "Op_1") - self.assertEqual(deserialized_graph[1].op_type, "Op_0") - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 3940ba9297..76023ea002 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -1,7 +1,6 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- +# ruff: noqa: TID251 from __future__ import annotations import dataclasses @@ -215,7 +214,7 @@ def __str__(self): def debug_print(self): if logger.isEnabledFor(logging.DEBUG): - logger.debug("%s: %s", type(self), str(self)) + logger.debug("%s: %s", type(self), self) def to_node_proto(self, node_name: str) -> onnx.NodeProto: n = helper.make_node( @@ -321,6 +320,7 @@ def to_model_proto( io_types: Optional[ONNXType] = None, input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, + value_infos: dict[str, ONNXType] | None = None, **kwargs, ) -> onnx.ModelProto: """Converts this instance into a `onnx.ModelProto`. @@ -334,12 +334,24 @@ def to_model_proto( are set to be of the corresponding type in this list. output_types: When specified, all the outputs of the model are set to be of the corresponding type in this list. + value_infos: A dictionary mapping intermediate variable names to ONNX types. + Used to set value_info for intermediate variables. kwargs: Additional parameters given to function :func:`onnx.helper.make_model`. Returns: An instance of :class:`onnx.ModelProto`. """ - graph, sub_functions = self.to_graph_and_functions(use_default_type=False) + value_infos = ( + [ + onnx.helper.make_value_info(name, type.to_type_proto()) + for name, type in value_infos.items() + ] + if value_infos + else None + ) + graph, sub_functions = self.to_graph_and_functions( + use_default_type=False, value_infos=value_infos + ) if io_types is not None: for input in graph.input: if not input.HasField("type"): @@ -370,13 +382,19 @@ def to_proto(f): for n in self.stmts: if n.callee.opset.domain not in opsets: opsets[n.callee.opset.domain] = n.callee.opset.version + + for proto in functions: + if proto.domain not in opsets: + opsets[proto.domain] = 1 + # TODO(rama): Handle conflicts with appropriate error/warning message. + for opset in proto.opset_import: + if opset.domain not in opsets: + opsets[opset.domain] = opset.version + if "" not in opsets: # No operator is using the standard opset. # A default value is given. opsets[""] = onnx_opset_version() - for proto in functions: - if proto.domain not in opsets: - opsets[proto.domain] = 1 if "ir_version" not in kwargs: kwargs["ir_version"] = select_ir_version(opsets[""]) @@ -389,7 +407,9 @@ def to_proto(f): ) def to_graph_and_functions( - self, use_default_type: bool = True + self, + use_default_type: bool = True, + value_infos: Sequence[ValueInfoProto] | None = None, ) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]: """Converts this instance into a `onnx.GraphProto` and a map from function-name to `onnx.FunctionProto`. @@ -397,6 +417,8 @@ def to_graph_and_functions( Args: use_default_type: if True, the function uses a default type for inputs and outputs that do not have a type + value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added + to the graph. Returns: a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto` @@ -410,6 +432,7 @@ def to_graph_and_functions( self.name, [x.to_value_info(use_default_type) for x in self.inputs], [y.to_value_info(use_default_type) for y in self.outputs], + value_info=value_infos, ) return graph, called_functions diff --git a/onnxscript/main.py b/onnxscript/main.py index 51c180e275..3ea3e50f90 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -1,22 +1,22 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- # pylint disable: protected-access from __future__ import annotations import ast import inspect import sys -import types -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, TypeVar -import onnx.helper +from typing_extensions import ParamSpec import onnxscript -from onnxscript import converter, irbuilder, values +from onnxscript import converter, ir, irbuilder, values from onnxscript._internal import ast_utils +_R = TypeVar("_R") +_P = ParamSpec("_P") + def script_check( f: ast.FunctionDef, @@ -42,7 +42,7 @@ def script( opset: Optional[values.Opset] = None, default_opset: Optional[values.Opset] = None, **kwargs: Any, -) -> Callable[[types.FunctionType], onnxscript.OnnxFunction]: +) -> Callable[[Callable[_P, _R]], onnxscript.OnnxFunction[_P, _R]]: """Main decorator. Declares a function as an onnx function. Args: @@ -78,7 +78,7 @@ def log2(x): "Script parameter must be an opset. Did you use @script instead of @script()?" ) - def transform(f: types.FunctionType) -> onnxscript.OnnxFunction: + def transform(f: Callable[_P, _R]) -> onnxscript.OnnxFunction[_P, _R]: if not inspect.isfunction(f): raise TypeError("The ONNXScript decorator should be applied to functions only.") @@ -98,7 +98,7 @@ def transform(f: types.FunctionType) -> onnxscript.OnnxFunction: return transform -def graph() -> Callable[[types.FunctionType], values.OnnxClosure]: +def graph() -> Callable[[Callable], values.OnnxClosure]: """A parametric decorator used to annotate nested-functions that are used as graph-attributes. @@ -145,7 +145,7 @@ def Sum(sum_in, next): onnx_function = wrapper_frame.f_locals["self"] nested_functions = onnx_function.function_ir.nested_functions - def transform(f: types.FunctionType) -> values.OnnxClosure: + def transform(f: Callable) -> values.OnnxClosure: return values.OnnxClosure(nested_functions[f.__name__], function_frame, f) return transform @@ -160,11 +160,17 @@ def export_onnx_lib(functions: Sequence[values.OnnxFunction], filename: str) -> # Since we don't yet have LibProto defined, we use a ModelProto as a temporary # container for the list of functions exported as a library, with an empty graph # and dummy opset_imports. - model = onnx.helper.make_model( - onnx.GraphProto(), - functions=[f.to_function_proto() for f in functions], + + # TODO(justinchuby): This function is not well supported. We should consider removing it + model = ir.Model( + ir.Graph( + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": 15}, + ), + functions=[ir.serde.deserialize_function(f.to_function_proto()) for f in functions], + ir_version=10, producer_name="p2o", - opset_imports=[onnx.helper.make_opsetid("", 15)], ) - - onnx.save(model, filename) + ir.save(model, filename) diff --git a/onnxscript/onnx_opset/__init__.py b/onnxscript/onnx_opset/__init__.py index c84d95c0cd..9b6ed0915c 100644 --- a/onnxscript/onnx_opset/__init__.py +++ b/onnxscript/onnx_opset/__init__.py @@ -2,13 +2,11 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 # -------------------------------------------------------------------------- from __future__ import annotations @@ -37,13 +35,15 @@ from onnxscript.onnx_opset._impl.opset18 import Opset18 from onnxscript.onnx_opset._impl.opset19 import Opset19 from onnxscript.onnx_opset._impl.opset20 import Opset20 +from onnxscript.onnx_opset._impl.opset21 import Opset21 +from onnxscript.onnx_opset._impl.opset22 import Opset22 +from onnxscript.onnx_opset._impl.opset23 import Opset23 +from onnxscript.onnx_opset._impl.opset24 import Opset24 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml1 import Opset_ai_onnx_ml1 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml2 import Opset_ai_onnx_ml2 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml3 import Opset_ai_onnx_ml3 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml4 import Opset_ai_onnx_ml4 -from onnxscript.onnx_opset._impl.opset_ai_onnx_preview_training1 import ( - Opset_ai_onnx_preview_training1, -) +from onnxscript.onnx_opset._impl.opset_ai_onnx_ml5 import Opset_ai_onnx_ml5 from onnxscript.values import Opset __all__ = [ @@ -68,11 +68,15 @@ "opset18", "opset19", "opset20", + "opset21", + "opset22", + "opset23", + "opset24", "opset_ai_onnx_ml1", "opset_ai_onnx_ml2", "opset_ai_onnx_ml3", "opset_ai_onnx_ml4", - "opset_ai_onnx_preview_training1", + "opset_ai_onnx_ml5", ] @@ -102,11 +106,15 @@ opset18 = Opset18() opset19 = Opset19() opset20 = Opset20() +opset21 = Opset21() +opset22 = Opset22() +opset23 = Opset23() +opset24 = Opset24() opset_ai_onnx_ml1 = Opset_ai_onnx_ml1() opset_ai_onnx_ml2 = Opset_ai_onnx_ml2() opset_ai_onnx_ml3 = Opset_ai_onnx_ml3() opset_ai_onnx_ml4 = Opset_ai_onnx_ml4() -opset_ai_onnx_preview_training1 = Opset_ai_onnx_preview_training1() +opset_ai_onnx_ml5 = Opset_ai_onnx_ml5() all_opsets: Mapping[Tuple[str, int], Opset] = { ( "", @@ -188,6 +196,22 @@ "", 20, ): opset20, + ( + "", + 21, + ): opset21, + ( + "", + 22, + ): opset22, + ( + "", + 23, + ): opset23, + ( + "", + 24, + ): opset24, ( "ai.onnx.ml", 1, @@ -205,7 +229,7 @@ 4, ): opset_ai_onnx_ml4, ( - "ai.onnx.preview.training", - 1, - ): opset_ai_onnx_preview_training1, + "ai.onnx.ml", + 5, + ): opset_ai_onnx_ml5, } diff --git a/onnxscript/onnx_opset/_impl/opset1.py b/onnxscript/onnx_opset/_impl/opset1.py index 756cc5a150..4af313184d 100644 --- a/onnxscript/onnx_opset/_impl/opset1.py +++ b/onnxscript/onnx_opset/_impl/opset1.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D214, D402, D405, D411, D416, D417 # -------------------------------------------------------------------------- from __future__ import annotations @@ -398,7 +397,18 @@ def BatchNormalization( ) T2_Cast: TypeAlias = Union[ - BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8 + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, ] def Cast(self, input: T1_Cast, *, to: str) -> T2_Cast: @@ -837,7 +847,11 @@ def Dropout( T_Elu = TypeVar("T_Elu", DOUBLE, FLOAT, FLOAT16) def Elu( - self, X: T_Elu, *, alpha: float = 1.0, consumed_inputs: Optional[Sequence[int]] = None + self, + X: T_Elu, + *, + alpha: float = 1.0, + consumed_inputs: Optional[Sequence[int]] = None, ) -> T_Elu: r"""[🌐 Elu(1)](https://onnx.ai/onnx/operators/onnx__Elu.html#elu-1 "Online Documentation") @@ -849,7 +863,7 @@ def Elu( Args: - X: 1D input tensor + X: Input tensor alpha: Coefficient of ELU default to 1.0. @@ -859,7 +873,9 @@ def Elu( schema = get_schema("Elu", 1, "") op = Op(self, "Elu", schema) return op( - *self._prepare_inputs(schema, X), alpha=alpha, consumed_inputs=consumed_inputs + *self._prepare_inputs(schema, X), + alpha=alpha, + consumed_inputs=consumed_inputs, ) T_Equal = TypeVar("T_Equal", BOOL, INT32, INT64) @@ -1338,7 +1354,12 @@ def GlobalMaxPool(self, X: T_GlobalMaxPool) -> T_GlobalMaxPool: T1_Greater: TypeAlias = BOOL def Greater( - self, A: T_Greater, B: T_Greater, *, axis: Optional[int] = None, broadcast: int = 0 + self, + A: T_Greater, + B: T_Greater, + *, + axis: Optional[int] = None, + broadcast: int = 0, ) -> T1_Greater: r"""[🌐 Greater(1)](https://onnx.ai/onnx/operators/onnx__Greater.html#greater-1 "Online Documentation") @@ -1603,7 +1624,11 @@ def LRN( schema = get_schema("LRN", 1, "") op = Op(self, "LRN", schema) return op( - *self._prepare_inputs(schema, X), alpha=alpha, beta=beta, bias=bias, size=size + *self._prepare_inputs(schema, X), + alpha=alpha, + beta=beta, + bias=bias, + size=size, ) T_LSTM = TypeVar("T_LSTM", DOUBLE, FLOAT, FLOAT16) @@ -1822,7 +1847,9 @@ def LeakyRelu( schema = get_schema("LeakyRelu", 1, "") op = Op(self, "LeakyRelu", schema) return op( - *self._prepare_inputs(schema, X), alpha=alpha, consumed_inputs=consumed_inputs + *self._prepare_inputs(schema, X), + alpha=alpha, + consumed_inputs=consumed_inputs, ) T_Less = TypeVar("T_Less", DOUBLE, FLOAT, FLOAT16) @@ -1935,7 +1962,11 @@ def LogSoftmax(self, input: T_LogSoftmax, *, axis: int = 1) -> T_LogSoftmax: ) def Loop( - self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, ) -> V_Loop: r"""[🌐 Loop(1)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-1 "Online Documentation") @@ -1954,7 +1985,7 @@ def Loop( This table summarizes the operating modes of this operator with equivalent C-style code: - Operator inputs defined as (max_trip_count, condition_var). + Operator inputs defined as (max_trip_count, condition_var). input ("", ""): for (int i=0; ; ++i) { @@ -2171,7 +2202,7 @@ def MatMul(self, A: T_MatMul, B: T_MatMul) -> T_MatMul: r"""[🌐 MatMul(1)](https://onnx.ai/onnx/operators/onnx__MatMul.html#matmul-1 "Online Documentation") - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). Args: @@ -2493,7 +2524,11 @@ def Or(self, A: T_Or, B: T_Or, *, axis: Optional[int] = None, broadcast: int = 0 T_PRelu = TypeVar("T_PRelu", DOUBLE, FLOAT, FLOAT16) def PRelu( - self, X: T_PRelu, slope: T_PRelu, *, consumed_inputs: Optional[Sequence[int]] = None + self, + X: T_PRelu, + slope: T_PRelu, + *, + consumed_inputs: Optional[Sequence[int]] = None, ) -> T_PRelu: r"""[🌐 PRelu(1)](https://onnx.ai/onnx/operators/onnx__PRelu.html#prelu-1 "Online Documentation") @@ -2567,7 +2602,10 @@ def Pad( schema = get_schema("Pad", 1, "") op = Op(self, "Pad", schema) return op( - *self._prepare_inputs(schema, data), mode=mode, paddings=paddings, value=value + *self._prepare_inputs(schema, data), + mode=mode, + paddings=paddings, + value=value, ) T_Pow = TypeVar("T_Pow", DOUBLE, FLOAT, FLOAT16) @@ -2975,7 +3013,11 @@ def RandomUniformLike( schema = get_schema("RandomUniformLike", 1, "") op = Op(self, "RandomUniformLike", schema) return op( - *self._prepare_inputs(schema, input), dtype=dtype, high=high, low=low, seed=seed + *self._prepare_inputs(schema, input), + dtype=dtype, + high=high, + low=low, + seed=seed, ) T_Reciprocal = TypeVar("T_Reciprocal", DOUBLE, FLOAT, FLOAT16) @@ -3004,7 +3046,11 @@ def Reciprocal( T_ReduceL1 = TypeVar("T_ReduceL1", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceL1( - self, data: T_ReduceL1, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceL1, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceL1: r"""[🌐 ReduceL1(1)](https://onnx.ai/onnx/operators/onnx__ReduceL1.html#reducel1-1 "Online Documentation") @@ -3034,7 +3080,11 @@ def ReduceL1( T_ReduceL2 = TypeVar("T_ReduceL2", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceL2( - self, data: T_ReduceL2, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceL2, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceL2: r"""[🌐 ReduceL2(1)](https://onnx.ai/onnx/operators/onnx__ReduceL2.html#reducel2-1 "Online Documentation") @@ -3066,7 +3116,11 @@ def ReduceL2( ) def ReduceLogSum( - self, data: T_ReduceLogSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceLogSum, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceLogSum: r"""[🌐 ReduceLogSum(1)](https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html#reducelogsum-1 "Online Documentation") @@ -3132,7 +3186,11 @@ def ReduceLogSumExp( T_ReduceMax = TypeVar("T_ReduceMax", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceMax( - self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMax, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMax: r"""[🌐 ReduceMax(1)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-1 "Online Documentation") @@ -3164,7 +3222,11 @@ def ReduceMax( ) def ReduceMean( - self, data: T_ReduceMean, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMean, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMean: r"""[🌐 ReduceMean(1)](https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-1 "Online Documentation") @@ -3194,7 +3256,11 @@ def ReduceMean( T_ReduceMin = TypeVar("T_ReduceMin", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceMin( - self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMin, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMin: r"""[🌐 ReduceMin(1)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-1 "Online Documentation") @@ -3226,7 +3292,11 @@ def ReduceMin( ) def ReduceProd( - self, data: T_ReduceProd, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceProd, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceProd: r"""[🌐 ReduceProd(1)](https://onnx.ai/onnx/operators/onnx__ReduceProd.html#reduceprod-1 "Online Documentation") @@ -3256,7 +3326,11 @@ def ReduceProd( T_ReduceSum = TypeVar("T_ReduceSum", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceSum( - self, data: T_ReduceSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceSum, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceSum: r"""[🌐 ReduceSum(1)](https://onnx.ai/onnx/operators/onnx__ReduceSum.html#reducesum-1 "Online Documentation") @@ -3371,7 +3445,9 @@ def Reshape( schema = get_schema("Reshape", 1, "") op = Op(self, "Reshape", schema) return op( - *self._prepare_inputs(schema, data), consumed_inputs=consumed_inputs, shape=shape + *self._prepare_inputs(schema, data), + consumed_inputs=consumed_inputs, + shape=shape, ) T_Selu = TypeVar("T_Selu", DOUBLE, FLOAT, FLOAT16) @@ -3538,7 +3614,7 @@ def Slice( Produces a slice of the input tensor along multiple axes. Similar to numpy: - https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html + https://numpy.org/doc/stable/reference/routines.indexing.html Slices uses `axes`, `starts` and `ends` attributes to specify the start and end dimension for each axis in the list of axes, it uses this information to slice the input `data` tensor. If a negative value is passed for any of the @@ -3632,7 +3708,7 @@ def Softplus(self, X: T_Softplus) -> T_Softplus: Args: - X: (differentiable) 1D input tensor + X: (differentiable) Input tensor """ schema = get_schema("Softplus", 1, "") @@ -4019,7 +4095,12 @@ def Unsqueeze(self, data: T_Unsqueeze, *, axes: Sequence[int]) -> T_Unsqueeze: T_Upsample = TypeVar("T_Upsample", BOOL, DOUBLE, FLOAT, FLOAT16, INT32, INT64) def Upsample( - self, X: T_Upsample, *, height_scale: float, mode: str = "nearest", width_scale: float + self, + X: T_Upsample, + *, + height_scale: float, + mode: str = "nearest", + width_scale: float, ) -> T_Upsample: r"""[🌐 Upsample(1)](https://onnx.ai/onnx/operators/onnx__Upsample.html#upsample-1 "Online Documentation") diff --git a/onnxscript/onnx_opset/_impl/opset10.py b/onnxscript/onnx_opset/_impl/opset10.py index 65ea0013e3..ec1734b266 100644 --- a/onnxscript/onnx_opset/_impl/opset10.py +++ b/onnxscript/onnx_opset/_impl/opset10.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations @@ -346,7 +345,7 @@ def MatMulInteger( r"""[🌐 MatMulInteger(10)](https://onnx.ai/onnx/operators/onnx__MatMulInteger.html#matmulinteger-10 "Online Documentation") - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. @@ -749,7 +748,7 @@ def QLinearMatMul( r"""[🌐 QLinearMatMul(10)](https://onnx.ai/onnx/operators/onnx__QLinearMatMul.html#qlinearmatmul-10 "Online Documentation") - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). It consumes two quantized input tensors, their scales and zero points, scale and zero point of output, and computes the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point). For (x / y_scale), it is rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. @@ -1067,7 +1066,7 @@ def Slice( Produces a slice of the input tensor along multiple axes. Similar to numpy: - https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html + https://numpy.org/doc/stable/reference/routines.indexing.html Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end dimension and step for each axis in the list of axes, it uses this information to slice the input `data` tensor. If a negative value is passed for any of the diff --git a/onnxscript/onnx_opset/_impl/opset11.py b/onnxscript/onnx_opset/_impl/opset11.py index bb54cbeb02..6538ac3afb 100644 --- a/onnxscript/onnx_opset/_impl/opset11.py +++ b/onnxscript/onnx_opset/_impl/opset11.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: E741, D214, D402, D405, D411, D416 # -------------------------------------------------------------------------- from __future__ import annotations @@ -1465,7 +1464,11 @@ def LogSoftmax(self, input: T_LogSoftmax, *, axis: int = 1) -> T_LogSoftmax: ) def Loop( - self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, ) -> V_Loop: r"""[🌐 Loop(11)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-11 "Online Documentation") @@ -1484,7 +1487,7 @@ def Loop( This table summarizes the operating modes of this operator with equivalent C-style code: - Operator inputs defined as (max_trip_count, condition_var). + Operator inputs defined as (max_trip_count, condition_var). input ("", ""): for (int i=0; ; ++i) { @@ -2238,7 +2241,11 @@ def Range(self, start: T_Range, limit: T_Range, delta: T_Range) -> T_Range: T_ReduceL1 = TypeVar("T_ReduceL1", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceL1( - self, data: T_ReduceL1, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceL1, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceL1: r"""[🌐 ReduceL1(11)](https://onnx.ai/onnx/operators/onnx__ReduceL1.html#reducel1-11 "Online Documentation") @@ -2268,7 +2275,11 @@ def ReduceL1( T_ReduceL2 = TypeVar("T_ReduceL2", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceL2( - self, data: T_ReduceL2, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceL2, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceL2: r"""[🌐 ReduceL2(11)](https://onnx.ai/onnx/operators/onnx__ReduceL2.html#reducel2-11 "Online Documentation") @@ -2300,7 +2311,11 @@ def ReduceL2( ) def ReduceLogSum( - self, data: T_ReduceLogSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceLogSum, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceLogSum: r"""[🌐 ReduceLogSum(11)](https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html#reducelogsum-11 "Online Documentation") @@ -2366,7 +2381,11 @@ def ReduceLogSumExp( T_ReduceMax = TypeVar("T_ReduceMax", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceMax( - self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMax, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMax: r"""[🌐 ReduceMax(11)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-11 "Online Documentation") @@ -2399,7 +2418,11 @@ def ReduceMax( ) def ReduceMean( - self, data: T_ReduceMean, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMean, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMean: r"""[🌐 ReduceMean(11)](https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-11 "Online Documentation") @@ -2429,7 +2452,11 @@ def ReduceMean( T_ReduceMin = TypeVar("T_ReduceMin", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceMin( - self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMin, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMin: r"""[🌐 ReduceMin(11)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-11 "Online Documentation") @@ -2462,7 +2489,11 @@ def ReduceMin( ) def ReduceProd( - self, data: T_ReduceProd, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceProd, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceProd: r"""[🌐 ReduceProd(11)](https://onnx.ai/onnx/operators/onnx__ReduceProd.html#reduceprod-11 "Online Documentation") @@ -2492,7 +2523,11 @@ def ReduceProd( T_ReduceSum = TypeVar("T_ReduceSum", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64) def ReduceSum( - self, data: T_ReduceSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceSum, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceSum: r"""[🌐 ReduceSum(11)](https://onnx.ai/onnx/operators/onnx__ReduceSum.html#reducesum-11 "Online Documentation") @@ -3314,7 +3349,9 @@ def SequenceEmpty(self, *, dtype: Optional[int] = None) -> S_SequenceEmpty: I_SequenceErase = TypeVar("I_SequenceErase", INT32, INT64) def SequenceErase( - self, input_sequence: S_SequenceErase, position: Optional[I_SequenceErase] = None + self, + input_sequence: S_SequenceErase, + position: Optional[I_SequenceErase] = None, ) -> S_SequenceErase: r"""[🌐 SequenceErase(11)](https://onnx.ai/onnx/operators/onnx__SequenceErase.html#sequenceerase-11 "Online Documentation") @@ -3481,7 +3518,7 @@ def Slice( Produces a slice of the input tensor along multiple axes. Similar to numpy: - https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html + https://numpy.org/doc/stable/reference/routines.indexing.html Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end dimension and step for each axis in the list of axes, it uses this information to slice the input `data` tensor. If a negative value is passed for any of the @@ -3798,7 +3835,10 @@ def TopK( schema = get_schema("TopK", 11, "") op = Op(self, "TopK", schema) return op( - *self._prepare_inputs(schema, X, K), axis=axis, largest=largest, sorted=sorted + *self._prepare_inputs(schema, X, K), + axis=axis, + largest=largest, + sorted=sorted, ) T_Unique = TypeVar( diff --git a/onnxscript/onnx_opset/_impl/opset12.py b/onnxscript/onnx_opset/_impl/opset12.py index ede4fb34a7..95b2ea83c5 100644 --- a/onnxscript/onnx_opset/_impl/opset12.py +++ b/onnxscript/onnx_opset/_impl/opset12.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations @@ -60,7 +59,12 @@ def __new__(cls): ) def ArgMax( - self, data: T_ArgMax, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0 + self, + data: T_ArgMax, + *, + axis: int = 0, + keepdims: int = 1, + select_last_index: int = 0, ) -> INT64: r"""[🌐 ArgMax(12)](https://onnx.ai/onnx/operators/onnx__ArgMax.html#argmax-12 "Online Documentation") @@ -111,7 +115,12 @@ def ArgMax( ) def ArgMin( - self, data: T_ArgMin, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0 + self, + data: T_ArgMin, + *, + axis: int = 0, + keepdims: int = 1, + select_last_index: int = 0, ) -> INT64: r"""[🌐 ArgMin(12)](https://onnx.ai/onnx/operators/onnx__ArgMin.html#argmin-12 "Online Documentation") @@ -674,7 +683,7 @@ def MaxPool( ``` output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) ``` - if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: ``` @@ -938,7 +947,11 @@ def Pow(self, X: T_Pow, Y: T1_Pow) -> T_Pow: ) def ReduceMax( - self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMax, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMax: r"""[🌐 ReduceMax(12)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-12 "Online Documentation") @@ -970,7 +983,11 @@ def ReduceMax( ) def ReduceMin( - self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMin, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMin: r"""[🌐 ReduceMin(12)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-12 "Online Documentation") diff --git a/onnxscript/onnx_opset/_impl/opset13.py b/onnxscript/onnx_opset/_impl/opset13.py index 616fe5ff69..5403df22cf 100644 --- a/onnxscript/onnx_opset/_impl/opset13.py +++ b/onnxscript/onnx_opset/_impl/opset13.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D214, D402, D405, D411, D416, D417 # -------------------------------------------------------------------------- from __future__ import annotations @@ -116,7 +115,12 @@ def Add(self, A: T_Add, B: T_Add) -> T_Add: ) def ArgMax( - self, data: T_ArgMax, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0 + self, + data: T_ArgMax, + *, + axis: int = 0, + keepdims: int = 1, + select_last_index: int = 0, ) -> INT64: r"""[🌐 ArgMax(13)](https://onnx.ai/onnx/operators/onnx__ArgMax.html#argmax-13 "Online Documentation") @@ -168,7 +172,12 @@ def ArgMax( ) def ArgMin( - self, data: T_ArgMin, *, axis: int = 0, keepdims: int = 1, select_last_index: int = 0 + self, + data: T_ArgMin, + *, + axis: int = 0, + keepdims: int = 1, + select_last_index: int = 0, ) -> INT64: r"""[🌐 ArgMin(13)](https://onnx.ai/onnx/operators/onnx__ArgMin.html#argmin-13 "Online Documentation") @@ -334,6 +343,8 @@ def Clip( Clip operator limits the given input within an interval. The interval is specified by the inputs 'min' and 'max'. They default to numeric_limits::lowest() and numeric_limits::max(), respectively. + When 'min' is greater than 'max', the clip operator sets all the 'input' values to + the value of 'max'. Thus, this is equivalent to 'Min(max, Max(input, min))'. Args: @@ -875,7 +886,22 @@ def Gather(self, data: T_Gather, indices: Tind_Gather, *, axis: int = 0) -> T_Ga entries of the axis dimension of `data` (by default outer-most one as axis=0) indexed by `indices`, and concatenates them in an output tensor of rank q + (r - 1). - If `axis = 0`, let `k = indices[i_{0}, ..., i_{q-1}]` + It is an indexing operation that indexes into the input `data` along a single (specified) axis. + Each entry in `indices` produces a `r-1` dimensional slice of the input tensor. + The entire operation produces, conceptually, a `q`-dimensional tensor of `r-1` dimensional slices, + which is arranged into a `q + (r-1)`-dimensional tensor, with the `q` dimensions taking the + place of the original `axis` that is being indexed into. + + The following few examples illustrate how `Gather` works for specific shapes of `data`, + `indices`, and given value of `axis`: + | data shape | indices shape | axis | output shape | output equation | + | --- | --- | --- | --- | --- | + | (P, Q) | ( ) (a scalar) | 0 | (Q) | output[q] = data[indices, q] | + | (P, Q, R) | ( ) (a scalar) | 1 | (P, R) | output[p, r] = data[p, indices, r] | + | (P, Q) | (R, S) | 0 | (R, S, Q) | output[r, s, q] = data[ [indices[r, s], q] | + | (P, Q) | (R, S) | 1 | (P, R, S) | output[p, r, s] = data[ p, indices[r, s]] | + + More generally, if `axis = 0`, let `k = indices[i_{0}, ..., i_{q-1}]` then `output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[k , j_{0}, ..., j_{r-2}]`: :: @@ -1462,7 +1488,11 @@ def LRN( schema = get_schema("LRN", 13, "") op = Op(self, "LRN", schema) return op( - *self._prepare_inputs(schema, X), alpha=alpha, beta=beta, bias=bias, size=size + *self._prepare_inputs(schema, X), + alpha=alpha, + beta=beta, + bias=bias, + size=size, ) T_Less = TypeVar( @@ -1589,7 +1619,11 @@ def LogSoftmax(self, input: T_LogSoftmax, *, axis: int = -1) -> T_LogSoftmax: ) def Loop( - self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, ) -> V_Loop: r"""[🌐 Loop(13)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-13 "Online Documentation") @@ -1608,7 +1642,7 @@ def Loop( This table summarizes the operating modes of this operator with equivalent C-style code: - Operator inputs defined as (max_trip_count, condition_var). + Operator inputs defined as (max_trip_count, condition_var). input ("", ""): for (int i=0; ; ++i) { @@ -1762,7 +1796,7 @@ def MatMul(self, A: T_MatMul, B: T_MatMul) -> T_MatMul: r"""[🌐 MatMul(13)](https://onnx.ai/onnx/operators/onnx__MatMul.html#matmul-13 "Online Documentation") - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). Args: @@ -1907,19 +1941,23 @@ def Mod(self, A: T_Mod, B: T_Mod, *, fmod: int = 0) -> T_Mod: r"""[🌐 Mod(13)](https://onnx.ai/onnx/operators/onnx__Mod.html#mod-13 "Online Documentation") - Performs element-wise binary modulus (with Numpy-style broadcasting support). - The sign of the remainder is the same as that of the Divisor. - - Mod operator can also behave like C fmod() or numpy.fmod. In this case, the sign of the remainder however, will be the same as the Dividend - (in contrast to integer mod). To force a behavior like numpy.fmod() an 'fmod' Attribute is provided. - This attribute is set to 0 by default causing the behavior to be like integer mod. - Setting this attribute to 1 causes the remainder to be calculated similar to that of numpy.fmod(). + Performs an element-wise binary modulo operation. + The semantics and supported data types depend on the value of the `fmod` attribute which must be `0` (default), or `1`. - If the input type is floating point, then `fmod` attribute must be set to 1. + If the `fmod` attribute is set to `0`, `T` is constrained to integer data types and the semantics follow that of the Python `%`-operator. + The sign of the result is that of the divisor. - In case of dividend being zero, the results will be platform dependent. + If `fmod` is set to `1`, the behavior of this operator follows that of the `fmod` function in C and `T` is constrained to floating point data types. + The result of this operator is the remainder of the division operation `x / y` where `x` and `y` are respective elements of `A` and `B`. The result is exactly the value `x - n * y`, where `n` is `x / y` with its fractional part truncated. + The returned value has the same sign as `x` (except if `x` is `-0`) and is less or equal to `|y|` in magnitude. + The following special cases apply when `fmod` is set to `1`: + - If `x` is `-0` and `y` is greater than zero, either `+0` or `-0` may be returned. + - If `x` is `±∞` and `y` is not `NaN`, `NaN` is returned. + - If `y` is `±0` and `x` is not `NaN`, `NaN` should be returned. + - If `y` is `±∞` and `x` is finite, `x` is returned. + - If either argument is `NaN`, `NaN` is returned. - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check `Broadcasting in ONNX `_. + This operator supports **multidirectional (i.e., NumPy-style) broadcasting**; for more details please check `Broadcasting in ONNX `_. Args: @@ -2414,7 +2452,11 @@ def Reciprocal(self, X: T_Reciprocal) -> T_Reciprocal: ) def ReduceL1( - self, data: T_ReduceL1, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceL1, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceL1: r"""[🌐 ReduceL1(13)](https://onnx.ai/onnx/operators/onnx__ReduceL1.html#reducel1-13 "Online Documentation") @@ -2448,7 +2490,11 @@ def ReduceL1( ) def ReduceL2( - self, data: T_ReduceL2, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceL2, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceL2: r"""[🌐 ReduceL2(13)](https://onnx.ai/onnx/operators/onnx__ReduceL2.html#reducel2-13 "Online Documentation") @@ -2482,7 +2528,11 @@ def ReduceL2( ) def ReduceLogSum( - self, data: T_ReduceLogSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceLogSum, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceLogSum: r"""[🌐 ReduceLogSum(13)](https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html#reducelogsum-13 "Online Documentation") @@ -2512,7 +2562,15 @@ def ReduceLogSum( return op(*self._prepare_inputs(schema, data), axes=axes, keepdims=keepdims) T_ReduceLogSumExp = TypeVar( - "T_ReduceLogSumExp", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64 + "T_ReduceLogSumExp", + BFLOAT16, + DOUBLE, + FLOAT, + FLOAT16, + INT32, + INT64, + UINT32, + UINT64, ) def ReduceLogSumExp( @@ -2564,7 +2622,11 @@ def ReduceLogSumExp( ) def ReduceMax( - self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMax, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMax: r"""[🌐 ReduceMax(13)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-13 "Online Documentation") @@ -2598,7 +2660,11 @@ def ReduceMax( ) def ReduceMean( - self, data: T_ReduceMean, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMean, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMean: r"""[🌐 ReduceMean(13)](https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13 "Online Documentation") @@ -2642,7 +2708,11 @@ def ReduceMean( ) def ReduceMin( - self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceMin, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceMin: r"""[🌐 ReduceMin(13)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-13 "Online Documentation") @@ -2676,7 +2746,11 @@ def ReduceMin( ) def ReduceProd( - self, data: T_ReduceProd, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1 + self, + data: T_ReduceProd, + *, + axes: Optional[Sequence[int]] = None, + keepdims: int = 1, ) -> T_ReduceProd: r"""[🌐 ReduceProd(13)](https://onnx.ai/onnx/operators/onnx__ReduceProd.html#reduceprod-13 "Online Documentation") @@ -2733,18 +2807,20 @@ def ReduceSum( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceSum", 13, "") @@ -2756,7 +2832,15 @@ def ReduceSum( ) T_ReduceSumSquare = TypeVar( - "T_ReduceSumSquare", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64 + "T_ReduceSumSquare", + BFLOAT16, + DOUBLE, + FLOAT, + FLOAT16, + INT32, + INT64, + UINT32, + UINT64, ) def ReduceSumSquare( diff --git a/onnxscript/onnx_opset/_impl/opset14.py b/onnxscript/onnx_opset/_impl/opset14.py index 21983c8a94..a9ec21f0d8 100644 --- a/onnxscript/onnx_opset/_impl/opset14.py +++ b/onnxscript/onnx_opset/_impl/opset14.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402, D405 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset15.py b/onnxscript/onnx_opset/_impl/opset15.py index 38c235bced..c0758999f0 100644 --- a/onnxscript/onnx_opset/_impl/opset15.py +++ b/onnxscript/onnx_opset/_impl/opset15.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402, D412 # -------------------------------------------------------------------------- from __future__ import annotations @@ -291,36 +290,37 @@ def CastLike(self, input: T1_CastLike, target_type: T2_CastLike) -> T2_CastLike: ) O_Optional: TypeAlias = Union[ - _Optional[Sequence[BOOL]], - _Optional[Sequence[COMPLEX128]], - _Optional[Sequence[COMPLEX64]], - _Optional[Sequence[DOUBLE]], - _Optional[Sequence[FLOAT]], - _Optional[Sequence[FLOAT16]], - _Optional[Sequence[INT16]], - _Optional[Sequence[INT32]], - _Optional[Sequence[INT64]], - _Optional[Sequence[INT8]], - _Optional[Sequence[STRING]], - _Optional[Sequence[UINT16]], - _Optional[Sequence[UINT32]], - _Optional[Sequence[UINT64]], - _Optional[Sequence[UINT8]], - _Optional[BOOL], - _Optional[COMPLEX128], - _Optional[COMPLEX64], - _Optional[DOUBLE], - _Optional[FLOAT], - _Optional[FLOAT16], - _Optional[INT16], - _Optional[INT32], - _Optional[INT64], - _Optional[INT8], - _Optional[STRING], - _Optional[UINT16], - _Optional[UINT32], - _Optional[UINT64], - _Optional[UINT8], + None, + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT64, + UINT8, ] def Optional( @@ -546,11 +546,11 @@ def Shape(self, data: T_Shape, *, end: _Optional[int] = None, start: int = 0) -> The end axis, if specified, is exclusive (and the returned value will not include the size of that axis). If the end axis is omitted, the axes upto the last one will be included. Negative axes indicate counting back from the last axis. - Note that axes will be clamped to the range [0, r-1], where r is the + Note that axes will be clamped to the range [0, r], where r is the rank of the input tensor if they are out-of-range (after adding r in the case of negative axis). Thus, specifying any end value > r is equivalent to specifying an end value of r, and specifying any start value < -r is equivalent to specifying a start - value of 0. + value of 0. If start > end, the result will be an empty shape. Examples: diff --git a/onnxscript/onnx_opset/_impl/opset16.py b/onnxscript/onnx_opset/_impl/opset16.py index c90392d582..21a92a6026 100644 --- a/onnxscript/onnx_opset/_impl/opset16.py +++ b/onnxscript/onnx_opset/_impl/opset16.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D214, D402, D405, D411, D416 # -------------------------------------------------------------------------- from __future__ import annotations @@ -253,38 +252,7 @@ def Identity(self, input: V_Identity) -> V_Identity: B_If: TypeAlias = BOOL V_If: TypeAlias = Union[ - Optional[Sequence[BFLOAT16]], - Optional[Sequence[BOOL]], - Optional[Sequence[COMPLEX128]], - Optional[Sequence[COMPLEX64]], - Optional[Sequence[DOUBLE]], - Optional[Sequence[FLOAT]], - Optional[Sequence[FLOAT16]], - Optional[Sequence[INT16]], - Optional[Sequence[INT32]], - Optional[Sequence[INT64]], - Optional[Sequence[INT8]], - Optional[Sequence[STRING]], - Optional[Sequence[UINT16]], - Optional[Sequence[UINT32]], - Optional[Sequence[UINT64]], - Optional[Sequence[UINT8]], - Optional[BFLOAT16], - Optional[BOOL], - Optional[COMPLEX128], - Optional[COMPLEX64], - Optional[DOUBLE], - Optional[FLOAT], - Optional[FLOAT16], - Optional[INT16], - Optional[INT32], - Optional[INT64], - Optional[INT8], - Optional[STRING], - Optional[UINT16], - Optional[UINT32], - Optional[UINT64], - Optional[UINT8], + None, Sequence[BFLOAT16], Sequence[BOOL], Sequence[COMPLEX128], @@ -476,7 +444,11 @@ def LessOrEqual(self, A: T_LessOrEqual, B: T_LessOrEqual) -> T1_LessOrEqual: ) def Loop( - self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, ) -> V_Loop: r"""[🌐 Loop(16)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-16 "Online Documentation") diff --git a/onnxscript/onnx_opset/_impl/opset17.py b/onnxscript/onnx_opset/_impl/opset17.py index 80b4b457c0..092658a502 100644 --- a/onnxscript/onnx_opset/_impl/opset17.py +++ b/onnxscript/onnx_opset/_impl/opset17.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset18.py b/onnxscript/onnx_opset/_impl/opset18.py index c4154635d9..a795391355 100644 --- a/onnxscript/onnx_opset/_impl/opset18.py +++ b/onnxscript/onnx_opset/_impl/opset18.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402, D405 # -------------------------------------------------------------------------- from __future__ import annotations @@ -169,12 +168,18 @@ def CenterCropPad( Center crop or pad an input to given dimensions. - The crop/pad dimensions can be specified for a subset of the `axes`. Non-specified dimensions will not be - cropped or padded. + The crop/pad dimensions can be specified for a subset of the `axes`; unspecified dimensions will remain unchanged. - If the input dimensions are bigger than the crop shape, a centered cropping window is extracted from the input. - If the input dimensions are smaller than the crop shape, the input is padded on each side equally, - so that the input is centered in the output. + If the input dimensions are larger than the target crop dimensions, a centered cropping window will be extracted + from the input. The starting value for the cropping window is rounded down, which means that if the difference + between the input shape and the crop shape is odd, the cropping window will be shifted half a pixel to the left + of the input center. + + If the input dimensions are smaller than the target crop dimensions, the input will be padded equally on both sides + to center it in the output. In cases where the total number of padding pixels is odd, an additional pixel will be + added to the right side. + + The padding value used is zero. Args: @@ -286,65 +291,6 @@ def Col2Im( strides=strides, ) - T_GroupNormalization = TypeVar("T_GroupNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16) - - def GroupNormalization( - self, - X: T_GroupNormalization, - scale: T_GroupNormalization, - bias: T_GroupNormalization, - *, - epsilon: float = 9.999999747378752e-06, - num_groups: int, - ) -> T_GroupNormalization: - r"""[🌐 GroupNormalization(18)](https://onnx.ai/onnx/operators/onnx__GroupNormalization.html#groupnormalization-18 "Online Documentation") - - - A GroupNormalization function. Carries out group normalization as described in - the paper https://arxiv.org/abs/1803.08494 - - This operator transforms input according to - :: - - y = scale * (x - mean) / sqrt(variance + epsilon) + bias, - - - where the mean and variance are computed per instance per group of channels, and - `scale` and `bias` should be specified for each group of channels. The number of - groups `num_groups` should be divisible by the number of channels so that there are - an equal number of channels per group. - - When the number of groups is the same as the number of channels, this operator is - equivalent to InstanceNormalization. When there is only one group, this operator - is equivalent to LayerNormalization. - - - Args: - X: (differentiable) Input data tensor. Dimensions for image cases are `(N x - C x H x W)`, where `N` is the batch size, `C` is the number of channels, - and `H` and `W` are the height and width of the data. Statistics are - computed for every group of channels over `C`, `H`, and `W`. For - non-image cases, the dimensions are in the form of `(N x C x D1 x D2 ... - Dn)`. - - scale: (differentiable) Scale tensor of shape `(num_groups)`. - - bias: (differentiable) Bias tensor of shape `(num_groups)`. - - epsilon: The epsilon value to use to avoid division by zero. - - num_groups: The number of groups of channels. It should be a divisor of the - number of channels `C`. - """ - - schema = get_schema("GroupNormalization", 18, "") - op = Op(self, "GroupNormalization", schema) - return op( - *self._prepare_inputs(schema, X, scale, bias), - epsilon=epsilon, - num_groups=num_groups, - ) - T_LpPool = TypeVar("T_LpPool", DOUBLE, FLOAT, FLOAT16) def LpPool( @@ -838,18 +784,20 @@ def ReduceL1( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceL1", 18, "") @@ -888,18 +836,20 @@ def ReduceL2( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceL2", 18, "") @@ -938,18 +888,20 @@ def ReduceLogSum( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceLogSum", 18, "") @@ -961,7 +913,15 @@ def ReduceLogSum( ) T_ReduceLogSumExp = TypeVar( - "T_ReduceLogSumExp", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64 + "T_ReduceLogSumExp", + BFLOAT16, + DOUBLE, + FLOAT, + FLOAT16, + INT32, + INT64, + UINT32, + UINT64, ) def ReduceLogSumExp( @@ -988,18 +948,20 @@ def ReduceLogSumExp( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceLogSumExp", 18, "") @@ -1048,18 +1010,20 @@ def ReduceMax( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceMax", 18, "") @@ -1098,18 +1062,20 @@ def ReduceMean( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceMean", 18, "") @@ -1158,18 +1124,20 @@ def ReduceMin( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceMin", 18, "") @@ -1208,18 +1176,20 @@ def ReduceProd( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceProd", 18, "") @@ -1231,7 +1201,15 @@ def ReduceProd( ) T_ReduceSumSquare = TypeVar( - "T_ReduceSumSquare", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64 + "T_ReduceSumSquare", + BFLOAT16, + DOUBLE, + FLOAT, + FLOAT16, + INT32, + INT64, + UINT32, + UINT64, ) def ReduceSumSquare( @@ -1258,18 +1236,20 @@ def ReduceSumSquare( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceSumSquare", 18, "") @@ -1434,13 +1414,13 @@ def Resize( keeping the original aspect ratio:
`scale = Min(sizes[i] / in_size[d])`
- `out_size[d] = round_int(scale * in_size[i])`
+ `out_size[d] = round_int(scale * in_size[d])`
If `keep_aspect_ratio_policy` is `"not_smaller"`, the sizes are adjusted so that no extent of the output is smaller than the specified size, while keeping the original aspect ratio:
`scale = Max(sizes[i] / in_size[d])`
- `out_size[d] = round_int(scale * in_size[i])`
+ `out_size[d] = round_int(scale * in_size[d])`
For non-resizable axes (those not specified in `axes`), the output size will be equal to the input size. @@ -1799,5 +1779,7 @@ def Split( schema = get_schema("Split", 18, "") op = Op(self, "Split", schema) return op( - *self._prepare_inputs(schema, input, split), axis=axis, num_outputs=num_outputs + *self._prepare_inputs(schema, input, split), + axis=axis, + num_outputs=num_outputs, ) diff --git a/onnxscript/onnx_opset/_impl/opset19.py b/onnxscript/onnx_opset/_impl/opset19.py index 467c23917e..18a7cba17a 100644 --- a/onnxscript/onnx_opset/_impl/opset19.py +++ b/onnxscript/onnx_opset/_impl/opset19.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D214, D402, D405, D411, D412, D416 # -------------------------------------------------------------------------- from __future__ import annotations @@ -80,7 +79,7 @@ def AveragePool( ``` output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) ``` - if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: ``` @@ -245,28 +244,31 @@ def Cast(self, input: T1_Cast, *, saturate: int = 1, to: int) -> T2_Cast: to the following rules. `[x]` means the value rounded to the target mantissa width. - | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | - |------|----|----|----|----| - | 0 | 0 | 0 | 0 | 0 | - |-0 | -0 | 0 | -0 | 0 | - | NaN | NaN | NaN | NaN | NaN | - | +/- Inf | +/- FLT_MAX | NaN | FLT_MAX | NaN | - | [x] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | - | [x] < -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | - | else | RNE | RNE | RNE | RNE | + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | -------- | -------- | -------- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | Inf | FLT_MAX | NaN | FLT_MAX | NaN | + | -Inf | -FLT_MAX | NaN | -FLT_MAX | NaN | + | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | + | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | + | else | RNE | RNE | RNE | RNE | The behavior changes if the parameter 'saturate' is set to False. The rules then become: - | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | - |------|----|----|----|----| - | 0 | 0 | 0 | 0 | 0 | - |-0 | -0 | 0 | -0 | 0 | - | NaN | NaN | NaN | NaN | NaN | - | +/- Inf | NaN | NaN | +/- Inf | NaN | - | [x] > FLT_MAX | NaN | NaN | Inf | NaN | - | [x] < -FLT_MAX | NaN | NaN | -Inf | NaN | - | else | RNE | RNE | RNE | RNE | + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | ------ | -------- | ---- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | -NaN | -NaN | NaN | -NaN | NaN | + | Inf | NaN | NaN | Inf | NaN | + | -Inf | -NaN | NaN | -Inf | NaN | + | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN | + | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN | + | else | RNE | RNE | RNE | RNE | Args: @@ -566,9 +568,10 @@ def DequantizeLinear( It's optional. Zero point is 0 when it's not specified. axis: (Optional) The axis of the dequantizing dimension of the input tensor. - Ignored for per-tensor quantization. Negative value means counting - dimensions from the back. Accepted range is [-r, r-1] where r = - rank(input). + Used only for per-axis quantization. Negative value means counting + dimensions from the back. Accepted range is `[-r, r-1]` where `r = + rank(input)`. When the rank of the input is 1, per-tensor quantization + is applied, rendering the axis unnecessary in this scenario. """ schema = get_schema("DequantizeLinear", 19, "") @@ -700,42 +703,7 @@ def Identity(self, input: V_Identity) -> V_Identity: B_If: TypeAlias = BOOL V_If: TypeAlias = Union[ - Optional[Sequence[BFLOAT16]], - Optional[Sequence[BOOL]], - Optional[Sequence[COMPLEX128]], - Optional[Sequence[COMPLEX64]], - Optional[Sequence[DOUBLE]], - Optional[Sequence[FLOAT]], - Optional[Sequence[FLOAT16]], - Optional[Sequence[INT16]], - Optional[Sequence[INT32]], - Optional[Sequence[INT64]], - Optional[Sequence[INT8]], - Optional[Sequence[STRING]], - Optional[Sequence[UINT16]], - Optional[Sequence[UINT32]], - Optional[Sequence[UINT64]], - Optional[Sequence[UINT8]], - Optional[BFLOAT16], - Optional[BOOL], - Optional[COMPLEX128], - Optional[COMPLEX64], - Optional[DOUBLE], - Optional[FLOAT], - Optional[FLOAT16], - Optional[FLOAT8E4M3FN], - Optional[FLOAT8E4M3FNUZ], - Optional[FLOAT8E5M2], - Optional[FLOAT8E5M2FNUZ], - Optional[INT16], - Optional[INT32], - Optional[INT64], - Optional[INT8], - Optional[STRING], - Optional[UINT16], - Optional[UINT32], - Optional[UINT64], - Optional[UINT8], + None, Sequence[BFLOAT16], Sequence[BOOL], Sequence[COMPLEX128], @@ -743,10 +711,6 @@ def Identity(self, input: V_Identity) -> V_Identity: Sequence[DOUBLE], Sequence[FLOAT], Sequence[FLOAT16], - Sequence[FLOAT8E4M3FN], - Sequence[FLOAT8E4M3FNUZ], - Sequence[FLOAT8E5M2], - Sequence[FLOAT8E5M2FNUZ], Sequence[INT16], Sequence[INT32], Sequence[INT64], @@ -776,6 +740,10 @@ def Identity(self, input: V_Identity) -> V_Identity: UINT32, UINT64, UINT8, + Sequence[FLOAT8E4M3FN], + Sequence[FLOAT8E4M3FNUZ], + Sequence[FLOAT8E5M2], + Sequence[FLOAT8E5M2FNUZ], ] def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If: @@ -888,7 +856,11 @@ def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> ) def Loop( - self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, ) -> V_Loop: r"""[🌐 Loop(19)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-19 "Online Documentation") @@ -1546,7 +1518,7 @@ def Resize( ``` scale = Min(sizes[i] / in_size[d]) - out_size[d] = round_int(scale * in_size[i]) + out_size[d] = round_int(scale * in_size[d]) ``` If @@ -1556,7 +1528,7 @@ def Resize( ``` scale = Max(sizes[i] / in_size[d]) - out_size[d] = round_int(scale * in_size[i]) + out_size[d] = round_int(scale * in_size[d]) ``` For @@ -1842,11 +1814,11 @@ def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) -> The end axis, if specified, is exclusive (and the returned value will not include the size of that axis). If the end axis is omitted, the axes upto the last one will be included. Negative axes indicate counting back from the last axis. - Note that axes will be clamped to the range [0, r-1], where r is the + Note that axes will be clamped to the range [0, r], where r is the rank of the input tensor if they are out-of-range (after adding r in the case of negative axis). Thus, specifying any end value > r is equivalent to specifying an end value of r, and specifying any start value < -r is equivalent to specifying a start - value of 0. + value of 0. If start > end, the result will be an empty shape. Examples: diff --git a/onnxscript/onnx_opset/_impl/opset2.py b/onnxscript/onnx_opset/_impl/opset2.py index e04537c5f4..a4a0e7f291 100644 --- a/onnxscript/onnx_opset/_impl/opset2.py +++ b/onnxscript/onnx_opset/_impl/opset2.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402, D411 # -------------------------------------------------------------------------- from __future__ import annotations @@ -132,7 +131,12 @@ def LpPool( T_Pad = TypeVar("T_Pad", DOUBLE, FLOAT, FLOAT16) def Pad( - self, data: T_Pad, *, mode: str = "constant", pads: Sequence[int], value: float = 0.0 + self, + data: T_Pad, + *, + mode: str = "constant", + pads: Sequence[int], + value: float = 0.0, ) -> T_Pad: r"""[🌐 Pad(2)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-2 "Online Documentation") diff --git a/onnxscript/onnx_opset/_impl/opset20.py b/onnxscript/onnx_opset/_impl/opset20.py index e05b5018a4..2f3f264c2a 100644 --- a/onnxscript/onnx_opset/_impl/opset20.py +++ b/onnxscript/onnx_opset/_impl/opset20.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations @@ -513,18 +512,20 @@ def ReduceMax( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceMax", 20, "") @@ -576,18 +577,20 @@ def ReduceMin( data: (differentiable) An input tensor. axes: (optional, non-differentiable) Optional input list of integers, along - which to reduce. The default is to reduce over all the dimensions of the - input tensor if 'noop_with_empty_axes' is false, else act as an Identity - op when 'noop_with_empty_axes' is true. Accepted range is [-r, r-1] - where r = rank(data). + which to reduce. The default is to reduce over empty axes. When axes is + empty (either not provided or explicitly empty), behavior depends on + 'noop_with_empty_axes': reduction over all axes if + 'noop_with_empty_axes' is false, or no reduction is applied if + 'noop_with_empty_axes' is true (but other operations will be performed). + Accepted range is [-r, r-1] where r = rank(data). keepdims: Keep the reduced dimension or not, default 1 means keep reduced dimension. - noop_with_empty_axes: Defines behavior if 'axes' is empty. Default behavior - with 'false' is to reduce all axes. When axes is empty and this - attribute is set to true, input tensor will not be reduced,and the - output tensor would be equivalent to input tensor. + noop_with_empty_axes: Defines behavior when axes is not provided or is + empty. If false (default), reduction happens over all axes. If true, no + reduction is applied, but other operations will be performed. For + example, ReduceSumSquare acts as a vanilla Square. """ schema = get_schema("ReduceMin", 20, "") diff --git a/onnxscript/onnx_opset/_impl/opset21.py b/onnxscript/onnx_opset/_impl/opset21.py new file mode 100644 index 0000000000..b0ae5a2e9c --- /dev/null +++ b/onnxscript/onnx_opset/_impl/opset21.py @@ -0,0 +1,1940 @@ +# -------------------------------------------------------------------------- +# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ +# ⚙️ Generated by 'python -m opgen' +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=W0221,W0222,R0901,W0237 +# mypy: disable-error-code=override +# ruff: noqa: D214, D402, D405, D411, D412, D416 +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Optional, Sequence, TypeVar, Union + +from onnx import GraphProto, SparseTensorProto, TensorProto +from onnx.defs import get_schema +from typing_extensions import TypeAlias + +from onnxscript.onnx_opset._impl.opset20 import Opset20 +from onnxscript.onnx_types import ( + BFLOAT16, + BOOL, + COMPLEX64, + COMPLEX128, + DOUBLE, + FLOAT, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT16, + INT4, + INT8, + INT16, + INT32, + INT64, + STRING, + UINT4, + UINT8, + UINT16, + UINT32, + UINT64, +) +from onnxscript.values import Op, Opset + + +class Opset21(Opset20): + def __new__(cls): + return Opset.__new__(cls, "", 21) + + T1_Cast = TypeVar( + "T1_Cast", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T2_Cast: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def Cast(self, input: T1_Cast, *, saturate: int = 1, to: int) -> T2_Cast: + r"""[🌐 Cast(21)](https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 "Online Documentation") + + + The operator casts the elements of a given input tensor to a data type + specified by the 'to' argument and returns an output tensor of the same size in + the converted type. The 'to' argument must be one of the data types specified + in the 'DataType' enum field in the TensorProto message. + + Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations + (e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may + yield result 100. There are some string literals reserved for special floating-point values; + "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively. + Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly, + this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors + to string tensors, plain floating-point representation (such as "314.15926") would be used. + Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases + of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior. + + Conversion from a numerical type to any numerical type is always allowed. + User must be aware of precision loss and value change caused by range difference between two types. + For example, a 64-bit float 3.1415926459 may be round to a 32-bit float 3.141592. Similarly, converting + an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type. + + In more detail, the conversion among numerical types should follow these rules + if the destination type is not a float 8 type. + + * Casting from floating point to: + * floating point: +/- infinity if OOR (out of range). + * fixed point: undefined if OOR. + * bool: +/- 0.0 to False; all else to True. + * Casting from fixed point to: + * floating point: +/- infinity if OOR. (+ infinity in the case of uint) + * fixed point: when OOR, discard higher bits and reinterpret (with respect to two's complement representation for + signed types). For example, 200 (int16) -> -56 (int8). + * bool: zero to False; nonzero to True. + * Casting from bool to: + * floating point: `{1.0, 0.0}`. + * fixed point: `{1, 0}`. + * bool: no change. + + Float 8 type were introduced to speed up the training of + deep models. By default the conversion of a float *x* obeys + to the following rules. `[x]` means the value rounded to + the target mantissa width. + + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | -------- | -------- | -------- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | Inf | FLT_MAX | NaN | FLT_MAX | NaN | + | -Inf | -FLT_MAX | NaN | -FLT_MAX | NaN | + | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | + | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | + | else | RNE | RNE | RNE | RNE | + + The behavior changes if the parameter 'saturate' is set to False. + The rules then become: + + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | ------ | -------- | ---- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | -NaN | -NaN | NaN | -NaN | NaN | + | Inf | NaN | NaN | Inf | NaN | + | -Inf | -NaN | NaN | -Inf | NaN | + | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN | + | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN | + | else | RNE | RNE | RNE | RNE | + + + Args: + input: (differentiable) Input tensor to be cast. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). + It is true by default. All cases are fully described in two tables + inserted in the operator description. + + to: The data type to which the elements of the input tensor are cast. + Strictly must be one of the types from DataType enum in TensorProto + """ + + schema = get_schema("Cast", 21, "") + op = Op(self, "Cast", schema) + return op(*self._prepare_inputs(schema, input), saturate=saturate, to=to) + + T1_CastLike = TypeVar( + "T1_CastLike", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T2_CastLike = TypeVar( + "T2_CastLike", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def CastLike( + self, input: T1_CastLike, target_type: T2_CastLike, *, saturate: int = 1 + ) -> T2_CastLike: + r"""[🌐 CastLike(21)](https://onnx.ai/onnx/operators/onnx__CastLike.html#castlike-21 "Online Documentation") + + + The operator casts the elements of a given input tensor (the first input) to + the same data type as the elements of the second input tensor. + See documentation of the Cast operator for further details. + + + Args: + input: (differentiable) Input tensor to be cast. + + target_type: (non-differentiable) The (first) input tensor will be cast to + produce a tensor of the same type as this (second input) tensor. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). + It is true by default. Please refer to operator Cast description for + further details. + """ + + schema = get_schema("CastLike", 21, "") + op = Op(self, "CastLike", schema) + return op(*self._prepare_inputs(schema, input, target_type), saturate=saturate) + + T_Constant: TypeAlias = Union[ + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def Constant( + self, + *, + sparse_value: Optional[SparseTensorProto] = None, + value: Optional[TensorProto] = None, + value_float: Optional[float] = None, + value_floats: Optional[Sequence[float]] = None, + value_int: Optional[int] = None, + value_ints: Optional[Sequence[int]] = None, + value_string: Optional[str] = None, + value_strings: Optional[Sequence[str]] = None, + ) -> T_Constant: + r"""[🌐 Constant(21)](https://onnx.ai/onnx/operators/onnx__Constant.html#constant-21 "Online Documentation") + + + This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value, + or value_* must be specified. + + + Args: + sparse_value: The value for the elements of the output tensor in sparse + format. + + value: The value for the elements of the output tensor. + + value_float: The value for the sole element for the scalar, float32, output + tensor. + + value_floats: The values for the elements for the 1D, float32, output + tensor. + + value_int: The value for the sole element for the scalar, int64, output + tensor. + + value_ints: The values for the elements for the 1D, int64, output tensor. + + value_string: The value for the sole element for the scalar, UTF-8 string, + output tensor. + + value_strings: The values for the elements for the 1D, UTF-8 string, output + tensor. + """ + + schema = get_schema("Constant", 21, "") + op = Op(self, "Constant", schema) + return op( + sparse_value=sparse_value, + value=value, + value_float=value_float, + value_floats=value_floats, + value_int=value_int, + value_ints=value_ints, + value_string=value_string, + value_strings=value_strings, + ) + + T1_ConstantOfShape: TypeAlias = INT64 + + T2_ConstantOfShape: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def ConstantOfShape( + self, input: T1_ConstantOfShape, *, value: Optional[TensorProto] = None + ) -> T2_ConstantOfShape: + r"""[🌐 ConstantOfShape(21)](https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html#constantofshape-21 "Online Documentation") + + + Generate a tensor with given value and shape. + + + Args: + input: 1D tensor. The shape of the expected output tensor. If empty tensor + is given, the output would be a scalar. All values must be >= 0. + + value: (Optional) The value of the output elements.Should be a one-element + tensor. If not specified, it defaults to a tensor of value 0 and + datatype float32 + """ + + schema = get_schema("ConstantOfShape", 21, "") + op = Op(self, "ConstantOfShape", schema) + return op(*self._prepare_inputs(schema, input), value=value) + + T1_DequantizeLinear = TypeVar( + "T1_DequantizeLinear", + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT8, + UINT16, + UINT4, + UINT8, + ) + + T2_DequantizeLinear = TypeVar("T2_DequantizeLinear", BFLOAT16, FLOAT, FLOAT16) + + def DequantizeLinear( + self, + x: T1_DequantizeLinear, + x_scale: T2_DequantizeLinear, + x_zero_point: Optional[T1_DequantizeLinear] = None, + *, + axis: int = 1, + block_size: int = 0, + ) -> T2_DequantizeLinear: + r"""[🌐 DequantizeLinear(21)](https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html#dequantizelinear-21 "Online Documentation") + + + The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the + full-precision tensor. The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point` + must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization, + a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization. + See QuantizeLinear for details on quantization granularity. + `x_zero_point` and `x` must have the same type. `x` and `y` must have the same shape. In the case of dequantizing + `int32`, there's no zero point (zero point is supposed to be 0). + `zero-point` is usually not used in the case of float8 types quantization, but the dequantization formula remains the same + for consistency, and `x_scale` still determines the output type. + + + Args: + x: N-D quantized input tensor to be de-quantized. + + x_scale: Scale for input `x`. For per-tensor/layer dequantization the scale + is a scalar, for per per-axis dequantization it is a 1-D Tensor and for + blocked dequantization it has the same shape as the input, except for + one dimension in which blocking is performed. + + x_zero_point: (optional) Zero point for input `x`. Shape must match x_scale. + It's optional. Zero point is 0 when it's not specified. + + axis: (Optional) The axis of the dequantizing dimension of the input tensor. + Used for per-axis and blocked quantization. Negative value means + counting dimensions from the back. Accepted range is `[-r, r-1]` where + `r = rank(input)`. + + block_size: (Optional) The size of the quantization block (number of times + every scale is replicated). Used only for blocked quantization. The + block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., + Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted + range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]` + """ + + schema = get_schema("DequantizeLinear", 21, "") + op = Op(self, "DequantizeLinear", schema) + return op( + *self._prepare_inputs(schema, x, x_scale, x_zero_point), + axis=axis, + block_size=block_size, + ) + + T_Flatten = TypeVar( + "T_Flatten", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Flatten(self, input: T_Flatten, *, axis: int = 1) -> T_Flatten: + r"""[🌐 Flatten(21)](https://onnx.ai/onnx/operators/onnx__Flatten.html#flatten-21 "Online Documentation") + + + Flattens the input tensor into a 2D matrix. If input tensor has shape + (d_0, d_1, ... d_n) then the output will have shape + (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn). + + + Args: + input: (differentiable) A tensor of rank >= axis. + + axis: Indicate up to which input dimensions (exclusive) should be flattened + to the outer dimension of the output. The value for axis must be in the + range [-r, r], where r is the rank of the input tensor. Negative value + means counting dimensions from the back. When axis = 0, the shape of the + output tensor is (1, (d_0 X d_1 ... d_n), where the shape of the input + tensor is (d_0, d_1, ... d_n). + """ + + schema = get_schema("Flatten", 21, "") + op = Op(self, "Flatten", schema) + return op(*self._prepare_inputs(schema, input), axis=axis) + + T_GroupNormalization = TypeVar("T_GroupNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def GroupNormalization( + self, + X: T_GroupNormalization, + scale: T_GroupNormalization, + bias: T_GroupNormalization, + *, + epsilon: float = 9.999999747378752e-06, + num_groups: int, + stash_type: int = 1, + ) -> T_GroupNormalization: + r"""[🌐 GroupNormalization(21)](https://onnx.ai/onnx/operators/onnx__GroupNormalization.html#groupnormalization-21 "Online Documentation") + + + A GroupNormalization function. Carries out group normalization as described in + the paper https://arxiv.org/abs/1803.08494 + + This operator transforms input according to + :: + + y = scale * (x - mean) / sqrt(variance + epsilon) + bias, + + + where the mean and variance are computed per instance per group of channels, and + `scale` and `bias` should be specified for each channel. The number of + groups `num_groups` should be divisible by the number of channels so that there are + an equal number of channels per group. + + The overall computation has two stages: the first stage normalizes the elements to + have zero mean and unit variance for each instance in each group, and the second + stage scales and shifts the results of the first stage. The floating-point precision + used in the first stage is determined by the `stash_type` attribute. For example, + if `stash_type` is 1, the operator casts all input variables to 32-bit float, + performs the computation, and finally casts the normalized results back to the + original type of `X`. The second stage does not depend on `stash_type`. + + When the number of groups is the same as the number of channels, this operator is + equivalent to InstanceNormalization. When there is only one group, this operator + is equivalent to LayerNormalization. + + + Args: + X: (differentiable) Input data tensor. Dimensions for image cases are `(N x + C x H x W)`, where `N` is the batch size, `C` is the number of channels, + and `H` and `W` are the height and width of the data. Statistics are + computed for every group of channels over `C`, `H`, and `W`. For + non-image cases, the dimensions are in the form of `(N x C x D1 x D2 ... + Dn)`. + + scale: (differentiable) Scale tensor of shape `(C)`. + + bias: (differentiable) Bias tensor of shape `(C)`. + + epsilon: The epsilon value to use to avoid division by zero. + + num_groups: The number of groups of channels. It should be a divisor of the + number of channels `C`. + + stash_type: The floating-point precision used in stage one of the + computation. + """ + + schema = get_schema("GroupNormalization", 21, "") + op = Op(self, "GroupNormalization", schema) + return op( + *self._prepare_inputs(schema, X, scale, bias), + epsilon=epsilon, + num_groups=num_groups, + stash_type=stash_type, + ) + + V_Identity = TypeVar( + "V_Identity", + Optional[Sequence[BOOL]], + Optional[Sequence[COMPLEX128]], + Optional[Sequence[COMPLEX64]], + Optional[Sequence[DOUBLE]], + Optional[Sequence[FLOAT]], + Optional[Sequence[FLOAT16]], + Optional[Sequence[INT16]], + Optional[Sequence[INT32]], + Optional[Sequence[INT64]], + Optional[Sequence[INT8]], + Optional[Sequence[STRING]], + Optional[Sequence[UINT16]], + Optional[Sequence[UINT32]], + Optional[Sequence[UINT64]], + Optional[Sequence[UINT8]], + Optional[BOOL], + Optional[COMPLEX128], + Optional[COMPLEX64], + Optional[DOUBLE], + Optional[FLOAT], + Optional[FLOAT16], + Optional[INT16], + Optional[INT32], + Optional[INT64], + Optional[INT8], + Optional[STRING], + Optional[UINT16], + Optional[UINT32], + Optional[UINT64], + Optional[UINT8], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Identity(self, input: V_Identity) -> V_Identity: + r"""[🌐 Identity(21)](https://onnx.ai/onnx/operators/onnx__Identity.html#identity-21 "Online Documentation") + + Identity operator + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Identity", 21, "") + op = Op(self, "Identity", schema) + return op(*self._prepare_inputs(schema, input)) + + B_If: TypeAlias = BOOL + + V_If: TypeAlias = Union[ + None, + Sequence[BFLOAT16], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + Sequence[FLOAT8E4M3FN], + Sequence[FLOAT8E4M3FNUZ], + Sequence[FLOAT8E5M2], + Sequence[FLOAT8E5M2FNUZ], + Sequence[INT4], + Sequence[UINT4], + ] + + def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If: + r"""[🌐 If(21)](https://onnx.ai/onnx/operators/onnx__If.html#if-21 "Online Documentation") + + If conditional + + Args: + cond: Condition for the if. The tensor must contain a single element. + + else_branch: Graph to run if condition is false. Has N outputs: values you + wish to be live-out to the enclosing scope. The number of outputs must + match the number of outputs in the then_branch. + + then_branch: Graph to run if condition is true. Has N outputs: values you + wish to be live-out to the enclosing scope. The number of outputs must + match the number of outputs in the else_branch. + """ + + schema = get_schema("If", 21, "") + op = Op(self, "If", schema) + return op( + *self._prepare_inputs(schema, cond), + else_branch=else_branch, + then_branch=then_branch, + ) + + I_Loop: TypeAlias = INT64 + + B_Loop: TypeAlias = BOOL + + V_Loop = TypeVar( + "V_Loop", + Optional[Sequence[BFLOAT16]], + Optional[Sequence[BOOL]], + Optional[Sequence[COMPLEX128]], + Optional[Sequence[COMPLEX64]], + Optional[Sequence[DOUBLE]], + Optional[Sequence[FLOAT]], + Optional[Sequence[FLOAT16]], + Optional[Sequence[INT16]], + Optional[Sequence[INT32]], + Optional[Sequence[INT64]], + Optional[Sequence[INT8]], + Optional[Sequence[STRING]], + Optional[Sequence[UINT16]], + Optional[Sequence[UINT32]], + Optional[Sequence[UINT64]], + Optional[Sequence[UINT8]], + Optional[BFLOAT16], + Optional[BOOL], + Optional[COMPLEX128], + Optional[COMPLEX64], + Optional[DOUBLE], + Optional[FLOAT], + Optional[FLOAT16], + Optional[FLOAT8E4M3FN], + Optional[FLOAT8E4M3FNUZ], + Optional[FLOAT8E5M2], + Optional[FLOAT8E5M2FNUZ], + Optional[INT16], + Optional[INT32], + Optional[INT4], + Optional[INT64], + Optional[INT8], + Optional[STRING], + Optional[UINT16], + Optional[UINT32], + Optional[UINT4], + Optional[UINT64], + Optional[UINT8], + Sequence[BFLOAT16], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[FLOAT8E4M3FN], + Sequence[FLOAT8E4M3FNUZ], + Sequence[FLOAT8E5M2], + Sequence[FLOAT8E5M2FNUZ], + Sequence[INT16], + Sequence[INT32], + Sequence[INT4], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT4], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Loop( + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, + ) -> V_Loop: + r"""[🌐 Loop(21)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-21 "Online Documentation") + + + Generic Looping construct. This loop has multiple termination conditions: + + 1) Trip count. Iteration count specified at runtime. Set by + specifying the input M. Optional. Set to empty string to omit. + Note that a static trip count (specified at graph construction time) can be + specified by passing in a constant node for input M. + 2) Loop termination condition. This is an input to the op that determines + whether to run the first iteration and also a loop-carried dependency for + the body graph. The body graph must yield a value for the condition variable, + whether this input is provided or not. + + This table summarizes the operating modes of this operator with equivalent + C-style code: + + Operator inputs defined as (max_trip_count, condition_var). + + * input ("", ""): + for (int i=0; ; ++i) { + cond = ... // Note this value is ignored, but is required in the body + } + + * input ("", cond) // Note this is analogous to a while loop + bool cond = ...; + for (int i=0; cond; ++i) { + cond = ...; + } + + * input ("", 1) // Note this is analogous to a do-while loop + bool cond = true + for (int i=0; cond; ++i) { + cond = ...; + } + + * input (trip_count, "") // Note this is analogous to a for loop + int trip_count = ... + for (int i=0; i < trip_count; ++i) { + cond = ...; // ignored + } + + * input (trip_count, cond) + int trip_count = ...; + bool cond = ...; + for (int i=0; i < trip_count && cond; ++i) { + cond = ...; + } + + + *Sample usage - cond as well as trip count* + + graph predict-net { + %a = Constant[value = ]() + %b = Constant[value = ]() + %keepgoing = Constant[value = ]() + %max_trip_count = Constant[value = ]() + %keepgoing_out, %b_out, %user_defined_vals = Loop[body = ](%max_trip_count, %keepgoing, %b) + return + } + + graph body-net ( + %i[INT32, scalar] // iteration number + %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used + %b_in[INT32, scalar] // incoming value of loop-carried-dependency b + ) { + %my_local = Add(%a, %b_in) + %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b + %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition + %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated + return %keepgoing_out, %b_out, %user_defined_val + } + + *Sample equivalent C code* + + { + /* User-defined code (enclosing scope) */ + int a = 3, b = 6; + bool keepgoing = true; // Analogous to input cond + /* End user-defined code */ + + /* Implicitly-defined code */ + const int max_trip_count = 10; // Analogous to input M + int user_defined_vals[]; // Imagine this is resizable + /* End implicitly-defined code */ + /* initialize loop-carried variables and scan-output variables */ + bool keepgoing_out = keepgoing + int b_out = b + + for (int i=0; i < max_trip_count && keepgoing_out; ++i) { + /* Implicitly-defined code: bind actual parameter values + to formal parameter variables of loop-body */ + bool keepgoing_in = keepgoing_out; + bool b_in = b_out; + + /* User-defined code (loop body) */ + int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine + b_out = a - b_in; + keepgoing_out = my_local > b_out; + user_defined_val = b_in + b_in; // b_in and b_out are different variables + /* End user-defined code */ + + /* Implicitly defined-code */ + user_defined_vals[i] = user_defined_val // accumulate scan-output values + } + // int t = my_local; // Can't do this. my_local is not accessible here. + + // The values below are bound to the output variables of the loop and therefore accessible + // b_out; user_defined_vals; keepgoing_out; + } + + There are several things of note in this code snippet: + + 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can + be referenced in the inputs of the loop. + 2) Any values computed in the loop body that needs to be used in a subsequent + iteration or after the loop are modelled using a pair of variables in the loop-body, + consisting of an input variable (eg., b_in) and an output variable (eg., b_out). + These are referred to as loop-carried dependences. The loop operation node + supplies the input value of the input variable for the first iteration, and + returns the output value of the output variable produced by the final + iteration. + 3) Scan_output variables are used to implicitly concatenate values computed across + all the iterations. In the above example, the value of user_defined_val computed + over all iterations are concatenated and returned as the value of user_defined_vals + after the loop. + 4) Values created in the body cannot be accessed in the enclosing scope, + except using the mechanism described above. + + Note that the semantics of this op support "diagonal" or "wavefront" execution. + (See Step 3 here for an example: + https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/). + Frontends should emit multi-layer RNNs as a series of While operators (with + time being the inner looping dimension), with each successive layer consuming + the scan_outputs from the previous layer, possibly going through several + point-wise operators (e.g. dropout, residual connections, linear layer). + + The input/output of subgraph (produced by loop node) matching is based on order instead of name. The implementation will figure out the names based on this order. + + + Args: + M: (optional) A maximum trip-count for the loop specified at runtime. + Optional. Pass empty string to skip. + + cond: (optional) A boolean termination condition. Optional. Pass empty + string to skip. + + v_initial: (variadic, heterogeneous) The initial values of any loop-carried + dependencies (values that change across loop iterations) + + body: The graph run each iteration. It has 2+N inputs: (iteration_num, + condition, loop carried dependencies...). It has 1+N+K outputs: + (condition, loop carried dependencies..., scan_outputs...). Each + scan_output is created by concatenating the value of the specified + output value at the end of each iteration of the loop. It is an error if + the dimensions or data type of these scan_outputs change across loop + iterations. + """ + + schema = get_schema("Loop", 21, "") + op = Op(self, "Loop", schema) + return op(*self._prepare_inputs(schema, M, cond, *v_initial), body=body) + + T_Pad = TypeVar( + "T_Pad", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + Tind_Pad = TypeVar("Tind_Pad", INT32, INT64) + + def Pad( + self, + data: T_Pad, + pads: INT64, + constant_value: Optional[T_Pad] = None, + axes: Optional[Tind_Pad] = None, + *, + mode: str = "constant", + ) -> T_Pad: + r"""[🌐 Pad(21)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-21 "Online Documentation") + + + Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, + a padded tensor (`output`) is generated. + + The three supported `modes` are (similar to corresponding modes supported by `numpy.pad`): + + 1) `constant`(default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False) + + 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis + + 3) `edge` - pads with the edge values of array + + 4) `wrap` - wrap-around padding as if the data tensor forms a torus + + + Example 1 (`constant` mode): + + Insert 0 pads to the beginning of the second dimension. + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'constant' + + constant_value = 0.0 + + output = [ + [0.0, 0.0, 1.0, 1.2], + [0.0, 0.0, 2.3, 3.4], + [0.0, 0.0, 4.5, 5.7], + ] + + + + Example 2 (`reflect` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'reflect' + + output = [ + [1.0, 1.2, 1.0, 1.2], + [2.3, 3.4, 2.3, 3.4], + [4.5, 5.7, 4.5, 5.7], + ] + + + + Example 3 (`edge` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'edge' + + output = [ + [1.0, 1.0, 1.0, 1.2], + [2.3, 2.3, 2.3, 3.4], + [4.5, 4.5, 4.5, 5.7], + ] + + + + Example 4 (`wrap` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [2, 1, 1, 1] + + mode = 'wrap' + + output = [ + [3.4, 2.3, 3.4, 2.3], + [5.7, 4.5, 5.7, 4.5], + [1.2, 1.0, 1.2, 1.0], + [3.4, 2.3, 3.4, 2.3], + [5.7, 4.5, 5.7, 4.5], + [1.2, 1.0, 1.2, 1.0], + ] + + + + + Args: + data: (differentiable) Input tensor. + + pads: (non-differentiable) Tensor of integers indicating the number of + padding elements to add or remove (if negative) at the beginning and end + of each axis. For 2D input tensor, it is the number of pixels. `pads` + should be a 1D tensor of shape [2 * num_axes] where `num_axes` refers to + the number of elements in the `axes` input or the input rank if `axes` + are not provided explicitly. `pads` format should be: [x1_begin, + x2_begin, ..., x1_end, x2_end,...], where xi_begin is the number of pad + values added at the beginning of axis `axes[i]` and xi_end, the number + of pad values added at the end of axis `axes[i]`. + + constant_value: (optional, non-differentiable) (Optional) A scalar value to + be used if the mode chosen is `constant` (by default it is 0, empty + string or False). + + axes: (optional, non-differentiable) 1-D tensor of axes that `pads` apply + to. Negative value means counting dimensions from the back. Accepted + range is [-r, r-1] where r = rank(data). Behavior is undefined if an + axis is repeated. If not provided, all axes are assumed (`[0, 1, ..., + input_rank-1]`). + + mode: Supported modes: `constant`(default), `reflect`, `edge`, `wrap` + """ + + schema = get_schema("Pad", 21, "") + op = Op(self, "Pad", schema) + return op(*self._prepare_inputs(schema, data, pads, constant_value, axes), mode=mode) + + T1_QLinearMatMul = TypeVar( + "T1_QLinearMatMul", + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT8, + UINT8, + ) + + TS_QLinearMatMul = TypeVar("TS_QLinearMatMul", BFLOAT16, FLOAT, FLOAT16) + + T2_QLinearMatMul = TypeVar( + "T2_QLinearMatMul", + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT8, + UINT8, + ) + + T3_QLinearMatMul = TypeVar( + "T3_QLinearMatMul", + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT8, + UINT8, + ) + + def QLinearMatMul( + self, + a: T1_QLinearMatMul, + a_scale: TS_QLinearMatMul, + a_zero_point: T1_QLinearMatMul, + b: T2_QLinearMatMul, + b_scale: TS_QLinearMatMul, + b_zero_point: T2_QLinearMatMul, + y_scale: TS_QLinearMatMul, + y_zero_point: T3_QLinearMatMul, + ) -> T3_QLinearMatMul: + r"""[🌐 QLinearMatMul(21)](https://onnx.ai/onnx/operators/onnx__QLinearMatMul.html#qlinearmatmul-21 "Online Documentation") + + + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). + It consumes two quantized input tensors, their scales and zero points, scale and zero point of output, + and computes the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point). + For (x / y_scale), it is rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. + Scale and zero point must have same shape. They must be either scalar (per tensor) or N-D tensor + (per row for 'a' and per column for 'b'). Scalar refers to per tensor quantization whereas N-D refers to per row + or per column quantization. If the input is 2D of shape [M, K] then zero point and scale tensor may be + an M element vector [v_1, v_2, ..., v_M] for per row quantization and K element vector of shape [v_1, v_2, ..., v_K] + for per column quantization. If the input is N-D tensor with shape [D1, D2, M, K] then zero point and scale tensor may + have shape [D1, D2, M, 1] for per row quantization and shape [D1, D2, 1, K] for per column quantization. + Production must never overflow, and accumulation may overflow if and only if in 32 bits. + + + Args: + a: (non-differentiable) N-dimensional quantized matrix a + + a_scale: (non-differentiable) scale of quantized input a + + a_zero_point: (non-differentiable) zero point of quantized input a + + b: (non-differentiable) N-dimensional quantized matrix b + + b_scale: (non-differentiable) scale of quantized input b + + b_zero_point: (non-differentiable) zero point of quantized input b + + y_scale: (non-differentiable) scale of quantized output y + + y_zero_point: (non-differentiable) zero point of quantized output y + """ + + schema = get_schema("QLinearMatMul", 21, "") + op = Op(self, "QLinearMatMul", schema) + return op( + *self._prepare_inputs( + schema, + a, + a_scale, + a_zero_point, + b, + b_scale, + b_zero_point, + y_scale, + y_zero_point, + ) + ) + + T1_QuantizeLinear = TypeVar("T1_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32) + + T2_QuantizeLinear = TypeVar( + "T2_QuantizeLinear", + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT4, + INT8, + UINT16, + UINT4, + UINT8, + ) + + def QuantizeLinear( + self, + x: T1_QuantizeLinear, + y_scale: T1_QuantizeLinear, + y_zero_point: Optional[T2_QuantizeLinear] = None, + *, + axis: int = 1, + block_size: int = 0, + output_dtype: int = 0, + saturate: int = 1, + ) -> T2_QuantizeLinear: + r"""[🌐 QuantizeLinear(21)](https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html#quantizelinear-21 "Online Documentation") + + + The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the + low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization + granularity. The quantization formula is `y = saturate((x / y_scale) + y_zero_point)`. + Saturation is done according to: + - uint16: [0, 65535] + - int16: [-32768, 32767] + - uint8: [0, 255] + - int8: [-128, 127] + - uint4: [0, 15] + - int4: [-8, 7] + For `(x / y_scale)`, it rounds to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details. + `y_zero_point` and `y` must have the same type. `y_zero_point` is usually not used for quantization to float8 types, but the quantization + formula remains the same for consistency, and the type of the attribute `y_zero_point` still determines the quantization type. + There are three supported quantization granularities, determined by the shape of `y_scale`. + In all cases, `y_zero_point` must have the same shape as `y_scale`. + - Per-tensor (per-layer) quantization: `y_scale` is a scalar. + - Per-axis quantization: The scale must be a 1-D tensor, with the length of the quantization axis. For an input shape + `(D0, ..., Di, ..., Dn)` and `axis=i`, `y_scale` is a 1-D tensor of length `Di`. + - Blocked quantization: The scale's shape is identical to the input's shape, except for one dimension, in which + blocking is performed. Given `x` shape `(D0, ..., Di, ..., Dn)`, `axis=i`, and block size `B`: `y_scale` shape is + `(D0, ..., ceil(Di/B), ..., Dn)`. + + + Args: + x: N-D full precision Input tensor to be quantized. + + y_scale: Scale for doing quantization to get `y`. For per-tensor/layer + quantization the scale is a scalar, for per-axis quantization it is a + 1-D Tensor and for blocked quantization it has the same shape as the + input, except for one dimension in which blocking is performed. + + y_zero_point: (optional) Zero point for doing quantization to get `y`. Shape + must match `y_scale`.Default is uint8 with zero point of 0 if it's not + specified. + + axis: (Optional) The axis of the dequantizing dimension of the input tensor. + Used only for per-axis and blocked quantization. Negative value means + counting dimensions from the back. Accepted range is `[-r, r-1]` where + `r = rank(input)`. When the rank of the input is 1, per-tensor + quantization is applied, rendering the axis unnecessary in this + scenario. + + block_size: (Optional) The size of the quantization block (number of times + every scale is replicated). Used only for blocked quantization. The + block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., + Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted + range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]` + + output_dtype: (Optional) The output data type. If not supplied, the output + data type is inferred from `y_zero_point` data type (`T2`). If neither + `output_dtype` nor `y_zero_point` are supplied, output data type is + uint8. If both `output_dtype` and `y_zero_point` are specified, + `output_dtype` must be `T2`. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + quantization (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). + It is true by default. All cases are fully described in two tables + inserted in the operator description. + """ + + schema = get_schema("QuantizeLinear", 21, "") + op = Op(self, "QuantizeLinear", schema) + return op( + *self._prepare_inputs(schema, x, y_scale, y_zero_point), + axis=axis, + block_size=block_size, + output_dtype=output_dtype, + saturate=saturate, + ) + + T_Reshape = TypeVar( + "T_Reshape", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Reshape(self, data: T_Reshape, shape: INT64, *, allowzero: int = 0) -> T_Reshape: + r"""[🌐 Reshape(21)](https://onnx.ai/onnx/operators/onnx__Reshape.html#reshape-21 "Online Documentation") + + + Reshape the input tensor similar to numpy.reshape. + First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor. + At most one dimension of the new shape can be -1. In this case, the value is + inferred from the size of the tensor and the remaining dimensions. A dimension + could also be 0, in which case the actual dimension value is unchanged (i.e. taken + from the input tensor). If 'allowzero' is set, and the new shape includes 0, the + dimension will be set explicitly to zero (i.e. not taken from input tensor). + Shape (second input) could be an empty shape, which means converting to a scalar. + The input tensor's shape and the output tensor's shape are required to have the same number of elements. + + If the attribute 'allowzero' is set, it is invalid for the specified shape to + contain both a zero value and -1, as the value of the dimension corresponding + to -1 cannot be determined uniquely. + + + Args: + data: (differentiable) An input tensor. + + shape: (non-differentiable) Specified shape for output. + + allowzero: (Optional) By default, when any value in the 'shape' input is + equal to zero the corresponding dimension value is copied from the input + tensor dynamically. allowzero=1 indicates that if any value in the + 'shape' input is set to zero, the zero value is honored, similar to + NumPy. + """ + + schema = get_schema("Reshape", 21, "") + op = Op(self, "Reshape", schema) + return op(*self._prepare_inputs(schema, data, shape), allowzero=allowzero) + + V_Scan = TypeVar( + "V_Scan", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Scan( + self, + *initial_state_and_scan_inputs: V_Scan, + body: GraphProto, + num_scan_inputs: int, + scan_input_axes: Optional[Sequence[int]] = None, + scan_input_directions: Optional[Sequence[int]] = None, + scan_output_axes: Optional[Sequence[int]] = None, + scan_output_directions: Optional[Sequence[int]] = None, + ) -> V_Scan: + r"""[🌐 Scan(21)](https://onnx.ai/onnx/operators/onnx__Scan.html#scan-21 "Online Documentation") + + + Scan can be used to iterate over one or more scan_input tensors, + constructing zero or more scan_output tensors. It combines ideas from general recurrences, + functional programming constructs such as scan, fold, map, and zip, and is intended to enable + generalizations of RNN-like constructs for sequence-to-sequence processing. + Other tensors (referred to as state_variables here) can be used to carry a state + when iterating from one element to another (similar to hidden-state in RNNs, also referred + to as loop-carried dependences in the context of loops). + Many common usages involve a single scan_input tensor (where functionality + similar to scan, fold and map can be obtained). When more than one scan_input is used, + a behavior similar to zip is obtained. + + The attribute body must be a graph, specifying the computation to be performed in + every iteration. It takes as input the current values of the state_variables and + the current iterated element of the scan_inputs. It must return the (updated) values + of the state_variables and zero or more scan_output_element tensors. The values of the + scan_output_element tensors are concatenated over all the iterations to produce the + scan_output values of the scan construct (similar to the concatenated intermediate + hidden-state values of RNN-like constructs). All the output tensors (state_variables as + well as scan_output_element tensors) are required to have the same shape in each iteration + of the loop (a restriction imposed to enable efficient memory allocation). + + Note that the iterated element passed to the body subgraph does not have a sequence + axis. It will have a rank one less than the rank of the corresponding scan_input. + + The scan operation returns the final values of the state_variables as well as the + scan_outputs. + + The optional attribute scan_input_directions specifies the direction (forward or backward) + for each scan input. If this attribute is omitted, all sequences are scanned in the forward + direction. A bidirectional scan may be performed by specifying the same tensor input twice + in the scan_inputs, once with a forward direction, and once with a backward direction. + + The scan_output of the operation is produced by concatenating the scan_output_element + values produced by the body in each iteration. The optional attribute scan_output_directions + specifies the direction in which scan_output is constructed (by appending or prepending the + scan_output_element to scan_output in each iteration) for each scan_output. If this attribute + is omitted, the scan_output_element is appended to the scan_output in each iteration. + + The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input. + If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the + batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1. + Note that scanning a non-zero axis may be less efficient than scanning axis zero. + + The optional attribute scan_output_axes specifies the axis along which the scan_outputs + are accumulated for each scan_output. For example, if axis 1 is the time axis (to be + scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis + value of 1. + + Note that because of the ONNX restriction that only the last parameter of an operator can + be variadic, the initial-states and scan-inputs are listed together as one input parameter. + Similarly, the final-states and scan-outputs are listed together as one output parameter. + The attribute num_scan_inputs indicates the number M of scan-inputs. + + The behavior of + + Scan < + num_scan_inputs = m, + body = loop-body, + scan_input_axes = [axis_1, ..., axis_m] + > (init_1, ..., init_n, scan_1, ..., scan_m) + + is equivalent to the following pseudo-code: + + // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i + // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j. + sequence_length = scan_1.shape[axis_1]; + + // initialize state-variables + st_1 = init_1; ... st_n = init_n; + // initialize scan-output variables: [] denotes an empty tensor + scan_out_1 = []; ...; scan_out_k = []; + // identify number of iterations: + + // execute loop + for (int t = 0; t < sequence_length; ++t) { + // generate the scan-input elements: the notation T[t] indicates the sub-tensor + // of rank one less than T obtained by indexing T at position t along axis k. + si_1 = scan_1[t]; + ... ; + si_m = scan_m[t]; + // execute loop-body + st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m) + // accumulate the scan-output elements + scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k); + } + + return st_1, ..., st_n, scan_out_1, ..., scan_out_k; + + *Sample usage: Encoding RNN using a Scan* + + The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi, + recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can + be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes + %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these + values are computed in the outer graph, they need to be passed in as extra state_variables. + + graph rnn-encoding { + %H_0 = ... + %X = ... + %Y_h, %Y = Scan[body = , num_scan_inputs=1](%H_0, %X) + return %Y, %Y_h + } + + graph rnn-cell-1 ( + %H_tminus1[FLOAT, tensor] + %X_t[FLOAT, tensor] + ) { + %Wi = ... + %Ri = ... + %Wbi = ... + %Rbi = ... + %t1 = X_t * (Wi^T) + %t2 = H_tminus1*(Ri^T) + %t3 = Add(%t1, %t2) + %t4 = Add(%t3, %Wbi) + %t5 = Add(%t4, %Rbi) + %Ht = Tanh(%t5) + %Accumulate = Identity(%Ht) + return %Ht, %Accumulate + } + + + + Args: + initial_state_and_scan_inputs: (variadic, heterogeneous) Initial values of + the loop's N state variables followed by M scan_inputs + + body: The graph run each iteration. It has N+M inputs: (loop state + variables..., scan_input_elts...). It has N+K outputs: (loop state + variables..., scan_output_elts...). Each scan_output is created by + concatenating the value of the specified scan_output_elt value at the + end of each iteration of the loop. It is an error if the dimensions of + these values change across loop iterations. + + num_scan_inputs: An attribute specifying the number of scan_inputs M. + + scan_input_axes: An optional list of M flags. The i-th element of the list + specifies the axis to be scanned (the sequence axis) for the i-th + scan_input. If omitted, 0 will be used as the scan axis for every + scan_input. Negative value for an axis means counting dimensions from + the back. Accepted range is [-r, r-1] where r = rank(input). + + scan_input_directions: An optional list of M flags. The i-th element of the + list specifies the direction to be scanned for the i-th scan_input + tensor: 0 indicates forward direction and 1 indicates reverse direction. + If omitted, all scan_input tensors will be scanned in the forward + direction. + + scan_output_axes: An optional list of K flags. The i-th element of the list + specifies the axis for the i-th scan_output. The scan outputs are + accumulated along the specified axis. If omitted, 0 will be used as the + scan axis for every scan_output. Negative value for an axis means + counting dimensions from the back. Accepted range is [-r, r-1]. + + scan_output_directions: An optional list of K flags, one for each + scan_output. The i-th element of the list specifies whether the i-th + scan_output should be constructed by appending or prepending a new value + in each iteration: 0 indicates appending and 1 indicates prepending. If + omitted, all scan_output tensors will be produced by appending a value + in each iteration. + """ + + schema = get_schema("Scan", 21, "") + op = Op(self, "Scan", schema) + return op( + *self._prepare_inputs(schema, *initial_state_and_scan_inputs), + body=body, + num_scan_inputs=num_scan_inputs, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + ) + + T_Shape = TypeVar( + "T_Shape", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T1_Shape: TypeAlias = INT64 + + def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) -> T1_Shape: + r"""[🌐 Shape(21)](https://onnx.ai/onnx/operators/onnx__Shape.html#shape-21 "Online Documentation") + + + Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor. + Optional attributes start and end can be used to compute a slice of the input tensor's shape. + If start axis is omitted, the slice starts from axis 0. + The end axis, if specified, is exclusive (and the returned value will not include the size of that axis). + If the end axis is omitted, the axes upto the last one will be included. + Negative axes indicate counting back from the last axis. + Note that axes will be clamped to the range [0, r], where r is the + rank of the input tensor if they are out-of-range (after adding r in the case of + negative axis). Thus, specifying any end value > r is equivalent to specifying an end + value of r, and specifying any start value < -r is equivalent to specifying a start + value of 0. If start > end, the result will be an empty shape. + + Examples: + + :: + + Input tensor with shape: [2, 3, 4] + No attributes specified. + Output: [2, 3, 4] + + + + :: + + Input tensor with shape: [2, 3, 4] + start: -1 + Output: [4] + + + + :: + + Input tensor with shape: [2, 3, 4] + end: -1 + Output: [2, 3] + + + + :: + + Input tensor with shape: [2, 3, 4] + start: 1 + end: 2 + Output: [3] + + + + + Args: + data: (non-differentiable) An input tensor. + + end: (Optional) Ending axis for slicing the shape. Negative value means + counting dimensions from the back. If omitted, sizes of all axes upto + (including) the last one will be included. + + start: (Optional) Starting axis for slicing the shape. Default value is + 0.Negative value means counting dimensions from the back. + """ + + schema = get_schema("Shape", 21, "") + op = Op(self, "Shape", schema) + return op(*self._prepare_inputs(schema, data), end=end, start=start) + + T_Size = TypeVar( + "T_Size", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T1_Size: TypeAlias = INT64 + + def Size(self, data: T_Size) -> T1_Size: + r"""[🌐 Size(21)](https://onnx.ai/onnx/operators/onnx__Size.html#size-21 "Online Documentation") + + + Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor. + + + Args: + data: (non-differentiable) An input tensor. + """ + + schema = get_schema("Size", 21, "") + op = Op(self, "Size", schema) + return op(*self._prepare_inputs(schema, data)) + + T_Squeeze = TypeVar( + "T_Squeeze", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Squeeze(self, data: T_Squeeze, axes: Optional[INT64] = None) -> T_Squeeze: + r"""[🌐 Squeeze(21)](https://onnx.ai/onnx/operators/onnx__Squeeze.html#squeeze-21 "Online Documentation") + + + Remove single-dimensional entries from the shape of a tensor. + Takes an input `axes` with a list of axes to squeeze. + If `axes` is not provided, all the single dimensions will be removed from + the shape. If an axis is selected with shape entry not equal to one, an error is raised. + + + Args: + data: (differentiable) Tensors with at least max(dims) dimensions. + + axes: (optional, non-differentiable) List of integers indicating the + dimensions to squeeze. Negative value means counting dimensions from the + back. Accepted range is [-r, r-1] where r = rank(data). + """ + + schema = get_schema("Squeeze", 21, "") + op = Op(self, "Squeeze", schema) + return op(*self._prepare_inputs(schema, data, axes)) + + T_Transpose = TypeVar( + "T_Transpose", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Transpose( + self, data: T_Transpose, *, perm: Optional[Sequence[int]] = None + ) -> T_Transpose: + r"""[🌐 Transpose(21)](https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21 "Online Documentation") + + + Transpose the input tensor similar to numpy.transpose. For example, when + perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape + will be (2, 1, 3). + + + Args: + data: (differentiable) An input tensor. + + perm: A list of integers. By default, reverse the dimensions, otherwise + permute the axes according to the values given. Its length must be equal + to the rank of the input. + """ + + schema = get_schema("Transpose", 21, "") + op = Op(self, "Transpose", schema) + return op(*self._prepare_inputs(schema, data), perm=perm) + + T_Unsqueeze = TypeVar( + "T_Unsqueeze", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Unsqueeze(self, data: T_Unsqueeze, axes: INT64) -> T_Unsqueeze: + r"""[🌐 Unsqueeze(21)](https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#unsqueeze-21 "Online Documentation") + + + Insert single-dimensional entries to the shape of an input tensor (`data`). + Takes one required input `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`). + + For example, given an input tensor (`data`) of shape [3, 4, 5], then + Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1]. + + The input `axes` should not contain any duplicate entries. It is an error if it contains duplicates. + The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`. + Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1]. + The order of values in `axes` does not matter and can come in any order. + + + Args: + data: (differentiable) Original tensor + + axes: (non-differentiable) List of integers indicating the dimensions to be + inserted. Negative value means counting dimensions from the back. + Accepted range is [-r, r-1] where r = rank(expanded). + """ + + schema = get_schema("Unsqueeze", 21, "") + op = Op(self, "Unsqueeze", schema) + return op(*self._prepare_inputs(schema, data, axes)) diff --git a/onnxscript/onnx_opset/_impl/opset22.py b/onnxscript/onnx_opset/_impl/opset22.py new file mode 100644 index 0000000000..2b1656ed2a --- /dev/null +++ b/onnxscript/onnx_opset/_impl/opset22.py @@ -0,0 +1,2593 @@ +# -------------------------------------------------------------------------- +# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ +# ⚙️ Generated by 'python -m opgen' +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=W0221,W0222,R0901,W0237 +# mypy: disable-error-code=override +# ruff: noqa: E741, D402, D405 +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Optional, Sequence, Tuple, TypeVar, Union + +from onnx.defs import get_schema +from typing_extensions import TypeAlias + +from onnxscript.onnx_opset._impl.opset21 import Opset21 +from onnxscript.onnx_types import ( + BFLOAT16, + BOOL, + COMPLEX64, + COMPLEX128, + DOUBLE, + FLOAT, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT16, + INT8, + INT16, + INT32, + INT64, + STRING, + UINT8, + UINT16, + UINT32, + UINT64, +) +from onnxscript.values import Op, Opset + + +class Opset22(Opset21): + def __new__(cls): + return Opset.__new__(cls, "", 22) + + T_Acos = TypeVar("T_Acos", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Acos(self, input: T_Acos) -> T_Acos: + r"""[🌐 Acos(22)](https://onnx.ai/onnx/operators/onnx__Acos.html#acos-22 "Online Documentation") + + + Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Acos", 22, "") + op = Op(self, "Acos", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Acosh = TypeVar("T_Acosh", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Acosh(self, input: T_Acosh) -> T_Acosh: + r"""[🌐 Acosh(22)](https://onnx.ai/onnx/operators/onnx__Acosh.html#acosh-22 "Online Documentation") + + + Calculates the hyperbolic arccosine of the given input tensor element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Acosh", 22, "") + op = Op(self, "Acosh", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Asin = TypeVar("T_Asin", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Asin(self, input: T_Asin) -> T_Asin: + r"""[🌐 Asin(22)](https://onnx.ai/onnx/operators/onnx__Asin.html#asin-22 "Online Documentation") + + + Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Asin", 22, "") + op = Op(self, "Asin", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Asinh = TypeVar("T_Asinh", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Asinh(self, input: T_Asinh) -> T_Asinh: + r"""[🌐 Asinh(22)](https://onnx.ai/onnx/operators/onnx__Asinh.html#asinh-22 "Online Documentation") + + + Calculates the hyperbolic arcsine of the given input tensor element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Asinh", 22, "") + op = Op(self, "Asinh", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Atan = TypeVar("T_Atan", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Atan(self, input: T_Atan) -> T_Atan: + r"""[🌐 Atan(22)](https://onnx.ai/onnx/operators/onnx__Atan.html#atan-22 "Online Documentation") + + + Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Atan", 22, "") + op = Op(self, "Atan", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Atanh = TypeVar("T_Atanh", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Atanh(self, input: T_Atanh) -> T_Atanh: + r"""[🌐 Atanh(22)](https://onnx.ai/onnx/operators/onnx__Atanh.html#atanh-22 "Online Documentation") + + + Calculates the hyperbolic arctangent of the given input tensor element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Atanh", 22, "") + op = Op(self, "Atanh", schema) + return op(*self._prepare_inputs(schema, input)) + + T_AveragePool = TypeVar("T_AveragePool", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def AveragePool( + self, + X: T_AveragePool, + *, + auto_pad: str = "NOTSET", + ceil_mode: int = 0, + count_include_pad: int = 0, + dilations: Optional[Sequence[int]] = None, + kernel_shape: Sequence[int], + pads: Optional[Sequence[int]] = None, + strides: Optional[Sequence[int]] = None, + ) -> T_AveragePool: + r"""[🌐 AveragePool(22)](https://onnx.ai/onnx/operators/onnx__AveragePool.html#averagepool-22 "Online Documentation") + + + AveragePool consumes an input tensor X and applies average pooling across + the tensor according to kernel sizes, stride sizes, and pad lengths. + average pooling consisting of computing the average on all values of a + subset of the input tensor according to the kernel size and downsampling the + data into the output tensor Y for further processing. The output spatial shape is calculated differently + depending on whether explicit padding is used, where pads is employed, or auto padding is used, where auto_pad is utilized. + With explicit padding (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html?highlight=maxpool#torch.nn.MaxPool2d): + ``` + output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) + ``` + or + ``` + output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) + ``` + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. + + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: + ``` + VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) + 1) / strides_spatial_shape[i]) + SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i]) + ``` + or when ceil_mode is disabled (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AveragePooling2D): + ``` + VALID: output_spatial_shape[i] = floor((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i]) + 1 + SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = floor((input_spatial_shape[i] - 1) / strides_spatial_shape[i]) + 1 + ``` + And pad shape will be following if `SAME_UPPER` or `SAME_LOWER`: + ``` + pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i] + ``` + The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). + + + Args: + X: (differentiable) Input data tensor from the previous operator; dimensions + for image case are (N x C x H x W), where N is the batch size, C is the + number of channels, and H and W are the height and the width of the + data. For non image case, the dimensions are in the form of (N x C x D1 + x D2 ... Dn), where N is the batch size. Optionally, if dimension + denotation is in effect, the operation expects the input data tensor to + arrive with the dimension denotation of [DATA_BATCH, DATA_CHANNEL, + DATA_FEATURE, DATA_FEATURE ...]. + + auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID. + Where default value is NOTSET, which means explicit padding is used. + SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] = + ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is + split between the two sides equally or almost equally (depending on + whether it is even or odd). In case the padding is an odd number, the + extra padding is added at the end for SAME_UPPER and at the beginning + for SAME_LOWER. + + ceil_mode: Whether to use ceil or floor (default) to compute the output + shape. + + count_include_pad: Whether include pad pixels when calculating values for + the edges. Default is 0, doesn't count include pad. + + dilations: Dilation value along each spatial axis of filter. If not present, + the dilation defaults to 1 along each spatial axis. + + kernel_shape: The size of the kernel along each axis. + + pads: Padding for the beginning and ending along each spatial axis, it can + take any value greater than or equal to 0. The value represent the + number of pixels added to the beginning and end part of the + corresponding axis. `pads` format should be as follow [x1_begin, + x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels + added at the beginning of axis `i` and xi_end, the number of pixels + added at the end of axis `i`. This attribute cannot be used + simultaneously with auto_pad attribute. If not present, the padding + defaults to 0 along start and end of each spatial axis. + + strides: Stride along each spatial axis. If not present, the stride defaults + to 1 along each spatial axis. + """ + + schema = get_schema("AveragePool", 22, "") + op = Op(self, "AveragePool", schema) + return op( + *self._prepare_inputs(schema, X), + auto_pad=auto_pad, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + dilations=dilations, + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + T1_Bernoulli = TypeVar("T1_Bernoulli", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T2_Bernoulli: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, + ] + + def Bernoulli( + self, + input: T1_Bernoulli, + *, + dtype: Optional[int] = None, + seed: Optional[float] = None, + ) -> T2_Bernoulli: + r"""[🌐 Bernoulli(22)](https://onnx.ai/onnx/operators/onnx__Bernoulli.html#bernoulli-22 "Online Documentation") + + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor + containing probabilities p (a value in the range [0,1]) to be used for drawing the binary random number, + where an output of 1 is produced with probability p and an output of 0 is produced with probability (1-p). + + This operator is non-deterministic and may not produce the same values in different + implementations (even if a seed is specified). + + + Args: + input: All values in input have to be in the range:[0, 1]. + + dtype: The data type for the elements of the output tensor. if not + specified, we will use the data type of the input tensor. + + seed: (Optional) Seed to the random generator, if not specified we will auto + generate one. + """ + + schema = get_schema("Bernoulli", 22, "") + op = Op(self, "Bernoulli", schema) + return op(*self._prepare_inputs(schema, input), dtype=dtype, seed=seed) + + T_Conv = TypeVar("T_Conv", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Conv( + self, + X: T_Conv, + W: T_Conv, + B: Optional[T_Conv] = None, + *, + auto_pad: str = "NOTSET", + dilations: Optional[Sequence[int]] = None, + group: int = 1, + kernel_shape: Optional[Sequence[int]] = None, + pads: Optional[Sequence[int]] = None, + strides: Optional[Sequence[int]] = None, + ) -> T_Conv: + r"""[🌐 Conv(22)](https://onnx.ai/onnx/operators/onnx__Conv.html#conv-22 "Online Documentation") + + + The convolution operator consumes an input tensor and a filter, and + computes the output. + + Args: + X: (differentiable) Input data tensor from previous layer; has size (N x C x + H x W), where N is the batch size, C is the number of channels, and H + and W are the height and width. Note that this is for the 2D image. + Otherwise the size is (N x C x D1 x D2 ... x Dn). Optionally, if + dimension denotation is in effect, the operation expects input data + tensor to arrive with the dimension denotation of [DATA_BATCH, + DATA_CHANNEL, DATA_FEATURE, DATA_FEATURE ...]. + + W: (differentiable) The weight tensor that will be used in the convolutions; + has size (M x C/group x kH x kW), where C is the number of channels, and + kH and kW are the height and width of the kernel, and M is the number of + feature maps. For more than 2 dimensions, the kernel shape will be (M x + C/group x k1 x k2 x ... x kn), where (k1 x k2 x ... kn) is the dimension + of the kernel. Optionally, if dimension denotation is in effect, the + operation expects the weight tensor to arrive with the dimension + denotation of [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, FILTER_SPATIAL, + FILTER_SPATIAL ...]. Assuming zero based indices for the shape array, + X.shape[1] == (W.shape[1] * group) == C and W.shape[0] mod G == 0. Or in + other words FILTER_IN_CHANNEL multiplied by the number of groups should + be equal to DATA_CHANNEL and the number of feature maps M should be a + multiple of the number of groups G. + + B: (optional, differentiable) Optional 1D bias to be added to the + convolution, has size of M. + + auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID. + Where default value is NOTSET, which means explicit padding is used. + SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] = + ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is + split between the two sides equally or almost equally (depending on + whether it is even or odd). In case the padding is an odd number, the + extra padding is added at the end for SAME_UPPER and at the beginning + for SAME_LOWER. + + dilations: dilation value along each spatial axis of the filter. If not + present, the dilation defaults is 1 along each spatial axis. + + group: number of groups input channels and output channels are divided into. + + kernel_shape: The shape of the convolution kernel. If not present, should be + inferred from input W. + + pads: Padding for the beginning and ending along each spatial axis, it can + take any value greater than or equal to 0. The value represent the + number of pixels added to the beginning and end part of the + corresponding axis. `pads` format should be as follow [x1_begin, + x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels + added at the beginning of axis `i` and xi_end, the number of pixels + added at the end of axis `i`. This attribute cannot be used + simultaneously with auto_pad attribute. If not present, the padding + defaults to 0 along start and end of each spatial axis. + + strides: Stride along each spatial axis. If not present, the stride defaults + is 1 along each spatial axis. + """ + + schema = get_schema("Conv", 22, "") + op = Op(self, "Conv", schema) + return op( + *self._prepare_inputs(schema, X, W, B), + auto_pad=auto_pad, + dilations=dilations, + group=group, + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + T_ConvTranspose = TypeVar("T_ConvTranspose", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def ConvTranspose( + self, + X: T_ConvTranspose, + W: T_ConvTranspose, + B: Optional[T_ConvTranspose] = None, + *, + auto_pad: str = "NOTSET", + dilations: Optional[Sequence[int]] = None, + group: int = 1, + kernel_shape: Optional[Sequence[int]] = None, + output_padding: Optional[Sequence[int]] = None, + output_shape: Optional[Sequence[int]] = None, + pads: Optional[Sequence[int]] = None, + strides: Optional[Sequence[int]] = None, + ) -> T_ConvTranspose: + r"""[🌐 ConvTranspose(22)](https://onnx.ai/onnx/operators/onnx__ConvTranspose.html#convtranspose-22 "Online Documentation") + + + The convolution transpose operator consumes an input tensor and a filter, + and computes the output. + + If the pads parameter is provided the shape of the output is calculated via the following equation: + + output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i] + + output_shape can also be explicitly specified in which case pads values are auto generated using these equations: + + total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i] + If (auto_pads == SAME_UPPER): pads[start_i] = total_padding[i]/2; pads[end_i] = total_padding[i] - (total_padding[i]/2) + Else: pads[start_i] = total_padding[i] - (total_padding[i]/2); pads[end_i] = (total_padding[i]/2). + + + + Args: + X: (differentiable) Input data tensor from previous layer; has size (N x C x + H x W), where N is the batch size, C is the number of channels, and H + and W are the height and width. Note that this is for the 2D image. + Otherwise the size is (N x C x D1 x D2 ... x Dn) + + W: (differentiable) The weight tensor that will be used in the convolutions; + has size (C x M/group x kH x kW), where C is the number of channels, and + kH and kW are the height and width of the kernel, and M is the number of + feature maps. For more than 2 dimensions, the weight shape will be (C x + M/group x k1 x k2 x ... x kn), where (k1 x k2 x ... x kn) is the + dimension of the kernel. The number of channels in the output should be + equal to W.shape[1] * group (assuming zero based indices of the shape + array) + + B: (optional, differentiable) Optional 1D bias to be added to the + convolution, has size of M. + + auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID. + Where default value is NOTSET, which means explicit padding is used. + SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] = + input_shape[i] * strides[i]` for each axis `i`. The padding is split + between the two sides equally or almost equally (depending on whether it + is even or odd). In case the padding is an odd number, the extra padding + is added at the end for SAME_UPPER and at the beginning for SAME_LOWER. + + dilations: dilation value along each spatial axis of the filter. If not + present, the dilation defaults to 1 along each spatial axis. + + group: number of groups input channels and output channels are divided into. + + kernel_shape: The shape of the convolution kernel. If not present, should be + inferred from input W. + + output_padding: Additional elements added to the side with higher coordinate + indices in the output. Each padding value in "output_padding" must be + less than the corresponding stride/dilation dimension. By default, this + attribute is a zero vector. Note that this attribute doesn't directly + affect the computed output values. It only controls the selection of the + computed values, so changing this attribute only adds or removes output + elements. If "output_shape" is explicitly provided, "output_padding" + does not contribute additional size to "output_shape" but participates + in the computation of the needed padding amount. This is also called + adjs or adjustment in some frameworks. + + output_shape: The shape of the output can be explicitly set which will cause + pads values to be auto generated. If output_shape is specified pads + values are ignored. See doc for details for equations to generate pads. + Note that the output_shape attribute value should not include dimensions + for batch size and channels, which are automatically inferred. + + pads: Padding for the beginning and ending along each spatial axis, it can + take any value greater than or equal to 0. The value represent the + number of pixels added to the beginning and end part of the + corresponding axis. `pads` format should be as follow [x1_begin, + x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels + added at the beginning of axis `i` and xi_end, the number of pixels + added at the end of axis `i`. This attribute cannot be used + simultaneously with auto_pad attribute. If not present, the padding + defaults to 0 along start and end of each spatial axis. + + strides: Stride along each spatial axis. If not present, the stride defaults + to 1 along each spatial axis. + """ + + schema = get_schema("ConvTranspose", 22, "") + op = Op(self, "ConvTranspose", schema) + return op( + *self._prepare_inputs(schema, X, W, B), + auto_pad=auto_pad, + dilations=dilations, + group=group, + kernel_shape=kernel_shape, + output_padding=output_padding, + output_shape=output_shape, + pads=pads, + strides=strides, + ) + + T_Cos = TypeVar("T_Cos", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Cos(self, input: T_Cos) -> T_Cos: + r"""[🌐 Cos(22)](https://onnx.ai/onnx/operators/onnx__Cos.html#cos-22 "Online Documentation") + + + Calculates the cosine of the given input tensor, element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Cos", 22, "") + op = Op(self, "Cos", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Cosh = TypeVar("T_Cosh", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Cosh(self, input: T_Cosh) -> T_Cosh: + r"""[🌐 Cosh(22)](https://onnx.ai/onnx/operators/onnx__Cosh.html#cosh-22 "Online Documentation") + + + Calculates the hyperbolic cosine of the given input tensor element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Cosh", 22, "") + op = Op(self, "Cosh", schema) + return op(*self._prepare_inputs(schema, input)) + + T_DeformConv = TypeVar("T_DeformConv", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def DeformConv( + self, + X: T_DeformConv, + W: T_DeformConv, + offset: T_DeformConv, + B: Optional[T_DeformConv] = None, + mask: Optional[T_DeformConv] = None, + *, + dilations: Optional[Sequence[int]] = None, + group: int = 1, + kernel_shape: Optional[Sequence[int]] = None, + offset_group: int = 1, + pads: Optional[Sequence[int]] = None, + strides: Optional[Sequence[int]] = None, + ) -> T_DeformConv: + r"""[🌐 DeformConv(22)](https://onnx.ai/onnx/operators/onnx__DeformConv.html#deformconv-22 "Online Documentation") + + + Performs deformable convolution as described in https://arxiv.org/abs/1703.06211 and https://arxiv.org/abs/1811.11168. + This operator specification supports the general N-D case. Note that most common use cases have 2D or 3D data. + + + Args: + X: Input data tensor. For 2D image data, it has shape (N, C, H, W) where N + is the batch size, C is the number of input channels, and H and W are + the height and width. In general, the shape is (N, C, D1, D2, ... , Dn) + for n-dimensional data, where D1 to Dn are the spatial dimension sizes. + Most common use cases have n = 2 or 3. + + W: Weight tensor that will be used in the convolutions. It has shape (oC, + C/group, kH, kW), where oC is the number of output channels and kH and + kW are the kernel height and width. For more than 2 dimensions, it has + shape (oC, C/group, k1, k2, ... , kn). + + offset: Offset tensor denoting the offset for the sampling locations in the + convolution kernel. It has shape (N, offset_group * kH * kW * 2, oH, oW) + for 2D data or (N, offset_group * k1 * k2 * ... * kn * n, o1, o2, ... , + on) for nD data. Use linear interpolationfor fractional offset values. + Sampling locations outside of the padded input tensor gives zero. + + B: (optional) Optional 1D bias of length oC to be added to the convolution. + Default is a tensor of zeros. + + mask: (optional) The mask tensor to be applied to each position in the + convolution kernel. It has shape (N, offset_group * kH * kW, oH, oW) for + 2D data or (N, offset_group * k1 * k2 * ... * kn * n, o1, o2, ... , on) + for nD data. Default is a tensor of ones. + + dilations: Dilation value along each spatial axis of the kernel. Default is + 1 along each axis. + + group: Number of groups the input and output channels, C and oC, are divided + into. C and oC must both be divisible by group. Default is 1. + + kernel_shape: Shape of the convolution kernel. If not present, it is + inferred from the shape of input W. + + offset_group: Number of groups of offset. C must be divisible by + offset_group. Default is 1. + + pads: Padding for the beginning and end along each spatial axis. The values + represent the number of pixels added to the beginning and end of the + corresponding axis and can take any nonnegative value. The format should + be as follows: [x1_begin, x2_begin, ..., x1_end, x2_end, ...], where + xi_begin is the number of pixels added at the beginning of axis `i` and + xi_end is the number of pixels added at the end of axis `i`. Default is + 0 along each axis. + + strides: Stride along each spatial axis. Default is 1 along each axis. + """ + + schema = get_schema("DeformConv", 22, "") + op = Op(self, "DeformConv", schema) + return op( + *self._prepare_inputs(schema, X, W, offset, B, mask), + dilations=dilations, + group=group, + kernel_shape=kernel_shape, + offset_group=offset_group, + pads=pads, + strides=strides, + ) + + T_Det = TypeVar("T_Det", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Det(self, X: T_Det) -> T_Det: + r"""[🌐 Det(22)](https://onnx.ai/onnx/operators/onnx__Det.html#det-22 "Online Documentation") + + + Det calculates determinant of a square matrix or batches of square matrices. + Det takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions, + and the inner-most 2 dimensions form square matrices. + The output is a tensor of shape `[*]`, containing the determinants of all input submatrices. + e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`). + + + Args: + X: (differentiable) Input tensor + """ + + schema = get_schema("Det", 22, "") + op = Op(self, "Det", schema) + return op(*self._prepare_inputs(schema, X)) + + T_Dropout = TypeVar( + "T_Dropout", + BFLOAT16, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + ) + + T1_Dropout = TypeVar( + "T1_Dropout", + BFLOAT16, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + ) + + T2_Dropout: TypeAlias = BOOL + + def Dropout( + self, + data: T_Dropout, + ratio: Optional[T1_Dropout] = None, + training_mode: Optional[T2_Dropout] = None, + *, + seed: Optional[int] = None, + ) -> Tuple[T_Dropout, T2_Dropout]: + r"""[🌐 Dropout(22)](https://onnx.ai/onnx/operators/onnx__Dropout.html#dropout-22 "Online Documentation") + + + Dropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar). It produces two tensor outputs, + output (floating-point tensor) and mask (optional `Tensor`). If `training_mode` is true then the output Y will be a random dropout; + Note that this Dropout scales the masked input data by the following equation, so to convert the trained model into inference mode, + the user can simply not pass `training_mode` input or set it to false. + :: + + output = scale * data * mask, + + + where + :: + + scale = 1. / (1. - ratio). + + + This operator has **optional** inputs/outputs. See `ONNX `_ for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + + + Args: + data: (differentiable) The input data as Tensor. + + ratio: (optional, non-differentiable) The ratio of random dropout, with + value in [0, 1). If set to 0, the output would be a simple copy of the + input. If it's non-zero, output will be a random dropout of the scaled + input, which is typically the case during training. It is an optional + value, if not specified it will default to 0.5. + + training_mode: (optional, non-differentiable) If set to true then it + indicates dropout is being used for training. It is an optional value + hence unless specified explicitly, it is false. If it is false, ratio is + ignored and the operation mimics inference mode where nothing will be + dropped from the input data and if mask is requested as output it will + contain all ones. + + seed: (Optional) Seed to the random generator, if not specified we will auto + generate one. + """ + + schema = get_schema("Dropout", 22, "") + op = Op(self, "Dropout", schema) + return op(*self._prepare_inputs(schema, data, ratio, training_mode), seed=seed) + + T_Elu = TypeVar("T_Elu", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Elu(self, X: T_Elu, *, alpha: float = 1.0) -> T_Elu: + r"""[🌐 Elu(22)](https://onnx.ai/onnx/operators/onnx__Elu.html#elu-22 "Online Documentation") + + + Elu takes one input data (Tensor) and produces one output data + (Tensor) where the function `f(x) = alpha * (exp(x) - 1.) for x < + 0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise. + + + + Args: + X: (differentiable) Input tensor + + alpha: Coefficient of ELU. + """ + + schema = get_schema("Elu", 22, "") + op = Op(self, "Elu", schema) + return op(*self._prepare_inputs(schema, X), alpha=alpha) + + T1_EyeLike = TypeVar( + "T1_EyeLike", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + T2_EyeLike: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, + ] + + def EyeLike( + self, input: T1_EyeLike, *, dtype: Optional[int] = None, k: int = 0 + ) -> T2_EyeLike: + r"""[🌐 EyeLike(22)](https://onnx.ai/onnx/operators/onnx__EyeLike.html#eyelike-22 "Online Documentation") + + + Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D + tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the + same as the input tensor. The data type can be specified by the 'dtype' argument. If + 'dtype' is not specified, then the type of input tensor is used. By default, the main diagonal + is populated with ones, but attribute 'k' can be used to populate upper or lower diagonals. + The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the + TensorProto message and be valid as an output type. + + + Args: + input: 2D input tensor to copy shape, and optionally, type information from. + + dtype: (Optional) The data type for the elements of the output tensor. If + not specified, the data type of the input tensor T1 is used. + + k: (Optional) Index of the diagonal to be populated with ones. Default is 0. + If T2 is the output, this op sets T2[i, i+k] = 1. k = 0 populates the + main diagonal, k > 0 populates an upper diagonal, and k < 0 populates a + lower diagonal. + """ + + schema = get_schema("EyeLike", 22, "") + op = Op(self, "EyeLike", schema) + return op(*self._prepare_inputs(schema, input), dtype=dtype, k=k) + + T_GRU = TypeVar("T_GRU", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T1_GRU: TypeAlias = INT32 + + def GRU( + self, + X: T_GRU, + W: T_GRU, + R: T_GRU, + B: Optional[T_GRU] = None, + sequence_lens: Optional[T1_GRU] = None, + initial_h: Optional[T_GRU] = None, + *, + activation_alpha: Optional[Sequence[float]] = None, + activation_beta: Optional[Sequence[float]] = None, + activations: Optional[Sequence[str]] = None, + clip: Optional[float] = None, + direction: str = "forward", + hidden_size: Optional[int] = None, + layout: int = 0, + linear_before_reset: int = 0, + ) -> Tuple[T_GRU, T_GRU]: + r"""[🌐 GRU(22)](https://onnx.ai/onnx/operators/onnx__GRU.html#gru-22 "Online Documentation") + + + Computes an one-layer GRU. This operator is usually supported via some custom + implementation such as CuDNN. + + Notations: + + * `X` - input tensor + * `z` - update gate + * `r` - reset gate + * `h` - hidden gate + * `t` - time step (t-1 means previous time step) + * `W[zrh]` - W parameter weight matrix for update, reset, and hidden gates + * `R[zrh]` - R recurrence weight matrix for update, reset, and hidden gates + * `Wb[zrh]` - W bias vectors for update, reset, and hidden gates + * `Rb[zrh]` - R bias vectors for update, reset, and hidden gates + * `WB[zrh]` - W parameter weight matrix for backward update, reset, and hidden gates + * `RB[zrh]` - R recurrence weight matrix for backward update, reset, and hidden gates + * `WBb[zrh]` - W bias vectors for backward update, reset, and hidden gates + * `RBb[zrh]` - R bias vectors for backward update, reset, and hidden gates + * `H` - Hidden state + * `num_directions` - 2 if direction == bidirectional else 1 + + Activation functions: + + * Relu(x) - max(0, x) + * Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) + * Sigmoid(x) - 1/(1 + e^{-x}) + + NOTE: + Below are optional + + * Affine(x) - alpha * x + beta + * LeakyRelu(x) - x if x >= 0 else alpha * x + * ThresholdedRelu(x) - x if x >= alpha else 0 + * ScaledTanh(x) - alpha * Tanh(beta * x) + * HardSigmoid(x) - min(max(alpha * x + beta, 0), 1) + * Elu(x) - x if x >= 0 else alpha * (e^x - 1) + * Softsign(x) - x/(1 + |x|) + * Softplus(x) - log(1 + e^x) + + Equations (Default: f=Sigmoid, g=Tanh): + + * zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) + * rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) + * ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0 + * ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0 + * Ht = (1 - zt) (.) ht + zt (.) Ht-1 + This operator has **optional** inputs/outputs. See `ONNX `_ for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + + + Args: + X: (differentiable) The input sequences packed (and potentially padded) into + one 3-D tensor with the shape of `[seq_length, batch_size, input_size]`. + + W: (differentiable) The weight tensor for the gates. Concatenation of + `W[zrh]` and `WB[zrh]` (if bidirectional) along dimension 0. This tensor + has shape `[num_directions, 3*hidden_size, input_size]`. + + R: (differentiable) The recurrence weight tensor. Concatenation of `R[zrh]` + and `RB[zrh]` (if bidirectional) along dimension 0. This tensor has + shape `[num_directions, 3*hidden_size, hidden_size]`. + + B: (optional, differentiable) The bias tensor for the gates. Concatenation + of `[Wb[zrh], Rb[zrh]]` and `[WBb[zrh], RBb[zrh]]` (if bidirectional) + along dimension 0. This tensor has shape `[num_directions, + 6*hidden_size]`. Optional: If not specified - assumed to be 0 + + sequence_lens: (optional, non-differentiable) Optional tensor specifying + lengths of the sequences in a batch. If not specified - assumed all + sequences in the batch to have length `seq_length`. It has shape + `[batch_size]`. + + initial_h: (optional, non-differentiable) Optional initial value of the + hidden. If not specified - assumed to be 0. It has shape + `[num_directions, batch_size, hidden_size]`. + + activation_alpha: Optional scaling values used by some activation functions. + The values are consumed in the order of activation functions, for + example (f, g, h) in LSTM. Default values are the same as of + corresponding ONNX operators.For example with LeakyRelu, the default + alpha is 0.01. + + activation_beta: Optional scaling values used by some activation functions. + The values are consumed in the order of activation functions, for + example (f, g, h) in LSTM. Default values are the same as of + corresponding ONNX operators. + + activations: A list of 2 (or 4 if bidirectional) activation functions for + update, reset, and hidden gates. The activation functions must be one of + the activation functions specified above. Optional: See the equations + for default if not specified. + + clip: Cell clip threshold. Clipping bounds the elements of a tensor in the + range of [-threshold, +threshold] and is applied to the input of + activations. No clip if not specified. + + direction: Specify if the RNN is forward, reverse, or bidirectional. Must be + one of forward (default), reverse, or bidirectional. + + hidden_size: Number of neurons in the hidden layer + + layout: The shape format of inputs X, initial_h and outputs Y, Y_h. If 0, + the following shapes are expected: X.shape = [seq_length, batch_size, + input_size], Y.shape = [seq_length, num_directions, batch_size, + hidden_size], initial_h.shape = Y_h.shape = [num_directions, batch_size, + hidden_size]. If 1, the following shapes are expected: X.shape = + [batch_size, seq_length, input_size], Y.shape = [batch_size, seq_length, + num_directions, hidden_size], initial_h.shape = Y_h.shape = [batch_size, + num_directions, hidden_size]. + + linear_before_reset: When computing the output of the hidden gate, apply the + linear transformation before multiplying by the output of the reset + gate. + """ + + schema = get_schema("GRU", 22, "") + op = Op(self, "GRU", schema) + return op( + *self._prepare_inputs(schema, X, W, R, B, sequence_lens, initial_h), + activation_alpha=activation_alpha, + activation_beta=activation_beta, + activations=activations, + clip=clip, + direction=direction, + hidden_size=hidden_size, + layout=layout, + linear_before_reset=linear_before_reset, + ) + + T_GlobalAveragePool = TypeVar("T_GlobalAveragePool", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def GlobalAveragePool(self, X: T_GlobalAveragePool) -> T_GlobalAveragePool: + r"""[🌐 GlobalAveragePool(22)](https://onnx.ai/onnx/operators/onnx__GlobalAveragePool.html#globalaveragepool-22 "Online Documentation") + + + GlobalAveragePool consumes an input tensor X and applies average pooling across + the values in the same channel. This is equivalent to AveragePool with kernel size + equal to the spatial dimension of input tensor. + + Args: + X: (differentiable) Input data tensor from the previous operator; dimensions + for image case are (N x C x H x W), where N is the batch size, C is the + number of channels, and H and W are the height and the width of the + data. For non image case, the dimensions are in the form of (N x C x D1 + x D2 ... Dn), where N is the batch size. + """ + + schema = get_schema("GlobalAveragePool", 22, "") + op = Op(self, "GlobalAveragePool", schema) + return op(*self._prepare_inputs(schema, X)) + + T_GlobalLpPool = TypeVar("T_GlobalLpPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def GlobalLpPool(self, X: T_GlobalLpPool, *, p: int = 2) -> T_GlobalLpPool: + r"""[🌐 GlobalLpPool(22)](https://onnx.ai/onnx/operators/onnx__GlobalLpPool.html#globallppool-22 "Online Documentation") + + + GlobalLpPool consumes an input tensor X and applies lp pool pooling across + the values in the same channel. This is equivalent to LpPool with kernel size + equal to the spatial dimension of input tensor. + + Args: + X: (differentiable) Input data tensor from the previous operator; dimensions + for image case are (N x C x H x W), where N is the batch size, C is the + number of channels, and H and W are the height and the width of the + data. For non image case, the dimensions are in the form of (N x C x D1 + x D2 ... Dn), where N is the batch size. + + p: p value of the Lp norm used to pool over the input data. + """ + + schema = get_schema("GlobalLpPool", 22, "") + op = Op(self, "GlobalLpPool", schema) + return op(*self._prepare_inputs(schema, X), p=p) + + T_GlobalMaxPool = TypeVar("T_GlobalMaxPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def GlobalMaxPool(self, X: T_GlobalMaxPool) -> T_GlobalMaxPool: + r"""[🌐 GlobalMaxPool(22)](https://onnx.ai/onnx/operators/onnx__GlobalMaxPool.html#globalmaxpool-22 "Online Documentation") + + + GlobalMaxPool consumes an input tensor X and applies max pooling across + the values in the same channel. This is equivalent to MaxPool with kernel size + equal to the spatial dimension of input tensor. + + Args: + X: (differentiable) Input data tensor from the previous operator; dimensions + for image case are (N x C x H x W), where N is the batch size, C is the + number of channels, and H and W are the height and the width of the + data. For non image case, the dimensions are in the form of (N x C x D1 + x D2 ... Dn), where N is the batch size. + """ + + schema = get_schema("GlobalMaxPool", 22, "") + op = Op(self, "GlobalMaxPool", schema) + return op(*self._prepare_inputs(schema, X)) + + T1_GridSample = TypeVar( + "T1_GridSample", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + T2_GridSample = TypeVar("T2_GridSample", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def GridSample( + self, + X: T1_GridSample, + grid: T2_GridSample, + *, + align_corners: int = 0, + mode: str = "linear", + padding_mode: str = "zeros", + ) -> T1_GridSample: + r"""[🌐 GridSample(22)](https://onnx.ai/onnx/operators/onnx__GridSample.html#gridsample-22 "Online Documentation") + + + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. + For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), + the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), + the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). + More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), + the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + + The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). + The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values + at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) + and a padding mode (for `grid` positions falling outside the 2-dimensional image). + + For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. + They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + + The GridSample operator is often used in doing grid generator and sampler in the + [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). + See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + + + Args: + X: (differentiable) Input tensor of rank r+2 that has shape (N, C, D1, D2, + ..., Dr), where N is the batch size, C is the number of channels, D1, + D2, ..., Dr are the spatial dimensions. + + grid: (non-differentiable) Input offset of shape (N, D1_out, D2_out, ..., + Dr_out, r), where D1_out, D2_out, ..., Dr_out are the spatial dimensions + of the grid and output, and r is the number of spatial dimensions. Grid + specifies the sampling locations normalized by the input spatial + dimensions. Therefore, it should have most values in the range of [-1, + 1]. If the grid has values outside the range of [-1, 1], the + corresponding outputs will be handled as defined by padding_mode. + Following computer vision convention, the coordinates in the length-r + location vector are listed from the innermost tensor dimension to the + outermost, the opposite of regular tensor indexing. + + align_corners: If align_corners=1, the extrema (-1 and 1) are considered as + referring to the center points of the input's corner pixels (voxels, + etc.). If align_corners=0, they are instead considered as referring to + the corner points of the input's corner pixels (voxels, etc.), making + the sampling more resolution agnostic. + + mode: Three interpolation modes: linear (default), nearest and cubic. The + "linear" mode includes linear and N-linear interpolation modes depending + on the number of spatial dimensions of the input tensor (i.e. linear for + 1 spatial dimension, bilinear for 2 spatial dimensions, etc.). The + "cubic" mode also includes N-cubic interpolation modes following the + same rules. The "nearest" mode rounds to the nearest even index when the + sampling point falls halfway between two indices. + + padding_mode: Support padding modes for outside grid values: + `zeros`(default), `border`, `reflection`. zeros: use 0 for out-of-bound + grid locations, border: use border values for out-of-bound grid + locations, reflection: use values at locations reflected by the border + for out-of-bound grid locations. If index 0 represents the margin pixel, + the reflected value at index -1 will be the same as the value at index + 1. For location far away from the border, it will keep being reflected + until becoming in bound. If pixel location x = -3.5 reflects by border + -1 and becomes x' = 1.5, then reflects by border 1 and becomes x'' = + 0.5. + """ + + schema = get_schema("GridSample", 22, "") + op = Op(self, "GridSample", schema) + return op( + *self._prepare_inputs(schema, X, grid), + align_corners=align_corners, + mode=mode, + padding_mode=padding_mode, + ) + + T_HardSigmoid = TypeVar("T_HardSigmoid", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def HardSigmoid( + self, X: T_HardSigmoid, *, alpha: float = 0.20000000298023224, beta: float = 0.5 + ) -> T_HardSigmoid: + r"""[🌐 HardSigmoid(22)](https://onnx.ai/onnx/operators/onnx__HardSigmoid.html#hardsigmoid-22 "Online Documentation") + + + HardSigmoid takes one input data (Tensor) and produces one output data + (Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)), + is applied to the tensor elementwise. + + + Args: + X: (differentiable) Input tensor + + alpha: Value of alpha. + + beta: Value of beta. + """ + + schema = get_schema("HardSigmoid", 22, "") + op = Op(self, "HardSigmoid", schema) + return op(*self._prepare_inputs(schema, X), alpha=alpha, beta=beta) + + T_HardSwish = TypeVar("T_HardSwish", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def HardSwish(self, X: T_HardSwish) -> T_HardSwish: + r"""[🌐 HardSwish(22)](https://onnx.ai/onnx/operators/onnx__HardSwish.html#hardswish-22 "Online Documentation") + + + HardSwish takes one input data (Tensor) and produces one output data (Tensor) where + the HardSwish function, y = x * max(0, min(1, alpha * x + beta)) = x * HardSigmoid(x), + where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise. + + + Args: + X: (differentiable) Input tensor + """ + + schema = get_schema("HardSwish", 22, "") + op = Op(self, "HardSwish", schema) + return op(*self._prepare_inputs(schema, X)) + + T_InstanceNormalization = TypeVar( + "T_InstanceNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16 + ) + + def InstanceNormalization( + self, + input: T_InstanceNormalization, + scale: T_InstanceNormalization, + B: T_InstanceNormalization, + *, + epsilon: float = 9.999999747378752e-06, + ) -> T_InstanceNormalization: + r"""[🌐 InstanceNormalization(22)](https://onnx.ai/onnx/operators/onnx__InstanceNormalization.html#instancenormalization-22 "Online Documentation") + + + Carries out instance normalization as described in the paper + https://arxiv.org/abs/1607.08022. + + y = scale * (x - mean) / sqrt(variance + epsilon) + B, + where mean and variance are computed per instance per channel. + + + + Args: + input: (differentiable) Input data tensor from the previous operator; + dimensions for image case are (N x C x H x W), where N is the batch + size, C is the number of channels, and H and W are the height and the + width of the data. For non image case, the dimensions are in the form of + (N x C x D1 x D2 ... Dn), where N is the batch size. + + scale: (differentiable) The input 1-dimensional scale tensor of size C. + + B: (differentiable) The input 1-dimensional bias tensor of size C. + + epsilon: The epsilon value to use to avoid division by zero. + """ + + schema = get_schema("InstanceNormalization", 22, "") + op = Op(self, "InstanceNormalization", schema) + return op(*self._prepare_inputs(schema, input, scale, B), epsilon=epsilon) + + T_LSTM = TypeVar("T_LSTM", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T1_LSTM: TypeAlias = INT32 + + def LSTM( + self, + X: T_LSTM, + W: T_LSTM, + R: T_LSTM, + B: Optional[T_LSTM] = None, + sequence_lens: Optional[T1_LSTM] = None, + initial_h: Optional[T_LSTM] = None, + initial_c: Optional[T_LSTM] = None, + P: Optional[T_LSTM] = None, + *, + activation_alpha: Optional[Sequence[float]] = None, + activation_beta: Optional[Sequence[float]] = None, + activations: Optional[Sequence[str]] = None, + clip: Optional[float] = None, + direction: str = "forward", + hidden_size: Optional[int] = None, + input_forget: int = 0, + layout: int = 0, + ) -> Tuple[T_LSTM, T_LSTM, T_LSTM]: + r"""[🌐 LSTM(22)](https://onnx.ai/onnx/operators/onnx__LSTM.html#lstm-22 "Online Documentation") + + + Computes an one-layer LSTM. This operator is usually supported via some + custom implementation such as CuDNN. + + Notations: + + * `X` - input tensor + * `i` - input gate + * `o` - output gate + * `f` - forget gate + * `c` - cell gate + * `t` - time step (t-1 means previous time step) + * `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates + * `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates + * `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates + * `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates + * `P[iof]` - P peephole weight vector for input, output, and forget gates + * `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates + * `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates + * `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates + * `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates + * `PB[iof]` - P peephole weight vector for backward input, output, and forget gates + * `H` - Hidden state + * `num_directions` - 2 if direction == bidirectional else 1 + + Activation functions: + + * Relu(x) - max(0, x) + * Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) + * Sigmoid(x) - 1/(1 + e^{-x}) + + NOTE: Below are optional + + * Affine(x) - alpha*x + beta + * LeakyRelu(x) - x if x >= 0 else alpha * x + * ThresholdedRelu(x) - x if x >= alpha else 0 + * ScaledTanh(x) - alpha*Tanh(beta*x) + * HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) + * Elu(x) - x if x >= 0 else alpha*(e^x - 1) + * Softsign(x) - x/(1 + |x|) + * Softplus(x) - log(1 + e^x) + + Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): + + * it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) + * ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) + * ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + * Ct = ft (.) Ct-1 + it (.) ct + * ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) + * Ht = ot (.) h(Ct) + This operator has **optional** inputs/outputs. See `ONNX `_ for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + + + Args: + X: (differentiable) The input sequences packed (and potentially padded) into + one 3-D tensor with the shape of `[seq_length, batch_size, input_size]`. + + W: (differentiable) The weight tensor for the gates. Concatenation of + `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0. The + tensor has shape `[num_directions, 4*hidden_size, input_size]`. + + R: (differentiable) The recurrence weight tensor. Concatenation of `R[iofc]` + and `RB[iofc]` (if bidirectional) along dimension 0. This tensor has + shape `[num_directions, 4*hidden_size, hidden_size]`. + + B: (optional, differentiable) The bias tensor for input gate. Concatenation + of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if + bidirectional) along dimension 0. This tensor has shape + `[num_directions, 8*hidden_size]`. Optional: If not specified - assumed + to be 0. + + sequence_lens: (optional, non-differentiable) Optional tensor specifying + lengths of the sequences in a batch. If not specified - assumed all + sequences in the batch to have length `seq_length`. It has shape + `[batch_size]`. + + initial_h: (optional, non-differentiable) Optional initial value of the + hidden. If not specified - assumed to be 0. It has shape + `[num_directions, batch_size, hidden_size]`. + + initial_c: (optional, non-differentiable) Optional initial value of the + cell. If not specified - assumed to be 0. It has shape `[num_directions, + batch_size, hidden_size]`. + + P: (optional, differentiable) The weight tensor for peepholes. Concatenation + of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0. It has + shape `[num_directions, 3*hidde_size]`. Optional: If not specified - + assumed to be 0. + + activation_alpha: Optional scaling values used by some activation functions. + The values are consumed in the order of activation functions, for + example (f, g, h) in LSTM. Default values are the same as of + corresponding ONNX operators.For example with LeakyRelu, the default + alpha is 0.01. + + activation_beta: Optional scaling values used by some activation functions. + The values are consumed in the order of activation functions, for + example (f, g, h) in LSTM. Default values are the same as of + corresponding ONNX operators. + + activations: A list of 3 (or 6 if bidirectional) activation functions for + input, output, forget, cell, and hidden. The activation functions must + be one of the activation functions specified above. Optional: See the + equations for default if not specified. + + clip: Cell clip threshold. Clipping bounds the elements of a tensor in the + range of [-threshold, +threshold] and is applied to the input of + activations. No clip if not specified. + + direction: Specify if the RNN is forward, reverse, or bidirectional. Must be + one of forward (default), reverse, or bidirectional. + + hidden_size: Number of neurons in the hidden layer + + input_forget: Couple the input and forget gates if 1. + + layout: The shape format of inputs X, initial_h, initial_c and outputs Y, + Y_h, Y_c. If 0, the following shapes are expected: X.shape = + [seq_length, batch_size, input_size], Y.shape = [seq_length, + num_directions, batch_size, hidden_size], initial_h.shape = Y_h.shape = + initial_c.shape = Y_c.shape = [num_directions, batch_size, hidden_size]. + If 1, the following shapes are expected: X.shape = [batch_size, + seq_length, input_size], Y.shape = [batch_size, seq_length, + num_directions, hidden_size], initial_h.shape = Y_h.shape = + initial_c.shape = Y_c.shape = [batch_size, num_directions, hidden_size]. + """ + + schema = get_schema("LSTM", 22, "") + op = Op(self, "LSTM", schema) + return op( + *self._prepare_inputs(schema, X, W, R, B, sequence_lens, initial_h, initial_c, P), + activation_alpha=activation_alpha, + activation_beta=activation_beta, + activations=activations, + clip=clip, + direction=direction, + hidden_size=hidden_size, + input_forget=input_forget, + layout=layout, + ) + + T_LpNormalization = TypeVar("T_LpNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def LpNormalization( + self, input: T_LpNormalization, *, axis: int = -1, p: int = 2 + ) -> T_LpNormalization: + r"""[🌐 LpNormalization(22)](https://onnx.ai/onnx/operators/onnx__LpNormalization.html#lpnormalization-22 "Online Documentation") + + + Given a matrix, apply Lp-normalization along the provided axis. + + + Args: + input: (differentiable) Input matrix + + axis: The axis on which to apply normalization, -1 mean last axis. + + p: The order of the normalization, only 1 or 2 are supported. + """ + + schema = get_schema("LpNormalization", 22, "") + op = Op(self, "LpNormalization", schema) + return op(*self._prepare_inputs(schema, input), axis=axis, p=p) + + T_LpPool = TypeVar("T_LpPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def LpPool( + self, + X: T_LpPool, + *, + auto_pad: str = "NOTSET", + ceil_mode: int = 0, + dilations: Optional[Sequence[int]] = None, + kernel_shape: Sequence[int], + p: int = 2, + pads: Optional[Sequence[int]] = None, + strides: Optional[Sequence[int]] = None, + ) -> T_LpPool: + r"""[🌐 LpPool(22)](https://onnx.ai/onnx/operators/onnx__LpPool.html#lppool-22 "Online Documentation") + + + LpPool consumes an input tensor X and applies Lp pooling across + the tensor according to kernel sizes, stride sizes, and pad lengths. + Lp pooling consisting of computing the Lp norm on all values of a subset + of the input tensor according to the kernel size and downsampling the + data into the output tensor Y for further processing. The output spatial shape will be following: + ``` + output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - {kernelSpatialShape}) / strides_spatial_shape[i] + 1) + ``` + or + ``` + output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - {kernelSpatialShape}) / strides_spatial_shape[i] + 1) + ``` + if ceil_mode is enabled `pad_shape[i]` is the sum of pads along axis `i`. + + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: + ``` + VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - {kernelSpatialShape} + 1) / strides_spatial_shape[i]) + SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i]) + ``` + And pad shape will be following if `SAME_UPPER` or `SAME_LOWER`: + ``` + pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + {kernelSpatialShape} - input_spatial_shape[i] + ``` + + Args: + X: (differentiable) Input data tensor from the previous operator; dimensions + for image case are (N x C x H x W), where N is the batch size, C is the + number of channels, and H and W are the height and the width of the + data. For non image case, the dimensions are in the form of (N x C x D1 + x D2 ... Dn), where N is the batch size. + + auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID. + Where default value is NOTSET, which means explicit padding is used. + SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] = + ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is + split between the two sides equally or almost equally (depending on + whether it is even or odd). In case the padding is an odd number, the + extra padding is added at the end for SAME_UPPER and at the beginning + for SAME_LOWER. + + ceil_mode: Whether to use ceil or floor (default) to compute the output + shape. + + dilations: dilation value along each spatial axis of the filter. If not + present, the dilation defaults is 1 along each spatial axis. + + kernel_shape: The size of the kernel along each axis. + + p: p value of the Lp norm used to pool over the input data. + + pads: Padding for the beginning and ending along each spatial axis, it can + take any value greater than or equal to 0. The value represent the + number of pixels added to the beginning and end part of the + corresponding axis. `pads` format should be as follow [x1_begin, + x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels + added at the beginning of axis `i` and xi_end, the number of pixels + added at the end of axis `i`. This attribute cannot be used + simultaneously with auto_pad attribute. If not present, the padding + defaults to 0 along start and end of each spatial axis. + + strides: Stride along each spatial axis. If not present, the stride defaults + to 1 along each spatial axis. + """ + + schema = get_schema("LpPool", 22, "") + op = Op(self, "LpPool", schema) + return op( + *self._prepare_inputs(schema, X), + auto_pad=auto_pad, + ceil_mode=ceil_mode, + dilations=dilations, + kernel_shape=kernel_shape, + p=p, + pads=pads, + strides=strides, + ) + + T_MaxPool = TypeVar("T_MaxPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16, INT8, UINT8) + + I_MaxPool: TypeAlias = INT64 + + def MaxPool( + self, + X: T_MaxPool, + *, + auto_pad: str = "NOTSET", + ceil_mode: int = 0, + dilations: Optional[Sequence[int]] = None, + kernel_shape: Sequence[int], + pads: Optional[Sequence[int]] = None, + storage_order: int = 0, + strides: Optional[Sequence[int]] = None, + ) -> Tuple[T_MaxPool, I_MaxPool]: + r"""[🌐 MaxPool(22)](https://onnx.ai/onnx/operators/onnx__MaxPool.html#maxpool-22 "Online Documentation") + + + MaxPool consumes an input tensor X and applies max pooling across + the tensor according to kernel sizes, stride sizes, and pad lengths. + max pooling consisting of computing the max on all values of a + subset of the input tensor according to the kernel size and downsampling the + data into the output tensor Y for further processing. The output spatial shape is calculated differently + depending on whether explicit padding is used, where pads is employed, or auto padding is used, where auto_pad is utilized. + With explicit padding (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html?highlight=maxpool#torch.nn.MaxPool2d): + ``` + output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) + ``` + or + ``` + output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) + ``` + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. + + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: + ``` + VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) + 1) / strides_spatial_shape[i]) + SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i]) + ``` + or when ceil_mode is disabled (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AveragePooling2D): + ``` + VALID: output_spatial_shape[i] = floor((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i]) + 1 + SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = floor((input_spatial_shape[i] - 1) / strides_spatial_shape[i]) + 1 + ``` + And pad shape will be following if `SAME_UPPER` or `SAME_LOWER`: + ``` + pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i] + ``` + The output of each pooling window is maximum number of elements exclude pad. + + + Args: + X: (differentiable) Input data tensor from the previous operator; dimensions + for image case are (N x C x H x W), where N is the batch size, C is the + number of channels, and H and W are the height and the width of the + data. For non image case, the dimensions are in the form of (N x C x D1 + x D2 ... Dn), where N is the batch size. Optionally, if dimension + denotation is in effect, the operation expects the input data tensor to + arrive with the dimension denotation of [DATA_BATCH, DATA_CHANNEL, + DATA_FEATURE, DATA_FEATURE ...]. + + auto_pad: auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or VALID. + Where default value is NOTSET, which means explicit padding is used. + SAME_UPPER or SAME_LOWER mean pad the input so that `output_shape[i] = + ceil(input_shape[i] / strides[i])` for each axis `i`. The padding is + split between the two sides equally or almost equally (depending on + whether it is even or odd). In case the padding is an odd number, the + extra padding is added at the end for SAME_UPPER and at the beginning + for SAME_LOWER. + + ceil_mode: Whether to use ceil or floor (default) to compute the output + shape. + + dilations: Dilation value along each spatial axis of filter. If not present, + the dilation defaults to 1 along each spatial axis. + + kernel_shape: The size of the kernel along each axis. + + pads: Padding for the beginning and ending along each spatial axis, it can + take any value greater than or equal to 0. The value represent the + number of pixels added to the beginning and end part of the + corresponding axis. `pads` format should be as follow [x1_begin, + x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels + added at the beginning of axis `i` and xi_end, the number of pixels + added at the end of axis `i`. This attribute cannot be used + simultaneously with auto_pad attribute. If not present, the padding + defaults to 0 along start and end of each spatial axis. + + storage_order: The storage order of the tensor. 0 is row major, and 1 is + column major. This attribute is used only to convert an n-tuple index + value into a single integer value for producing the second output. + + strides: Stride along each spatial axis. If not present, the stride defaults + to 1 along each spatial axis. + """ + + schema = get_schema("MaxPool", 22, "") + op = Op(self, "MaxPool", schema) + return op( + *self._prepare_inputs(schema, X), + auto_pad=auto_pad, + ceil_mode=ceil_mode, + dilations=dilations, + kernel_shape=kernel_shape, + pads=pads, + storage_order=storage_order, + strides=strides, + ) + + T_MaxRoiPool = TypeVar("T_MaxRoiPool", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def MaxRoiPool( + self, + X: T_MaxRoiPool, + rois: T_MaxRoiPool, + *, + pooled_shape: Sequence[int], + spatial_scale: float = 1.0, + ) -> T_MaxRoiPool: + r"""[🌐 MaxRoiPool(22)](https://onnx.ai/onnx/operators/onnx__MaxRoiPool.html#maxroipool-22 "Online Documentation") + + + ROI max pool consumes an input tensor X and region of interests (RoIs) to + apply max pooling across each RoI, to produce output 4-D tensor of shape + (num_rois, channels, pooled_shape[0], pooled_shape[1]). + + Args: + X: (differentiable) Input data tensor from the previous operator; dimensions + for image case are (N x C x H x W), where N is the batch size, C is the + number of channels, and H and W are the height and the width of the + data. + + rois: (non-differentiable) RoIs (Regions of Interest) to pool over. Should + be a 2-D tensor of shape (num_rois, 5) given as [[batch_id, x1, y1, x2, + y2], ...]. + + pooled_shape: ROI pool output shape (height, width). + + spatial_scale: Multiplicative spatial scale factor to translate ROI + coordinates from their input scale to the scale used when pooling. + """ + + schema = get_schema("MaxRoiPool", 22, "") + op = Op(self, "MaxRoiPool", schema) + return op( + *self._prepare_inputs(schema, X, rois), + pooled_shape=pooled_shape, + spatial_scale=spatial_scale, + ) + + T1_MaxUnpool = TypeVar("T1_MaxUnpool", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T2_MaxUnpool: TypeAlias = INT64 + + def MaxUnpool( + self, + X: T1_MaxUnpool, + I: T2_MaxUnpool, + output_shape: Optional[T2_MaxUnpool] = None, + *, + kernel_shape: Sequence[int], + pads: Optional[Sequence[int]] = None, + strides: Optional[Sequence[int]] = None, + ) -> T1_MaxUnpool: + r"""[🌐 MaxUnpool(22)](https://onnx.ai/onnx/operators/onnx__MaxUnpool.html#maxunpool-22 "Online Documentation") + + + MaxUnpool essentially computes the partial inverse of the MaxPool op. + The input information to this op is typically the output information from a MaxPool op. The first + input tensor X is the tensor that needs to be unpooled, which is typically the pooled tensor (first output) + from MaxPool. The second input tensor, I, contains the indices to the (locally maximal) elements corresponding + to the elements in the first input tensor X. Input tensor I is typically the second output of the MaxPool op. + The third (optional) input is a tensor that specifies the output size of the unpooling operation. + + MaxUnpool is intended to do 'partial' inverse of the MaxPool op. 'Partial' because all the non-maximal + values from the original input to MaxPool are set to zero in the output of the MaxUnpool op. Pooling + the result of an unpooling operation should give back the original input to the unpooling op. + + MaxUnpool can produce the same output size for several input sizes, which makes unpooling op ambiguous. + The third input argument, output_size, is meant to disambiguate the op and produce output tensor of + known/predictable size. + + In addition to the inputs, MaxUnpool takes three attributes, namely kernel_shape, strides, and pads, + which define the exact unpooling op. The attributes typically have the same values as the corresponding + pooling op that the unpooling op is trying to invert. + + + Args: + X: (differentiable) Input data tensor that has to be unpooled. This tensor + is typically the first output of the MaxPool op.Dimensions for image + case are (N x C x H x W), where N is the batch size, C is the number of + channels, and H and W are the height and the width of the data. For + non-image case, the dimensions are in the form of (N x C x D1 x D2 ... + Dn), where N is the batch size. Optionally, if dimension denotation is + in effect, the operation expects the input data tensor to arrive with + the dimension denotation of [DATA_BATCH, DATA_CHANNEL, DATA_FEATURE, + DATA_FEATURE ...]. + + I: (non-differentiable) Input data tensor containing the indices + corresponding to elements in the first input tensor X.This tensor is + typically the second output of the MaxPool op.Dimensions must be the + same as input tensor X. The indices are linear, i.e. computed + considering the tensor as flattened 1-D tensor, assuming row-major + storage. Also, the linear indices should not consider padding. So the + values in indices are in the range [0, N x C x D1 x ... x Dn). + + output_shape: (optional, non-differentiable) The shape of the output can be + explicitly set which will cause pads values to be auto generated. If + 'output_shape' is specified, 'pads' values are ignored. + + kernel_shape: The size of the kernel along each axis. + + pads: Padding for the beginning and ending along each spatial axis, it can + take any value greater than or equal to 0. The value represent the + number of pixels added to the beginning and end part of the + corresponding axis. `pads` format should be as follow [x1_begin, + x2_begin...x1_end, x2_end,...], where xi_begin the number of pixels + added at the beginning of axis `i` and xi_end, the number of pixels + added at the end of axis `i`. This attribute cannot be used + simultaneously with auto_pad attribute. If not present, the padding + defaults to 0 along start and end of each spatial axis. + + strides: Stride along each spatial axis. If not present, the stride defaults + to 1 along each spatial axis. + """ + + schema = get_schema("MaxUnpool", 22, "") + op = Op(self, "MaxUnpool", schema) + return op( + *self._prepare_inputs(schema, X, I, output_shape), + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + T_Mish = TypeVar("T_Mish", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Mish(self, X: T_Mish) -> T_Mish: + r"""[🌐 Mish(22)](https://onnx.ai/onnx/operators/onnx__Mish.html#mish-22 "Online Documentation") + + + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + Perform the linear unit element-wise on the input tensor X using formula: + + :: + + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) + + + + + Args: + X: (differentiable) Input tensor + """ + + schema = get_schema("Mish", 22, "") + op = Op(self, "Mish", schema) + return op(*self._prepare_inputs(schema, X)) + + T1_Multinomial = TypeVar("T1_Multinomial", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T2_Multinomial: TypeAlias = Union[INT32, INT64] + + def Multinomial( + self, + input: T1_Multinomial, + *, + dtype: int = 6, + sample_size: int = 1, + seed: Optional[float] = None, + ) -> T2_Multinomial: + r"""[🌐 Multinomial(22)](https://onnx.ai/onnx/operators/onnx__Multinomial.html#multinomial-22 "Online Documentation") + + + Generate a tensor of samples from a multinomial distribution according to the probabilities + of each of the possible outcomes. + + + Args: + input: Input tensor with shape [batch_size, class_size], where class_size is + the number of all possible outcomes. Each value along the axis zero + represents the unnormalized log-probability of each corresponding + outcome in a batch. + + dtype: (Optional) The data type for the elements of the output tensor, if + not specified, we will use int32. + + sample_size: Number of times to sample. + + seed: (Optional) Seed to the random generator, if not specified we will auto + generate one. + """ + + schema = get_schema("Multinomial", 22, "") + op = Op(self, "Multinomial", schema) + return op( + *self._prepare_inputs(schema, input), + dtype=dtype, + sample_size=sample_size, + seed=seed, + ) + + T_NegativeLogLikelihoodLoss = TypeVar( + "T_NegativeLogLikelihoodLoss", BFLOAT16, DOUBLE, FLOAT, FLOAT16 + ) + + Tind_NegativeLogLikelihoodLoss = TypeVar("Tind_NegativeLogLikelihoodLoss", INT32, INT64) + + def NegativeLogLikelihoodLoss( + self, + input: T_NegativeLogLikelihoodLoss, + target: Tind_NegativeLogLikelihoodLoss, + weight: Optional[T_NegativeLogLikelihoodLoss] = None, + *, + ignore_index: Optional[int] = None, + reduction: str = "mean", + ) -> T_NegativeLogLikelihoodLoss: + r"""[🌐 NegativeLogLikelihoodLoss(22)](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html#negativeloglikelihoodloss-22 "Online Documentation") + + + A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss. + Its "input" tensor has the shape of (N, C, d1, d2, ..., dk) where k >= 0. + The "input" tensor contains log-probabilities for input[n, :, d_1, d_2,..., d_k] being in a class of [0, C). + The operator's "target" input tensor has the shape of (N, d1, d2, ..., dk). It encodes class labels (one of C classes) + or it may contain a special value (indicated by an attribute ignore_index) for N x d1 x d2 x ... x dk samples. + The loss value for input[n, :, d_1, d_2,...d_k] being classified as class c = target[n][d_1][d_2]...[d_k] is computed as: + + :: + + loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k]. + + + + When an optional "weight" is provided, the sample loss is calculated as: + + :: + + loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k] * weight[c]. + + + + loss is zero for the case when target-value equals ignore_index. + + :: + + loss[n][d_1][d_2]...[d_k] = 0, when target[n][d_1][d_2]...[d_k] = ignore_index + + + + If "reduction" attribute is set to "none", the operator's output will be the above loss with shape (N, d1, d2, ..., dk). + If "reduction" attribute is set to "mean" (the default attribute value), the output loss is (weight) averaged: + + :: + + mean(loss), if "weight" is not provided, + + + + or if weight is provided, + + :: + + sum(loss) / sum(weight[target[n][d_1][d_2]...[d_k]]]), for all samples. + + + + If "reduction" attribute is set to "sum", the output is a scalar: `sum(loss)`. + + See also https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss. + + Example 1: + + :: + + // negative log likelihood loss, "none" reduction + N, C, d1 = 2, 3, 2 + input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], + [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] + target = [[2, 1], [0, 2]] + + loss = np.zeros((N, d1)) + for n in range(N): + for d_1 in range(d1): + c = target[n][d_1] + loss[n][d_1] = -input[n][c][d_1] + + // print(loss) + // [[-3. -2.] + // [-0. -2.]] + + + + Example 2: + + :: + + // weighted negative log likelihood loss, sum reduction + N, C, d1 = 2, 3, 2 + input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], + [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] + target = [[2, 1], [0, 2]] + weight = [0.2, 0.3, 0.1] + loss = np.zeros((N, d1)) + for n in range(N): + for d_1 in range(d1): + c = target[n][d_1] + loss[n][d_1] = -input[n][c][d_1] * weight[c] + + loss = np.sum(loss) + // print(loss) + // -1.1 + + + + Example 3: + + :: + + // weighted negative log likelihood loss, mean reduction + N, C, d1 = 2, 3, 2 + input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], + [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] + target = [[2, 1], [0, 2]] + weight = [0.2, 0.3, 0.1] + loss = np.zeros((N, d1)) + weight_total = 0 + for n in range(N): + for d_1 in range(d1): + c = target[n][d_1] + loss[n][d_1] = -input[n][c][d_1] * weight[c] + weight_total = weight_total + weight[c] + + loss = np.sum(loss) / weight_total + // print(loss) + // -1.57 + + + + + Args: + input: (differentiable) Input tensor of shape (N, C) or (N, C, d1, d2, ..., + dk). + + target: (non-differentiable) Target tensor of shape (N) or (N, d1, d2, ..., + dk). Target element value shall be in range of [0, C). If ignore_index + is specified, it may have a value outside [0, C) and the target values + should either be in the range [0, C) or have the value ignore_index. + + weight: (optional, non-differentiable) Optional rescaling weight tensor. If + given, it has to be a tensor of size C. Otherwise, it is treated as if + having all ones. + + ignore_index: Specifies a target value that is ignored and does not + contribute to the input gradient. It's an optional value. + + reduction: Type of reduction to apply to loss: none, sum, mean (default). + 'none': the output is the loss for each sample. 'sum': the output will + be summed. 'mean': the sum of the output will be divided by the sum of + applied weights. + """ + + schema = get_schema("NegativeLogLikelihoodLoss", 22, "") + op = Op(self, "NegativeLogLikelihoodLoss", schema) + return op( + *self._prepare_inputs(schema, input, target, weight), + ignore_index=ignore_index, + reduction=reduction, + ) + + T_RNN = TypeVar("T_RNN", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T1_RNN: TypeAlias = INT32 + + def RNN( + self, + X: T_RNN, + W: T_RNN, + R: T_RNN, + B: Optional[T_RNN] = None, + sequence_lens: Optional[T1_RNN] = None, + initial_h: Optional[T_RNN] = None, + *, + activation_alpha: Optional[Sequence[float]] = None, + activation_beta: Optional[Sequence[float]] = None, + activations: Sequence[str] = ("Tanh", "Tanh"), + clip: Optional[float] = None, + direction: str = "forward", + hidden_size: Optional[int] = None, + layout: int = 0, + ) -> Tuple[T_RNN, T_RNN]: + r"""[🌐 RNN(22)](https://onnx.ai/onnx/operators/onnx__RNN.html#rnn-22 "Online Documentation") + + + Computes an one-layer simple RNN. This operator is usually supported + via some custom implementation such as CuDNN. + + Notations: + + * `X` - input tensor + * `i` - input gate + * `t` - time step (t-1 means previous time step) + * `Wi` - W parameter weight matrix for input gate + * `Ri` - R recurrence weight matrix for input gate + * `Wbi` - W parameter bias vector for input gate + * `Rbi` - R parameter bias vector for input gate + * `WBi` - W parameter weight matrix for backward input gate + * `RBi` - R recurrence weight matrix for backward input gate + * `WBbi` - WR bias vectors for backward input gate + * `RBbi` - RR bias vectors for backward input gate + * `H` - Hidden state + * `num_directions` - 2 if direction == bidirectional else 1 + + Activation functions: + + * Relu(x) - max(0, x) + * Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) + * Sigmoid(x) - 1/(1 + e^{-x}) + + NOTE: Below are optional + + * Affine(x) - alpha*x + beta + * LeakyRelu(x) - x if x >= 0 else alpha * x + * ThresholdedRelu(x) - x if x >= alpha else 0 + * ScaledTanh(x) - alpha*Tanh(beta*x) + * HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) + * Elu(x) - x if x >= 0 else alpha*(e^x - 1) + * Softsign(x) - x/(1 + |x|) + * Softplus(x) - log(1 + e^x) + + Equations (Default: f=Tanh): + + * Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + This operator has **optional** inputs/outputs. See `ONNX `_ for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + + + Args: + X: (differentiable) The input sequences packed (and potentially padded) into + one 3-D tensor with the shape of `[seq_length, batch_size, input_size]`. + + W: (differentiable) The weight tensor for input gate. Concatenation of `Wi` + and `WBi` (if bidirectional). The tensor has shape `[num_directions, + hidden_size, input_size]`. + + R: (differentiable) The recurrence weight tensor. Concatenation of `Ri` and + `RBi` (if bidirectional). The tensor has shape `[num_directions, + hidden_size, hidden_size]`. + + B: (optional, differentiable) The bias tensor for input gate. Concatenation + of `[Wbi, Rbi]` and `[WBbi, RBbi]` (if bidirectional). The tensor has + shape `[num_directions, 2*hidden_size]`. Optional: If not specified - + assumed to be 0. + + sequence_lens: (optional, non-differentiable) Optional tensor specifying + lengths of the sequences in a batch. If not specified - assumed all + sequences in the batch to have length `seq_length`. It has shape + `[batch_size]`. + + initial_h: (optional, non-differentiable) Optional initial value of the + hidden. If not specified - assumed to be 0. It has shape + `[num_directions, batch_size, hidden_size]`. + + activation_alpha: Optional scaling values used by some activation functions. + The values are consumed in the order of activation functions, for + example (f, g, h) in LSTM. Default values are the same as of + corresponding ONNX operators.For example with LeakyRelu, the default + alpha is 0.01. + + activation_beta: Optional scaling values used by some activation functions. + The values are consumed in the order of activation functions, for + example (f, g, h) in LSTM. Default values are the same as of + corresponding ONNX operators. + + activations: One (or two if bidirectional) activation function for input + gate. The activation function must be one of the activation functions + specified above. Optional: Default `Tanh` if not specified. + + clip: Cell clip threshold. Clipping bounds the elements of a tensor in the + range of [-threshold, +threshold] and is applied to the input of + activations. No clip if not specified. + + direction: Specify if the RNN is forward, reverse, or bidirectional. Must be + one of forward (default), reverse, or bidirectional. + + hidden_size: Number of neurons in the hidden layer + + layout: The shape format of inputs X, initial_h and outputs Y, Y_h. If 0, + the following shapes are expected: X.shape = [seq_length, batch_size, + input_size], Y.shape = [seq_length, num_directions, batch_size, + hidden_size], initial_h.shape = Y_h.shape = [num_directions, batch_size, + hidden_size]. If 1, the following shapes are expected: X.shape = + [batch_size, seq_length, input_size], Y.shape = [batch_size, seq_length, + num_directions, hidden_size], initial_h.shape = Y_h.shape = [batch_size, + num_directions, hidden_size]. + """ + + schema = get_schema("RNN", 22, "") + op = Op(self, "RNN", schema) + return op( + *self._prepare_inputs(schema, X, W, R, B, sequence_lens, initial_h), + activation_alpha=activation_alpha, + activation_beta=activation_beta, + activations=activations, + clip=clip, + direction=direction, + hidden_size=hidden_size, + layout=layout, + ) + + T_RandomNormal: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16] + + def RandomNormal( + self, + *, + dtype: int = 1, + mean: float = 0.0, + scale: float = 1.0, + seed: Optional[float] = None, + shape: Sequence[int], + ) -> T_RandomNormal: + r"""[🌐 RandomNormal(22)](https://onnx.ai/onnx/operators/onnx__RandomNormal.html#randomnormal-22 "Online Documentation") + + + Generate a tensor with random values drawn from a normal distribution. The shape + of the tensor is specified by the `shape` argument and the parameter of the normal distribution + specified by `mean` and `scale`. + + The data type is specified by the 'dtype' argument. The 'dtype' argument must + be one of the data types specified in the 'DataType' enum field in the + TensorProto message. + + + Args: + dtype: The data type for the elements of the output tensor. Default is + TensorProto::FLOAT. + + mean: The mean of the normal distribution. + + scale: The standard deviation of the normal distribution. + + seed: (Optional) Seed to the random generator, if not specified we will auto + generate one. + + shape: The shape of the output tensor. + """ + + schema = get_schema("RandomNormal", 22, "") + op = Op(self, "RandomNormal", schema) + return op(dtype=dtype, mean=mean, scale=scale, seed=seed, shape=shape) + + T1_RandomNormalLike = TypeVar( + "T1_RandomNormalLike", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + T2_RandomNormalLike: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16] + + def RandomNormalLike( + self, + input: T1_RandomNormalLike, + *, + dtype: Optional[int] = None, + mean: float = 0.0, + scale: float = 1.0, + seed: Optional[float] = None, + ) -> T2_RandomNormalLike: + r"""[🌐 RandomNormalLike(22)](https://onnx.ai/onnx/operators/onnx__RandomNormalLike.html#randomnormallike-22 "Online Documentation") + + + Generate a tensor with random values drawn from a normal distribution. + The shape of the output tensor is copied from the shape of the input tensor, + and the parameters of the normal distribution are specified by `mean` and `scale`. + + The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided. + The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the + TensorProto message, and be valid as an output type. + + + Args: + input: Input tensor to copy shape and optionally type information from. + + dtype: (Optional) The data type for the elements of the output tensor, if + not specified, we will use the data type of the input tensor. + + mean: The mean of the normal distribution. + + scale: The standard deviation of the normal distribution. + + seed: (Optional) Seed to the random generator, if not specified we will auto + generate one. + """ + + schema = get_schema("RandomNormalLike", 22, "") + op = Op(self, "RandomNormalLike", schema) + return op( + *self._prepare_inputs(schema, input), + dtype=dtype, + mean=mean, + scale=scale, + seed=seed, + ) + + T_RandomUniform: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16] + + def RandomUniform( + self, + *, + dtype: int = 1, + high: float = 1.0, + low: float = 0.0, + seed: Optional[float] = None, + shape: Sequence[int], + ) -> T_RandomUniform: + r"""[🌐 RandomUniform(22)](https://onnx.ai/onnx/operators/onnx__RandomUniform.html#randomuniform-22 "Online Documentation") + + + Generate a tensor with random values drawn from a uniform distribution. The shape + of the tensor is specified by the `shape` argument and the range by `low` and `high`. + + The data type is specified by the 'dtype' argument. The 'dtype' argument must + be one of the data types specified in the 'DataType' enum field in the + TensorProto message. + + + Args: + dtype: The data type for the elements of the output tensor. If not + specified, default is TensorProto::FLOAT. + + high: Upper boundary of the output values. + + low: Lower boundary of the output values. + + seed: (Optional) Seed to the random generator, if not specified we will auto + generate one. + + shape: The shape of the output tensor. + """ + + schema = get_schema("RandomUniform", 22, "") + op = Op(self, "RandomUniform", schema) + return op(dtype=dtype, high=high, low=low, seed=seed, shape=shape) + + T1_RandomUniformLike = TypeVar( + "T1_RandomUniformLike", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + T2_RandomUniformLike: TypeAlias = Union[BFLOAT16, DOUBLE, FLOAT, FLOAT16] + + def RandomUniformLike( + self, + input: T1_RandomUniformLike, + *, + dtype: Optional[int] = None, + high: float = 1.0, + low: float = 0.0, + seed: Optional[float] = None, + ) -> T2_RandomUniformLike: + r"""[🌐 RandomUniformLike(22)](https://onnx.ai/onnx/operators/onnx__RandomUniformLike.html#randomuniformlike-22 "Online Documentation") + + + Generate a tensor with random values drawn from a uniform distribution. + The shape of the output tensor is copied from the shape of the input tensor, + and the parameters of the uniform distribution are specified by `low` and `high`. + + The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided. + The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the + TensorProto message and be valid as an output type. + + + Args: + input: Input tensor to copy shape and optionally type information from. + + dtype: (Optional) The data type for the elements of the output tensor, if + not specified, we will use the data type of the input tensor. + + high: Upper boundary of the output values. + + low: Lower boundary of the output values. + + seed: (Optional) Seed to the random generator, if not specified we will auto + generate one. + """ + + schema = get_schema("RandomUniformLike", 22, "") + op = Op(self, "RandomUniformLike", schema) + return op( + *self._prepare_inputs(schema, input), + dtype=dtype, + high=high, + low=low, + seed=seed, + ) + + T1_RoiAlign = TypeVar("T1_RoiAlign", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T2_RoiAlign: TypeAlias = INT64 + + def RoiAlign( + self, + X: T1_RoiAlign, + rois: T1_RoiAlign, + batch_indices: T2_RoiAlign, + *, + coordinate_transformation_mode: str = "half_pixel", + mode: str = "avg", + output_height: int = 1, + output_width: int = 1, + sampling_ratio: int = 0, + spatial_scale: float = 1.0, + ) -> T1_RoiAlign: + r"""[🌐 RoiAlign(22)](https://onnx.ai/onnx/operators/onnx__RoiAlign.html#roialign-22 "Online Documentation") + + + Region of Interest (RoI) align operation described in the + [Mask R-CNN paper](https://arxiv.org/abs/1703.06870). + RoiAlign consumes an input tensor X and region of interests (rois) + to apply pooling across each RoI; it produces a 4-D tensor of shape + (num_rois, C, output_height, output_width). + + RoiAlign is proposed to avoid the misalignment by removing + quantizations while converting from original image into feature + map and from feature map into RoI feature; in each ROI bin, + the value of the sampled locations are computed directly + through bilinear interpolation. + + + Args: + X: Input data tensor from the previous operator; 4-D feature map of shape + (N, C, H, W), where N is the batch size, C is the number of channels, + and H and W are the height and the width of the data. + + rois: RoIs (Regions of Interest) to pool over; rois is 2-D input of shape + (num_rois, 4) given as [[x1, y1, x2, y2], ...]. The RoIs' coordinates + are in the coordinate system of the input image. Each coordinate set has + a 1:1 correspondence with the 'batch_indices' input. + + batch_indices: 1-D tensor of shape (num_rois,) with each element denoting + the index of the corresponding image in the batch. + + coordinate_transformation_mode: Allowed values are 'half_pixel' and + 'output_half_pixel'. Use the value 'half_pixel' to pixel shift the input + coordinates by -0.5 (the recommended behavior). Use the value + 'output_half_pixel' to omit the pixel shift for the input (use this for + a backward-compatible behavior). + + mode: The pooling method. Two modes are supported: 'avg' and 'max'. Default + is 'avg'. + + output_height: default 1; Pooled output Y's height. + + output_width: default 1; Pooled output Y's width. + + sampling_ratio: Number of sampling points in the interpolation grid used to + compute the output value of each pooled output bin. If > 0, then exactly + sampling_ratio x sampling_ratio grid points are used. If == 0, then an + adaptive number of grid points are used (computed as ceil(roi_width / + output_width), and likewise for height). Default is 0. + + spatial_scale: Multiplicative spatial scale factor to translate ROI + coordinates from their input spatial scale to the scale used when + pooling, i.e., spatial scale of the input feature map X relative to the + input image. E.g.; default is 1.0f. + """ + + schema = get_schema("RoiAlign", 22, "") + op = Op(self, "RoiAlign", schema) + return op( + *self._prepare_inputs(schema, X, rois, batch_indices), + coordinate_transformation_mode=coordinate_transformation_mode, + mode=mode, + output_height=output_height, + output_width=output_width, + sampling_ratio=sampling_ratio, + spatial_scale=spatial_scale, + ) + + T_Round = TypeVar("T_Round", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Round(self, X: T_Round) -> T_Round: + r"""[🌐 Round(22)](https://onnx.ai/onnx/operators/onnx__Round.html#round-22 "Online Documentation") + + + Round takes one input Tensor and rounds the values, element-wise, meaning + it finds the nearest integer for each value. + In case of halves, the rule is to round them to the nearest even integer. + If input x is integral, +0, -0, NaN, or infinite, x itself is returned. + The output tensor has the same shape and type as the input. + + Examples: + :: + + round([0.9]) = [1.0] + round([2.5]) = [2.0] + round([2.3]) = [2.0] + round([1.5]) = [2.0] + round([-4.5]) = [-4.0] + + + + + Args: + X: (non-differentiable) Input tensor + """ + + schema = get_schema("Round", 22, "") + op = Op(self, "Round", schema) + return op(*self._prepare_inputs(schema, X)) + + T_Selu = TypeVar("T_Selu", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Selu( + self, + X: T_Selu, + *, + alpha: float = 1.6732631921768188, + gamma: float = 1.0507010221481323, + ) -> T_Selu: + r"""[🌐 Selu(22)](https://onnx.ai/onnx/operators/onnx__Selu.html#selu-22 "Online Documentation") + + + Selu takes one input data (Tensor) and produces one output data + (Tensor) where the scaled exponential linear unit function, + `y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`, + is applied to the tensor elementwise. + + + Args: + X: (differentiable) Input tensor + + alpha: Coefficient of SELU default to 1.67326319217681884765625 (i.e., + float32 approximation of 1.6732632423543772848170429916717). + + gamma: Coefficient of SELU default to 1.05070102214813232421875 (i.e., + float32 approximation of 1.0507009873554804934193349852946). + """ + + schema = get_schema("Selu", 22, "") + op = Op(self, "Selu", schema) + return op(*self._prepare_inputs(schema, X), alpha=alpha, gamma=gamma) + + T_Sin = TypeVar("T_Sin", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Sin(self, input: T_Sin) -> T_Sin: + r"""[🌐 Sin(22)](https://onnx.ai/onnx/operators/onnx__Sin.html#sin-22 "Online Documentation") + + + Calculates the sine of the given input tensor, element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Sin", 22, "") + op = Op(self, "Sin", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Sinh = TypeVar("T_Sinh", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Sinh(self, input: T_Sinh) -> T_Sinh: + r"""[🌐 Sinh(22)](https://onnx.ai/onnx/operators/onnx__Sinh.html#sinh-22 "Online Documentation") + + + Calculates the hyperbolic sine of the given input tensor element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Sinh", 22, "") + op = Op(self, "Sinh", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Softplus = TypeVar("T_Softplus", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Softplus(self, X: T_Softplus) -> T_Softplus: + r"""[🌐 Softplus(22)](https://onnx.ai/onnx/operators/onnx__Softplus.html#softplus-22 "Online Documentation") + + + Softplus takes one input data (Tensor) and produces one output data + (Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to + the tensor elementwise. + + + Args: + X: (differentiable) Input tensor + """ + + schema = get_schema("Softplus", 22, "") + op = Op(self, "Softplus", schema) + return op(*self._prepare_inputs(schema, X)) + + T_Softsign = TypeVar("T_Softsign", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Softsign(self, input: T_Softsign) -> T_Softsign: + r"""[🌐 Softsign(22)](https://onnx.ai/onnx/operators/onnx__Softsign.html#softsign-22 "Online Documentation") + + + Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Softsign", 22, "") + op = Op(self, "Softsign", schema) + return op(*self._prepare_inputs(schema, input)) + + T_Tan = TypeVar("T_Tan", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Tan(self, input: T_Tan) -> T_Tan: + r"""[🌐 Tan(22)](https://onnx.ai/onnx/operators/onnx__Tan.html#tan-22 "Online Documentation") + + + Calculates the tangent of the given input tensor, element-wise. + + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Tan", 22, "") + op = Op(self, "Tan", schema) + return op(*self._prepare_inputs(schema, input)) + + T_ThresholdedRelu = TypeVar("T_ThresholdedRelu", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def ThresholdedRelu( + self, X: T_ThresholdedRelu, *, alpha: float = 1.0 + ) -> T_ThresholdedRelu: + r"""[🌐 ThresholdedRelu(22)](https://onnx.ai/onnx/operators/onnx__ThresholdedRelu.html#thresholdedrelu-22 "Online Documentation") + + + ThresholdedRelu takes one input data (Tensor) and produces one output data + (Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise, + is applied to the tensor elementwise. + + + Args: + X: (differentiable) Input tensor + + alpha: Threshold value + """ + + schema = get_schema("ThresholdedRelu", 22, "") + op = Op(self, "ThresholdedRelu", schema) + return op(*self._prepare_inputs(schema, X), alpha=alpha) diff --git a/onnxscript/onnx_opset/_impl/opset23.py b/onnxscript/onnx_opset/_impl/opset23.py new file mode 100644 index 0000000000..73b7480073 --- /dev/null +++ b/onnxscript/onnx_opset/_impl/opset23.py @@ -0,0 +1,2210 @@ +# -------------------------------------------------------------------------- +# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ +# ⚙️ Generated by 'python -m opgen' +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=W0221,W0222,R0901,W0237 +# mypy: disable-error-code=override +# ruff: noqa: D214, D402, D405, D411, D412, D416 +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Optional, Sequence, Tuple, TypeVar, Union + +from onnx import GraphProto, SparseTensorProto, TensorProto +from onnx.defs import get_schema +from typing_extensions import TypeAlias + +from onnxscript.onnx_opset._impl.opset22 import Opset22 +from onnxscript.onnx_types import ( + BFLOAT16, + BOOL, + COMPLEX64, + COMPLEX128, + DOUBLE, + FLOAT, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT16, + INT4, + INT8, + INT16, + INT32, + INT64, + STRING, + UINT4, + UINT8, + UINT16, + UINT32, + UINT64, +) +from onnxscript.values import Op, Opset + + +class Opset23(Opset22): + def __new__(cls): + return Opset.__new__(cls, "", 23) + + T1_Attention = TypeVar("T1_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T2_Attention = TypeVar("T2_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + U_Attention = TypeVar( + "U_Attention", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + def Attention( + self, + Q: T1_Attention, + K: T1_Attention, + V: T2_Attention, + attn_mask: Optional[U_Attention] = None, + past_key: Optional[T1_Attention] = None, + past_value: Optional[T2_Attention] = None, + *, + is_causal: int = 0, + kv_num_heads: Optional[int] = None, + q_num_heads: Optional[int] = None, + qk_matmul_output_mode: int = 0, + scale: Optional[float] = None, + softcap: float = 0.0, + softmax_precision: Optional[int] = None, + ) -> Tuple[T1_Attention, T1_Attention, T2_Attention, T1_Attention]: + r"""[🌐 Attention(23)](https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23 "Online Documentation") + + + + Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed. + + This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V. + + For self attention, `kv_sequence_length` equals to `q_sequence_length`. + + For cross attention, query and key might have different lengths. + + This operator also covers the 3 following variants based on the number of heads: + 1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`. + 2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`. + 3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`. + + Attention bias to be added is calculated based on `attn_mask` input and `is_causal attribute`, only one of which can be provided. + 1) If `is_causal` is set to `1`, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment. + 2) `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention or a float mask of the same type as query, key, value that is added to the attention score. + + Both past and present state key/values are optional. They shall be used together, and not allowed to use only one of them. + The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided: + + :: + + The following pattern is applied by this operator: + Q K V + | | | + Q*sqrt(scale) K*sqrt(scale) | + | | | + | Transpose | + | | | + ---MatMul--- | + | | + at_mask---Add | + | | + softcap (if provided) | + | | + Softmax | + | | + -----MatMul------ + | + Y + + + + + + Args: + Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, + q_sequence_length, head_size)` or 3D tensor with shape `(batch_size, + q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor, + `q_hidden_size = q_num_heads * head_size` + + K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, + kv_sequence_length, head_size)` or 3D tensor with shape `(batch_size, + kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor, + `k_hidden_size = kv_num_heads * head_size` + + V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, + kv_sequence_length, v_head_size)` or 3D tensor with shape `(batch_size, + kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor, + `v_hidden_size = kv_num_heads * v_head_size` + + attn_mask: (optional) Attention mask. Shape must be broadcastable to 4D + tensor with shape `(batch_size, q_num_heads, q_sequence_length, + total_sequence_length)` where `total_sequence_length = + past_sequence_length + kv_sequence_length.` Two types of masks are + supported. A boolean mask where a value of `True` indicates that the + element should take part in attention. Also supports a float mask of the + same type as query, key, value that is added to the attention score. + + past_key: (optional) past state cache for key with shape `(batch_size, + kv_num_heads, past_sequence_length, head_size)` + + past_value: (optional) past state cache for value with shape `(batch_size, + kv_num_heads, past_sequence_length, v_head_size)` + + is_causal: If set to `1`, the attention masking is a lower triangular matrix + when the mask is a square matrix. The attention masking has the form of + the upper left causal bias due to the alignment. + + kv_num_heads: Number of heads of key and value. Must be used with 3D inputs + of Q, K and V. + + q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K + and V. + + qk_matmul_output_mode: If set to `0`, qk_matmul_output is the output of qk + matmul. If set to `1`, qk_matmul_output includes the addition of the + attention mask to the output of qk matmul. If set to `2`, + qk_matmul_output is the output after the softcap operation. If set to + `3`, qk_matmul_output is the output after the softmax operation. Default + value is 0. + + scale: Scaling factor applied to $Q*K^T$. Default value is + `1/sqrt(head_size)`. To prevent [numerical + overflow](https://tinyurl.com/sudb9s96), scale `Q`, `K` by `sqrt(scale)` + before matmul. + + softcap: Softcap value for attention weights. Default value is 0. + + softmax_precision: The floating-point precision used in softmax computation. + If softmax precision is not provided, the same precision as the input of + softmax (Q and K) is used. + """ + + schema = get_schema("Attention", 23, "") + op = Op(self, "Attention", schema) + return op( + *self._prepare_inputs(schema, Q, K, V, attn_mask, past_key, past_value), + is_causal=is_causal, + kv_num_heads=kv_num_heads, + q_num_heads=q_num_heads, + qk_matmul_output_mode=qk_matmul_output_mode, + scale=scale, + softcap=softcap, + softmax_precision=softmax_precision, + ) + + T1_Cast = TypeVar( + "T1_Cast", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T2_Cast: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def Cast(self, input: T1_Cast, *, saturate: int = 1, to: int) -> T2_Cast: + r"""[🌐 Cast(23)](https://onnx.ai/onnx/operators/onnx__Cast.html#cast-23 "Online Documentation") + + + The operator casts the elements of a given input tensor to a data type + specified by the 'to' argument and returns an output tensor of the same size in + the converted type. The 'to' argument must be one of the data types specified + in the 'DataType' enum field in the TensorProto message. + + Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations + (e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may + yield result 100. There are some string literals reserved for special floating-point values; + "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively. + Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly, + this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors + to string tensors, plain floating-point representation (such as "314.15926") would be used. + Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases + of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior. + + Conversion from a numerical type to any numerical type is always allowed. + User must be aware of precision loss and value change caused by range difference between two types. + For example, a 64-bit float 3.1415926459 may be round to a 32-bit float 3.141592. Similarly, converting + an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type. + + In more detail, the conversion among numerical types should follow these rules + if the destination type is not a float 8 type. + + * Casting from floating point to: + * floating point: +/- infinity if OOR (out of range). + * fixed point: undefined if OOR. + * bool: +/- 0.0 to False; all else to True. + * Casting from fixed point to: + * floating point: +/- infinity if OOR. (+ infinity in the case of uint) + * fixed point: when OOR, discard higher bits and reinterpret (with respect to two's complement representation for + signed types). For example, 200 (int16) -> -56 (int8). + * bool: zero to False; nonzero to True. + * Casting from bool to: + * floating point: `{1.0, 0.0}`. + * fixed point: `{1, 0}`. + * bool: no change. + + Float 8 type were introduced to speed up the training of + deep models. By default the conversion of a float *x* obeys + to the following rules. `[x]` means the value rounded to + the target mantissa width. + + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | -------- | -------- | -------- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | Inf | FLT_MAX | NaN | FLT_MAX | NaN | + | -Inf | -FLT_MAX | NaN | -FLT_MAX | NaN | + | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | + | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | + | else | RNE | RNE | RNE | RNE | + + The behavior changes if the parameter 'saturate' is set to False. + The rules then become: + + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | ------ | -------- | ---- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | -NaN | -NaN | NaN | -NaN | NaN | + | Inf | NaN | NaN | Inf | NaN | + | -Inf | -NaN | NaN | -Inf | NaN | + | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN | + | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN | + | else | RNE | RNE | RNE | RNE | + + + Args: + input: (differentiable) Input tensor to be cast. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). + It is true by default. All cases are fully described in two tables + inserted in the operator description. + + to: The data type to which the elements of the input tensor are cast. + Strictly must be one of the types from DataType enum in TensorProto + """ + + schema = get_schema("Cast", 23, "") + op = Op(self, "Cast", schema) + return op(*self._prepare_inputs(schema, input), saturate=saturate, to=to) + + T1_CastLike = TypeVar( + "T1_CastLike", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T2_CastLike = TypeVar( + "T2_CastLike", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def CastLike( + self, input: T1_CastLike, target_type: T2_CastLike, *, saturate: int = 1 + ) -> T2_CastLike: + r"""[🌐 CastLike(23)](https://onnx.ai/onnx/operators/onnx__CastLike.html#castlike-23 "Online Documentation") + + + The operator casts the elements of a given input tensor (the first input) to + the same data type as the elements of the second input tensor. + See documentation of the Cast operator for further details. + + + Args: + input: (differentiable) Input tensor to be cast. + + target_type: (non-differentiable) The (first) input tensor will be cast to + produce a tensor of the same type as this (second input) tensor. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). + It is true by default. Please refer to operator Cast description for + further details. + """ + + schema = get_schema("CastLike", 23, "") + op = Op(self, "CastLike", schema) + return op(*self._prepare_inputs(schema, input, target_type), saturate=saturate) + + T_Constant: TypeAlias = Union[ + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def Constant( + self, + *, + sparse_value: Optional[SparseTensorProto] = None, + value: Optional[TensorProto] = None, + value_float: Optional[float] = None, + value_floats: Optional[Sequence[float]] = None, + value_int: Optional[int] = None, + value_ints: Optional[Sequence[int]] = None, + value_string: Optional[str] = None, + value_strings: Optional[Sequence[str]] = None, + ) -> T_Constant: + r"""[🌐 Constant(23)](https://onnx.ai/onnx/operators/onnx__Constant.html#constant-23 "Online Documentation") + + + This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value, + or value_* must be specified. + + + Args: + sparse_value: The value for the elements of the output tensor in sparse + format. + + value: The value for the elements of the output tensor. + + value_float: The value for the sole element for the scalar, float32, output + tensor. + + value_floats: The values for the elements for the 1D, float32, output + tensor. + + value_int: The value for the sole element for the scalar, int64, output + tensor. + + value_ints: The values for the elements for the 1D, int64, output tensor. + + value_string: The value for the sole element for the scalar, UTF-8 string, + output tensor. + + value_strings: The values for the elements for the 1D, UTF-8 string, output + tensor. + """ + + schema = get_schema("Constant", 23, "") + op = Op(self, "Constant", schema) + return op( + sparse_value=sparse_value, + value=value, + value_float=value_float, + value_floats=value_floats, + value_int=value_int, + value_ints=value_ints, + value_string=value_string, + value_strings=value_strings, + ) + + T1_ConstantOfShape: TypeAlias = INT64 + + T2_ConstantOfShape: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def ConstantOfShape( + self, input: T1_ConstantOfShape, *, value: Optional[TensorProto] = None + ) -> T2_ConstantOfShape: + r"""[🌐 ConstantOfShape(23)](https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html#constantofshape-23 "Online Documentation") + + + Generate a tensor with given value and shape. + + + Args: + input: 1D tensor. The shape of the expected output tensor. If empty tensor + is given, the output would be a scalar. All values must be >= 0. + + value: (Optional) The value of the output elements.Should be a one-element + tensor. If not specified, it defaults to a tensor of value 0 and + datatype float32 + """ + + schema = get_schema("ConstantOfShape", 23, "") + op = Op(self, "ConstantOfShape", schema) + return op(*self._prepare_inputs(schema, input), value=value) + + T1_DequantizeLinear = TypeVar( + "T1_DequantizeLinear", + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT8, + UINT16, + UINT4, + UINT8, + ) + + T2_DequantizeLinear = TypeVar("T2_DequantizeLinear", BFLOAT16, FLOAT, FLOAT16) + + T3_DequantizeLinear: TypeAlias = Union[BFLOAT16, FLOAT, FLOAT16] + + def DequantizeLinear( + self, + x: T1_DequantizeLinear, + x_scale: T2_DequantizeLinear, + x_zero_point: Optional[T1_DequantizeLinear] = None, + *, + axis: int = 1, + block_size: int = 0, + output_dtype: int = 0, + ) -> T3_DequantizeLinear: + r"""[🌐 DequantizeLinear(23)](https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html#dequantizelinear-23 "Online Documentation") + + + The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the + full-precision tensor. The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point` + must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization, + a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization. + See QuantizeLinear for details on quantization granularity. + + `x_zero_point` and `x` must have the same type. `x` and `y` must have the same shape. In the case of dequantizing + `int32`, there's no zero point (zero point is supposed to be 0). + `zero-point` is usually not used in the case of float8 and 4-bit types quantization, but the dequantization formula remains the same + for consistency. The output type is determined by the attribute `output_dtype`. If `output_dtype` is not supplied then the output type + is the same as `x_scale`. The output type also determines the precision of the multiplication operation. + + + + Args: + x: N-D quantized input tensor to be de-quantized. + + x_scale: Scale for input `x`. For per-tensor/layer dequantization the scale + is a scalar, for per per-axis dequantization it is a 1-D Tensor and for + blocked dequantization it has the same shape as the input, except for + one dimension in which blocking is performed. + + x_zero_point: (optional) Zero point for input `x`. Shape must match x_scale. + It's optional. Zero point is 0 when it's not specified. + + axis: (Optional) The axis of the dequantizing dimension of the input tensor. + Used for per-axis and blocked quantization. Negative value means + counting dimensions from the back. Accepted range is `[-r, r-1]` where + `r = rank(input)`. + + block_size: (Optional) The size of the quantization block (number of times + every scale is replicated). Used only for blocked quantization. The + block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., + Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted + range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]` + + output_dtype: (Optional) The output data type. If not supplied, the output + data type is inferred from `x_scale` data type (`T2`) + """ + + schema = get_schema("DequantizeLinear", 23, "") + op = Op(self, "DequantizeLinear", schema) + return op( + *self._prepare_inputs(schema, x, x_scale, x_zero_point), + axis=axis, + block_size=block_size, + output_dtype=output_dtype, + ) + + T_Flatten = TypeVar( + "T_Flatten", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Flatten(self, input: T_Flatten, *, axis: int = 1) -> T_Flatten: + r"""[🌐 Flatten(23)](https://onnx.ai/onnx/operators/onnx__Flatten.html#flatten-23 "Online Documentation") + + + Flattens the input tensor into a 2D matrix. If input tensor has shape + (d_0, d_1, ... d_n) then the output will have shape + (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn). + + + Args: + input: (differentiable) A tensor of rank >= axis. + + axis: Indicate up to which input dimensions (exclusive) should be flattened + to the outer dimension of the output. The value for axis must be in the + range [-r, r], where r is the rank of the input tensor. Negative value + means counting dimensions from the back. When axis = 0, the shape of the + output tensor is (1, (d_0 X d_1 ... d_n), where the shape of the input + tensor is (d_0, d_1, ... d_n). + """ + + schema = get_schema("Flatten", 23, "") + op = Op(self, "Flatten", schema) + return op(*self._prepare_inputs(schema, input), axis=axis) + + V_Identity = TypeVar( + "V_Identity", + Optional[Sequence[BOOL]], + Optional[Sequence[COMPLEX128]], + Optional[Sequence[COMPLEX64]], + Optional[Sequence[DOUBLE]], + Optional[Sequence[FLOAT]], + Optional[Sequence[FLOAT16]], + Optional[Sequence[INT16]], + Optional[Sequence[INT32]], + Optional[Sequence[INT64]], + Optional[Sequence[INT8]], + Optional[Sequence[STRING]], + Optional[Sequence[UINT16]], + Optional[Sequence[UINT32]], + Optional[Sequence[UINT64]], + Optional[Sequence[UINT8]], + Optional[BOOL], + Optional[COMPLEX128], + Optional[COMPLEX64], + Optional[DOUBLE], + Optional[FLOAT], + Optional[FLOAT16], + Optional[INT16], + Optional[INT32], + Optional[INT64], + Optional[INT8], + Optional[STRING], + Optional[UINT16], + Optional[UINT32], + Optional[UINT64], + Optional[UINT8], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Identity(self, input: V_Identity) -> V_Identity: + r"""[🌐 Identity(23)](https://onnx.ai/onnx/operators/onnx__Identity.html#identity-23 "Online Documentation") + + Identity operator + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Identity", 23, "") + op = Op(self, "Identity", schema) + return op(*self._prepare_inputs(schema, input)) + + B_If: TypeAlias = BOOL + + V_If: TypeAlias = Union[ + None, + Sequence[BFLOAT16], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + Sequence[FLOAT4E2M1], + Sequence[FLOAT8E4M3FN], + Sequence[FLOAT8E4M3FNUZ], + Sequence[FLOAT8E5M2], + Sequence[FLOAT8E5M2FNUZ], + Sequence[INT4], + Sequence[UINT4], + ] + + def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If: + r"""[🌐 If(23)](https://onnx.ai/onnx/operators/onnx__If.html#if-23 "Online Documentation") + + If conditional + + Args: + cond: Condition for the if. The tensor must contain a single element. + + else_branch: Graph to run if condition is false. Has N outputs: values you + wish to be live-out to the enclosing scope. The number of outputs must + match the number of outputs in the then_branch. + + then_branch: Graph to run if condition is true. Has N outputs: values you + wish to be live-out to the enclosing scope. The number of outputs must + match the number of outputs in the else_branch. + """ + + schema = get_schema("If", 23, "") + op = Op(self, "If", schema) + return op( + *self._prepare_inputs(schema, cond), + else_branch=else_branch, + then_branch=then_branch, + ) + + I_Loop: TypeAlias = INT64 + + B_Loop: TypeAlias = BOOL + + V_Loop = TypeVar( + "V_Loop", + Optional[Sequence[BFLOAT16]], + Optional[Sequence[BOOL]], + Optional[Sequence[COMPLEX128]], + Optional[Sequence[COMPLEX64]], + Optional[Sequence[DOUBLE]], + Optional[Sequence[FLOAT]], + Optional[Sequence[FLOAT16]], + Optional[Sequence[INT16]], + Optional[Sequence[INT32]], + Optional[Sequence[INT64]], + Optional[Sequence[INT8]], + Optional[Sequence[STRING]], + Optional[Sequence[UINT16]], + Optional[Sequence[UINT32]], + Optional[Sequence[UINT64]], + Optional[Sequence[UINT8]], + Optional[BFLOAT16], + Optional[BOOL], + Optional[COMPLEX128], + Optional[COMPLEX64], + Optional[DOUBLE], + Optional[FLOAT], + Optional[FLOAT16], + Optional[FLOAT4E2M1], + Optional[FLOAT8E4M3FN], + Optional[FLOAT8E4M3FNUZ], + Optional[FLOAT8E5M2], + Optional[FLOAT8E5M2FNUZ], + Optional[INT16], + Optional[INT32], + Optional[INT4], + Optional[INT64], + Optional[INT8], + Optional[STRING], + Optional[UINT16], + Optional[UINT32], + Optional[UINT4], + Optional[UINT64], + Optional[UINT8], + Sequence[BFLOAT16], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[FLOAT4E2M1], + Sequence[FLOAT8E4M3FN], + Sequence[FLOAT8E4M3FNUZ], + Sequence[FLOAT8E5M2], + Sequence[FLOAT8E5M2FNUZ], + Sequence[INT16], + Sequence[INT32], + Sequence[INT4], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT4], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Loop( + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, + ) -> V_Loop: + r"""[🌐 Loop(23)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-23 "Online Documentation") + + + Generic Looping construct. This loop has multiple termination conditions: + + 1) Trip count. Iteration count specified at runtime. Set by + specifying the input M. Optional. Set to empty string to omit. + Note that a static trip count (specified at graph construction time) can be + specified by passing in a constant node for input M. + 2) Loop termination condition. This is an input to the op that determines + whether to run the first iteration and also a loop-carried dependency for + the body graph. The body graph must yield a value for the condition variable, + whether this input is provided or not. + + This table summarizes the operating modes of this operator with equivalent + C-style code: + + Operator inputs defined as (max_trip_count, condition_var). + + * input ("", ""): + for (int i=0; ; ++i) { + cond = ... // Note this value is ignored, but is required in the body + } + + * input ("", cond) // Note this is analogous to a while loop + bool cond = ...; + for (int i=0; cond; ++i) { + cond = ...; + } + + * input ("", 1) // Note this is analogous to a do-while loop + bool cond = true + for (int i=0; cond; ++i) { + cond = ...; + } + + * input (trip_count, "") // Note this is analogous to a for loop + int trip_count = ... + for (int i=0; i < trip_count; ++i) { + cond = ...; // ignored + } + + * input (trip_count, cond) + int trip_count = ...; + bool cond = ...; + for (int i=0; i < trip_count && cond; ++i) { + cond = ...; + } + + + *Sample usage - cond as well as trip count* + + graph predict-net { + %a = Constant[value = ]() + %b = Constant[value = ]() + %keepgoing = Constant[value = ]() + %max_trip_count = Constant[value = ]() + %keepgoing_out, %b_out, %user_defined_vals = Loop[body = ](%max_trip_count, %keepgoing, %b) + return + } + + graph body-net ( + %i[INT32, scalar] // iteration number + %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used + %b_in[INT32, scalar] // incoming value of loop-carried-dependency b + ) { + %my_local = Add(%a, %b_in) + %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b + %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition + %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated + return %keepgoing_out, %b_out, %user_defined_val + } + + *Sample equivalent C code* + + { + /* User-defined code (enclosing scope) */ + int a = 3, b = 6; + bool keepgoing = true; // Analogous to input cond + /* End user-defined code */ + + /* Implicitly-defined code */ + const int max_trip_count = 10; // Analogous to input M + int user_defined_vals[]; // Imagine this is resizable + /* End implicitly-defined code */ + /* initialize loop-carried variables and scan-output variables */ + bool keepgoing_out = keepgoing + int b_out = b + + for (int i=0; i < max_trip_count && keepgoing_out; ++i) { + /* Implicitly-defined code: bind actual parameter values + to formal parameter variables of loop-body */ + bool keepgoing_in = keepgoing_out; + bool b_in = b_out; + + /* User-defined code (loop body) */ + int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine + b_out = a - b_in; + keepgoing_out = my_local > b_out; + user_defined_val = b_in + b_in; // b_in and b_out are different variables + /* End user-defined code */ + + /* Implicitly defined-code */ + user_defined_vals[i] = user_defined_val // accumulate scan-output values + } + // int t = my_local; // Can't do this. my_local is not accessible here. + + // The values below are bound to the output variables of the loop and therefore accessible + // b_out; user_defined_vals; keepgoing_out; + } + + There are several things of note in this code snippet: + + 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can + be referenced in the inputs of the loop. + 2) Any values computed in the loop body that needs to be used in a subsequent + iteration or after the loop are modelled using a pair of variables in the loop-body, + consisting of an input variable (eg., b_in) and an output variable (eg., b_out). + These are referred to as loop-carried dependences. The loop operation node + supplies the input value of the input variable for the first iteration, and + returns the output value of the output variable produced by the final + iteration. + 3) Scan_output variables are used to implicitly concatenate values computed across + all the iterations. In the above example, the value of user_defined_val computed + over all iterations are concatenated and returned as the value of user_defined_vals + after the loop. + 4) Values created in the body cannot be accessed in the enclosing scope, + except using the mechanism described above. + + Note that the semantics of this op support "diagonal" or "wavefront" execution. + (See Step 3 here for an example: + https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/). + Frontends should emit multi-layer RNNs as a series of While operators (with + time being the inner looping dimension), with each successive layer consuming + the scan_outputs from the previous layer, possibly going through several + point-wise operators (e.g. dropout, residual connections, linear layer). + + The input/output of subgraph (produced by loop node) matching is based on order instead of name. The implementation will figure out the names based on this order. + + + Args: + M: (optional) A maximum trip-count for the loop specified at runtime. + Optional. Pass empty string to skip. + + cond: (optional) A boolean termination condition. Optional. Pass empty + string to skip. + + v_initial: (variadic, heterogeneous) The initial values of any loop-carried + dependencies (values that change across loop iterations) + + body: The graph run each iteration. It has 2+N inputs: (iteration_num, + condition, loop carried dependencies...). It has 1+N+K outputs: + (condition, loop carried dependencies..., scan_outputs...). Each + scan_output is created by concatenating the value of the specified + output value at the end of each iteration of the loop. It is an error if + the dimensions or data type of these scan_outputs change across loop + iterations. + """ + + schema = get_schema("Loop", 23, "") + op = Op(self, "Loop", schema) + return op(*self._prepare_inputs(schema, M, cond, *v_initial), body=body) + + T_Pad = TypeVar( + "T_Pad", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + Tind_Pad = TypeVar("Tind_Pad", INT32, INT64) + + def Pad( + self, + data: T_Pad, + pads: INT64, + constant_value: Optional[T_Pad] = None, + axes: Optional[Tind_Pad] = None, + *, + mode: str = "constant", + ) -> T_Pad: + r"""[🌐 Pad(23)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-23 "Online Documentation") + + + Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, + a padded tensor (`output`) is generated. + + The three supported `modes` are (similar to corresponding modes supported by `numpy.pad`): + + 1) `constant`(default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False) + + 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis + + 3) `edge` - pads with the edge values of array + + 4) `wrap` - wrap-around padding as if the data tensor forms a torus + + + Example 1 (`constant` mode): + + Insert 0 pads to the beginning of the second dimension. + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'constant' + + constant_value = 0.0 + + output = [ + [0.0, 0.0, 1.0, 1.2], + [0.0, 0.0, 2.3, 3.4], + [0.0, 0.0, 4.5, 5.7], + ] + + + + Example 2 (`reflect` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'reflect' + + output = [ + [1.0, 1.2, 1.0, 1.2], + [2.3, 3.4, 2.3, 3.4], + [4.5, 5.7, 4.5, 5.7], + ] + + + + Example 3 (`edge` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'edge' + + output = [ + [1.0, 1.0, 1.0, 1.2], + [2.3, 2.3, 2.3, 3.4], + [4.5, 4.5, 4.5, 5.7], + ] + + + + Example 4 (`wrap` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [2, 1, 1, 1] + + mode = 'wrap' + + output = [ + [3.4, 2.3, 3.4, 2.3], + [5.7, 4.5, 5.7, 4.5], + [1.2, 1.0, 1.2, 1.0], + [3.4, 2.3, 3.4, 2.3], + [5.7, 4.5, 5.7, 4.5], + [1.2, 1.0, 1.2, 1.0], + ] + + + + + Args: + data: (differentiable) Input tensor. + + pads: (non-differentiable) Tensor of integers indicating the number of + padding elements to add or remove (if negative) at the beginning and end + of each axis. For 2D input tensor, it is the number of pixels. `pads` + should be a 1D tensor of shape [2 * num_axes] where `num_axes` refers to + the number of elements in the `axes` input or the input rank if `axes` + are not provided explicitly. `pads` format should be: [x1_begin, + x2_begin, ..., x1_end, x2_end,...], where xi_begin is the number of pad + values added at the beginning of axis `axes[i]` and xi_end, the number + of pad values added at the end of axis `axes[i]`. + + constant_value: (optional, non-differentiable) (Optional) A scalar value to + be used if the mode chosen is `constant` (by default it is 0, empty + string or False). + + axes: (optional, non-differentiable) 1-D tensor of axes that `pads` apply + to. Negative value means counting dimensions from the back. Accepted + range is [-r, r-1] where r = rank(data). Behavior is undefined if an + axis is repeated. If not provided, all axes are assumed (`[0, 1, ..., + input_rank-1]`). + + mode: Supported modes: `constant`(default), `reflect`, `edge`, `wrap` + """ + + schema = get_schema("Pad", 23, "") + op = Op(self, "Pad", schema) + return op(*self._prepare_inputs(schema, data, pads, constant_value, axes), mode=mode) + + T1_QuantizeLinear = TypeVar("T1_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32) + + T2_QuantizeLinear = TypeVar("T2_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32) + + T3_QuantizeLinear = TypeVar( + "T3_QuantizeLinear", + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT4, + INT8, + UINT16, + UINT4, + UINT8, + ) + + def QuantizeLinear( + self, + x: T1_QuantizeLinear, + y_scale: T2_QuantizeLinear, + y_zero_point: Optional[T3_QuantizeLinear] = None, + *, + axis: int = 1, + block_size: int = 0, + output_dtype: int = 0, + precision: int = 0, + saturate: int = 1, + ) -> T3_QuantizeLinear: + r"""[🌐 QuantizeLinear(23)](https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html#quantizelinear-23 "Online Documentation") + + + The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the + low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization + granularity. The quantization formula is `y = saturate((x / y_scale) + y_zero_point)`. + + Saturation is done according to: + - uint16: [0, 65535] + - int16: [-32768, 32767] + - uint8: [0, 255] + - int8: [-128, 127] + - uint4: [0, 15] + - int4: [-8, 7] + + For `(x / y_scale)`, it rounds to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details. + + `y_zero_point` and `y` must have the same type. `y_zero_point` is usually not used for quantization to float8 and 4bit types, but the quantization + formula remains the same for consistency, and the type of the attribute `y_zero_point` still determines the quantization type. + `x` and `y_scale` are allowed to have different types. The type of `y_scale` determines the precision of the division operation between `x` and + `y_scale`, unless the `precision` attribute is specified. + + There are three supported quantization granularities, determined by the shape of `y_scale`. + In all cases, `y_zero_point` must have the same shape as `y_scale`. + - Per-tensor (per-layer) quantization: `y_scale` is a scalar. + - Per-axis quantization: The scale must be a 1-D tensor, with the length of the quantization axis. For an input shape + `(D0, ..., Di, ..., Dn)` and `axis=i`, `y_scale` is a 1-D tensor of length `Di`. + - Blocked quantization: The scale's shape is identical to the input's shape, except for one dimension, in which + blocking is performed. Given `x` shape `(D0, ..., Di, ..., Dn)`, `axis=i`, and block size `B`: `y_scale` shape is + `(D0, ..., ceil(Di/B), ..., Dn)`. + + + Args: + x: N-D full precision Input tensor to be quantized. + + y_scale: Scale for doing quantization to get `y`. For per-tensor/layer + quantization the scale is a scalar, for per-axis quantization it is a + 1-D Tensor and for blocked quantization it has the same shape as the + input, except for one dimension in which blocking is performed. + + y_zero_point: (optional) Zero point for doing quantization to get `y`. Shape + must match `y_scale`.Default is uint8 with zero point of 0 if it's not + specified. + + axis: (Optional) The axis of the dequantizing dimension of the input tensor. + Used only for per-axis and blocked quantization. Negative value means + counting dimensions from the back. Accepted range is `[-r, r-1]` where + `r = rank(input)`. When the rank of the input is 1, per-tensor + quantization is applied, rendering the axis unnecessary in this + scenario. + + block_size: (Optional) The size of the quantization block (number of times + every scale is replicated). Used only for blocked quantization. The + block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., + Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted + range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]` + + output_dtype: (Optional) The output data type. If not supplied, the output + data type is inferred from `y_zero_point` data type (`T3`). If neither + `output_dtype` nor `y_zero_point` are supplied, output data type is + uint8. If both `output_dtype` and `y_zero_point` are specified, + `output_dtype` must be `T3`. + + precision: (Optional) The precision of the division operation between `x` + and `y_scale`. If not provided, it will be the same as the type of + `y_scale`. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + quantization (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). + It is true by default. All cases are fully described in two tables + inserted in the operator description. + """ + + schema = get_schema("QuantizeLinear", 23, "") + op = Op(self, "QuantizeLinear", schema) + return op( + *self._prepare_inputs(schema, x, y_scale, y_zero_point), + axis=axis, + block_size=block_size, + output_dtype=output_dtype, + precision=precision, + saturate=saturate, + ) + + T_RMSNormalization = TypeVar("T_RMSNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + V_RMSNormalization = TypeVar("V_RMSNormalization", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def RMSNormalization( + self, + X: T_RMSNormalization, + scale: V_RMSNormalization, + *, + axis: int = -1, + epsilon: float = 9.999999747378752e-06, + stash_type: int = 1, + ) -> V_RMSNormalization: + r"""[🌐 RMSNormalization(23)](https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#rmsnormalization-23 "Online Documentation") + + + This is RMS normalization defined in ONNX as function as described in the paper https://arxiv.org/pdf/1910.07467. + The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions, + where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), + the rms norm is computed over the last 2 dimensions of the input. The computation required by standardization can be + described by the following equations. + ``` + XSquared = Mul(X, X) + XSquaredMean = ReduceMean(XSquared) + MeanSquareEpsilon = Add(XSquaredMean, epsilon) + RMS = Sqrt(MeanSquareEpsilon) + Normalized = Div(X, RMS) + ``` + where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variables `RMS` stand for root mean square, + Depending on `stash_type` attribute, the actual computation + must happen in different floating-point precision. + For example, if `stash_type` is 1, this operator casts + all input variables to 32-bit float, perform the computation, and + finally cast `Normalized` back to the original type of `X`. + The second stage then scales the outcome of the first stage using: + ``` + Y= Mul(Normalized, Scale) + ``` + Let `d[i]` indicate the i-th dimension of `X`. + If `X`'s shape is `[d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]`, + the shape of `RMS` is `[d[0], ..., d[axis-1], 1, ..., 1]`. + `Y` and `X` have the same shape. This operator supports unidirectional broadcasting + (`Scale` should be unidirectional broadcastable to tensor `X`); + for more details please check `Broadcasting in ONNX `_. + + + Args: + X: The input tensor to be normalized. In general, the shape is (D1, D2, ... + , Dn) for n-dimensional data, where the root mean squared norm is taken + over the last D dimensions, D is determined by the axis attribute. + + scale: Scale tensor. Scale tensor shape should be broadcastable to the + normalized shape. + + axis: The first normalization dimension. If rank(X) is r, axis' allowed + range is [-r, r). Negative value means counting dimensions from the + back. + + epsilon: The epsilon value to use to avoid division by zero. + + stash_type: The floating-point precision used in stage one of the + computation. + """ + + schema = get_schema("RMSNormalization", 23, "") + op = Op(self, "RMSNormalization", schema) + return op( + *self._prepare_inputs(schema, X, scale), + axis=axis, + epsilon=epsilon, + stash_type=stash_type, + ) + + T_Reshape = TypeVar( + "T_Reshape", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Reshape(self, data: T_Reshape, shape: INT64, *, allowzero: int = 0) -> T_Reshape: + r"""[🌐 Reshape(23)](https://onnx.ai/onnx/operators/onnx__Reshape.html#reshape-23 "Online Documentation") + + + Reshape the input tensor similar to numpy.reshape. + First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor. + At most one dimension of the new shape can be -1. In this case, the value is + inferred from the size of the tensor and the remaining dimensions. A dimension + could also be 0, in which case the actual dimension value is unchanged (i.e. taken + from the input tensor). If 'allowzero' is set, and the new shape includes 0, the + dimension will be set explicitly to zero (i.e. not taken from input tensor). + Shape (second input) could be an empty shape, which means converting to a scalar. + The input tensor's shape and the output tensor's shape are required to have the same number of elements. + + If the attribute 'allowzero' is set, it is invalid for the specified shape to + contain both a zero value and -1, as the value of the dimension corresponding + to -1 cannot be determined uniquely. + + + Args: + data: (differentiable) An input tensor. + + shape: (non-differentiable) Specified shape for output. + + allowzero: (Optional) By default, when any value in the 'shape' input is + equal to zero the corresponding dimension value is copied from the input + tensor dynamically. allowzero=1 indicates that if any value in the + 'shape' input is set to zero, the zero value is honored, similar to + NumPy. + """ + + schema = get_schema("Reshape", 23, "") + op = Op(self, "Reshape", schema) + return op(*self._prepare_inputs(schema, data, shape), allowzero=allowzero) + + T_RotaryEmbedding = TypeVar("T_RotaryEmbedding", BFLOAT16, FLOAT, FLOAT16) + + M_RotaryEmbedding: TypeAlias = INT64 + + def RotaryEmbedding( + self, + X: T_RotaryEmbedding, + cos_cache: T_RotaryEmbedding, + sin_cache: T_RotaryEmbedding, + position_ids: Optional[M_RotaryEmbedding] = None, + *, + interleaved: int = 0, + num_heads: Optional[int] = None, + rotary_embedding_dim: int = 0, + ) -> T_RotaryEmbedding: + r"""[🌐 RotaryEmbedding(23)](https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23 "Online Documentation") + + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864. + The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances + between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token's absolute position (position_ids). + + The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles. + For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the + embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector. + The rotation matrix is parameterized by the token's position in the sequence. The rotated halves of the embedding vector are concatenated + to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism. + The rotation ensures that the model captures both absolute and relative positional information. + + Rotary embeddings are defined using the following algorithm: + + :: + + def compute_rotary_embedding( + input, + position_ids, + sin_cache, + cos_cache, + interleaved=0, + rotary_embedding_dim=0, + num_heads=0, + ): + # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size] + if len(input.shape) == 4: + input = np.transpose(input, (0, 2, 1, 3)) + batch_size = input.shape[0] + sequence_length = input.shape[1] + if len(input.shape) == 3: + hidden_size = input.shape[2] + assert num_heads != 0 + head_size = int(hidden_size / num_heads) + new_shape = [batch_size, sequence_length, num_heads, head_size] + input = np.reshape(input, new_shape) + assert len(input.shape) == 4 + head_size = input.shape[3] + + # Fully or partially perform rotation on input based on rotary_embedding_dim attribute + if rotary_embedding_dim == 0: + # If rotary_embedding_dim not provided, perform full rotation by using head_size + rotary_embedding_dim = head_size + x_rotate = input[:, :, :, :rotary_embedding_dim] + x_not_rotate = input[:, :, :, rotary_embedding_dim:] + rotary_embedding_dim_half = int(rotary_embedding_dim / 2) + + # Retrieve sin and cos caches using position ids + if position_ids is not None: + cos = cos_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2] + sin = sin_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2] + else: + cos = cos_cache + sin = sin_cache + cos = cos[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + sin = sin[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + cos = np.expand_dims(cos, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + sin = np.expand_dims(sin, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] + + # Either divide the input in halves or interleave (based on interleaved attribute) + if interleaved: + x1 = x_rotate[:, :, :, 0::2] + x2 = x_rotate[:, :, :, 1::2] + else: + x1, x2 = np.split(x_rotate, 2, axis=-1) + + # Calculate real and imaginary values + real = cos * x1 - sin * x2 + imag = sin * x1 + cos * x2 + + # Inserted rotated embeddings back to the original input + if interleaved: + # x_rotate[:, :, :, 0::2] = real + # x_rotate[:, :, :, 1::2] = imag + real = np.expand_dims(real, axis=-1) + imag = np.expand_dims(imag, axis=-1) + x_rotate_concat = np.concatenate((real, imag), axis=-1) + x_rotate = np.reshape(x_rotate_concat, x_rotate.shape) + else: + x_rotate = np.concatenate((real, imag), axis=-1) + output = np.concatenate((x_rotate, x_not_rotate), axis=-1) + if len(original_input_shape) == 3: + output = np.reshape(output, input.shape) + else: + output = np.transpose(output, (0, 2, 1, 3)) + return output + + + + + Args: + X: The input tensor representing the token embeddings. 4D tensor with shape + `(batch_size, num_heads, sequence_length, head_size)` or 3D tensor with + shape `(batch_size, sequence_length, hidden_size)`. For cases with a 4D + input tensor, `head_size` has to be even. For cases with a 3D input + tensor, `num_heads` attribute must be provided and `hidden_size` must be + an even multiple of `num_heads` where `hidden_size = num_heads * + head_size` + + cos_cache: The cosine values for the rotation. 2D tensor with shape + `(max_position_id_plus_1, head_size / 2)` for full rotation or + `(max_position_id_plus_1, rotary_embedding_dim / 2)` for partial + rotation when `position_ids` are provided. 3D tensor with shape + `(batch_size, sequence_length, head_size / 2)` for full rotation or + `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial + rotation when `position_ids` are not provided. `max_position_id_plus_1` + is a parameter to the model. + + sin_cache: The sine values for the rotation. 2D tensor with shape + `(max_position_id_plus_1, head_size / 2)` for full rotation or + `(max_position_id_plus_1, rotary_embedding_dim / 2)` for partial + rotation when `position_ids` are provided. 3D tensor with shape + `(batch_size, sequence_length, head_size / 2)` for full rotation or + `(batch_size, sequence_length, rotary_embedding_dim / 2)` for partial + rotation when `position_ids` are not provided. `max_position_id_plus_1` + is a parameter to the model. + + position_ids: (optional) The position indices for the tokens. 2D tensor with + shape `(batch_size, sequence_length)` + + interleaved: Rotate using interleaved pattern. Default value is 0 (False). + + num_heads: Number of attention heads. Must be provided when input is a 3D + tensor. + + rotary_embedding_dim: Rotary embedding dimension used to apply partial + rotary embeddings. + """ + + schema = get_schema("RotaryEmbedding", 23, "") + op = Op(self, "RotaryEmbedding", schema) + return op( + *self._prepare_inputs(schema, X, cos_cache, sin_cache, position_ids), + interleaved=interleaved, + num_heads=num_heads, + rotary_embedding_dim=rotary_embedding_dim, + ) + + V_Scan = TypeVar( + "V_Scan", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Scan( + self, + *initial_state_and_scan_inputs: V_Scan, + body: GraphProto, + num_scan_inputs: int, + scan_input_axes: Optional[Sequence[int]] = None, + scan_input_directions: Optional[Sequence[int]] = None, + scan_output_axes: Optional[Sequence[int]] = None, + scan_output_directions: Optional[Sequence[int]] = None, + ) -> V_Scan: + r"""[🌐 Scan(23)](https://onnx.ai/onnx/operators/onnx__Scan.html#scan-23 "Online Documentation") + + + Scan can be used to iterate over one or more scan_input tensors, + constructing zero or more scan_output tensors. It combines ideas from general recurrences, + functional programming constructs such as scan, fold, map, and zip, and is intended to enable + generalizations of RNN-like constructs for sequence-to-sequence processing. + Other tensors (referred to as state_variables here) can be used to carry a state + when iterating from one element to another (similar to hidden-state in RNNs, also referred + to as loop-carried dependences in the context of loops). + Many common usages involve a single scan_input tensor (where functionality + similar to scan, fold and map can be obtained). When more than one scan_input is used, + a behavior similar to zip is obtained. + + The attribute body must be a graph, specifying the computation to be performed in + every iteration. It takes as input the current values of the state_variables and + the current iterated element of the scan_inputs. It must return the (updated) values + of the state_variables and zero or more scan_output_element tensors. The values of the + scan_output_element tensors are concatenated over all the iterations to produce the + scan_output values of the scan construct (similar to the concatenated intermediate + hidden-state values of RNN-like constructs). All the output tensors (state_variables as + well as scan_output_element tensors) are required to have the same shape in each iteration + of the loop (a restriction imposed to enable efficient memory allocation). + + Note that the iterated element passed to the body subgraph does not have a sequence + axis. It will have a rank one less than the rank of the corresponding scan_input. + + The scan operation returns the final values of the state_variables as well as the + scan_outputs. + + The optional attribute scan_input_directions specifies the direction (forward or backward) + for each scan input. If this attribute is omitted, all sequences are scanned in the forward + direction. A bidirectional scan may be performed by specifying the same tensor input twice + in the scan_inputs, once with a forward direction, and once with a backward direction. + + The scan_output of the operation is produced by concatenating the scan_output_element + values produced by the body in each iteration. The optional attribute scan_output_directions + specifies the direction in which scan_output is constructed (by appending or prepending the + scan_output_element to scan_output in each iteration) for each scan_output. If this attribute + is omitted, the scan_output_element is appended to the scan_output in each iteration. + + The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input. + If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the + batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1. + Note that scanning a non-zero axis may be less efficient than scanning axis zero. + + The optional attribute scan_output_axes specifies the axis along which the scan_outputs + are accumulated for each scan_output. For example, if axis 1 is the time axis (to be + scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis + value of 1. + + Note that because of the ONNX restriction that only the last parameter of an operator can + be variadic, the initial-states and scan-inputs are listed together as one input parameter. + Similarly, the final-states and scan-outputs are listed together as one output parameter. + The attribute num_scan_inputs indicates the number M of scan-inputs. + + The behavior of + + Scan < + num_scan_inputs = m, + body = loop-body, + scan_input_axes = [axis_1, ..., axis_m] + > (init_1, ..., init_n, scan_1, ..., scan_m) + + is equivalent to the following pseudo-code: + + // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i + // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j. + sequence_length = scan_1.shape[axis_1]; + + // initialize state-variables + st_1 = init_1; ... st_n = init_n; + // initialize scan-output variables: [] denotes an empty tensor + scan_out_1 = []; ...; scan_out_k = []; + // identify number of iterations: + + // execute loop + for (int t = 0; t < sequence_length; ++t) { + // generate the scan-input elements: the notation T[t] indicates the sub-tensor + // of rank one less than T obtained by indexing T at position t along axis k. + si_1 = scan_1[t]; + ... ; + si_m = scan_m[t]; + // execute loop-body + st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m) + // accumulate the scan-output elements + scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k); + } + + return st_1, ..., st_n, scan_out_1, ..., scan_out_k; + + *Sample usage: Encoding RNN using a Scan* + + The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi, + recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can + be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes + %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these + values are computed in the outer graph, they need to be passed in as extra state_variables. + + graph rnn-encoding { + %H_0 = ... + %X = ... + %Y_h, %Y = Scan[body = , num_scan_inputs=1](%H_0, %X) + return %Y, %Y_h + } + + graph rnn-cell-1 ( + %H_tminus1[FLOAT, tensor] + %X_t[FLOAT, tensor] + ) { + %Wi = ... + %Ri = ... + %Wbi = ... + %Rbi = ... + %t1 = X_t * (Wi^T) + %t2 = H_tminus1*(Ri^T) + %t3 = Add(%t1, %t2) + %t4 = Add(%t3, %Wbi) + %t5 = Add(%t4, %Rbi) + %Ht = Tanh(%t5) + %Accumulate = Identity(%Ht) + return %Ht, %Accumulate + } + + + + Args: + initial_state_and_scan_inputs: (variadic, heterogeneous) Initial values of + the loop's N state variables followed by M scan_inputs + + body: The graph run each iteration. It has N+M inputs: (loop state + variables..., scan_input_elts...). It has N+K outputs: (loop state + variables..., scan_output_elts...). Each scan_output is created by + concatenating the value of the specified scan_output_elt value at the + end of each iteration of the loop. It is an error if the dimensions of + these values change across loop iterations. + + num_scan_inputs: An attribute specifying the number of scan_inputs M. + + scan_input_axes: An optional list of M flags. The i-th element of the list + specifies the axis to be scanned (the sequence axis) for the i-th + scan_input. If omitted, 0 will be used as the scan axis for every + scan_input. Negative value for an axis means counting dimensions from + the back. Accepted range is [-r, r-1] where r = rank(input). + + scan_input_directions: An optional list of M flags. The i-th element of the + list specifies the direction to be scanned for the i-th scan_input + tensor: 0 indicates forward direction and 1 indicates reverse direction. + If omitted, all scan_input tensors will be scanned in the forward + direction. + + scan_output_axes: An optional list of K flags. The i-th element of the list + specifies the axis for the i-th scan_output. The scan outputs are + accumulated along the specified axis. If omitted, 0 will be used as the + scan axis for every scan_output. Negative value for an axis means + counting dimensions from the back. Accepted range is [-r, r-1]. + + scan_output_directions: An optional list of K flags, one for each + scan_output. The i-th element of the list specifies whether the i-th + scan_output should be constructed by appending or prepending a new value + in each iteration: 0 indicates appending and 1 indicates prepending. If + omitted, all scan_output tensors will be produced by appending a value + in each iteration. + """ + + schema = get_schema("Scan", 23, "") + op = Op(self, "Scan", schema) + return op( + *self._prepare_inputs(schema, *initial_state_and_scan_inputs), + body=body, + num_scan_inputs=num_scan_inputs, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + ) + + T_Shape = TypeVar( + "T_Shape", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T1_Shape: TypeAlias = INT64 + + def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) -> T1_Shape: + r"""[🌐 Shape(23)](https://onnx.ai/onnx/operators/onnx__Shape.html#shape-23 "Online Documentation") + + + Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor. + Optional attributes start and end can be used to compute a slice of the input tensor's shape. + If start axis is omitted, the slice starts from axis 0. + The end axis, if specified, is exclusive (and the returned value will not include the size of that axis). + If the end axis is omitted, the axes upto the last one will be included. + Negative axes indicate counting back from the last axis. + Note that axes will be clamped to the range [0, r], where r is the + rank of the input tensor if they are out-of-range (after adding r in the case of + negative axis). Thus, specifying any end value > r is equivalent to specifying an end + value of r, and specifying any start value < -r is equivalent to specifying a start + value of 0. If start > end, the result will be an empty shape. + + Examples: + + :: + + Input tensor with shape: [2, 3, 4] + No attributes specified. + Output: [2, 3, 4] + + + + :: + + Input tensor with shape: [2, 3, 4] + start: -1 + Output: [4] + + + + :: + + Input tensor with shape: [2, 3, 4] + end: -1 + Output: [2, 3] + + + + :: + + Input tensor with shape: [2, 3, 4] + start: 1 + end: 2 + Output: [3] + + + + + Args: + data: (non-differentiable) An input tensor. + + end: (Optional) Ending axis for slicing the shape. Negative value means + counting dimensions from the back. If omitted, sizes of all axes upto + (including) the last one will be included. + + start: (Optional) Starting axis for slicing the shape. Default value is + 0.Negative value means counting dimensions from the back. + """ + + schema = get_schema("Shape", 23, "") + op = Op(self, "Shape", schema) + return op(*self._prepare_inputs(schema, data), end=end, start=start) + + T_Size = TypeVar( + "T_Size", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T1_Size: TypeAlias = INT64 + + def Size(self, data: T_Size) -> T1_Size: + r"""[🌐 Size(23)](https://onnx.ai/onnx/operators/onnx__Size.html#size-23 "Online Documentation") + + + Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor. + + + Args: + data: (non-differentiable) An input tensor. + """ + + schema = get_schema("Size", 23, "") + op = Op(self, "Size", schema) + return op(*self._prepare_inputs(schema, data)) + + T_Squeeze = TypeVar( + "T_Squeeze", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Squeeze(self, data: T_Squeeze, axes: Optional[INT64] = None) -> T_Squeeze: + r"""[🌐 Squeeze(23)](https://onnx.ai/onnx/operators/onnx__Squeeze.html#squeeze-23 "Online Documentation") + + + Remove single-dimensional entries from the shape of a tensor. + Takes an input `axes` with a list of axes to squeeze. + If `axes` is not provided, all the single dimensions will be removed from + the shape. If an axis is selected with shape entry not equal to one, an error is raised. + + + Args: + data: (differentiable) Tensors with at least max(dims) dimensions. + + axes: (optional, non-differentiable) 1D tensor of integers indicating the + dimensions to squeeze. Negative value means counting dimensions from the + back. Accepted range is [-r, r-1] where r = rank(data). + """ + + schema = get_schema("Squeeze", 23, "") + op = Op(self, "Squeeze", schema) + return op(*self._prepare_inputs(schema, data, axes)) + + T_Transpose = TypeVar( + "T_Transpose", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Transpose( + self, data: T_Transpose, *, perm: Optional[Sequence[int]] = None + ) -> T_Transpose: + r"""[🌐 Transpose(23)](https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-23 "Online Documentation") + + + Transpose the input tensor similar to numpy.transpose. For example, when + perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape + will be (2, 1, 3). + + + Args: + data: (differentiable) An input tensor. + + perm: A list of integers. By default, reverse the dimensions, otherwise + permute the axes according to the values given. Its length must be equal + to the rank of the input. + """ + + schema = get_schema("Transpose", 23, "") + op = Op(self, "Transpose", schema) + return op(*self._prepare_inputs(schema, data), perm=perm) + + T_Unsqueeze = TypeVar( + "T_Unsqueeze", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Unsqueeze(self, data: T_Unsqueeze, axes: INT64) -> T_Unsqueeze: + r"""[🌐 Unsqueeze(23)](https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#unsqueeze-23 "Online Documentation") + + + Insert single-dimensional entries to the shape of an input tensor (`data`). + Takes one required input `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`). + + For example, given an input tensor (`data`) of shape [3, 4, 5], then + Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1]. + + The input `axes` should not contain any duplicate entries. It is an error if it contains duplicates. + The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`. + Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1]. + The order of values in `axes` does not matter and can come in any order. + + + Args: + data: (differentiable) Original tensor + + axes: (non-differentiable) 1D tensor of integers indicating the dimensions + to be inserted. Negative value means counting dimensions from the back. + Accepted range is [-r, r-1] where r = rank(expanded). + """ + + schema = get_schema("Unsqueeze", 23, "") + op = Op(self, "Unsqueeze", schema) + return op(*self._prepare_inputs(schema, data, axes)) diff --git a/onnxscript/onnx_opset/_impl/opset24.py b/onnxscript/onnx_opset/_impl/opset24.py new file mode 100644 index 0000000000..d85fcaefe5 --- /dev/null +++ b/onnxscript/onnx_opset/_impl/opset24.py @@ -0,0 +1,2342 @@ +# -------------------------------------------------------------------------- +# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ +# ⚙️ Generated by 'python -m opgen' +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=W0221,W0222,R0901,W0237 +# mypy: disable-error-code=override +# ruff: noqa: D214, D402, D405, D411, D412, D416 +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Optional, Sequence, Tuple, TypeVar, Union + +from onnx import GraphProto, SparseTensorProto, TensorProto +from onnx.defs import get_schema +from typing_extensions import TypeAlias + +from onnxscript.onnx_opset._impl.opset23 import Opset23 +from onnxscript.onnx_types import ( + BFLOAT16, + BOOL, + COMPLEX64, + COMPLEX128, + DOUBLE, + FLOAT, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + FLOAT16, + INT4, + INT8, + INT16, + INT32, + INT64, + STRING, + UINT4, + UINT8, + UINT16, + UINT32, + UINT64, +) +from onnxscript.values import Op, Opset + + +class Opset24(Opset23): + def __new__(cls): + return Opset.__new__(cls, "", 24) + + T1_Attention = TypeVar("T1_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + T2_Attention = TypeVar("T2_Attention", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + U_Attention = TypeVar( + "U_Attention", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + def Attention( + self, + Q: T1_Attention, + K: T1_Attention, + V: T2_Attention, + attn_mask: Optional[U_Attention] = None, + past_key: Optional[T1_Attention] = None, + past_value: Optional[T2_Attention] = None, + nonpad_kv_seqlen: Optional[INT64] = None, + *, + is_causal: int = 0, + kv_num_heads: Optional[int] = None, + q_num_heads: Optional[int] = None, + qk_matmul_output_mode: int = 0, + scale: Optional[float] = None, + softcap: float = 0.0, + softmax_precision: Optional[int] = None, + ) -> Tuple[T1_Attention, T1_Attention, T2_Attention, T1_Attention]: + r"""[🌐 Attention(24)](https://onnx.ai/onnx/operators/onnx__Attention.html#attention-24 "Online Documentation") + + + + Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed. + + This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V. + + For self attention, `kv_sequence_length` equals to `q_sequence_length`. + + For cross attention, query and key might have different lengths. + + This operator also covers the 3 following variants based on the number of heads: + 1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`. + 2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`. + 3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads=1`. + + Attention bias to be added is calculated based on `attn_mask` input and `is_causal` attribute: + 1) `attn_mask`: A boolean mask where a value of `True` indicates that the element should take part in attention or a float mask of the same type as query, key, value that is added to the attention score. + 2) If `is_causal` is set to `1`, attention scores above the diagonal are masked out, regardless of the `attn_mask` input. + + With respect to KV cache update, this operator allows the following two use cases: + + 1) Cache update happens inside the Attention operator. In this case, the `K` and `V` inputs contain only the incoming + tokens for the current autoregressive step, and the four optional inputs/outputs past and present key and value are + all needed. The Attention op performs a Concat operation on the past and incoming key and value to form the present + key and value, respectively. Note that this only works correctly for the special case where the past key and value + do not contain padded tokens. + 2) Cache update happens outside the Attention operator (for example, through the `TensorScatter` operator). In this + case, the `K` and `V` inputs correspond to the entire cache tensor, so the four optional inputs/outputs past and + present key and value should not be used. An additional input `nonpad_kv_seqlen` of shape (batch_size,) may be + provided to indicate the number of non-padding tokens in each sample of the batch to save unnecessary computation. + Here, the kv_sequence dimension of `attn_mask` can be shorter than `K` and `V`, but still needs to be at least as long + as the maximum value of `nonpad_kv_seqlen`. + + Both past and present state key/values are optional. They shall be used together, and not allowed to use only one of them. + The following pattern is applied to the Q, K and V inputs after appropriate reshaping of K and V inputs based on sequence lengths and num heads provided: + + :: + + The following pattern is applied by this operator: + Q K V + | | | + Q*sqrt(scale) K*sqrt(scale) | + | | | + | Transpose | + | | | + ---MatMul--- | + | | + at_mask---Add | + | | + softcap (if provided) | + | | + Softmax | + | | + -----MatMul------ + | + Y + + + + + + Args: + Q: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, + q_sequence_length, head_size)` or 3D tensor with shape `(batch_size, + q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor, + `q_hidden_size = q_num_heads * head_size` + + K: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, + kv_sequence_length, head_size)` or 3D tensor with shape `(batch_size, + kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor, + `k_hidden_size = kv_num_heads * head_size` + + V: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, + kv_sequence_length, v_head_size)` or 3D tensor with shape `(batch_size, + kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor, + `v_hidden_size = kv_num_heads * v_head_size` + + attn_mask: (optional) Attention mask. Shape must be broadcastable to + `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` + where `total_sequence_length = past_sequence_length + + kv_sequence_length.` The last dimension can also be shorter than + `total_sequence_length` and will be padded to `total_sequence_length` + with negative infinity. Two types of masks are supported: a boolean mask + where a value of `True` indicates that the element should take part in + attention, or a float mask of the same type as query, key, value that is + added to the attention score. + + past_key: (optional) past state cache for key with shape `(batch_size, + kv_num_heads, past_sequence_length, head_size)` + + past_value: (optional) past state cache for value with shape `(batch_size, + kv_num_heads, past_sequence_length, v_head_size)` + + nonpad_kv_seqlen: (optional) A vector of integers of shape `(batch_size,)` + that indicates the number of valid (ie, non-padding) tokens in each + sample. A padding mask can be derived from this. This should not be used + together with `past_key` and `past_value` inputs or `present_key` and + `present_value` outputs (See the KV cache use cases in the operator + description). + + is_causal: If set to `1`, the attention masking is a lower triangular matrix + when the mask is a square matrix. The attention masking has the form of + the upper left causal bias due to the alignment. + + kv_num_heads: Number of heads of key and value. Must be used with 3D inputs + of Q, K and V. + + q_num_heads: Number of heads of query. Must be used with 3D inputs of Q, K + and V. + + qk_matmul_output_mode: If set to `0`, qk_matmul_output is the output of qk + matmul. If set to `1`, qk_matmul_output includes the addition of the + attention mask to the output of qk matmul. If set to `2`, + qk_matmul_output is the output after the softcap operation. If set to + `3`, qk_matmul_output is the output after the softmax operation. Default + value is 0. + + scale: Scaling factor applied to $Q*K^T$. Default value is + `1/sqrt(head_size)`. To prevent [numerical + overflow](https://tinyurl.com/sudb9s96), scale `Q`, `K` by `sqrt(scale)` + before matmul. + + softcap: Softcap value for attention weights. Default value is 0. + + softmax_precision: The floating-point precision used in softmax computation. + If softmax precision is not provided, the same precision as the input of + softmax (Q and K) is used. + """ + + schema = get_schema("Attention", 24, "") + op = Op(self, "Attention", schema) + return op( + *self._prepare_inputs( + schema, Q, K, V, attn_mask, past_key, past_value, nonpad_kv_seqlen + ), + is_causal=is_causal, + kv_num_heads=kv_num_heads, + q_num_heads=q_num_heads, + qk_matmul_output_mode=qk_matmul_output_mode, + scale=scale, + softcap=softcap, + softmax_precision=softmax_precision, + ) + + T1_Cast = TypeVar( + "T1_Cast", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T2_Cast: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def Cast( + self, input: T1_Cast, *, round_mode: str = "up", saturate: int = 1, to: int + ) -> T2_Cast: + r"""[🌐 Cast(24)](https://onnx.ai/onnx/operators/onnx__Cast.html#cast-24 "Online Documentation") + + + The operator casts the elements of a given input tensor to a data type + specified by the 'to' argument and returns an output tensor of the same size in + the converted type. The 'to' argument must be one of the data types specified + in the 'DataType' enum field in the TensorProto message. + + Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations + (e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may + yield result 100. There are some string literals reserved for special floating-point values; + "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively. + Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly, + this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors + to string tensors, plain floating-point representation (such as "314.15926") would be used. + Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases + of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior. + + Conversion from a numerical type to any numerical type is always allowed. + User must be aware of precision loss and value change caused by range difference between two types. + For example, a 64-bit float 3.1415926459 may be round to a 32-bit float 3.141592. Similarly, converting + an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type. + + In more detail, the conversion among numerical types should follow these rules + if the destination type is not a float 8 type. + + * Casting from floating point to: + * floating point: +/- infinity if OOR (out of range). + * fixed point: undefined if OOR. + * bool: +/- 0.0 to False; all else to True. + * Casting from fixed point to: + * floating point: +/- infinity if OOR. (+ infinity in the case of uint) + * fixed point: when OOR, discard higher bits and reinterpret (with respect to two's complement representation for + signed types). For example, 200 (int16) -> -56 (int8). + * bool: zero to False; nonzero to True. + * Casting from bool to: + * floating point: `{1.0, 0.0}`. + * fixed point: `{1, 0}`. + * bool: no change. + + Float 8 types (E4M3FN, E4M3FNUZ, E5M2, E5M2FNUZ) were introduced to speed up the training of + deep models. By default the conversion of a float *x* obeys + to the following rules. `[x]` means the value rounded to + the target mantissa width. + + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | -------- | -------- | -------- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | Inf | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | + | -Inf | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | + | \[x\] > FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | FLT_MAX | + | \[x\] \< -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | -FLT_MAX | + | else | RNE | RNE | RNE | RNE | + + The behavior changes if the parameter 'saturate' is set to False. + The rules then become: + + | x | E4M3FN | E4M3FNUZ | E5M2 | E5M2FNUZ | + | ----------------- | ------ | -------- | ---- | -------- | + | 0 | 0 | 0 | 0 | 0 | + | -0 | -0 | 0 | -0 | 0 | + | NaN | NaN | NaN | NaN | NaN | + | -NaN | -NaN | NaN | -NaN | NaN | + | Inf | NaN | NaN | Inf | NaN | + | -Inf | -NaN | NaN | -Inf | NaN | + | \[x\] > FLT_MAX | NaN | NaN | Inf | NaN | + | \[x\] \< -FLT_MAX | NaN | NaN | -Inf | NaN | + | else | RNE | RNE | RNE | RNE | + + FLOAT8E8M0 type was introduced to enable [Microscaling (MX) formats](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). + When casting to FLOAT8E8M0, the rounding behavior can be specified using the `round_mode` and `saturate` attributes. + The current CUDA behavior is to round up and saturate. Casting negative values to FLOAT8E8M0 gives undefined behavior. + The following table describes the casting behavior of special values to FLOAT8E8M0 in the two most common cases. + + | x | saturate + up | non-saturate + nearest | + | ----------------- | ------------- | --------------------- | + | 0 | 0 | NaN | + | -0 | Unspecified | Unspecified | + | NaN | NaN | NaN | + | Inf | E8M0_MAX | NaN | + | x > E8M0_MAX | E8M0_MAX | NaN | + | x \< E8M0_MIN | E8M0_MIN | NaN | + | x \< 0 | Unspecified | Unspecified | + + + Args: + input: (differentiable) Input tensor to be cast. + + round_mode: Rounding mode for conversion to float8e8m0. It only applies to + casting to float8e8m0 and is `up` by default. `up`: round to nearest + value away from zero, `down`: round to nearest value towards zero, + `nearest`: round to nearest value and ties round up. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz, + float8e8m0). It is true by default. All cases are fully described in the + tables inserted in the operator description. + + to: The data type to which the elements of the input tensor are cast. + Strictly must be one of the types from DataType enum in TensorProto + """ + + schema = get_schema("Cast", 24, "") + op = Op(self, "Cast", schema) + return op( + *self._prepare_inputs(schema, input), + round_mode=round_mode, + saturate=saturate, + to=to, + ) + + T1_CastLike = TypeVar( + "T1_CastLike", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T2_CastLike = TypeVar( + "T2_CastLike", + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def CastLike( + self, + input: T1_CastLike, + target_type: T2_CastLike, + *, + round_mode: str = "up", + saturate: int = 1, + ) -> T2_CastLike: + r"""[🌐 CastLike(24)](https://onnx.ai/onnx/operators/onnx__CastLike.html#castlike-24 "Online Documentation") + + + The operator casts the elements of a given input tensor (the first input) to + the same data type as the elements of the second input tensor. + See documentation of the Cast operator for further details. + + + Args: + input: (differentiable) Input tensor to be cast. + + target_type: (non-differentiable) The (first) input tensor will be cast to + produce a tensor of the same type as this (second input) tensor. + + round_mode: Rounding mode for conversion to float8e8m0. It only applies to + casting to float8e8m0 and is `up` by default. `up`: round to nearest + value away from zero, `down`: round to nearest value towards zero, + `nearest`: round to nearest value and ties round up. Please refer to + operator Cast description for further details. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + conversion (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz, + float8e8m0). It is true by default. Please refer to operator Cast + description for further details. + """ + + schema = get_schema("CastLike", 24, "") + op = Op(self, "CastLike", schema) + return op( + *self._prepare_inputs(schema, input, target_type), + round_mode=round_mode, + saturate=saturate, + ) + + T_Constant: TypeAlias = Union[ + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def Constant( + self, + *, + sparse_value: Optional[SparseTensorProto] = None, + value: Optional[TensorProto] = None, + value_float: Optional[float] = None, + value_floats: Optional[Sequence[float]] = None, + value_int: Optional[int] = None, + value_ints: Optional[Sequence[int]] = None, + value_string: Optional[str] = None, + value_strings: Optional[Sequence[str]] = None, + ) -> T_Constant: + r"""[🌐 Constant(24)](https://onnx.ai/onnx/operators/onnx__Constant.html#constant-24 "Online Documentation") + + + This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value, + or value_* must be specified. + + + Args: + sparse_value: The value for the elements of the output tensor in sparse + format. + + value: The value for the elements of the output tensor. + + value_float: The value for the sole element for the scalar, float32, output + tensor. + + value_floats: The values for the elements for the 1D, float32, output + tensor. + + value_int: The value for the sole element for the scalar, int64, output + tensor. + + value_ints: The values for the elements for the 1D, int64, output tensor. + + value_string: The value for the sole element for the scalar, UTF-8 string, + output tensor. + + value_strings: The values for the elements for the 1D, UTF-8 string, output + tensor. + """ + + schema = get_schema("Constant", 24, "") + op = Op(self, "Constant", schema) + return op( + sparse_value=sparse_value, + value=value, + value_float=value_float, + value_floats=value_floats, + value_int=value_int, + value_ints=value_ints, + value_string=value_string, + value_strings=value_strings, + ) + + T1_ConstantOfShape: TypeAlias = INT64 + + T2_ConstantOfShape: TypeAlias = Union[ + BFLOAT16, + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ] + + def ConstantOfShape( + self, input: T1_ConstantOfShape, *, value: Optional[TensorProto] = None + ) -> T2_ConstantOfShape: + r"""[🌐 ConstantOfShape(24)](https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html#constantofshape-24 "Online Documentation") + + + Generate a tensor with given value and shape. + + + Args: + input: 1D tensor. The shape of the expected output tensor. If empty tensor + is given, the output would be a scalar. All values must be >= 0. + + value: (Optional) The value of the output elements.Should be a one-element + tensor. If not specified, it defaults to a tensor of value 0 and + datatype float32 + """ + + schema = get_schema("ConstantOfShape", 24, "") + op = Op(self, "ConstantOfShape", schema) + return op(*self._prepare_inputs(schema, input), value=value) + + T1_DequantizeLinear = TypeVar( + "T1_DequantizeLinear", + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT32, + INT4, + INT8, + UINT16, + UINT4, + UINT8, + ) + + T2_DequantizeLinear = TypeVar("T2_DequantizeLinear", BFLOAT16, FLOAT, FLOAT16, FLOAT8E8M0) + + T3_DequantizeLinear: TypeAlias = Union[BFLOAT16, FLOAT, FLOAT16] + + def DequantizeLinear( + self, + x: T1_DequantizeLinear, + x_scale: T2_DequantizeLinear, + x_zero_point: Optional[T1_DequantizeLinear] = None, + *, + axis: int = 1, + block_size: int = 0, + output_dtype: int = 0, + ) -> T3_DequantizeLinear: + r"""[🌐 DequantizeLinear(24)](https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html#dequantizelinear-24 "Online Documentation") + + + The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the + full-precision tensor. The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point` + must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization, + a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization. + See QuantizeLinear for details on quantization granularity. + + `x_zero_point` and `x` must have the same type. `x` and `y` must have the same shape. In the case of dequantizing + `int32`, there's no zero point (zero point is supposed to be 0). + `zero-point` is usually not used in the case of float8 and 4-bit types quantization, but the dequantization formula remains the same + for consistency. The output type is determined by the attribute `output_dtype`. If `output_dtype` is not supplied then the output type + is the same as `x_scale`. The output type also determines the precision of the multiplication operation. + + + + Args: + x: N-D quantized input tensor to be de-quantized. + + x_scale: Scale for input `x`. For per-tensor/layer dequantization the scale + is a scalar, for per per-axis dequantization it is a 1-D Tensor and for + blocked dequantization it has the same shape as the input, except for + one dimension in which blocking is performed. + + x_zero_point: (optional) Zero point for input `x`. Shape must match x_scale. + It's optional. Zero point is 0 when it's not specified. + + axis: (Optional) The axis of the dequantizing dimension of the input tensor. + Used for per-axis and blocked quantization. Negative value means + counting dimensions from the back. Accepted range is `[-r, r-1]` where + `r = rank(input)`. + + block_size: (Optional) The size of the quantization block (number of times + every scale is replicated). Used only for blocked quantization. The + block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., + Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted + range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]` + + output_dtype: (Optional) The output data type. If not supplied, the output + data type is inferred from `x_scale` data type (`T2`) + """ + + schema = get_schema("DequantizeLinear", 24, "") + op = Op(self, "DequantizeLinear", schema) + return op( + *self._prepare_inputs(schema, x, x_scale, x_zero_point), + axis=axis, + block_size=block_size, + output_dtype=output_dtype, + ) + + T_Flatten = TypeVar( + "T_Flatten", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Flatten(self, input: T_Flatten, *, axis: int = 1) -> T_Flatten: + r"""[🌐 Flatten(24)](https://onnx.ai/onnx/operators/onnx__Flatten.html#flatten-24 "Online Documentation") + + + Flattens the input tensor into a 2D matrix. If input tensor has shape + (d_0, d_1, ... d_n) then the output will have shape + (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn). + + + Args: + input: (differentiable) A tensor of rank >= axis. + + axis: Indicate up to which input dimensions (exclusive) should be flattened + to the outer dimension of the output. The value for axis must be in the + range [-r, r], where r is the rank of the input tensor. Negative value + means counting dimensions from the back. When axis = 0, the shape of the + output tensor is (1, (d_0 X d_1 ... d_n), where the shape of the input + tensor is (d_0, d_1, ... d_n). + """ + + schema = get_schema("Flatten", 24, "") + op = Op(self, "Flatten", schema) + return op(*self._prepare_inputs(schema, input), axis=axis) + + V_Identity = TypeVar( + "V_Identity", + Optional[Sequence[BOOL]], + Optional[Sequence[COMPLEX128]], + Optional[Sequence[COMPLEX64]], + Optional[Sequence[DOUBLE]], + Optional[Sequence[FLOAT]], + Optional[Sequence[FLOAT16]], + Optional[Sequence[INT16]], + Optional[Sequence[INT32]], + Optional[Sequence[INT64]], + Optional[Sequence[INT8]], + Optional[Sequence[STRING]], + Optional[Sequence[UINT16]], + Optional[Sequence[UINT32]], + Optional[Sequence[UINT64]], + Optional[Sequence[UINT8]], + Optional[BOOL], + Optional[COMPLEX128], + Optional[COMPLEX64], + Optional[DOUBLE], + Optional[FLOAT], + Optional[FLOAT16], + Optional[INT16], + Optional[INT32], + Optional[INT64], + Optional[INT8], + Optional[STRING], + Optional[UINT16], + Optional[UINT32], + Optional[UINT64], + Optional[UINT8], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Identity(self, input: V_Identity) -> V_Identity: + r"""[🌐 Identity(24)](https://onnx.ai/onnx/operators/onnx__Identity.html#identity-24 "Online Documentation") + + Identity operator + + Args: + input: (differentiable) Input tensor + """ + + schema = get_schema("Identity", 24, "") + op = Op(self, "Identity", schema) + return op(*self._prepare_inputs(schema, input)) + + B_If: TypeAlias = BOOL + + V_If: TypeAlias = Union[ + None, + Sequence[BFLOAT16], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + Sequence[FLOAT4E2M1], + Sequence[FLOAT8E4M3FN], + Sequence[FLOAT8E4M3FNUZ], + Sequence[FLOAT8E5M2], + Sequence[FLOAT8E5M2FNUZ], + Sequence[FLOAT8E8M0], + Sequence[INT4], + Sequence[UINT4], + ] + + def If(self, cond: B_If, *, else_branch: GraphProto, then_branch: GraphProto) -> V_If: + r"""[🌐 If(24)](https://onnx.ai/onnx/operators/onnx__If.html#if-24 "Online Documentation") + + If conditional + + Args: + cond: Condition for the if. The tensor must contain a single element. + + else_branch: Graph to run if condition is false. Has N outputs: values you + wish to be live-out to the enclosing scope. The number of outputs must + match the number of outputs in the then_branch. + + then_branch: Graph to run if condition is true. Has N outputs: values you + wish to be live-out to the enclosing scope. The number of outputs must + match the number of outputs in the else_branch. + """ + + schema = get_schema("If", 24, "") + op = Op(self, "If", schema) + return op( + *self._prepare_inputs(schema, cond), + else_branch=else_branch, + then_branch=then_branch, + ) + + I_Loop: TypeAlias = INT64 + + B_Loop: TypeAlias = BOOL + + V_Loop = TypeVar( + "V_Loop", + Optional[Sequence[BFLOAT16]], + Optional[Sequence[BOOL]], + Optional[Sequence[COMPLEX128]], + Optional[Sequence[COMPLEX64]], + Optional[Sequence[DOUBLE]], + Optional[Sequence[FLOAT]], + Optional[Sequence[FLOAT16]], + Optional[Sequence[INT16]], + Optional[Sequence[INT32]], + Optional[Sequence[INT64]], + Optional[Sequence[INT8]], + Optional[Sequence[STRING]], + Optional[Sequence[UINT16]], + Optional[Sequence[UINT32]], + Optional[Sequence[UINT64]], + Optional[Sequence[UINT8]], + Optional[BFLOAT16], + Optional[BOOL], + Optional[COMPLEX128], + Optional[COMPLEX64], + Optional[DOUBLE], + Optional[FLOAT], + Optional[FLOAT16], + Optional[FLOAT4E2M1], + Optional[FLOAT8E4M3FN], + Optional[FLOAT8E4M3FNUZ], + Optional[FLOAT8E5M2], + Optional[FLOAT8E5M2FNUZ], + Optional[FLOAT8E8M0], + Optional[INT16], + Optional[INT32], + Optional[INT4], + Optional[INT64], + Optional[INT8], + Optional[STRING], + Optional[UINT16], + Optional[UINT32], + Optional[UINT4], + Optional[UINT64], + Optional[UINT8], + Sequence[BFLOAT16], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[FLOAT4E2M1], + Sequence[FLOAT8E4M3FN], + Sequence[FLOAT8E4M3FNUZ], + Sequence[FLOAT8E5M2], + Sequence[FLOAT8E5M2FNUZ], + Sequence[FLOAT8E8M0], + Sequence[INT16], + Sequence[INT32], + Sequence[INT4], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT4], + Sequence[UINT64], + Sequence[UINT8], + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Loop( + self, + M: Optional[I_Loop], + cond: Optional[B_Loop], + *v_initial: V_Loop, + body: GraphProto, + ) -> V_Loop: + r"""[🌐 Loop(24)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-24 "Online Documentation") + + + Generic Looping construct. This loop has multiple termination conditions: + + 1) Trip count. Iteration count specified at runtime. Set by + specifying the input M. Optional. Set to empty string to omit. + Note that a static trip count (specified at graph construction time) can be + specified by passing in a constant node for input M. + 2) Loop termination condition. This is an input to the op that determines + whether to run the first iteration and also a loop-carried dependency for + the body graph. The body graph must yield a value for the condition variable, + whether this input is provided or not. + + This table summarizes the operating modes of this operator with equivalent + C-style code: + + Operator inputs defined as (max_trip_count, condition_var). + + * input ("", ""): + for (int i=0; ; ++i) { + cond = ... // Note this value is ignored, but is required in the body + } + + * input ("", cond) // Note this is analogous to a while loop + bool cond = ...; + for (int i=0; cond; ++i) { + cond = ...; + } + + * input ("", 1) // Note this is analogous to a do-while loop + bool cond = true + for (int i=0; cond; ++i) { + cond = ...; + } + + * input (trip_count, "") // Note this is analogous to a for loop + int trip_count = ... + for (int i=0; i < trip_count; ++i) { + cond = ...; // ignored + } + + * input (trip_count, cond) + int trip_count = ...; + bool cond = ...; + for (int i=0; i < trip_count && cond; ++i) { + cond = ...; + } + + + *Sample usage - cond as well as trip count* + + graph predict-net { + %a = Constant[value = ]() + %b = Constant[value = ]() + %keepgoing = Constant[value = ]() + %max_trip_count = Constant[value = ]() + %keepgoing_out, %b_out, %user_defined_vals = Loop[body = ](%max_trip_count, %keepgoing, %b) + return + } + + graph body-net ( + %i[INT32, scalar] // iteration number + %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used + %b_in[INT32, scalar] // incoming value of loop-carried-dependency b + ) { + %my_local = Add(%a, %b_in) + %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b + %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition + %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated + return %keepgoing_out, %b_out, %user_defined_val + } + + *Sample equivalent C code* + + { + /* User-defined code (enclosing scope) */ + int a = 3, b = 6; + bool keepgoing = true; // Analogous to input cond + /* End user-defined code */ + + /* Implicitly-defined code */ + const int max_trip_count = 10; // Analogous to input M + int user_defined_vals[]; // Imagine this is resizable + /* End implicitly-defined code */ + /* initialize loop-carried variables and scan-output variables */ + bool keepgoing_out = keepgoing + int b_out = b + + for (int i=0; i < max_trip_count && keepgoing_out; ++i) { + /* Implicitly-defined code: bind actual parameter values + to formal parameter variables of loop-body */ + bool keepgoing_in = keepgoing_out; + bool b_in = b_out; + + /* User-defined code (loop body) */ + int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine + b_out = a - b_in; + keepgoing_out = my_local > b_out; + user_defined_val = b_in + b_in; // b_in and b_out are different variables + /* End user-defined code */ + + /* Implicitly defined-code */ + user_defined_vals[i] = user_defined_val // accumulate scan-output values + } + // int t = my_local; // Can't do this. my_local is not accessible here. + + // The values below are bound to the output variables of the loop and therefore accessible + // b_out; user_defined_vals; keepgoing_out; + } + + There are several things of note in this code snippet: + + 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can + be referenced in the inputs of the loop. + 2) Any values computed in the loop body that needs to be used in a subsequent + iteration or after the loop are modelled using a pair of variables in the loop-body, + consisting of an input variable (eg., b_in) and an output variable (eg., b_out). + These are referred to as loop-carried dependences. The loop operation node + supplies the input value of the input variable for the first iteration, and + returns the output value of the output variable produced by the final + iteration. + 3) Scan_output variables are used to implicitly concatenate values computed across + all the iterations. In the above example, the value of user_defined_val computed + over all iterations are concatenated and returned as the value of user_defined_vals + after the loop. + 4) Values created in the body cannot be accessed in the enclosing scope, + except using the mechanism described above. + + Note that the semantics of this op support "diagonal" or "wavefront" execution. + (See Step 3 here for an example: + https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/). + Frontends should emit multi-layer RNNs as a series of While operators (with + time being the inner looping dimension), with each successive layer consuming + the scan_outputs from the previous layer, possibly going through several + point-wise operators (e.g. dropout, residual connections, linear layer). + + The input/output of subgraph (produced by loop node) matching is based on order instead of name. The implementation will figure out the names based on this order. + + + Args: + M: (optional) A maximum trip-count for the loop specified at runtime. + Optional. Pass empty string to skip. + + cond: (optional) A boolean termination condition. Optional. Pass empty + string to skip. + + v_initial: (variadic, heterogeneous) The initial values of any loop-carried + dependencies (values that change across loop iterations) + + body: The graph run each iteration. It has 2+N inputs: (iteration_num, + condition, loop carried dependencies...). It has 1+N+K outputs: + (condition, loop carried dependencies..., scan_outputs...). Each + scan_output is created by concatenating the value of the specified + output value at the end of each iteration of the loop. It is an error if + the dimensions or data type of these scan_outputs change across loop + iterations. + """ + + schema = get_schema("Loop", 24, "") + op = Op(self, "Loop", schema) + return op(*self._prepare_inputs(schema, M, cond, *v_initial), body=body) + + T_Pad = TypeVar( + "T_Pad", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + Tind_Pad = TypeVar("Tind_Pad", INT32, INT64) + + def Pad( + self, + data: T_Pad, + pads: INT64, + constant_value: Optional[T_Pad] = None, + axes: Optional[Tind_Pad] = None, + *, + mode: str = "constant", + ) -> T_Pad: + r"""[🌐 Pad(24)](https://onnx.ai/onnx/operators/onnx__Pad.html#pad-24 "Online Documentation") + + + Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, + a padded tensor (`output`) is generated. + + The three supported `modes` are (similar to corresponding modes supported by `numpy.pad`): + + 1) `constant`(default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False) + + 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis + + 3) `edge` - pads with the edge values of array + + 4) `wrap` - wrap-around padding as if the data tensor forms a torus + + + Example 1 (`constant` mode): + + Insert 0 pads to the beginning of the second dimension. + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'constant' + + constant_value = 0.0 + + output = [ + [0.0, 0.0, 1.0, 1.2], + [0.0, 0.0, 2.3, 3.4], + [0.0, 0.0, 4.5, 5.7], + ] + + + + Example 2 (`reflect` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'reflect' + + output = [ + [1.0, 1.2, 1.0, 1.2], + [2.3, 3.4, 2.3, 3.4], + [4.5, 5.7, 4.5, 5.7], + ] + + + + Example 3 (`edge` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'edge' + + output = [ + [1.0, 1.0, 1.0, 1.2], + [2.3, 2.3, 2.3, 3.4], + [4.5, 4.5, 4.5, 5.7], + ] + + + + Example 4 (`wrap` mode): + + :: + + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [2, 1, 1, 1] + + mode = 'wrap' + + output = [ + [3.4, 2.3, 3.4, 2.3], + [5.7, 4.5, 5.7, 4.5], + [1.2, 1.0, 1.2, 1.0], + [3.4, 2.3, 3.4, 2.3], + [5.7, 4.5, 5.7, 4.5], + [1.2, 1.0, 1.2, 1.0], + ] + + + + + Args: + data: (differentiable) Input tensor. + + pads: (non-differentiable) Tensor of integers indicating the number of + padding elements to add or remove (if negative) at the beginning and end + of each axis. For 2D input tensor, it is the number of pixels. `pads` + should be a 1D tensor of shape [2 * num_axes] where `num_axes` refers to + the number of elements in the `axes` input or the input rank if `axes` + are not provided explicitly. `pads` format should be: [x1_begin, + x2_begin, ..., x1_end, x2_end,...], where xi_begin is the number of pad + values added at the beginning of axis `axes[i]` and xi_end, the number + of pad values added at the end of axis `axes[i]`. + + constant_value: (optional, non-differentiable) (Optional) A scalar value to + be used if the mode chosen is `constant` (by default it is 0, empty + string or False). + + axes: (optional, non-differentiable) 1-D tensor of axes that `pads` apply + to. Negative value means counting dimensions from the back. Accepted + range is [-r, r-1] where r = rank(data). Behavior is undefined if an + axis is repeated. If not provided, all axes are assumed (`[0, 1, ..., + input_rank-1]`). + + mode: Supported modes: `constant`(default), `reflect`, `edge`, `wrap` + """ + + schema = get_schema("Pad", 24, "") + op = Op(self, "Pad", schema) + return op(*self._prepare_inputs(schema, data, pads, constant_value, axes), mode=mode) + + T1_QuantizeLinear = TypeVar("T1_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, INT32) + + T2_QuantizeLinear = TypeVar( + "T2_QuantizeLinear", BFLOAT16, FLOAT, FLOAT16, FLOAT8E8M0, INT32 + ) + + T3_QuantizeLinear = TypeVar( + "T3_QuantizeLinear", + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + INT16, + INT4, + INT8, + UINT16, + UINT4, + UINT8, + ) + + def QuantizeLinear( + self, + x: T1_QuantizeLinear, + y_scale: T2_QuantizeLinear, + y_zero_point: Optional[T3_QuantizeLinear] = None, + *, + axis: int = 1, + block_size: int = 0, + output_dtype: int = 0, + precision: int = 0, + saturate: int = 1, + ) -> T3_QuantizeLinear: + r"""[🌐 QuantizeLinear(24)](https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html#quantizelinear-24 "Online Documentation") + + + The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the + low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization + granularity. The quantization formula is `y = saturate((x / y_scale) + y_zero_point)`. + + Saturation is done according to: + - uint16: [0, 65535] + - int16: [-32768, 32767] + - uint8: [0, 255] + - int8: [-128, 127] + - uint4: [0, 15] + - int4: [-8, 7] + + For `(x / y_scale)`, it rounds to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details. + + `y_zero_point` and `y` must have the same type. `y_zero_point` is usually not used for quantization to float8 and 4bit types, but the quantization + formula remains the same for consistency, and the type of the attribute `y_zero_point` still determines the quantization type. + `x` and `y_scale` are allowed to have different types. The type of `y_scale` determines the precision of the division operation between `x` and + `y_scale`, unless the `precision` attribute is specified. + + There are three supported quantization granularities, determined by the shape of `y_scale`. + In all cases, `y_zero_point` must have the same shape as `y_scale`. + - Per-tensor (per-layer) quantization: `y_scale` is a scalar. + - Per-axis quantization: The scale must be a 1-D tensor, with the length of the quantization axis. For an input shape + `(D0, ..., Di, ..., Dn)` and `axis=i`, `y_scale` is a 1-D tensor of length `Di`. + - Blocked quantization: The scale's shape is identical to the input's shape, except for one dimension, in which + blocking is performed. Given `x` shape `(D0, ..., Di, ..., Dn)`, `axis=i`, and block size `B`: `y_scale` shape is + `(D0, ..., ceil(Di/B), ..., Dn)`. + + + Args: + x: N-D full precision Input tensor to be quantized. + + y_scale: Scale for doing quantization to get `y`. For per-tensor/layer + quantization the scale is a scalar, for per-axis quantization it is a + 1-D Tensor and for blocked quantization it has the same shape as the + input, except for one dimension in which blocking is performed. + + y_zero_point: (optional) Zero point for doing quantization to get `y`. Shape + must match `y_scale`. Default is uint8 with zero point of 0 if it's not + specified. + + axis: (Optional) The axis of the dequantizing dimension of the input tensor. + Used only for per-axis and blocked quantization. Negative value means + counting dimensions from the back. Accepted range is `[-r, r-1]` where + `r = rank(input)`. When the rank of the input is 1, per-tensor + quantization is applied, rendering the axis unnecessary in this + scenario. + + block_size: (Optional) The size of the quantization block (number of times + every scale is replicated). Used only for blocked quantization. The + block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., + Dn)`, `y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted + range is `[ceil(Di/Si), ceil(Di/(Si-1))-1]` + + output_dtype: (Optional) The output data type. If not supplied, the output + data type is inferred from `y_zero_point` data type (`T3`). If neither + `output_dtype` nor `y_zero_point` are supplied, output data type is + uint8. If both `output_dtype` and `y_zero_point` are specified, + `output_dtype` must be `T3`. + + precision: (Optional) The precision of the division operation between `x` + and `y_scale`. If not provided, it will be the same as the type of + `y_scale`. + + saturate: The parameter defines how the conversion behaves if an input value + is out of range of the destination type. It only applies for float 8 + quantization (float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). + It is true by default. All cases are fully described in two tables + inserted in the operator description. + """ + + schema = get_schema("QuantizeLinear", 24, "") + op = Op(self, "QuantizeLinear", schema) + return op( + *self._prepare_inputs(schema, x, y_scale, y_zero_point), + axis=axis, + block_size=block_size, + output_dtype=output_dtype, + precision=precision, + saturate=saturate, + ) + + T_Reshape = TypeVar( + "T_Reshape", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Reshape(self, data: T_Reshape, shape: INT64, *, allowzero: int = 0) -> T_Reshape: + r"""[🌐 Reshape(24)](https://onnx.ai/onnx/operators/onnx__Reshape.html#reshape-24 "Online Documentation") + + + Reshape the input tensor similar to numpy.reshape. + First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor. + At most one dimension of the new shape can be -1. In this case, the value is + inferred from the size of the tensor and the remaining dimensions. A dimension + could also be 0, in which case the actual dimension value is unchanged (i.e. taken + from the input tensor). If 'allowzero' is set, and the new shape includes 0, the + dimension will be set explicitly to zero (i.e. not taken from input tensor). + Shape (second input) could be an empty shape, which means converting to a scalar. + The input tensor's shape and the output tensor's shape are required to have the same number of elements. + + If the attribute 'allowzero' is set, it is invalid for the specified shape to + contain both a zero value and -1, as the value of the dimension corresponding + to -1 cannot be determined uniquely. + + + Args: + data: (differentiable) An input tensor. + + shape: (non-differentiable) Specified shape for output. + + allowzero: (Optional) By default, when any value in the 'shape' input is + equal to zero the corresponding dimension value is copied from the input + tensor dynamically. allowzero=1 indicates that if any value in the + 'shape' input is set to zero, the zero value is honored, similar to + NumPy. + """ + + schema = get_schema("Reshape", 24, "") + op = Op(self, "Reshape", schema) + return op(*self._prepare_inputs(schema, data, shape), allowzero=allowzero) + + V_Scan = TypeVar( + "V_Scan", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Scan( + self, + *initial_state_and_scan_inputs: V_Scan, + body: GraphProto, + num_scan_inputs: int, + scan_input_axes: Optional[Sequence[int]] = None, + scan_input_directions: Optional[Sequence[int]] = None, + scan_output_axes: Optional[Sequence[int]] = None, + scan_output_directions: Optional[Sequence[int]] = None, + ) -> V_Scan: + r"""[🌐 Scan(24)](https://onnx.ai/onnx/operators/onnx__Scan.html#scan-24 "Online Documentation") + + + Scan can be used to iterate over one or more scan_input tensors, + constructing zero or more scan_output tensors. It combines ideas from general recurrences, + functional programming constructs such as scan, fold, map, and zip, and is intended to enable + generalizations of RNN-like constructs for sequence-to-sequence processing. + Other tensors (referred to as state_variables here) can be used to carry a state + when iterating from one element to another (similar to hidden-state in RNNs, also referred + to as loop-carried dependences in the context of loops). + Many common usages involve a single scan_input tensor (where functionality + similar to scan, fold and map can be obtained). When more than one scan_input is used, + a behavior similar to zip is obtained. + + The attribute body must be a graph, specifying the computation to be performed in + every iteration. It takes as input the current values of the state_variables and + the current iterated element of the scan_inputs. It must return the (updated) values + of the state_variables and zero or more scan_output_element tensors. The values of the + scan_output_element tensors are concatenated over all the iterations to produce the + scan_output values of the scan construct (similar to the concatenated intermediate + hidden-state values of RNN-like constructs). All the output tensors (state_variables as + well as scan_output_element tensors) are required to have the same shape in each iteration + of the loop (a restriction imposed to enable efficient memory allocation). + + Note that the iterated element passed to the body subgraph does not have a sequence + axis. It will have a rank one less than the rank of the corresponding scan_input. + + The scan operation returns the final values of the state_variables as well as the + scan_outputs. + + The optional attribute scan_input_directions specifies the direction (forward or backward) + for each scan input. If this attribute is omitted, all sequences are scanned in the forward + direction. A bidirectional scan may be performed by specifying the same tensor input twice + in the scan_inputs, once with a forward direction, and once with a backward direction. + + The scan_output of the operation is produced by concatenating the scan_output_element + values produced by the body in each iteration. The optional attribute scan_output_directions + specifies the direction in which scan_output is constructed (by appending or prepending the + scan_output_element to scan_output in each iteration) for each scan_output. If this attribute + is omitted, the scan_output_element is appended to the scan_output in each iteration. + + The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input. + If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the + batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1. + Note that scanning a non-zero axis may be less efficient than scanning axis zero. + + The optional attribute scan_output_axes specifies the axis along which the scan_outputs + are accumulated for each scan_output. For example, if axis 1 is the time axis (to be + scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis + value of 1. + + Note that because of the ONNX restriction that only the last parameter of an operator can + be variadic, the initial-states and scan-inputs are listed together as one input parameter. + Similarly, the final-states and scan-outputs are listed together as one output parameter. + The attribute num_scan_inputs indicates the number M of scan-inputs. + + The behavior of + + Scan < + num_scan_inputs = m, + body = loop-body, + scan_input_axes = [axis_1, ..., axis_m] + > (init_1, ..., init_n, scan_1, ..., scan_m) + + is equivalent to the following pseudo-code: + + // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i + // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j. + sequence_length = scan_1.shape[axis_1]; + + // initialize state-variables + st_1 = init_1; ... st_n = init_n; + // initialize scan-output variables: [] denotes an empty tensor + scan_out_1 = []; ...; scan_out_k = []; + // identify number of iterations: + + // execute loop + for (int t = 0; t < sequence_length; ++t) { + // generate the scan-input elements: the notation T[t] indicates the sub-tensor + // of rank one less than T obtained by indexing T at position t along axis k. + si_1 = scan_1[t]; + ... ; + si_m = scan_m[t]; + // execute loop-body + st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m) + // accumulate the scan-output elements + scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k); + } + + return st_1, ..., st_n, scan_out_1, ..., scan_out_k; + + *Sample usage: Encoding RNN using a Scan* + + The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi, + recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can + be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes + %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these + values are computed in the outer graph, they need to be passed in as extra state_variables. + + graph rnn-encoding { + %H_0 = ... + %X = ... + %Y_h, %Y = Scan[body = , num_scan_inputs=1](%H_0, %X) + return %Y, %Y_h + } + + graph rnn-cell-1 ( + %H_tminus1[FLOAT, tensor] + %X_t[FLOAT, tensor] + ) { + %Wi = ... + %Ri = ... + %Wbi = ... + %Rbi = ... + %t1 = X_t * (Wi^T) + %t2 = H_tminus1*(Ri^T) + %t3 = Add(%t1, %t2) + %t4 = Add(%t3, %Wbi) + %t5 = Add(%t4, %Rbi) + %Ht = Tanh(%t5) + %Accumulate = Identity(%Ht) + return %Ht, %Accumulate + } + + + + Args: + initial_state_and_scan_inputs: (variadic, heterogeneous) Initial values of + the loop's N state variables followed by M scan_inputs + + body: The graph run each iteration. It has N+M inputs: (loop state + variables..., scan_input_elts...). It has N+K outputs: (loop state + variables..., scan_output_elts...). Each scan_output is created by + concatenating the value of the specified scan_output_elt value at the + end of each iteration of the loop. It is an error if the dimensions of + these values change across loop iterations. + + num_scan_inputs: An attribute specifying the number of scan_inputs M. + + scan_input_axes: An optional list of M flags. The i-th element of the list + specifies the axis to be scanned (the sequence axis) for the i-th + scan_input. If omitted, 0 will be used as the scan axis for every + scan_input. Negative value for an axis means counting dimensions from + the back. Accepted range is [-r, r-1] where r = rank(input). + + scan_input_directions: An optional list of M flags. The i-th element of the + list specifies the direction to be scanned for the i-th scan_input + tensor: 0 indicates forward direction and 1 indicates reverse direction. + If omitted, all scan_input tensors will be scanned in the forward + direction. + + scan_output_axes: An optional list of K flags. The i-th element of the list + specifies the axis for the i-th scan_output. The scan outputs are + accumulated along the specified axis. If omitted, 0 will be used as the + scan axis for every scan_output. Negative value for an axis means + counting dimensions from the back. Accepted range is [-r, r-1]. + + scan_output_directions: An optional list of K flags, one for each + scan_output. The i-th element of the list specifies whether the i-th + scan_output should be constructed by appending or prepending a new value + in each iteration: 0 indicates appending and 1 indicates prepending. If + omitted, all scan_output tensors will be produced by appending a value + in each iteration. + """ + + schema = get_schema("Scan", 24, "") + op = Op(self, "Scan", schema) + return op( + *self._prepare_inputs(schema, *initial_state_and_scan_inputs), + body=body, + num_scan_inputs=num_scan_inputs, + scan_input_axes=scan_input_axes, + scan_input_directions=scan_input_directions, + scan_output_axes=scan_output_axes, + scan_output_directions=scan_output_directions, + ) + + T_Shape = TypeVar( + "T_Shape", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T1_Shape: TypeAlias = INT64 + + def Shape(self, data: T_Shape, *, end: Optional[int] = None, start: int = 0) -> T1_Shape: + r"""[🌐 Shape(24)](https://onnx.ai/onnx/operators/onnx__Shape.html#shape-24 "Online Documentation") + + + Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor. + Optional attributes start and end can be used to compute a slice of the input tensor's shape. + If start axis is omitted, the slice starts from axis 0. + The end axis, if specified, is exclusive (and the returned value will not include the size of that axis). + If the end axis is omitted, the axes upto the last one will be included. + Negative axes indicate counting back from the last axis. + Note that axes will be clamped to the range [0, r], where r is the + rank of the input tensor if they are out-of-range (after adding r in the case of + negative axis). Thus, specifying any end value > r is equivalent to specifying an end + value of r, and specifying any start value < -r is equivalent to specifying a start + value of 0. If start > end, the result will be an empty shape. + + Examples: + + :: + + Input tensor with shape: [2, 3, 4] + No attributes specified. + Output: [2, 3, 4] + + + + :: + + Input tensor with shape: [2, 3, 4] + start: -1 + Output: [4] + + + + :: + + Input tensor with shape: [2, 3, 4] + end: -1 + Output: [2, 3] + + + + :: + + Input tensor with shape: [2, 3, 4] + start: 1 + end: 2 + Output: [3] + + + + + Args: + data: (non-differentiable) An input tensor. + + end: (Optional) Ending axis for slicing the shape. Negative value means + counting dimensions from the back. If omitted, sizes of all axes upto + (including) the last one will be included. + + start: (Optional) Starting axis for slicing the shape. Default value is + 0.Negative value means counting dimensions from the back. + """ + + schema = get_schema("Shape", 24, "") + op = Op(self, "Shape", schema) + return op(*self._prepare_inputs(schema, data), end=end, start=start) + + T_Size = TypeVar( + "T_Size", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + T1_Size: TypeAlias = INT64 + + def Size(self, data: T_Size) -> T1_Size: + r"""[🌐 Size(24)](https://onnx.ai/onnx/operators/onnx__Size.html#size-24 "Online Documentation") + + + Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor. + + + Args: + data: (non-differentiable) An input tensor. + """ + + schema = get_schema("Size", 24, "") + op = Op(self, "Size", schema) + return op(*self._prepare_inputs(schema, data)) + + T_SplitToSequence = TypeVar( + "T_SplitToSequence", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + I_SplitToSequence = TypeVar("I_SplitToSequence", INT32, INT64) + + S_SplitToSequence: TypeAlias = Union[ + Sequence[BFLOAT16], + Sequence[BOOL], + Sequence[COMPLEX128], + Sequence[COMPLEX64], + Sequence[DOUBLE], + Sequence[FLOAT], + Sequence[FLOAT16], + Sequence[INT16], + Sequence[INT32], + Sequence[INT64], + Sequence[INT8], + Sequence[STRING], + Sequence[UINT16], + Sequence[UINT32], + Sequence[UINT64], + Sequence[UINT8], + ] + + def SplitToSequence( + self, + input: T_SplitToSequence, + split: Optional[I_SplitToSequence] = None, + *, + axis: int = 0, + keepdims: int = 1, + ) -> S_SplitToSequence: + r"""[🌐 SplitToSequence(24)](https://onnx.ai/onnx/operators/onnx__SplitToSequence.html#splittosequence-24 "Online Documentation") + + + Split a tensor into a sequence of tensors, along the specified 'axis'. + Lengths of the parts can be specified using the optional argument 'split'. + If the argument `split' is not specified, a default scalar value of 1 + is used as the value of `split'. + 'split' must contain only positive numbers. + 'split' is either a scalar (tensor of empty shape), or a 1-D tensor. + If 'split' is a scalar, then 'input' will be split into chunks all of size 'split' + if possible. The last chunk alone may be smaller than 'split' if the 'input' size + along the given axis 'axis' is not divisible by 'split'. + If 'split' is a 1-dimensional tensor, the input tensor is split into 'size(split)' chunks, + with lengths of the parts on 'axis' specified in 'split'. In this scenario, the sum of entries + in 'split' must be equal to the dimension size of input tensor on 'axis'. + + + Args: + input: The tensor to split + + split: (optional) Length of each output. It can be either a scalar(tensor of + empty shape), or a 1-D tensor. All values must be >= 0. + + axis: Which axis to split on. A negative value means counting dimensions + from the back. Accepted range is [-rank, rank-1]. + + keepdims: Keep the split dimension or not. Default 1, which means we keep + split dimension. If input 'split' is specified, this attribute is + ignored. + """ + + schema = get_schema("SplitToSequence", 24, "") + op = Op(self, "SplitToSequence", schema) + return op(*self._prepare_inputs(schema, input, split), axis=axis, keepdims=keepdims) + + T_Squeeze = TypeVar( + "T_Squeeze", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Squeeze(self, data: T_Squeeze, axes: Optional[INT64] = None) -> T_Squeeze: + r"""[🌐 Squeeze(24)](https://onnx.ai/onnx/operators/onnx__Squeeze.html#squeeze-24 "Online Documentation") + + + Remove single-dimensional entries from the shape of a tensor. + Takes an input `axes` with a list of axes to squeeze. + If `axes` is not provided, all the single dimensions will be removed from + the shape. If an axis is selected with shape entry not equal to one, an error is raised. + + + Args: + data: (differentiable) Tensors with at least max(dims) dimensions. + + axes: (optional, non-differentiable) 1D tensor of integers indicating the + dimensions to squeeze. Negative value means counting dimensions from the + back. Accepted range is [-r, r-1] where r = rank(data). + """ + + schema = get_schema("Squeeze", 24, "") + op = Op(self, "Squeeze", schema) + return op(*self._prepare_inputs(schema, data, axes)) + + T_Swish = TypeVar("T_Swish", BFLOAT16, DOUBLE, FLOAT, FLOAT16) + + def Swish(self, X: T_Swish, *, alpha: float = 1.0) -> T_Swish: + r"""[🌐 Swish(24)](https://onnx.ai/onnx/operators/onnx__Swish.html#swish-24 "Online Documentation") + + + Swish function takes one input data (Tensor) and produces one output data (Tensor) of the same shape, + where $Swish(x) = x * sigmoid(alpha * x)$. + + + Args: + X: (differentiable) Input tensor + + alpha: Coefficient to multiply with input before sigmoid. + """ + + schema = get_schema("Swish", 24, "") + op = Op(self, "Swish", schema) + return op(*self._prepare_inputs(schema, X), alpha=alpha) + + T_TensorScatter = TypeVar( + "T_TensorScatter", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def TensorScatter( + self, + past_cache: T_TensorScatter, + update: T_TensorScatter, + write_indices: Optional[INT64] = None, + *, + axis: int = -2, + mode: str = "linear", + ) -> T_TensorScatter: + r"""[🌐 TensorScatter(24)](https://onnx.ai/onnx/operators/onnx__TensorScatter.html#tensorscatter-24 "Online Documentation") + + + TensorScatter is a generic tensor update operation, motivated by the requirements for KV cache updates for Attention + ops commonly found in LLMs. It is a functional operation that models an in-place update to a KV cache buffer. + + The past and present cache tensors have the same shape (batch_size, D1, D2, ..., max_sequence_length, ..., Dn), with + the sequence dimension (indicated by the `axis` attribute) being max_sequence_length, so the sizes of these tensors do + not need to grow between iterations. The `update` tensor's shape only differs from the cache tensors in the sequence + dimension: (batch_size, D1, D2, ..., sequence_length, ..., Dn), where sequence_length <= max_sequence_length. + + The optional `write_indices` input indicates the write index for each sample in the batch, assumed to be zero + if not provided. When the `mode` attribute is set to "circular", the write index is modulo max_sequence_length. + The operation can be described using the following pseudocode: + + :: + + for prefix_idx in np.ndindex(past_cache.shape[:axis]): + batch_idx = prefix_idx[0] + for sequence_idx in range(sequence_length): + cache_idx = (*prefix_idx, write_indices[batch_idx] + sequence_idx) + if mode == "circular": + cache_idx = tuple(np.mod(np.asarray(cache_idx), max_sequence_length)) + update_idx = (*prefix_idx, sequence_idx) + present_cache[cache_idx] = update[update_idx] + + + + During the prefill phase of attention, only the first two inputs are needed. During the decode phase, `write_indices` + is also needed so that the incoming key or value update can be appended after the last valid token for each sample + in the batch. + + + Args: + past_cache: (differentiable) Past state cache for key or value with shape + `(batch_size, D1, D2, ..., max_sequence_length, ..., Dn)`. + + update: (differentiable) New update tensor with shape `(batch_size, D1, D2, + ..., sequence_length, ..., Dn)`. + + write_indices: (optional, non-differentiable) Write indices for the incoming + update tensor in the cache. Shape is `(batch_size,)`. Assumed to be all + zeros if not provided. + + axis: Sequence dimension of the `past_cache` and `update` tensors. It cannot + be 0 (the batch dimension). Default is -2. + + mode: Write mode of cache update. Supported modes include `linear` and + `circular`. `linear` mode requires + write_indices+sequence_length<=max_sequence_length. For `circular` mode, + the updates happen in wrap-around fashion, ie, the update index is + modulo `max_sequence_length` + """ + + schema = get_schema("TensorScatter", 24, "") + op = Op(self, "TensorScatter", schema) + return op( + *self._prepare_inputs(schema, past_cache, update, write_indices), + axis=axis, + mode=mode, + ) + + T_TopK = TypeVar( + "T_TopK", + BFLOAT16, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, + ) + + I_TopK: TypeAlias = INT64 + + def TopK( + self, X: T_TopK, K: INT64, *, axis: int = -1, largest: int = 1, sorted: int = 1 + ) -> Tuple[T_TopK, I_TopK]: + r"""[🌐 TopK(24)](https://onnx.ai/onnx/operators/onnx__TopK.html#topk-24 "Online Documentation") + + + Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of + shape [a_0, a_1, ..., a_{n-1}] and integer argument k, return two outputs: + + * Value tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1}] + which contains the values of the top k elements along the specified axis + * Index tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1}] which + contains the indices of the top k elements (original indices from the input + tensor). + + * If "largest" is 1 (the default value) then the k largest elements are returned. + * If "sorted" is 1 (the default value) then the resulting k elements will be sorted. + * If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined. + + Given two equivalent values, this operator uses the indices along the axis as + a tiebreaker. That is, the element with the lower index will appear first. + + + Args: + X: (differentiable) Tensor of shape [a_0, a_1, ..., a_{n-1}] + + K: (non-differentiable) A 1-D tensor containing a single positive value + corresponding to the number of top elements to retrieve + + axis: Dimension on which to do the sort. Negative value means counting + dimensions from the back. Accepted range is [-r, r-1] where r = + rank(input). + + largest: Whether to return the top-K largest or smallest elements. + + sorted: Whether to return the elements in sorted order. + """ + + schema = get_schema("TopK", 24, "") + op = Op(self, "TopK", schema) + return op( + *self._prepare_inputs(schema, X, K), + axis=axis, + largest=largest, + sorted=sorted, + ) + + T_Transpose = TypeVar( + "T_Transpose", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Transpose( + self, data: T_Transpose, *, perm: Optional[Sequence[int]] = None + ) -> T_Transpose: + r"""[🌐 Transpose(24)](https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-24 "Online Documentation") + + + Transpose the input tensor similar to numpy.transpose. For example, when + perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape + will be (2, 1, 3). + + + Args: + data: (differentiable) An input tensor. + + perm: A list of integers. By default, reverse the dimensions, otherwise + permute the axes according to the values given. Its length must be equal + to the rank of the input. + """ + + schema = get_schema("Transpose", 24, "") + op = Op(self, "Transpose", schema) + return op(*self._prepare_inputs(schema, data), perm=perm) + + T_Unsqueeze = TypeVar( + "T_Unsqueeze", + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + FLOAT4E2M1, + FLOAT8E4M3FN, + FLOAT8E4M3FNUZ, + FLOAT8E5M2, + FLOAT8E5M2FNUZ, + FLOAT8E8M0, + INT16, + INT32, + INT4, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT4, + UINT64, + UINT8, + ) + + def Unsqueeze(self, data: T_Unsqueeze, axes: INT64) -> T_Unsqueeze: + r"""[🌐 Unsqueeze(24)](https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#unsqueeze-24 "Online Documentation") + + + Insert single-dimensional entries to the shape of an input tensor (`data`). + Takes one required input `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`). + + For example, given an input tensor (`data`) of shape [3, 4, 5], then + Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1]. + + The input `axes` should not contain any duplicate entries. It is an error if it contains duplicates. + The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`. + Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1]. + The order of values in `axes` does not matter and can come in any order. + + + Args: + data: (differentiable) Original tensor + + axes: (non-differentiable) 1D tensor of integers indicating the dimensions + to be inserted. Negative value means counting dimensions from the back. + Accepted range is [-r, r-1] where r = rank(expanded). + """ + + schema = get_schema("Unsqueeze", 24, "") + op = Op(self, "Unsqueeze", schema) + return op(*self._prepare_inputs(schema, data, axes)) diff --git a/onnxscript/onnx_opset/_impl/opset3.py b/onnxscript/onnx_opset/_impl/opset3.py index f9bbf5d770..fd684dd238 100644 --- a/onnxscript/onnx_opset/_impl/opset3.py +++ b/onnxscript/onnx_opset/_impl/opset3.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset4.py b/onnxscript/onnx_opset/_impl/opset4.py index 0a4f68981a..a1b7fb890b 100644 --- a/onnxscript/onnx_opset/_impl/opset4.py +++ b/onnxscript/onnx_opset/_impl/opset4.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset5.py b/onnxscript/onnx_opset/_impl/opset5.py index f445cfdce4..d7e34f8d5d 100644 --- a/onnxscript/onnx_opset/_impl/opset5.py +++ b/onnxscript/onnx_opset/_impl/opset5.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset6.py b/onnxscript/onnx_opset/_impl/opset6.py index 911192df22..b7b7981154 100644 --- a/onnxscript/onnx_opset/_impl/opset6.py +++ b/onnxscript/onnx_opset/_impl/opset6.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations @@ -211,7 +210,18 @@ def BatchNormalization( ) T2_Cast: TypeAlias = Union[ - BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8 + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, ] def Cast(self, input: T1_Cast, *, to: int) -> T2_Cast: @@ -370,7 +380,7 @@ def Elu(self, X: T_Elu, *, alpha: float = 1.0) -> T_Elu: Args: - X: (differentiable) 1D input tensor + X: (differentiable) Input tensor alpha: Coefficient of ELU. """ diff --git a/onnxscript/onnx_opset/_impl/opset7.py b/onnxscript/onnx_opset/_impl/opset7.py index e584d06c5a..eed9bde7d2 100644 --- a/onnxscript/onnx_opset/_impl/opset7.py +++ b/onnxscript/onnx_opset/_impl/opset7.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset8.py b/onnxscript/onnx_opset/_impl/opset8.py index 39d01f198b..6bedb39b86 100644 --- a/onnxscript/onnx_opset/_impl/opset8.py +++ b/onnxscript/onnx_opset/_impl/opset8.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: D402 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset9.py b/onnxscript/onnx_opset/_impl/opset9.py index 7d99f002ff..be1cec969d 100644 --- a/onnxscript/onnx_opset/_impl/opset9.py +++ b/onnxscript/onnx_opset/_impl/opset9.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: E741, D402 # -------------------------------------------------------------------------- from __future__ import annotations @@ -313,7 +312,18 @@ def Constant(self, *, value: TensorProto) -> T_Constant: T1_ConstantOfShape: TypeAlias = INT64 T2_ConstantOfShape: TypeAlias = Union[ - BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8 + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, ] def ConstantOfShape( @@ -402,7 +412,18 @@ def Erf(self, input: T_Erf) -> T_Erf: ) T2_EyeLike: TypeAlias = Union[ - BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8 + BOOL, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + UINT16, + UINT32, + UINT64, + UINT8, ] def EyeLike( @@ -633,7 +654,7 @@ def MatMul(self, A: T_MatMul, B: T_MatMul) -> T_MatMul: r"""[🌐 MatMul(9)](https://onnx.ai/onnx/operators/onnx__MatMul.html#matmul-9 "Online Documentation") - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). Args: @@ -1142,7 +1163,12 @@ def Scan( Tind_Scatter = TypeVar("Tind_Scatter", INT32, INT64) def Scatter( - self, data: T_Scatter, indices: Tind_Scatter, updates: T_Scatter, *, axis: int = 0 + self, + data: T_Scatter, + indices: Tind_Scatter, + updates: T_Scatter, + *, + axis: int = 0, ) -> T_Scatter: r"""[🌐 Scatter(9)](https://onnx.ai/onnx/operators/onnx__Scatter.html#scatter-9 "Online Documentation") diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py index a190eb17f9..d69cc686a0 100644 --- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py +++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: N801, D417 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py index a78e3ae551..49b38d3344 100644 --- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py +++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: N801 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py index 0092b4fd40..57c0d90a4e 100644 --- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py +++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: N801 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py index 552e545d75..02dc271c6e 100644 --- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py +++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py @@ -2,13 +2,12 @@ # ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ # ⚙️ Generated by 'python -m opgen' # -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=W0221,W0222,R0901,W0237 # mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 +# ruff: noqa: N801 # -------------------------------------------------------------------------- from __future__ import annotations diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py new file mode 100644 index 0000000000..d3f3f0b5cc --- /dev/null +++ b/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py @@ -0,0 +1,157 @@ +# -------------------------------------------------------------------------- +# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ +# ⚙️ Generated by 'python -m opgen' +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=W0221,W0222,R0901,W0237 +# mypy: disable-error-code=override +# ruff: noqa: N801 +# -------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Optional, Sequence, TypeVar + +from onnx import TensorProto +from onnx.defs import get_schema + +from onnxscript.onnx_opset._impl.opset_ai_onnx_ml4 import Opset_ai_onnx_ml4 +from onnxscript.onnx_types import DOUBLE, FLOAT, FLOAT16 +from onnxscript.values import Op, Opset + + +class Opset_ai_onnx_ml5(Opset_ai_onnx_ml4): + def __new__(cls): + return Opset.__new__(cls, "ai.onnx.ml", 5) + + T_TreeEnsemble = TypeVar("T_TreeEnsemble", DOUBLE, FLOAT, FLOAT16) + + def TreeEnsemble( + self, + X: T_TreeEnsemble, + *, + aggregate_function: int = 1, + leaf_targetids: Sequence[int], + leaf_weights: TensorProto, + membership_values: Optional[TensorProto] = None, + n_targets: Optional[int] = None, + nodes_falseleafs: Sequence[int], + nodes_falsenodeids: Sequence[int], + nodes_featureids: Sequence[int], + nodes_hitrates: Optional[TensorProto] = None, + nodes_missing_value_tracks_true: Optional[Sequence[int]] = None, + nodes_modes: TensorProto, + nodes_splits: TensorProto, + nodes_trueleafs: Sequence[int], + nodes_truenodeids: Sequence[int], + post_transform: int = 0, + tree_roots: Sequence[int], + ) -> T_TreeEnsemble: + r"""[🌐 ai.onnx.ml::TreeEnsemble(5)](https://onnx.ai/onnx/operators/onnx_aionnxml_TreeEnsemble.html#treeensemble-5 "Online Documentation") + + + Tree Ensemble operator. Returns the regressed values for each input in a batch. + Inputs have dimensions `[N, F]` where `N` is the input batch size and `F` is the number of input features. + Outputs have dimensions `[N, num_targets]` where `N` is the batch size and `num_targets` is the number of targets, which is a configurable attribute. + + The encoding of this attribute is split along interior nodes and the leaves of the trees. Notably, attributes with the prefix `nodes_*` are associated with interior nodes, and attributes with the prefix `leaf_*` are associated with leaves. + The attributes `nodes_*` must all have the same length and encode a sequence of tuples, as defined by taking all the `nodes_*` fields at a given position. + + All fields prefixed with `leaf_*` represent tree leaves, and similarly define tuples of leaves and must have identical length. + + This operator can be used to implement both the previous `TreeEnsembleRegressor` and `TreeEnsembleClassifier` nodes. + The `TreeEnsembleRegressor` node maps directly to this node and requires changing how the nodes are represented. + The `TreeEnsembleClassifier` node can be implemented by adding a `ArgMax` node after this node to determine the top class. + To encode class labels, a `LabelEncoder` or `GatherND` operator may be used. + + + Args: + X: Input of shape [Batch Size, Number of Features] + + aggregate_function: Defines how to aggregate leaf values within a target. +
One of 'AVERAGE' (0) 'SUM' (1) 'MIN' (2) 'MAX (3) defaults to 'SUM' + (1) + + leaf_targetids: The index of the target that this leaf contributes to (this + must be in range `[0, n_targets)`). + + leaf_weights: The weight for each leaf. + + membership_values: Members to test membership of for each set membership + node. List all of the members to test again in the order that the + 'BRANCH_MEMBER' mode appears in `node_modes`, delimited by `NaN`s. Will + have the same number of sets of values as nodes with mode + 'BRANCH_MEMBER'. This may be omitted if the node doesn't contain any + 'BRANCH_MEMBER' nodes. + + n_targets: The total number of targets. + + nodes_falseleafs: 1 if false branch is leaf for each node and 0 if an + interior node. To represent a tree that is a leaf (only has one node), + one can do so by having a single `nodes_*` entry with true and false + branches referencing the same `leaf_*` entry + + nodes_falsenodeids: If `nodes_falseleafs` is false at an entry, this + represents the position of the false branch node. This position can be + used to index into a `nodes_*` entry. If `nodes_falseleafs` is false, it + is an index into the leaf_* attributes. + + nodes_featureids: Feature id for each node. + + nodes_hitrates: Popularity of each node, used for performance and may be + omitted. + + nodes_missing_value_tracks_true: For each node, define whether to follow the + true branch (if attribute value is 1) or false branch (if attribute + value is 0) in the presence of a NaN input feature. This attribute may + be left undefined and the default value is false (0) for all nodes. + + nodes_modes: The comparison operation performed by the node. This is encoded + as an enumeration of 0 ('BRANCH_LEQ'), 1 ('BRANCH_LT'), 2 + ('BRANCH_GTE'), 3 ('BRANCH_GT'), 4 ('BRANCH_EQ'), 5 ('BRANCH_NEQ'), and + 6 ('BRANCH_MEMBER'). Note this is a tensor of type uint8. + + nodes_splits: Thresholds to do the splitting on for each node with mode that + is not 'BRANCH_MEMBER'. + + nodes_trueleafs: 1 if true branch is leaf for each node and 0 an interior + node. To represent a tree that is a leaf (only has one node), one can do + so by having a single `nodes_*` entry with true and false branches + referencing the same `leaf_*` entry + + nodes_truenodeids: If `nodes_trueleafs` is false at an entry, this + represents the position of the true branch node. This position can be + used to index into a `nodes_*` entry. If `nodes_trueleafs` is false, it + is an index into the leaf_* attributes. + + post_transform: Indicates the transform to apply to the score.
One of + 'NONE' (0), 'SOFTMAX' (1), 'LOGISTIC' (2), 'SOFTMAX_ZERO' (3) or + 'PROBIT' (4), defaults to 'NONE' (0) + + tree_roots: Index into `nodes_*` for the root of each tree. The tree + structure is derived from the branching of each node. + """ + + schema = get_schema("TreeEnsemble", 5, "ai.onnx.ml") + op = Op(self, "TreeEnsemble", schema) + return op( + *self._prepare_inputs(schema, X), + aggregate_function=aggregate_function, + leaf_targetids=leaf_targetids, + leaf_weights=leaf_weights, + membership_values=membership_values, + n_targets=n_targets, + nodes_falseleafs=nodes_falseleafs, + nodes_falsenodeids=nodes_falsenodeids, + nodes_featureids=nodes_featureids, + nodes_hitrates=nodes_hitrates, + nodes_missing_value_tracks_true=nodes_missing_value_tracks_true, + nodes_modes=nodes_modes, + nodes_splits=nodes_splits, + nodes_trueleafs=nodes_trueleafs, + nodes_truenodeids=nodes_truenodeids, + post_transform=post_transform, + tree_roots=tree_roots, + ) diff --git a/onnxscript/onnx_opset/_impl/opset_ai_onnx_preview_training1.py b/onnxscript/onnx_opset/_impl/opset_ai_onnx_preview_training1.py deleted file mode 100644 index cb201bdf97..0000000000 --- a/onnxscript/onnx_opset/_impl/opset_ai_onnx_preview_training1.py +++ /dev/null @@ -1,577 +0,0 @@ -# -------------------------------------------------------------------------- -# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ -# ⚙️ Generated by 'python -m opgen' -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# pylint: disable=W0221,W0222,R0901,W0237 -# mypy: disable-error-code=override -# ruff: noqa: N801,E741 -# ruff: noqa: D214,D402,D405,D411,D412,D416,D417 -# -------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Optional, Sequence, TypeVar, Union - -from onnx.defs import get_schema -from typing_extensions import TypeAlias - -from onnxscript.onnx_types import ( - BOOL, - COMPLEX64, - COMPLEX128, - DOUBLE, - FLOAT, - FLOAT16, - INT8, - INT16, - INT32, - INT64, - STRING, - UINT8, - UINT16, - UINT32, - UINT64, -) -from onnxscript.values import Op, Opset - - -class Opset_ai_onnx_preview_training1(Opset): - def __new__(cls): - return Opset.__new__(cls, "ai.onnx.preview.training", 1) - - T1_Adagrad = TypeVar("T1_Adagrad", DOUBLE, FLOAT) - - T2_Adagrad: TypeAlias = INT64 - - T3_Adagrad = TypeVar("T3_Adagrad", DOUBLE, FLOAT) - - def Adagrad( - self, - R: T1_Adagrad, - T: T2_Adagrad, - *inputs: T3_Adagrad, - decay_factor: float = 0.0, - epsilon: float = 9.999999974752427e-07, - norm_coefficient: float = 0.0, - ) -> T3_Adagrad: - r"""[🌐 ai.onnx.preview.training::Adagrad(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Adagrad.html#adagrad-1 "Online Documentation") - - - Compute one iteration of ADAGRAD, a stochastic gradient based optimization - algorithm. This operator can conduct the optimization of multiple tensor variables. - - Let's define the behavior of this operator. As you can imagine, ADAGRAD requires - some parameters: - - - The initial learning-rate "R". - - The update count "T". That is, the number of training iterations conducted. - - A L2-norm regularization coefficient "norm_coefficient". - - A learning-rate decay factor "decay_factor". - - A small constant "epsilon" to avoid dividing-by-zero. - - At each ADAGRAD iteration, the optimized tensors are moved along a direction - computed based on their estimated gradient and accumulated squared gradient. Assume - that only a single tensor "X" is updated by this operator. We need the value of "X", - its gradient "G", and its accumulated squared gradient "H". Therefore, variables in - this operator's input list are sequentially "R", "T", "X", "G", and "H". Other - parameters are given as attributes because they are usually constants. Also, the - corresponding output tensors are the new value of "X" (called "X_new"), and then - the new accumulated squared gradient (called "H_new"). Those outputs are computed - from the given inputs following the pseudo code below. - - Let "+", "-", "*", and "/" are all element-wise arithmetic operations with - numpy-style broadcasting support. The pseudo code to compute those outputs is: - - // Compute a scalar learning-rate factor. At the first update of X, T is generally - // 0 (0-based update index) or 1 (1-based update index). - r = R / (1 + T * decay_factor); - - // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm. - G_regularized = norm_coefficient * X + G; - - // Compute new accumulated squared gradient. - H_new = H + G_regularized * G_regularized; - - // Compute the adaptive part of per-coordinate learning rate. Note that Sqrt(...) - // computes element-wise square-root. - H_adaptive = Sqrt(H_new) + epsilon - - // Compute the new value of "X". - X_new = X - r * G_regularized / H_adaptive; - - If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2", the same - pseudo code may be extended to handle all tensors jointly. More specifically, we can view "X" as a - concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should - be concatenated too) and then just reuse the entire pseudo code. - - Note that ADAGRAD was first proposed in http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf. - In that reference paper, this operator is a special case of the Figure 1's composite mirror - descent update. - - - Args: - R: The initial learning rate. - - T: The update count of "X". It should be a scalar. - - inputs: (variadic, heterogeneous) The current values of optimized tensors, - followed by their respective gradients, followed by their respective - accumulated squared gradients.For example, if two tensor "X_1" and "X_2" - are optimized, The input list would be ["X_1", "X_2", gradient of "X_1", - gradient of "X_2", accumulated squared gradient of "X_1", accumulated - squared gradient of "X_2"]. - - decay_factor: The decay factor of learning rate after one update.The - effective learning rate is computed by r = R / (1 + T * decay_factor). - Default to 0 so that increasing update counts doesn't reduce the - learning rate. - - epsilon: Small scalar to avoid dividing by zero. - - norm_coefficient: Regularization coefficient in 0.5 * norm_coefficient * - ||X||_2^2. Default to 0, which means no regularization. - """ - - schema = get_schema("Adagrad", 1, "ai.onnx.preview.training") - op = Op(self, "Adagrad", schema) - return op( - *self._prepare_inputs(schema, R, T, *inputs), - decay_factor=decay_factor, - epsilon=epsilon, - norm_coefficient=norm_coefficient, - ) - - T1_Adam = TypeVar("T1_Adam", DOUBLE, FLOAT) - - T2_Adam: TypeAlias = INT64 - - T3_Adam = TypeVar("T3_Adam", DOUBLE, FLOAT) - - def Adam( - self, - R: T1_Adam, - T: T2_Adam, - *inputs: T3_Adam, - alpha: float = 0.8999999761581421, - beta: float = 0.9990000128746033, - epsilon: float = 9.999999974752427e-07, - norm_coefficient: float = 0.0, - norm_coefficient_post: float = 0.0, - ) -> T3_Adam: - r"""[🌐 ai.onnx.preview.training::Adam(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Adam.html#adam-1 "Online Documentation") - - - Compute one iteration of Adam, a stochastic gradient based optimization - algorithm. This operator can conduct the optimization of multiple tensor variables. - - Let's define the behavior of this operator. First of all, Adam requires - some parameters: - - - The learning-rate "R". - - The update count "T". That is, the number of training iterations conducted. - - A L2-norm regularization coefficient "norm_coefficient". - - A small constant "epsilon" to avoid dividing-by-zero. - - Two coefficients, "alpha" and "beta". - - At each Adam iteration, the optimized tensors are moved along a direction - computed based on their exponentially-averaged historical gradient and - exponentially-averaged historical squared gradient. Assume that only a tensor - "X" is being optimized. The rest of required information is - - - the value of "X", - - "X"'s gradient (denoted by "G"), - - "X"'s exponentially-averaged historical gradient (denoted by "V"), and - - "X"'s exponentially-averaged historical squared gradient (denoted by "H"). - - Some of those parameters are passed into this operator as input tensors and others - are stored as this operator's attributes. Specifically, this operator's input tensor - list is ["R", "T", "X", "G", "V", "H"]. That is, "R" is the first input, "T" is - the second input, and so on. Other parameters are given as attributes because they - are constants. Moreover, the corresponding output tensors are - - - the new value of "X" (called "X_new"), - - the new exponentially-averaged historical gradient (denoted by "V_new"), and - - the new exponentially-averaged historical squared gradient (denoted by "H_new"). - - Those outputs are computed following the pseudo code below. - - Let "+", "-", "*", and "/" are all element-wise arithmetic operations with - numpy-style broadcasting support. The pseudo code to compute those outputs is: - - // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm. - G_regularized = norm_coefficient * X + G - - // Update exponentially-averaged historical gradient. - V_new = alpha * V + (1 - alpha) * G_regularized - - // Update exponentially-averaged historical squared gradient. - H_new = beta * H + (1 - beta) * G_regularized * G_regularized - - // Compute the element-wise square-root of H_new. V_new will be element-wisely - // divided by H_sqrt for a better update direction. - H_sqrt = Sqrt(H_new) + epsilon - - // Compute learning-rate. Note that "alpha**T"/"beta**T" is alpha's/beta's T-th power. - R_adjusted = T > 0 ? R * Sqrt(1 - beta**T) / (1 - alpha**T) : R - - // Compute new value of "X". - X_new = X - R_adjusted * V_new / H_sqrt - - // Post-update regularization. - X_final = (1 - norm_coefficient_post) * X_new - - If there are multiple inputs to be optimized, the pseudo code will be applied - independently to each of them. - - - Args: - R: The initial learning rate. - - T: The update count of "X". It should be a scalar. - - inputs: (variadic, heterogeneous) The tensors to be optimized, followed by - their respective gradients, followed by their respective accumulated - gradients (aka momentum), followed by their respective accumulated - squared gradients. For example, to optimize tensors "X_1" and "X_2,", - the input list would be ["X_1", "X_2", gradient of "X_1", gradient of - "X_2", accumulated gradient of "X_1", accumulated gradient of "X_2", - accumulated squared gradient of "X_1", accumulated squared gradient of - "X_2"]. - - alpha: Coefficient of previously accumulated gradient in running average. - Default to 0.9. - - beta: Coefficient of previously accumulated squared-gradient in running - average. Default to 0.999. - - epsilon: Small scalar to avoid dividing by zero. - - norm_coefficient: Regularization coefficient of 0.5 * norm_coefficient * - ||X||_2^2. Default to 0, which means no regularization. - - norm_coefficient_post: Regularization coefficient of 0.5 * norm_coefficient - * ||X||_2^2. Default to 0, which means no regularization. - """ - - schema = get_schema("Adam", 1, "ai.onnx.preview.training") - op = Op(self, "Adam", schema) - return op( - *self._prepare_inputs(schema, R, T, *inputs), - alpha=alpha, - beta=beta, - epsilon=epsilon, - norm_coefficient=norm_coefficient, - norm_coefficient_post=norm_coefficient_post, - ) - - T1_Gradient = TypeVar( - "T1_Gradient", - BOOL, - COMPLEX128, - COMPLEX64, - DOUBLE, - FLOAT, - FLOAT16, - INT16, - INT32, - INT64, - INT8, - STRING, - UINT16, - UINT32, - UINT64, - UINT8, - ) - - T2_Gradient: TypeAlias = Union[DOUBLE, FLOAT, FLOAT16] - - def Gradient( - self, - *Inputs: T1_Gradient, - xs: Sequence[str], - y: str, - zs: Optional[Sequence[str]] = None, - ) -> T2_Gradient: - r"""[🌐 ai.onnx.preview.training::Gradient(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Gradient.html#gradient-1 "Online Documentation") - - - Gradient operator computes the partial derivatives of a specific tensor w.r.t. - some other tensors. This operator is widely used in gradient-based training - algorithms. To illustrate its use, let's consider a computation graph, - - :: - - X -----. - | - v - W --> Conv --> H --> Gemm --> Y - ^ - | - Z - - - - , where W and Z are trainable tensors. Note that operators' attributes are - omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of - Y with respect to W (Z). The user can compute gradient by inserting Gradient - operator to form another graph shown below. - - :: - - W --> Conv --> H --> Gemm --> Y - | ^ ^ - | | | - | X Z - | | | - | | .----------' - | | | (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in - | | | "xs" followed by "zs") - | v v - '---> Gradient(xs=["W", "Z"], zs=["X"], y="Y") - | | - | '-----------------------------------> dY/dW (1st output of Gradient) - | - '---------------------------------------> dY/dZ (2nd output of Gradient) - - - - By definition, the tensor "y" is a function of independent variables in "xs" - and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable - variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H" - cannot appear in "xs" and "zs". The reason is that "H" can be determined by - tensors "W" and "X" and therefore "H" is not an independent variable. - - All outputs are optional. If needed, for example, user can assign an empty - string to the 1st output name of that Gradient to skip the generation of dY/dW. - Note that the concept of optional outputs can also be found in ONNX's RNN, GRU, - and LSTM. - - Gradient operator can compute derivative against intermediate tensors. For - example, the gradient of Y with respect to H can be done via - - :: - - W --> Conv --> H --> Gemm --> Y - ^ | ^ - | | | - X | Z - .-------' | - | .----------' - | | (H/Z is the 1st/2nd input of Gradient as shown in "xs") - v v - Gradient(xs=["H", "Z"], y="Y") - | | - | '-----------------------------------> dY/dH (1st output of Gradient) - | - '---------------------------------------> dY/dZ (2nd output of Gradient) - - - - It is possible to represent high-order differentiation using Gradient operators. - For example, given the following linear model: - - :: - - W --> Gemm --> Y --> Loss --> O - ^ ^ - | | - X L - - - - To compute the 2nd order derivative of O with respect to W (denoted by - d^2O/dW^2), one can do - - :: - - W --> Gemm --> Y --> Loss --> O - | ^ ^ - | | | - | X .------------L - | | | | - | | | v - +------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient) - | | | | - | | | '---> dO/dW (2nd output of Gradient) - | v v - '---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)dX (1st output of - | Gradient) - | - | - '---> d^2O/dW^2 (2nd output of Gradient) - - - - The tensors named in attributes "xs", "zs", and "y" define the differentiated - computation graph, and the inputs to Gradient node define the values at - which the gradient is computed. We can feed different tensors to the identified - graph. For example, one can compute the gradient of Y with respect to H at - a specific value of H, H_1, by providing that value as an input to the Gradient - node. - - :: - - W --> Conv --> H --> Gemm --> Y - ^ ^ - | | - X Z - - Z_1 (2nd input of Gradient) - | - v - H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Y = Y_1. - | - '------------------------------> dY/dZ (2nd output of Gradient) - - - - When the inputs of Gradient are the tensors named in "xs" and "zs", the - computation can be optimized. More specifically, intermediate variables in - forward pass can be reused if the gradient is computed via reverse-mode - auto-differentiation. - - - - Args: - Inputs: (variadic, heterogeneous) The values fed into graph identified by - the attributes. The i-th input is the value of the i-th tensor specified - in the concatenated list of the attribute "xs" and the attribute "zs". - For example, if xs=["A", "B"] and zs=["C"], the first input is used as - the value of symbol "A" and the 3rd input is substituted for all the - occurrences of "C". - - xs: Input tensor names of the differentiated sub-graph. It contains only the - necessary differentiated inputs of a (sub-)graph. Variables (usually - called intermediate variables) that can be generated from inputs cannot - be included in this attribute. - - y: The targeted tensor. It can be viewed as the output of the differentiated - function. The attribute "xs" and attribute "zs" are the minimal - independent variable set that determines the value of "y". - - zs: Input tensor names of the differentiated sub-graph. It contains only the - necessary non-differentiated inputs of a (sub-)graph. Variables (usually - called intermediate variables) that can be generated from inputs cannot - be included in this attribute. - """ - - schema = get_schema("Gradient", 1, "ai.onnx.preview.training") - op = Op(self, "Gradient", schema) - return op(*self._prepare_inputs(schema, *Inputs), xs=xs, y=y, zs=zs) - - T1_Momentum = TypeVar("T1_Momentum", DOUBLE, FLOAT) - - T2_Momentum: TypeAlias = INT64 - - T3_Momentum = TypeVar("T3_Momentum", DOUBLE, FLOAT) - - def Momentum( - self, - R: T1_Momentum, - T: T2_Momentum, - *inputs: T3_Momentum, - alpha: float, - beta: float, - mode: str, - norm_coefficient: float, - ) -> T3_Momentum: - r"""[🌐 ai.onnx.preview.training::Momentum(1)](https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Momentum.html#momentum-1 "Online Documentation") - - - Compute one iteration of stochastic gradient update with momentum. - This operator can conduct the optimization of multiple tensor variables. - - Let's define the behavior of this operator. As you can imagine, SG with momentum requires - several parameters: - - - The learning-rate "R". - - The update count "T". That is, the number of conducted training iterations. It should - be zero in the first training iteration. - - A L2-norm regularization coefficient "norm_coefficient". - - A decay coefficient of previous accumulated gradient (i.e., momentum) "alpha". - - The scaling coefficient of current gradient "beta". - - An attribute to choose either standard momentum or Nesterov's momentum "mode" should - be used. - - For the sake of simplicity, assume that there is only one tensor (called "X") to be optimized. - Other necessary inputs are "X"'s gradient (called "G") and "X"'s momentum (called "V"). This - Momentum operator maps all these inputs to the new value of "X" (called "X_new") and its new - momentum (called "V_new"). - - This operator supports two different momentum algorithms. Set the attribute "mode" to - "nesterov" if Nesterov's momentum is desired. Otherwise, set the attribute "model" to - "standard" to use standard momentum. Computation details are described subsequently. - - Let "+", "-", "*", and "/" are all element-wise operations with numpy-style broadcasting. - - Pseudo code for SG with standard momentum: - - // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared - // values of all elements in X. - G_regularized = norm_coefficient * X + G - - // In the first training iteration, beta should always be 1. - beta_adjusted = T > 0 ? beta : 1 - - // Compute the current momentum based on previous momentum and the current gradient. - V_new = alpha * V + beta_adjusted * G_regularized - - // Update X. - X_new = X - R * V_new - - Pseudo code for SG with Nesterov's momentum: - - // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared - // values of all elements in X. - G_regularized = norm_coefficient * X + G; - - // In the first training iteration, beta should always be 1. - beta_adjusted = T > 0 ? beta : 1 - - // Compute the current momentum based on previous momentum and the current gradient. - V_new = alpha * V + beta_adjusted * G_regularized; - - // Compute final update direction and then update X. - X_new = X - R * (G_regularized + alpha * V_new) - - If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2". The same - pseudo code would be extended to handle all tensors jointly. More specifically, we can view "X" as a - concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should - be concatenated too) and then our pseudo code becomes applicable. - - - Args: - R: The learning rate. - - T: Update count of "X". It should be a scalar. - - inputs: (variadic, heterogeneous) It sequentially contains the current - values of optimized tensors, then their gradient tensors, and finally - their momentum tensors. For example, if two tensors "X_1" and "X_2" are - optimized, The expected input list would be ["X_1", "X_2", gradient of - "X_1", gradient of "X_2", momentum of "X_1", momentum of "X_2"]. - - alpha: The decay factor of momentum. It should be a scalar. - - beta: The coefficient of gradient in computing new momentum. It should be a - scalar. - - mode: Its value should be either "nesterov" or "standard". The value - "nesterov" leads to the use of Nesterov's momentum while "standard" - invokes stochastic gradient method using standard momentum - - norm_coefficient: Coefficient of 0.5 * norm_coefficient * ||X||^2. - """ - - schema = get_schema("Momentum", 1, "ai.onnx.preview.training") - op = Op(self, "Momentum", schema) - return op( - *self._prepare_inputs(schema, R, T, *inputs), - alpha=alpha, - beta=beta, - mode=mode, - norm_coefficient=norm_coefficient, - ) diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 6af57d4b1d..9642e3f111 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations @@ -9,31 +7,27 @@ from typing import ClassVar, Optional, Tuple, Union import onnx -import onnx.helper +import onnx_ir as ir -DType = onnx.TensorProto.DataType +_DType = ir.DataType +_DimType = Union[int, str, type(None)] +_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)] -DimType = Union[int, str, type(None)] +_tensor_type_shape_cache: dict[_DType, TensorType] = {} +tensor_type_registry: dict[_DType, TensorType] = {} -def check_dim(dim): +def _check_dim(dim): if not isinstance(dim, (int, str, type(None))): raise TypeError(f"Invalid dimension {dim}") -ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)] - - -def check_shape(shape): +def _check_shape(shape): if isinstance(shape, tuple): for dim in shape: - check_dim(dim) + _check_dim(dim) elif shape != Ellipsis: - check_dim(shape) - - -tensor_type_registry: dict[DType, TensorType] = {} -_tensor_type_shape_cache: dict[DType, TensorType] = {} + _check_dim(shape) class TensorType(abc.ABC): @@ -60,13 +54,13 @@ class TensorType(abc.ABC): tensor: FLOAT[128, 1024] """ - dtype: ClassVar[DType] - shape: ClassVar[Optional[ShapeType]] + dtype: ClassVar[_DType] + shape: ClassVar[Optional[_ShapeType]] def __new__(cls): raise NotImplementedError("TensorTypes cannot be instantiated") - def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None): + def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None): cls.dtype = dtype cls.shape = shape if shape is None: @@ -78,9 +72,9 @@ def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None): ) tensor_type_registry[dtype] = cls else: - check_shape(shape) + _check_shape(shape) - def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]: + def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]: if cls.shape is not None: raise ValueError("Invalid usage: shape already specified.") if shape is None: @@ -103,98 +97,116 @@ def to_type_proto(cls) -> onnx.TypeProto: shape = cls.shape # example: "FLOAT[10,20]" else: shape = [cls.shape] # example: "FLOAT[10]" - return onnx.helper.make_tensor_type_proto(cls.dtype, shape) + return onnx.helper.make_tensor_type_proto(cls.dtype, shape) # noqa: TID251 @classmethod def to_string(cls) -> str: return f"tensor({cls.__name__.lower()})" -class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): +class FLOAT(TensorType, dtype=ir.DataType.FLOAT): + pass + + +class UINT8(TensorType, dtype=ir.DataType.UINT8): + pass + + +class INT8(TensorType, dtype=ir.DataType.INT8): + pass + + +class UINT16(TensorType, dtype=ir.DataType.UINT16): + pass + + +class INT16(TensorType, dtype=ir.DataType.INT16): pass -class UINT8(TensorType, dtype=onnx.TensorProto.UINT8): +class INT32(TensorType, dtype=ir.DataType.INT32): pass -class INT8(TensorType, dtype=onnx.TensorProto.INT8): +class INT64(TensorType, dtype=ir.DataType.INT64): pass -class UINT16(TensorType, dtype=onnx.TensorProto.UINT16): +class STRING(TensorType, dtype=ir.DataType.STRING): pass -class INT16(TensorType, dtype=onnx.TensorProto.INT16): +class BOOL(TensorType, dtype=ir.DataType.BOOL): pass -class INT32(TensorType, dtype=onnx.TensorProto.INT32): +class FLOAT16(TensorType, dtype=ir.DataType.FLOAT16): pass -class INT64(TensorType, dtype=onnx.TensorProto.INT64): +class DOUBLE(TensorType, dtype=ir.DataType.DOUBLE): pass -class STRING(TensorType, dtype=onnx.TensorProto.STRING): +class UINT32(TensorType, dtype=ir.DataType.UINT32): pass -class BOOL(TensorType, dtype=onnx.TensorProto.BOOL): +class UINT64(TensorType, dtype=ir.DataType.UINT64): pass -class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16): +class COMPLEX64(TensorType, dtype=ir.DataType.COMPLEX64): pass -class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE): +class COMPLEX128(TensorType, dtype=ir.DataType.COMPLEX128): pass -class UINT32(TensorType, dtype=onnx.TensorProto.UINT32): +class BFLOAT16(TensorType, dtype=ir.DataType.BFLOAT16): pass -class UINT64(TensorType, dtype=onnx.TensorProto.UINT64): +class FLOAT8E4M3FN(TensorType, dtype=ir.DataType.FLOAT8E4M3FN): pass -class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64): +class FLOAT8E4M3FNUZ(TensorType, dtype=ir.DataType.FLOAT8E4M3FNUZ): pass -class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128): +class FLOAT8E5M2(TensorType, dtype=ir.DataType.FLOAT8E5M2): pass -class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16): +class FLOAT8E5M2FNUZ(TensorType, dtype=ir.DataType.FLOAT8E5M2FNUZ): pass -class FLOAT8E4M3FN(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FN): +class INT4(TensorType, dtype=ir.DataType.INT4): pass -class FLOAT8E4M3FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FNUZ): +class UINT4(TensorType, dtype=ir.DataType.UINT4): pass -class FLOAT8E5M2(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2): +class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1): pass -class FLOAT8E5M2FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2FNUZ): +class FLOAT8E8M0(TensorType, dtype=ir.DataType.FLOAT8E8M0): pass -def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: +def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto, *, reversible: bool = True) -> str: """Converts an onnx type into the string representation of the type in *onnxscript*. Args: onnx_type: an instance of onnx TypeProto + reversible: if True, the conversion produces only types that are + recognized by the onnxscript converter. Returns: The string representation of the type in onnxscript @@ -218,6 +230,10 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: return name return f"{name}[{','.join(shape)}]" return f"{name}[...]" + if not reversible: + if onnx_type.HasField("sequence_type"): + elem_type = onnx_type.sequence_type.elem_type + return f"List[{onnx_type_to_onnxscript_repr(elem_type)}]" raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.") diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 03c1e748eb..978a1b4d65 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -1,114 +1,131 @@ -import logging -from typing import Any +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import TypeVar + +__all__ = [ + "basic_constant_propagation", + "fold_constants_ir", + "fold_constants", + "inline", + "optimize_ir", + "optimize", + "remove_unused_nodes", +] import onnx -import onnx.shape_inference - -from onnxscript import rewriter -from onnxscript.optimizer.constant_folding import fold_constants -from onnxscript.optimizer.remove_unused import remove_unused_nodes -from onnxscript.optimizer.remove_unused_function import remove_unused_functions -from onnxscript.optimizer.simple_function_folding import ( - inline_functions_with_unused_outputs, - inline_simple_functions, -) -from onnxscript.rewriter import ( - broadcast_to_matmul, - cast_constant_of_shape, - gemm_to_matmul_add, - no_op, -) - -logger = logging.getLogger(__name__) +import onnx_ir.passes.common as common_passes + +import onnxscript.optimizer._constant_folding as constant_folding +from onnxscript import ir +from onnxscript.optimizer._constant_folding import basic_constant_propagation +from onnxscript.optimizer._constant_folding import fold_constants as fold_constants_ir +from onnxscript.optimizer._optimizer import optimize_ir + +_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) def optimize( - model: onnx.ModelProto, + model: _ModelProtoOrIr, num_iterations: int = 2, *, onnx_shape_inference: bool = True, stop_if_no_change: bool = True, - external_data_folder: str = "", - **kwargs: Any, -) -> onnx.ModelProto: - """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. + input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, + output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, + inline: bool = True, +) -> _ModelProtoOrIr: + """Optimizes a model. Args: - model (onnx.ModelProto): The model to optimize. - num_iterations (int, optional): Number of iterations to perform. - onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. - Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. - This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries - the symbolic shapes recorded from dynamo tracing. - stop_if_no_change (bool, optional): Whether to stop if no change is detected. - external_data_folder (str, optional): The folder to store external data. - **kwargs: Additional keyword arguments. For BC purposes. + model: The model to be optimized. + num_iterations: Number of times the optimization loop is repeated. + onnx_shape_inference: Applies node-level shape-inference as part of optimization + input_size_limit: Will not apply constant folding to ops with any input of size + greater than this. Does not apply to special ops like Shape() and Size(). + output_size_limit: Will not rewrite any foldable-op into a Constant op if the size + of the output tensor is greater than this. + stop_if_no_change: Stop the optimization loop if no change is detected in an iteration. + inline: If True, inlines all functions in the model. + + Returns: + The optimized model. If the input was a ModelProto, the output will also be a + ModelProto. If the input was an ir.Model, the output will also be an ir.Model. """ - if kwargs.pop("function_aware_folding", None) is not None: - logger.warning( - "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " - "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " - "This would turn off incremental onnx shape inference and rely on model carried shapes and types. " - "See 'onnx_shape_inference' for more details." + if isinstance(model, ir.Model): + # In this case, optimize is done inplace. + # TODO(justinchuby): Maybe make functional + optimize_ir( + model, + num_iterations=num_iterations, + onnx_shape_inference=onnx_shape_inference, + stop_if_no_change=stop_if_no_change, + input_size_limit=input_size_limit, + output_size_limit=output_size_limit, + inline=inline, ) - for _ in range(num_iterations): - if onnx_shape_inference: - if model.ByteSize() < 1024 * 1024 * 1024 * 2: - # NOTE: strict mode is disabled because it crashes on the models - # that have different shapes inferred from the model carried shapes. - # The case can be found in: - # https://github.com/microsoft/onnxscript/issues/1443 - model = onnx.shape_inference.infer_shapes( - model, check_type=True, strict_mode=False, data_prop=True - ) - else: - logger.warning( - "The model size is too large for full model shape inference. " - "Skipping this step." - ) - - inline_simple_functions(model) - modified = fold_constants( - model, external_data_folder, onnx_shape_inference=onnx_shape_inference + return model + else: + assert isinstance(model, onnx.ModelProto) + model_ir = ir.serde.deserialize_model(model) + optimize_ir( + model_ir, + num_iterations=num_iterations, + onnx_shape_inference=onnx_shape_inference, + stop_if_no_change=stop_if_no_change, + input_size_limit=input_size_limit, + output_size_limit=output_size_limit, + inline=inline, ) + # Move the model back to the proto + new_proto = ir.serde.serialize_model(model_ir) + return new_proto - remove_unused_nodes(model) - inline_simple_functions(model) - remove_unused_functions(model) - inline_functions_with_unused_outputs(model) - # NOTE: This is general rewrite rules - model = rewriter.rewrite( - model, - pattern_rewrite_rules=[ - *no_op.rules.rules, # TODO: merge this rule into constant folding? - *broadcast_to_matmul.rules.rules, - gemm_to_matmul_add.rule, - *cast_constant_of_shape.rules.rules, - ], - ) - if stop_if_no_change and not modified: - logger.debug("Stopping after %d iterations.", _) - break - for node in model.graph.node: - logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) +def inline(model: ir.Model) -> None: + """Inline all function calls (recursively) in the model.""" + if model.functions: + common_passes.InlinePass()(model) - for function in model.functions: - for node in function.node: - logger.debug( - "Function %s::%s node %s::%s name %s.", - function.domain, - function.name, - node.domain, - node.op_type, - node.name, - ) - return model +def fold_constants( + model: ir.Model | onnx.ModelProto, *args, **kwargs +) -> constant_folding.FoldConstantsResult: + """Fold constants in a model in place.""" + if isinstance(model, ir.Model): + return constant_folding.fold_constants(model, *args, **kwargs) + else: + assert isinstance(model, onnx.ModelProto) + model_proto = model + model = ir.serde.deserialize_model(model_proto) + result = constant_folding.fold_constants(model, *args, **kwargs) + # Move the model back to the proto + new_proto = ir.serde.serialize_model(model) + model_proto.Clear() + model_proto.CopyFrom(new_proto) + return result -__all__ = [ - "fold_constants", - "remove_unused_nodes", - "optimize", -] +def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None: + """Removes unused nodes from a model inplace.""" + if isinstance(model, ir.Model): + common_passes.RemoveUnusedNodesPass()(model) + else: + model_ir = ir.serde.deserialize_model(model) + model_ir = common_passes.RemoveUnusedNodesPass()(model_ir).model + new_proto = ir.serde.serialize_model(model_ir) + model.Clear() + model.CopyFrom(new_proto) + + +def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None: + """Removes unused functions from a model inplace.""" + if isinstance(model, ir.Model): + common_passes.RemoveUnusedFunctionsPass()(model) + else: + model_ir = ir.serde.deserialize_model(model) + model_ir = common_passes.RemoveUnusedFunctionsPass()(model_ir).model + new_proto = ir.serde.serialize_model(model_ir) + model.Clear() + model.CopyFrom(new_proto) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py new file mode 100644 index 0000000000..9a740c783c --- /dev/null +++ b/onnxscript/optimizer/_constant_folding.py @@ -0,0 +1,1372 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files. + +from __future__ import annotations + +__all__ = [ + "basic_constant_propagation", + "fold_constants", + "FoldConstantsPass", + "FOLDED_FROM_KEY", +] + +import dataclasses +import logging +import math +import typing +from typing import Any, Callable, Iterable, Sequence, Union + +import numpy as np +import onnx +import onnx.reference.ops +import onnx_ir as ir + +import onnxscript.utils.utils as utils +from onnxscript.ir import _tape + +DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 8192 + +DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512 + +# Key used to store the metadata +FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from" + + +_NON_DETERMINISTIC_OPS = frozenset( + { + "RandomUniform", + "RandomNormal", + "RandomUniformLike", + "RandomNormalLike", + "Multinomial", + } +) + +# A list of ops to always fold regardless of their input size limits, as long as +# they are the single consumer of the large input tensors +_DEFAULT_ALWAYS_FOLD_OPS = frozenset( + { + ("", "Transpose"), + } +) + +logger = logging.getLogger(__name__) + + +def _is_control_flow_op(node: ir.Node) -> bool: + graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS} + return any(attr.type in graph_types for attr in node.attributes.values()) + + +def _is_non_deterministic_op(node: ir.Node) -> bool: + return node.op_type in _NON_DETERMINISTIC_OPS and utils.is_onnx_domain(node.domain) + + +def _is_onnx_op(node: ir.Node, op_type: str) -> bool: + return node.op_type == op_type and utils.is_onnx_domain(node.domain) + + +# "Standard" evaluators are used to perform constant-folding. +# The API below works only for non-control-flow ops (ops without any graph-attributes). +# This currently used ONNX's reference implementation. But we could also +# use ORT's implementation if we want to. + + +def _process_constant_node(node: ir.Node) -> None: + """Sets const_value of output value of a Constant op node.""" + if not _is_onnx_op(node, "Constant"): + return + if len(node.attributes) != 1: + return + attr_name, attr_value = next(iter(node.attributes.items())) + if len(node.outputs) != 1: + return + ir_value = node.outputs[0] + + if attr_value is None or not isinstance(attr_value, ir.Attr): + return + + # Even if this is an attribute, the value property might not be set, which + # happens e.g. in case of attribute references, i.e., ref_attr_name is set + if attr_value.value is None: + # For now reject this to prevent TypeError from accessing Nones below + return + + const_value: ir.TensorProtocol + if attr_name in {"value_float", "value_floats"}: + const_value = ir.Tensor( + np.array(attr_value.value, dtype=np.float32), name=ir_value.name + ) + elif attr_name in {"value_int", "value_ints"}: + const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name) + elif attr_name in {"value_string", "value_strings"}: + const_value = ir.StringTensor( + np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name + ) + elif attr_name == "value": + const_value = typing.cast(ir.TensorProtocol, attr_value.value) + else: + return + + ir_value.const_value = const_value + ir_value.shape = const_value.shape # type: ignore + ir_value.dtype = const_value.dtype + + +def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None: + """Performs basic constant propagation for a sequence of nodes. + + Just marks the output values of Constant op nodes with their const_value. + """ + for node in nodes: + _process_constant_node(node) + + +class ReferenceEvaluator: + def get_evaluator(self, domain: str, op: str, version: int) -> Callable | None: + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + return op_impl_class.eval # noqa: TRY300 + except Exception: + return None + + def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: + logger.debug("Evaluating %s::%s", domain, op) + evaluator = self.get_evaluator(domain, op, version) + if evaluator is None: + return None + try: + return evaluator(*args, **kwargs) + except Exception as e: + logger.warning("Evaluation failed: %s", e) + return None + + +_reference_evaluator = ReferenceEvaluator() + + +@dataclasses.dataclass +class Replacement: + """A replacement for a node in the graph.""" + + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + + +# The optimizer tracks an optional symbolic value for each value in the model. +# The symbolic value attached to a value X can be: +# - another IR value Y (indicating that X is equal to Y) +# - a list of IR values [Y1, Y2, ...] (indicating that X is a sequence of values Y1, Y2, ...) +# - a Shape object (indicating that X is a shape value) +# A Shape object as a symbolic value indicates that the corresponding value is +# 1-D (or 0-D) tensor of INT64 values. The values in this object may be constants +# or symbolic dimension values (like "batch_size", "sequence_length", etc.). +# Currently, we assume that symbolic dimensions are also guaranteed to be non-negative. +# TODO: Add support for negative symbolic dimensions. + +SymbolicValue = Union[ir.Value, list[ir.Value], ir.Shape] + + +class OptimizerState: + def __init__(self): + self._sym_value_map: dict[ir.Value, SymbolicValue] = {} + self._initializer_inputs: list[set[ir.Value]] = [] + + @property + def symbolic_value_map(self) -> dict[ir.Value, SymbolicValue]: + return self._sym_value_map + + def get_sym_value(self, value: ir.Value | None) -> SymbolicValue | None: + if value is None: + return None + return self._sym_value_map.get(value) + + def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None: + self._sym_value_map[value] = sym_value + + def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None: + const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10) + if const_value is not None: + if const_value.ndim == 1: + return ir.Shape(const_value.tolist()) + return None + sym_value = self.get_sym_value(value) + if isinstance(sym_value, ir.Shape): + return sym_value + # TODO use shape of value if available + return None + + +# The "partial evaluators" below are non-standard evaluators. They are used to perform +# partial evaluation and/or static program analysis (abstract interpretation). + +# A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns +# a Replacement for the node or None (if no replacement is needed). It may also return just +# the ir.Value or ir.Values to replace the output values of the node, when the new nodes +# can be inferred from the RewriterContext used to build the new nodes. + +RewriterContext = _tape.Builder +ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None] +PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue] + + +@dataclasses.dataclass +class PartialEvaluator: + """A class that represents a partial-evaluator for a particular op. + + It is applicable for a specific version range (min_version, max_version) of the op. + The min_version and max_version can be None, indicating that there is no version + constraint in that direction. + """ + + min_version: int | None + max_version: int | None + function: PartialEvaluatorFunction + + def valid_for(self, version: int) -> bool: + """Returns True if this evaluator is applicable for the given version.""" + return (self.min_version is None or version >= self.min_version) and ( + self.max_version is None or version <= self.max_version + ) + + +class PartialEvaluatorRegistry: + """A class that maintains a registry of evaluators for ops.""" + + def __init__(self): + self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} + + def lookup_evaluators(self, domain: str, opname: str, version: int): + evaluator_list = self.op_evaluators.get((domain, opname), []) + return [ + evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) + ] + + def register( + self, opname: str, domain: str = "", version=None + ) -> Callable[[PartialEvaluatorFunction], PartialEvaluatorFunction]: + if (domain, opname) in self.op_evaluators: + evaluator_list = self.op_evaluators[(domain, opname)] + else: + evaluator_list = [] + self.op_evaluators[(domain, opname)] = evaluator_list + if version is None: + min_version = None + max_version = None + elif isinstance(version, int): + min_version = version + max_version = version + elif isinstance(version, tuple): + min_version, max_version = version + + def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: + evaluator_list.append(PartialEvaluator(min_version, max_version, function)) + return function + + return decorator + + +registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() + +register = registry.register + + +def _same_shape(shape1: ir.Shape, shape2: ir.Shape) -> bool: + # Comparison of shapes as tuples works except if any dimension is None + # (which represents an unknown dimension value). Thus, two shapes such + # as (Batch, 1024) and (Batch, 1024) are considered equal, but (None, 1024) + # and (None, 1024) are not considered equal. + if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in shape1): + return False + return shape1.dims == shape2.dims + + +def _get_numpy_value( + val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None +) -> np.ndarray | None: + """Returns the numpy value of a constant value, if available. + + It returns None if the value is not a constant value, or if the value is not of + the specified element dtype, or if the size of the value exceeds the specified + size_limit. + """ + if val is None: + return None + const_value = val.const_value + if const_value is not None: + if dtype is not None and const_value.dtype != dtype: + return None + if size_limit is not None and const_value.size > size_limit: + return None + try: + # Turn the constant value into a numpy array representation with the + # specifics of this conversion handled by the tensor type + array = const_value.numpy() + # Can/should not reinterpret strings via .view, resulting in + # "TypeError: Cannot change data-type for array of references." + # There is also no reason to reinterpret strings, this is only + # relevant for some arithmetic types + if const_value.dtype != ir.DataType.STRING: + # Reinterpret the array with `.view()` because some + # implementations of ir.TensorProtocol (e.g. PyTorch<=2.7) do + # not use ml_dtypes for bfloat16 etc. + array = array.view(const_value.dtype.numpy()) + except FileNotFoundError: + # External data is not available. + logger.warning( + "External data for value '%s' is not available. " + "This may lead to incorrect constant folding.", + val.name, + ) + return None + assert isinstance(array, np.ndarray) + return array + return None + + +def _get_bool_value(val: ir.Value | None) -> bool | None: + if val is None: + return None + value = _get_numpy_value(val) + if value is None: + return None + if value.size == 1 and value.dtype == bool: + return value.item(0) + return None + + +def _get_input(node: ir.Node, index: int) -> ir.Value | None: + if index < len(node.inputs): + return node.inputs[index] + return None + + +def _get_output(node: ir.Node, index: int) -> ir.Value | None: + if index < len(node.outputs): + return node.outputs[index] + return None + + +def _get_input_element_type(node: ir.Node, index: int) -> int: + input = _get_input(node, index) + if input is not None and input.type is not None: + return input.type.dtype.value + return ir.DataType.UNDEFINED.value + + +def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: + if name in node.attributes: + attr = node.attributes[name] + if not isinstance(attr, ir.Attr): + return None + attr_val = attr.value + if isinstance(attr_val, int): + return attr_val + # This is an invalid model: attribute has invalid/unexpected type. + # For now, we just return None. We could raise an error too. + return None + return default + + +@register("Add") +def add(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Propagate symbolic dim values.""" + + def get_dim_value(input_index): + input = _get_input(node, input_index) + if input is None: + return None + shape_value: ir.Shape | None = state.get_shape_value(input) + if shape_value is None or len(shape_value) != 1: + return None + dim: int | ir.SymbolicDim = shape_value[0] + return dim if isinstance(dim, int) else dim.value + + dim0 = get_dim_value(0) + dim1 = get_dim_value(1) + if dim0 is None or dim1 is None: + return None + if isinstance(dim0, int) and isinstance(dim1, int): + result_dim_value: int | ir.SymbolicDim = dim0 + dim1 + else: + result_dim_value = ir.SymbolicDim(f"{dim0}+{dim1}") + output = _get_output(node, 0) + if output is not None: + state.set_sym_value(output, ir.Shape([result_dim_value])) + + +@register("Abs") +def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace an Abs node by Identity when applicable. + + Currently, addresses Abs applied to symbolic shapes. + """ + input = _get_input(node, 0) + input_sym_value = state.get_shape_value(input) + if input_sym_value is None: + return None + if any(isinstance(d, int) and d < 0 for d in input_sym_value): + return None + # Abs applied to a symbolic shape of the form [1, 1, SequenceLength]. + # We assume that SequenceLength is a non-negative integer. + # The Abs op is redundant in this case. + return op.Identity(input) + + +@register("Gather") +def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Gather node by a constant when applicable. + + Currently, handles the case of Gathering from a shape tensor. + """ + input = _get_input(node, 0) + indices = _get_input(node, 1) + if input is None or indices is None: + return None + input_sym_value = state.get_shape_value(input) + if input_sym_value is None: + return None + axis = _get_int_attribute(node, "axis", None) + if axis != 0: + return None + indices_numpy_value = _get_numpy_value(indices) + if indices_numpy_value is None: + return None + if indices_numpy_value.ndim != 1: + return None + gathered = [input_sym_value[i] for i in indices_numpy_value] + output = _get_output(node, 0) + if output is not None: + state.set_sym_value(output, ir.Shape(gathered)) + if all(isinstance(d, int) for d in gathered): + return op.Constant(value_ints=ir.AttrInt64s("value_ints", gathered)) + return None + + +def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Propagates symbolic shape value of input 0 to output 0. + + Applies to ops like Reshape/Squeeze/Unsqueeze where the shape of the tensor may change + but the values in the tensor remain the same. + """ + input = _get_input(node, 0) + input_shape_value = state.get_shape_value(input) + output = _get_output(node, 0) + if output is not None and input_shape_value is not None: + state.set_sym_value(output, input_shape_value) + return None + + +@register("Reshape") +def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Reshape node by Identity when applicable. + + Also propagate symbolic shape values. + """ + input = _get_input(node, 0) + shape = _get_input(node, 1) + if input is None or shape is None: + return None + + input_shape = input.shape + shape_value = state.get_shape_value(shape) + + if shape_value is None or input_shape is None: + return _propagate_shape_value(node, op, state) + + # No need to check for special values like -1, 0, etc. here + if _same_shape(input_shape, shape_value): + return op.Identity(input) + return _propagate_shape_value(node, op, state) + + +@register("Squeeze") +def squeeze(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Propagate symbolic shape values.""" + return _propagate_shape_value(node, op, state) + + +@register("Cast") +def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = _get_input(node, 0) + output = _get_output(node, 0) + + if input is None or output is None: + return None + + # TODO(rama): Parts of the following logic (implementing type/shape inference + # for Cast op) should be unnecessary. Generic incremental shape-inference + # should handle this. Only the optimization to eliminate redundant Cast ops + # should be needed here. + + output.shape = _merge_shapes(output.shape, input.shape) + + input_dtype = _get_input_element_type(node, 0) + output_dtype = _get_int_attribute(node, "to", None) + if output_dtype is not None: + if input_dtype == output_dtype: + return op.Identity(input) + output.type = ir.TensorType(ir.DataType(output_dtype)) + return None + + +@register("CastLike") +def cast_like(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input0 = node.inputs[0] + source_element_type = _get_input_element_type(node, 0) + target_element_type = _get_input_element_type(node, 1) + + if target_element_type == ir.DataType.UNDEFINED: + return None + if source_element_type == target_element_type: + return op.Identity(input0) + return op.Cast(input0, to=target_element_type) + + +@register("Shape") +def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = node.inputs[0] + if input is None: + return None + shape = input.shape + if shape is None: + return None + start = _get_int_attribute(node, "start", 0) + end = _get_int_attribute(node, "end", None) + shape_slice = shape[start:end] + output = _get_output(node, 0) + if output is not None: + state.set_sym_value(output, ir.Shape(shape_slice)) + if all(isinstance(d, int) for d in shape_slice): + return op.Constant(value_ints=ir.AttrInt64s("value_ints", list(shape_slice))) + return None + + +@register("Size") +def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = _get_input(node, 0) + if input is None: + return None + shape = input.shape + if shape is None: + return None + size = 1 + for d in shape: + if not isinstance(d, int): + return None + size *= d + return op.Constant(value_int=size) + + +@register("If") +def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + cond_input = _get_input(node, 0) + cond = _get_bool_value(cond_input) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph_attr = node.attributes.get(branch) + if graph_attr is None: + return None + if graph_attr.type != ir.AttributeType.GRAPH: + return None + assert isinstance(graph_attr, ir.Attr) + graph = graph_attr.as_graph() + # Copy the graph outputs and clear the graph outputs so that the values are free to move + formal_outs = list(graph.outputs) + graph.outputs.clear() + actual_outs = node.outputs + renamings = { + formal.name: actual.name + for formal, actual in zip(formal_outs, actual_outs) + if actual is not None + } + # TODO: Extend renaming to intermediate values. + + def rename(name): + return renamings.get(name, name) + + graph_nodes = list(graph) + graph.remove(graph_nodes) + for sub_node in graph_nodes: + # TODO: handle renaming inside subgraphs in nodes + for v in sub_node.outputs: + v.name = rename(v.name) + # Avoid name collision. + sub_node.name = f"{node.name}_{sub_node.name}" + + # TODO: we should handle initializers as well! + return Replacement(formal_outs, graph_nodes) + return None + + +@register("Identity") +def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + del op + input = node.inputs[0] + output = node.outputs[0] + if input is not None and output is not None: + input.shape = _merge_shapes(input.shape, output.shape) + if input.type is None: + input.type = output.type + state.set_sym_value(output, input) + return None + + +@register("SequenceConstruct") +def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + del op + output = node.outputs[0] + if output is not None: + state.set_sym_value(output, list(node.inputs)) + return None + + +@register("Concat") +def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Concat node with a single input by Identity""" + + # Replace Concat(x) by Identity(x) + inputs = node.inputs + if len(inputs) == 1: + return op.Identity(inputs[0]) + + axis = _get_int_attribute(node, "axis", None) + if axis is None: + return None + + # Eliminate zero-length operands from Concat + def has_zero_size(operand: ir.Value | None) -> bool: + if operand is None: + return False # Invalid model + if (shape := operand.shape) is None: + return False + try: + # We have already checked that axis is an int value (!= None) + dim_size = shape[axis] # type: ignore[index] + except IndexError: + return False + return dim_size == 0 # return False if symbolic or None or non-zero int value + + new_inputs = [x for x in inputs if not has_zero_size(x)] + if len(new_inputs) != len(inputs): + if new_inputs: + # Remove zero-length operands from Concat + logger.debug( + "Concat: removing zero-length operand(s) %s => %s", inputs, new_inputs + ) + return op.Concat(*new_inputs, axis=axis) + elif inputs: + # All operands are zero-length. Concat is a no-op, but we need to use one of the + # inputs to get the other dimensions correct: + logger.debug("Concat: removing all zero-length operands %s", inputs) + return op.Identity(inputs[0]) + else: + # No inputs: invalid model. + return None + + # Track value of tensors that carry a shape value: + + # Check axis attribute is 0 + + if axis != 0: + return None + shapes = [state.get_shape_value(input) for input in inputs] + if any(shape is None for shape in shapes): + return None + concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims) # type: ignore[union-attr] + output = node.outputs[0] + if output is None: + return None + state.set_sym_value(output, concatenated) + return None + + +@register("Dropout", version=(12, None)) +def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Dropout by Identity when applicable.""" + + def optimized_dropout(): + input = node.inputs[0] + output = op.Identity(input) + if len(node.outputs) == 1: + return output + else: + true_tensor = ir.tensor([True]) + input_shape = op.Shape(input) + mask = op.ConstantOfShape(input_shape, value=true_tensor) + return output, mask + + inputs = node.inputs + if (len(inputs) <= 2) or inputs[2] is None: + # No training_mode specified: + return optimized_dropout() + if _get_bool_value(inputs[2]) is False: + # training_mode is False: dropout is not applied. + return optimized_dropout() + ratio = _get_numpy_value(inputs[1]) + if ratio is None: + return None + if ratio.size != 1: # Only scalar dropout ratio is supported. + return None + if ratio.item() == 0: + # dropout ratio is 0: dropout is not applied. + return optimized_dropout() + return None + + +@register("Expand") +def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace an Expand node by Identity when applicable.""" + if len(node.inputs) != 2: + return None + if (input := node.inputs[0]) is None: + return None + if (input_shape := input.shape) is None: + # Input shape is not known. + return None + if (expanded_shape := _get_numpy_value(node.inputs[1])) is None: + # Target shape is not known. + expanded_sym_shape = state.get_shape_value(node.inputs[1]) + if expanded_sym_shape is None or not _same_shape(input_shape, expanded_sym_shape): + return None + return op.Identity(input) + if expanded_shape.ndim != 1: + # Target shape must be a 1D tensor. Erroneous model. + return None + if input_shape.dims == tuple(expanded_shape.tolist()): + return op.Identity(input) + return None + + +@register("ConcatFromSequence") +def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = node.inputs[0] + inputs = state.get_sym_value(input) + if inputs is None or any(x is None for x in inputs): + return None + new_axis = _get_int_attribute(node, "new_axis", 0) + axis = _get_int_attribute(node, "axis", None) + if axis is None: + return None + if input is not None and isinstance(inputs, list): + if new_axis == 0: + logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs]) + return op.Concat(*inputs, axis=axis) + if new_axis == 1: + # Unsqueeze the inputs with concat axis if new_axis is 1 + axis_value = op.Constant(value_int=axis) + unsqueezed_inputs = [] + for node_input in inputs: + unsqueezed_input = op.Unsqueeze( + node_input, axis_value, _outputs=[f"{node_input.name}_unsqueeze"] + ) + unsqueezed_inputs.append(unsqueezed_input) + # Send unsqueezed outputs to Concat + logger.debug( + "ConcatFromSequence => Concat %s", [x.name for x in unsqueezed_inputs] + ) + return op.Concat(*unsqueezed_inputs, axis=axis) + return None + + +@register("SplitToSequence") +def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Rewriting pattern. + + From + + splits = onnx::SplitToSequence(input, split, axis=axis) + + to + + split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + or + + split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + where number of output tensors in `splits` is statically known. + onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. + This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. + """ + input = node.inputs[0] + if len(node.inputs) == 1: + # split is not provided + return None + split = node.inputs[1] + output = node.outputs[0] + + if input is None or split is None or output is None: + return None + + axis = _get_int_attribute(node, "axis", 0) + if axis is None: + return None + shape = input.shape + if shape is None: + return None + rank = len(shape) + if axis < 0: + axis = axis + rank + if axis < 0 or axis >= rank: + return None + + # NOTE: Split needs to either be a scalar or a 1-D tensor. We need to + # calculate the number of outputs for Split. + # If split is a scalar, we split into chunks of size 'split' if possible. + # * the split dimension size and split_value has to be known. + # If split is a 1-D tensor, we split into 'size(split)' chunks + # * Get the size from split_value if it's numpy array. + # * Get the size from symbolic shape if split_value is not available. + split_value = _get_numpy_value(split) + split_shape = ( + split.shape.numpy() if split.shape is not None and split.shape.is_static() else None + ) + + # No information about split value or shape. + if split_value is None and split_shape is None: + return None + + if isinstance(split_shape, tuple) and len(split_shape) == 1: + # If split_shape is known, we can use it to determine the number of outputs. + split_dimension_size = split_shape[0] + assert isinstance(split_dimension_size, int) + num_outputs = split_dimension_size + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) + elif split_value.ndim == 1: + # split into 'size(split)' chunks + num_outputs = split_value.size + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) + elif split_value.ndim == 0: + # split into chunks all of size 'split' if possible. + split_dimension_size = shape[axis] + if not isinstance(split_dimension_size, int): + return None + num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split( + input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs + ) + else: + return None + + # If Split returns a single value, we need to wrap it into a list. + if isinstance(split_values, ir.Value): + split_values = [split_values] + + keepdims = _get_int_attribute(node, "keepdims", 1) + if keepdims is None: + return None + if keepdims == 0: + # squeeze the split dimension if keepdims is 0 + axis_val = op.Constant(value_ints=[axis], _outputs=[f"{output.name}_axis"]) + squeezed_values = [] + for i in range(num_outputs): + squeezed = op.Squeeze( + split_values[i], axis_val, _outputs=[f"{split_outputs[i]}_squeeze"] + ) + squeezed_values.append(squeezed) + split_values = squeezed_values + + logger.debug("SplitToSequence => Split + SequenceConstruct") + + if isinstance(split_values, ir.Value): + split_values = [split_values] + return op.SequenceConstruct(*split_values) + + +@register("SequenceAt") +def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = node.inputs[0] + position = node.inputs[1] + output = node.outputs[0] + if input is not None and position is not None: + input_vals = state.get_sym_value(input) + position_val = _get_numpy_value(position) + if isinstance(input_vals, list) and position_val is not None: + if position_val.size != 1: + return None + position_val = position_val.item() + try: + result = input_vals[position_val] # type: ignore[index] + except IndexError: + return None + state.set_sym_value(output, result) + logger.debug("SequenceAt %s => %s", input.name, result.name) + return op.Identity(result) + return None + + +def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None: + def merge_dims(dim1, dim2): + if dim1 == dim2: + return dim1 + if not isinstance(dim1, ir.SymbolicDim): + return dim1 # Prefer int value over symbolic dim + if not isinstance(dim2, ir.SymbolicDim): + return dim2 + if dim1.value is None: + return dim2 + return dim1 + + if shape1 is None: + return shape2 + if shape2 is None: + return shape1 + if len(shape1) != len(shape2): + raise ValueError("Shapes must have the same rank.") + return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) + + +def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None: + """Record the set of original input values that contributed to the constant-folded outputs.""" + folded_from: set[str] = set() + for input in original_node.inputs: + if input is None: + continue + folded_from.update(input.meta.get(FOLDED_FROM_KEY, set())) + assert input.name is not None + folded_from.add(input.name) + + for new_output in replacement.new_outputs: + if new_output is None: + continue + new_output.meta[FOLDED_FROM_KEY] = folded_from + # Store the string representation of the set to metadata_props to persist it across serialization + new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from)) + + +class FoldConstantsPass(ir.passes.InPlacePass): + """A pass that folds constant expressions in the model. + + Attributes: + shape_inference: Whether to perform shape inference. + input_size_limit: Maximum size of input tensors to fold. + output_size_limit: Maximum size of output tensors to fold. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding. + The function should return True/False value to indicate if this particular + node should be folded, or None to use the default folding rules. + """ + + def __init__( + self, + *, + shape_inference: bool, + input_size_limit: int, + output_size_limit: int, + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, + ) -> None: + self.shape_inference = shape_inference + self.input_size_limit = input_size_limit + self.output_size_limit = output_size_limit + self.should_fold = should_fold + + self._opset_imports: dict[str, int] = {} + self._counts: dict[str, int] = {} + self._sizes: dict[str, int] = {} + self._modified: bool = False + self._state = OptimizerState() + self._reset() + + def _reset(self) -> None: + """Reset internal states for a new run.""" + self._counts = {} + self._sizes = {} + self._modified = False + self._state = OptimizerState() + + def _do_inference(self, node: ir.Node) -> None: + output_types = {} + + # TODO: handle optional inputs + def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: + value = _get_numpy_value(x, size_limit=20) + if value is not None: + assert x.const_value is not None + return ir.serde.serialize_tensor(x.const_value) + return None + + def get_type(value: ir.Value) -> onnx.TypeProto | None: + if value.type is not None: + type_proto = ir.serde.serialize_type(value.type) + if value.shape is not None: + ir.serde.serialize_shape_into(type_proto, value.shape) + return type_proto + return None + + input_types = {x.name: get_type(x) for x in node.inputs if x is not None} + input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None} + input_data = {k: v for k, v in input_data.items() if v is not None} + if any(t is None for t in input_types.values()): + logger.debug( + "Skipping shape inference for node %r due to missing input type.", + node.name, + ) + else: + # TODO: pass in constant values, ir_version + try: + schema = onnx.defs.get_schema( + node.op_type, self._opset_imports[node.domain], node.domain + ) + output_types = onnx.shape_inference.infer_node_outputs( + schema, + ir.serde.serialize_node(node), + input_types, # type: ignore[arg-type] + input_data, # type: ignore[arg-type] + ) + for output in node.outputs: + if output.name in output_types: + inferred_type = output_types[output.name] + # TODO: merge types, check for conflicts + inferred_shape = ir.serde.deserialize_type_proto_for_shape( + inferred_type + ) + output.shape = _merge_shapes(output.shape, inferred_shape) + output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) + except Exception as e: + logger.debug( + "Skipping shape inference for node %r due to exception: %s", + node.name, + e, + ) + + def new_constant(self, node: ir.Node, value) -> ir.Node | None: + irvalue = node.outputs[0] + if not isinstance(value, np.ndarray): + # ONNX does not have a way to represent non-tensor constants, eg. a sequence. + # So, a constant-value of type sequence is not folded, but it can be used + # to optimize subsequent operations when possible. + logger.info( + "Skip storing constant folded value %s due to unsupported type %s.", + irvalue.name, + type(value), + ) + return None + + tensor = ir.tensor(value) + tensor.name = irvalue.name + irvalue.const_value = tensor + + if value.size > self.output_size_limit: + # Handle examples like Transpose(weight) to be folded even if the size is large, + # as long as weight has no other uses. This won't increase model size. + removed_input_size = 0 + for input in node.inputs: + if (input is not None) and (len(input.uses()) == 1): + array = _get_numpy_value(input) + if array is not None: + removed_input_size += array.size + increased_size = value.size - removed_input_size + if increased_size > 0: + logger.info( + "Skip storing constant folded nvalue %s due to large size %s.", + irvalue.name, + value.size, + ) + return None + + logger.debug( + "New constant for value %s dtype: %s shape: %s", + irvalue.name, + value.dtype, + value.shape, + ) + + attributes = ir.convenience.convert_attributes({"value": tensor}) + node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) + return node + + def process_node(self, node: ir.Node) -> Replacement | None: + """Process a node and return a Replacement if the node can be replaced.""" + for i, value in enumerate(node.inputs): + sym_value = self._state.get_sym_value(value) + if isinstance(sym_value, ir.Value): + logger.debug( + "Node [%s]: Replacing input %s with %s", + node.name, + value.name, # type: ignore[union-attr] + sym_value.name, + ) + node.replace_input_with(i, sym_value) + self._modified = True + # TODO(rama): consider merging type/other info from both values + + # Propagate const_value, and manually find out shape and type + # to avoid potentially expensive shape inference on large tensors. + if _is_onnx_op(node, "Constant"): + _process_constant_node(node) + # Do incremental shape inference + elif self.shape_inference and not _is_control_flow_op(node): + self._do_inference(node) + + if node.domain not in self._opset_imports: + return None + version = self._opset_imports[node.domain] + op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) + for optimizer in op_optimizers: + assert optimizer + context = RewriterContext() + output = optimizer(node, context, self._state) + if output is not None: + if isinstance(output, Replacement): + return output + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) + + if _is_onnx_op(node, "Constant"): + logger.debug("Skipping constant folding for Constant node %r", node.name) + return None + + if _is_control_flow_op(node): + logger.info( + "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet", + node.name, + node.domain, + node.op_type, + ) + + return None + + if _is_non_deterministic_op(node): + logger.info( + "Skipping constant folding for non-deterministic op %r (%s::%s)", + node.name, + node.domain, + node.op_type, + ) + return None + + if any(x.is_graph_input() for x in node.inputs if x is not None): + logger.info( + "Skipping constant folding for node %r because it is graph input to preserve graph signature", + node.name, + ) + return None + + # Ensure all node inputs are constants + if any(x.const_value is None for x in node.inputs if x is not None): + return None + + should_fold = self.should_fold(node) + + if should_fold is False: + logger.info( + "Skipping constant folding for node %r because should_fold returned False", + node.name, + ) + return None + + elif should_fold is None: + # Use default rules to decide whether to fold the node: + # - ConstantOfShape is preserved to avoid increasing model size unnecessarily + # - If the any tensor input size exceeds the input_size_limit, skip folding the node + if _is_onnx_op(node, "ConstantOfShape"): + logger.info( + "Skipping constant folding for node %r because ConstantOfShape is preserved by default", + node.name, + ) + return None + + input_tensors = [x.const_value if x is not None else None for x in node.inputs] + large_inputs = [ + tensor is not None and tensor.size > self.input_size_limit + for tensor in input_tensors + ] + if any(large_inputs): + # Decide whether to fold large constants + assert len(node.inputs) == len(large_inputs) + if (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS and all( + len(input.consumers()) == 1 or (not is_large) + for input, is_large in zip(node.inputs, large_inputs) + if input is not None + ): + # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node, + # we can still fold it even if the input size exceeds the limit + pass + else: + # Skip folding large tensors + if logger.isEnabledFor(logging.INFO): + input_sizes = [ + tensor.size for tensor in input_tensors if tensor is not None + ] + logger.info( + "Skipping constant folding for node %r due to large input sizes: %s", + node, + input_sizes, + ) + return None + else: + logger.info( + "Constant folding node %r because should_fold returned True", + node.name, + ) + + input_values = [_get_numpy_value(x) for x in node.inputs] + + def convert(av): + if av.type == ir.AttributeType.TENSOR: + return ir.serde.serialize_tensor(av.value) + return av.value + + # TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node + attr_values = {name: convert(attr) for name, attr in node.attributes.items()} + outputs = _reference_evaluator.evaluate( + node.domain, node.op_type, version, *input_values, **attr_values + ) + + if outputs is None: + return None + if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): + replacement = self.new_constant(node, outputs) + if replacement is None: + return None + return Replacement(replacement.outputs, [replacement]) + else: + logger.warning( + "Skipping constant folding for op %s with multiple outputs.", node.op_type + ) + return None + + def replace_node( + self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function + ) -> None: + logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) + + # Record the names of the values that has contributed to the replacement + _record_contributing_values(node, replacement) + + ir.convenience.replace_nodes_and_values( + root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs + ) + + self._modified = True + + # TODO: what about new opset_imports? + # TODO: track statistics about replaced nodes and sizes of new constants + + def visit_attribute(self, attr: ir.Attr) -> None: + if attr.is_ref(): + return + if attr.type == ir.AttributeType.GRAPH: + self.visit_graph(attr.as_graph()) + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.as_graphs(): + self.visit_graph(graph) + + def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None: + replacement = self.process_node(node) + if replacement is None: + # No change. Process attributes. + for attr in node.attributes.values(): + self.visit_attribute(attr) + return + else: + self.replace_node(node, replacement, root) + + def visit_graph(self, graph: ir.Graph) -> None: + for node in graph: + self.visit_node(node, graph) + + # Replace outputs if output nodes can be folded. This are typically outputs from + # Identity nodes + for i, output in enumerate(graph.outputs): + if output is None: + continue + sym_value = self._state.get_sym_value(output) + if not isinstance(sym_value, ir.Value): + # An output must be a Value + continue + if not _sym_value_can_replace_graph_output(graph, sym_value, output): + continue + # Rename sym_value to match the output name + sym_value.name = output.name + graph.outputs[i] = sym_value + self._modified = True + + def visit_function(self, function: ir.Function) -> None: + for node in function: + self.visit_node(node, function) + + def call(self, model: ir.Model) -> FoldConstantsResult: + self._reset() + self._opset_imports = model.opset_imports + self.visit_graph(model.graph) + for function in model.functions.values(): + # TODO(rama): Should we specialize functions? + self.visit_function(function) + return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map) + + +def _sym_value_can_replace_graph_output( + graph: ir.Graph, sym_value: ir.Value, output: ir.Value +) -> bool: + if (producer := sym_value.producer()) is None: + # If the sym_value has no producer, it is some graph's input + # ONNX does not allow a graph input to be a graph output + return False + if producer.graph is not graph: + # The sym_value must be produced by a node in the graph to be an output of this graph + return False + if sym_value.is_graph_output(): + # If the sym_value is already an output of a graph, we cannot rename it + # to this output name. Otherwise the graph output represented by sym_value + # will lose its name. + return False + return True + + +@dataclasses.dataclass +class FoldConstantsResult(ir.passes.PassResult): + symbolic_value_map: dict[ir.Value, SymbolicValue] + + # Add conversion to bool for backward compatibility. The previously returned value + # for the fold_constants method was a boolean indicating whether the model was modified. + def __bool__(self) -> bool: + return self.modified + + +def fold_constants( + model: ir.Model, + *, + onnx_shape_inference: bool = False, + input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, + output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, +) -> FoldConstantsResult: + """ + Applies constant folding optimization to the model. + + Args: + model: The ONNX model to optimize. + onnx_shape_inference: Whether to enable ONNX shape inference during + constant folding. Defaults to False. + input_size_limit: The maximum size of input tensors + that can be considered for constant folding. Defaults to + `DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT`. + output_size_limit: The maximum size of output tensors + that can be stored after constant folding. Defaults to + `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding, False if it should not be folded, + or None to use the default rules. Defaults to a function that always returns None. + + Returns: + An instance of `FoldConstantsResult`. + + """ + folder_pass = FoldConstantsPass( + shape_inference=onnx_shape_inference, + input_size_limit=input_size_limit, + output_size_limit=output_size_limit, + should_fold=should_fold, + ) + return folder_pass(model) # type: ignore[return-value] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py new file mode 100644 index 0000000000..d3d76c4a23 --- /dev/null +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -0,0 +1,699 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np +import onnx +import parameterized + +import onnxscript.optimizer as optimizer +from onnxscript import ir +from onnxscript.optimizer import _constant_folding + + +class FoldConstantsTest(unittest.TestCase): + def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs): + if isinstance(model, str): + model = ir.from_onnx_text(model) + _constant_folding.fold_constants( + model, onnx_shape_inference=onnx_shape_inference, **kwargs + ) + optimizer.remove_unused_nodes(model) + # Ensure the model is valid after optimization + onnx.checker.check_model(ir.serde.serialize_model(model)) + return model + + def test_fold_add(self): + model = """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, four) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") + + def test_fold_cast_like(self): + model = """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + two_float = CastLike(two, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") + + def test_fold_shape(self): + model = """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + two_float = CastLike(rank, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") + + def test_fold_shape_slice(self): + model = """ + + agraph (float[M, N, 16, 16] x) => (float[M, N, 16, 16] z) { + shape = Shape (x) + two = Size(shape) + two_float = CastLike(two, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") + + def test_fold_if_cond(self): + model = """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + zero = Constant () + zero_cast = CastLike (zero, rank) + is_scalar = Equal(zero_cast, rank) + z = If (is_scalar) < + then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, + else_branch = else_graph () => (else_z) { else_z = Mul (x, x) } + > + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph[0].outputs[0].name, "z") + self.assertEqual(optimized.graph[0].op_type, "Mul") + + def test_fold_inside_if_branch(self): + model = """ + + agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { + two = Constant () + z = If (cond) < + then_branch = then_graph () => (then_z) { + three = Constant () + temp = Add (two, three) + then_z = Mul (temp, x) + }, + else_branch = else_graph () => (else_z) { + four = Constant () + temp = Add (two, four) + else_z = Mul (temp, x) + } + > + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + then_graph = optimized.graph[0].attributes["then_branch"].as_graph() + self.assertEqual(len(then_graph), 2) + else_graph = optimized.graph[0].attributes["else_branch"].as_graph() + self.assertEqual(len(else_graph), 2) + + def test_fold_if_propagate(self): + model = """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + zero = Constant () + two = Constant () + zero_cast = CastLike (zero, rank) + is_scalar = Equal(zero_cast, rank) + m = If (is_scalar) < + then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, + else_branch = else_graph () => (else_z) { else_z = Mul (two, two) } + > + m_square = Mul (m, m) + z = Mul (x, m_square) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "m_square") + self.assertEqual(optimized.graph[0].op_type, "Constant") + + def test_fold_redundant_cast(self): + model = """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + x_cast = CastLike(x, two) + z = Mul(x_cast, two) + } + """ + + optimized = self._fold(model, onnx_shape_inference=True) + self.assertEqual(len(optimized.graph), 2) + + def test_fold_redundant_cast2(self): + model = """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + z = CastLike(x, two) + } + """ + + optimized = self._fold(model, onnx_shape_inference=True) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph[0].op_type, "Identity") + self.assertEqual(optimized.graph[0].outputs[0].name, "z") + self.assertEqual(optimized.graph[0].inputs[0].name, "x") + + def test_shape_inference(self): + model = """ + + agraph (int64[64] x) => (int64[N] z) { + one = Constant () + cond = Equal(one, one) + temp = If (cond) < + then_branch = then_graph () => (then_z) { + shape1 = Constant () + then_z = Reshape(x, shape1) + }, + else_branch = else_graph () => (else_z) { + shape2 = Constant () + else_z = Reshape(x, shape2) + }> + shape = Shape(temp) # shape = [8, 8] or [64], but [8, 8] after constant propagation + rank = Size(shape) # rank = 2 or 1, but 2 after constant propagation + C = Add (rank, rank) + z = Mul(x, C) + } + """ + + optimized = self._fold(model, onnx_shape_inference=True) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "C") + + def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( + self, + ): + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => (float[1,512] return_val) { + int64_128 = Constant () + splits = SplitToSequence (x, int64_128) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + int64_3 = Constant () + split_3 = SequenceAt (splits, int64_3) + return_val = Concat (split_0, split_1, split_2, split_3) +}""" + + # TODO: There is an unrelated limitation that `symbolic_value` is not + # utilized when the value is only referenced by graph output. + # E.g., the following test model will not have this optimization + # applied. + # + # < + # ir_version: 8, + # opset_import: ["" : 18] + # > + # func (float[1,512] x) => ( split_0, split_1, split_2, split_3) { + # int64_128 = Constant () + # splits = SplitToSequence (x, int64_128) + # int64_0 = Constant () + # split_0 = SequenceAt (splits, int64_0) + # int64_1 = Constant () + # split_1 = SequenceAt (splits, int64_1) + # int64_2 = Constant () + # split_2 = SequenceAt (splits, int64_2) + # int64_3 = Constant () + # split_3 = SequenceAt (splits, int64_3) + # } + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(len(optimized.graph[-2].outputs), 4) + self.assertEqual(optimized.graph[-2].op_type, "Split") + + def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split( + self, + ): + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => (float[1,N] return_val) { + const = Constant () + splits = SplitToSequence (x, const) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + return_val = Concat (split_0, split_1, split_2) +}""" + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 3) + self.assertEqual(len(optimized.graph[-2].outputs), 3) + self.assertEqual(optimized.graph[-2].op_type, "Split") + + def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze( + self, + ): + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => (float[1,3] return_val) { + const = Constant () + splits = SplitToSequence (x, const) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + return_val = Concat (split_0, split_1, split_2) +}""" + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 7) + self.assertEqual(len(optimized.graph[1].outputs), 3) + self.assertEqual(optimized.graph[1].op_type, "Split") + self.assertEqual(len([n for n in optimized.graph if n.op_type == "Squeeze"]), 3) + + def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( + self, + ): + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => (float[1,3] return_val) { + const = Constant () + splits = SplitToSequence (x, const) + return_val = ConcatFromSequence (splits) +}""" + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 3) + self.assertEqual(optimized.graph[2].op_type, "Concat") + + def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( + self, + ): + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => (float[1,3] return_val) { + const = Constant () + splits = SplitToSequence (x, const) + return_val = ConcatFromSequence (splits) +}""" + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 7) + self.assertEqual(optimized.graph[6].op_type, "Concat") + + def test_dynamic_split_to_sequence_list_shape_rewrite(self): + # split is a graph input with known 1-D static shape [4]; values unknown (not constant) + # Ensures the branch: if isinstance(split_shape, tuple) and len(split_shape) == 1 + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[2,N] x, int64[4] split) => (float[2,N] return_val) { + splits = SplitToSequence (x, split) + i0 = Constant () + s0 = SequenceAt (splits, i0) + i1 = Constant () + s1 = SequenceAt (splits, i1) + i2 = Constant () + s2 = SequenceAt (splits, i2) + i3 = Constant () + s3 = SequenceAt (splits, i3) + return_val = Concat (s0, s1, s2, s3) +}""" + optimized = self._fold(model) + # Expect: Split + Concat (index constants & SequenceAt removed) + split_nodes = [n for n in optimized.graph if n.op_type == "Split"] + self.assertEqual(len(split_nodes), 1) + self.assertEqual(len(split_nodes[0].outputs), 4) + self.assertEqual(split_nodes[0].op_type, "Split") + self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph)) + + def test_dynamic_split_to_sequence_list_shape_no_keepdims(self): + # keepdims=0 path with dynamic (non-constant) splits input; triggers squeeze logic. + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,M] x, int64[3] split) => (float[1,M] return_val) { + splits = SplitToSequence (x, split) + i0 = Constant () + s0 = SequenceAt (splits, i0) + i1 = Constant () + s1 = SequenceAt (splits, i1) + i2 = Constant () + s2 = SequenceAt (splits, i2) + return_val = Concat (s0, s1, s2) +}""" + optimized = self._fold(model) + split_nodes = [n for n in optimized.graph if n.op_type == "Split"] + self.assertEqual(len(split_nodes), 1) + self.assertEqual(len(split_nodes[0].outputs), 3) + self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph)) + # Each split output should have a corresponding Squeeze (keepdims=0 branch) + squeeze_nodes = [n for n in optimized.graph if n.op_type == "Squeeze"] + self.assertEqual(len(squeeze_nodes), 3) + + def test_initializer_input_not_folded(self): + model_text = """ + + agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z) + { + # c is not a constant, and following should not be folded. + two_c = Add (c, c) + z = Mul (x, two_c) + }""" + optimized = self._fold(model_text) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph.node(0).op_type, "Add") + + @parameterized.parameterized.expand( + [ + ("output = Dropout(input)",), + ("output = Dropout(input, zero, true)",), + ("output = Dropout(input, half)",), + ("output = Dropout(input, half, false)",), + ] + ) + def test_dropout_identity(self, dropout_node: str): + model = f""" + + agraph (float[N] input) => (float[N] output) + + {{ + {dropout_node} + }} + """ + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph.node(0).op_type, "Identity") + + @parameterized.parameterized.expand( + [ + ("output, mask = Dropout(input)",), + ("output, mask = Dropout(input, zero, true)",), + ("output, mask = Dropout(input, half)",), + ("output, mask = Dropout(input, half, false)",), + ] + ) + def test_dropout_identity_mask(self, dropout_node: str): + model = f""" + + agraph (float[N] input) => (float[N] output, bool[N] mask) + + {{ + {dropout_node} + }} + """ + optimized = self._fold(model) + nodes = list(optimized.graph) + self.assertEqual(len(nodes), 3) + ops = [node.op_type for node in nodes] + self.assertEqual(ops, ["Identity", "Shape", "ConstantOfShape"]) + + def test_concat_identity(self): + model = """ + + agraph (float[N] x) => (float[N] z) + { + z = Concat (x) + } + """ + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph.node(0).op_type, "Identity") + + def test_concat_zero_length(self): + model = """ + + agraph (float[N, 128] x1, float[N, 0] x2, float[N, 128] x3) => (float[N, M] z) + { + z = Concat (x1, x2, x3) + } + """ + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x1", "x3"]) + + def test_concat_zero_length_identity(self): + model = """ + + agraph (float[N, 0] x1, float[N, 128] x2, float[N, 0] x3) => (float[N, M] z) + { + z = Concat (x1, x2, x3) + } + """ + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph.node(0).op_type, "Identity") + self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x2"]) + + def test_concat_zero_length_output(self): + model = """ + + agraph (float[N, 0] x1, float[N, 0] x2, float[N, 0] x3) => (float[N, M] z) + { + z = Concat (x1, x2, x3) + } + """ + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph.node(0).op_type, "Identity") + self.assertEqual([x.name for x in optimized.graph.node(0).inputs], ["x1"]) + + def test_expand_identity(self): + model = """ + + agraph (float[128, 256] x) => (float[128, 256] z) + { + shape = Constant () + z = Expand (x, shape) + } + """ + optimized = self._fold(model) + self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + + def test_expand_identity_symdim(self): + model = """ + + agraph (float[B, 256] x) => (float[B, 256] z) + { + b = Shape (x) + const_256 = Constant () + shape = Concat (b, const_256) + z = Expand (x, shape) + } + """ + optimized = self._fold(model) + self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + + def test_abs_symdim(self): + model = """ + + agraph (float[B, 256] x) => (float[B, 256] z) + { + b = Shape (x) + const_256 = Constant () + b_256 = Concat (b, const_256) + shape = Abs (b_256) + z = Expand (x, shape) + } + """ + optimized = self._fold(model) + self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + + def test_reshape_identity(self): + model = """ + + agraph (float[128, 256] x) => (float[128, 256] z) + { + shape = Constant () + z = Reshape (x, shape) + } + """ + optimized = self._fold(model) + self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + + def test_reshape_identity_symdim(self): + model = """ + + agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z) + { + b = Shape (y) + const_256 = Constant () + shape = Concat (b, const_256) + z = Reshape (x, shape) + } + """ + optimized = self._fold(model) + self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + + def test_gather_symdim(self): + model = """ + + agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z) + { + b_128 = Shape (y) + index_0 = Constant () + b = Gather (b_128, index_0) + const_256 = Constant () + shape = Concat (b, const_256) + z = Reshape (x, shape) + } + """ + optimized = self._fold(model) + self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + + def test_input_size_limit(self): + model_text = """ + + agraph (float[M, 256] x) => (float[M, 256] z) + # placeholder for large initializer of shape [256, 256] + { + w_squared = Mul (w, w) + z = Add (x, w_squared) + } + """ + model = ir.from_onnx_text(model_text) + w = model.graph.initializers["w"] + w.shape = ir.Shape([256, 256]) + w.const_value = ir.tensor(np.random.random((256, 256)).astype(np.float32)) + + # Input size limit will prevent folding of Mul op + optimized = self._fold(model, onnx_shape_inference=False, input_size_limit=128 * 128) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["Mul", "Add"]) + + # Input size limit will allow folding of Mul op + # Since there is no increase in model-size, output-size is not a concern. + optimized = self._fold(model, input_size_limit=256 * 256, output_size_limit=256 * 256) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["Constant", "Add"]) + + def test_transpose_is_always_folded(self): + model_text = """ + + agraph (float[M, 256] x) => (float[M, 512] z) + # placeholder for large initializer of shape [512, 256] + { + z = Transpose (w) + } + """ + model = ir.from_onnx_text(model_text) + w = model.graph.initializers["w"] + w.shape = ir.Shape([512, 256]) + w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32)) + + # Input size limit will not prevent folding of Transpose op + optimized = self._fold(model, input_size_limit=1) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["Constant"]) + + def test_node_is_folded_if_specified_as_should_fold(self): + model_text = """ + + agraph (float[M, 256] x) => (float[42, 42] z) + + { + z = ConstantOfShape (w) + } + """ + model = ir.from_onnx_text(model_text) + + # ConstantOfShape is not folded by default + optimized = self._fold(model) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["ConstantOfShape"]) + + # But ConstantOfShape is folded when specified in should_fold + optimized = self._fold( + model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None + ) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["Constant"]) + np.testing.assert_array_equal( + optimized.graph.node(0).attributes["value"].as_tensor().numpy(), + np.ones((42, 42), dtype=np.int64), + ) + + def test_multi_graph_identity_output_preserves_output_name(self): + model = """ + + agraph (float[N] x) => (float[N] graph_output1, float[N] graph_output2) { + t = Identity(x) + graph_output1 = Identity(t) + graph_output2 = Identity(t) + }""" + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual([n.op_type for n in optimized.graph], ["Identity", "Identity"]) + self.assertEqual( + [n.outputs[0].name for n in optimized.graph], ["graph_output1", "graph_output2"] + ) + self.assertEqual([input.name for input in optimized.graph.inputs], ["x"]) + + # This should not be constant-foldable as the constant references an + # attribute and thus the shape cannot be resolved. At the same time it + # should not fail due to the attribute value being None in + # _process_constant_node + def test_attribute_reference(self): + model = """ + + agraph () => (int64[N] z) { + x = Constant () + z = Shape (x) + } + """ + + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/optimizer/_function_folding_test.py b/onnxscript/optimizer/_function_folding_test.py new file mode 100644 index 0000000000..6f2b052b9e --- /dev/null +++ b/onnxscript/optimizer/_function_folding_test.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import onnx + +import onnxscript.testing +from onnxscript import ir, optimizer + + +def _create_model(model_text: str) -> ir.Model: + """Create a model from the given text.""" + model = onnx.parser.parse_model(model_text) + return ir.serde.deserialize_model(model) + + +class FunctionFoldingTest(unittest.TestCase): + def test_identity(self): + model = _create_model( + """ + + agraph (float[N] x1, bool cond1) => (float[N] z1) { + z1 = local.fun1(x1, cond1) + } + + fun1 (x, cond) => (z) { + t = Identity(x) + t2 = Identity(t) + t3 = If (cond) < + then_branch = then_graph() => (t4) { + t5 = Identity(t2) + t4 = Identity(t5) + }, + else_branch = else__graph() => (t6) { + t7 = Identity(t) + t6 = Identity(t7) + } + > + t4 = Add(t3, t3) + z = Identity(t4) + }""" + ) + optimized = optimizer.optimize( + model, onnx_shape_inference=False, num_iterations=1, inline=True + ) + self.assertEqual(len(optimized.functions), 0) + self.assertEqual(len(optimized.graph), 2) + + def test_sequence_concat(self): + model = _create_model( + """ + + agraph (float[N] x1) => (float[M] z1) { + z1 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + }""" + ) + optimized = optimizer.optimize( + model, onnx_shape_inference=False, num_iterations=1, inline=False + ) + function = optimized.functions[("local", "fun1", "")] + self.assertEqual(len(function), 3) + self.assertEqual(function[2].op_type, "Concat") + + def test_sequence_at(self): + model = _create_model( + """ + + agraph (float[N] x) => (float[M] z) { + t0 = Add (x, x) + t1 = Mul (x, x) + s = SequenceConstruct (x, t0, t1) + one = Constant () + z = SequenceAt (s, one) + }""" + ) + optimized = optimizer.optimize( + model, onnx_shape_inference=False, num_iterations=1, inline=False + ) + expected = _create_model( + """ + + agraph (float[N] x) => (float[M] z) { + z = Add (x, x) + }""" + ) + # TODO(justinchuby): Implement assert_isomorphic_graph for IR objects + onnxscript.testing.assert_isomorphic_graph( + ir.to_proto(optimized.graph), ir.to_proto(expected.graph) + ) + + def test_single_user_function_is_modified_inplace_after_folding(self): + model = _create_model( + """ + + agraph (float[N] x1) => (float[M] z1) { + z1 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + }""" + ) + optimized = optimizer.optimize( + model, onnx_shape_inference=False, num_iterations=1, inline=False + ) + self.assertEqual(next(iter(optimized.functions.values())).name, "fun1") + + def test_fold_nested_if_function_succeeds(self): + model = _create_model( + """ + < + ir_version: 9, + opset_import: ["this" : 1, "" : 18] + > + func (float[1,512] x, float[1,512] y) => ( out) { + out = this.foldable_func (x, y) + } + < + domain: "this", + opset_import: ["" : 18] + > + foldable_func (x, y) => (z_6) + { + cond = Constant () + z_6 = If (cond) ( z_2) { + cond_0 = Not (cond) + z_2 = If (cond_0) ( z) { + z = Add (x, x) + }, else_branch: graph = elseGraph_5 () => ( z_1) { + z_1 = Identity (x) + }> + }, else_branch: graph = elseGraph_4 () => ( z_5) { + z_5 = If (cond) ( z_3) { + z_3 = Add (y, y) + }, else_branch: graph = elseGraph_10 () => ( z_4) { + z_4 = Add (x, y) + }> + }> + }""" + ) + optimized = optimizer.optimize(model, onnx_shape_inference=False, inline=True) + + self.assertEqual(len(optimized.functions), 0) + self.assertEqual(len(optimized.graph), 1) + self.assertNotIn("If", {n.op_type for n in optimized.graph}) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py new file mode 100644 index 0000000000..307144462f --- /dev/null +++ b/onnxscript/optimizer/_optimizer.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import logging +from typing import Callable + +import onnx_ir as ir +import onnx_ir.passes.common as common_passes + +from onnxscript import rewriter +from onnxscript.optimizer import _constant_folding + +logger = logging.getLogger(__name__) + + +def optimize_ir( + model: ir.Model, + num_iterations: int = 2, + *, + onnx_shape_inference: bool = True, + stop_if_no_change: bool = True, + input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, + output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, + inline: bool = True, +) -> None: + """Optimizes a model. + + Args: + model: The model to be optimized. + num_iterations: Number of times the optimization loop is repeated. + onnx_shape_inference: Applies node-level shape-inference as part of optimization + stop_if_no_change: Stop the optimization loop if no change is detected in an iteration. + input_size_limit: Will not apply constant folding to ops with any input of size + greater than this. Does not apply to special ops like Shape() and Size(). + output_size_limit: Will not rewrite any foldable-op into a Constant op if the size + of the output tensor is greater than this. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding. + The function should return True/False value to indicate if this particular + node should be folded, or None to use the default folding rules. + inline: If True, inlines all functions in the model. + """ + passes = [ + ir.passes.PassManager( + [ + _constant_folding.FoldConstantsPass( + shape_inference=onnx_shape_inference, + input_size_limit=input_size_limit, + output_size_limit=output_size_limit, + should_fold=should_fold, + ), + rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES), + common_passes.RemoveUnusedNodesPass(), + common_passes.RemoveUnusedFunctionsPass(), + common_passes.RemoveUnusedOpsetsPass(), + ], + steps=num_iterations, + early_stop=stop_if_no_change, + ), + common_passes.RemoveUnusedNodesPass(), + common_passes.LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=0), + common_passes.LiftSubgraphInitializersToMainGraphPass(), + common_passes.DeduplicateInitializersPass(), + common_passes.CommonSubexpressionEliminationPass(), + ] + if inline: + # Inline all functions first before optimizing + passes = [common_passes.InlinePass(), *passes] + optimizer_pass = ir.passes.Sequential(*passes) + assert optimizer_pass.in_place + result = optimizer_pass(model) + assert result.model is model diff --git a/onnxscript/optimizer/_optimizer_test.py b/onnxscript/optimizer/_optimizer_test.py new file mode 100644 index 0000000000..0aed7f57ca --- /dev/null +++ b/onnxscript/optimizer/_optimizer_test.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import onnx +import onnx_ir as ir + +import onnxscript.optimizer as optimizer + + +class OptimizerTest(unittest.TestCase): + def _model_proto(self) -> onnx.ModelProto: + return onnx.parser.parse_model( + """ + < + ir_version: 8, + opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], + producer_name: "pytorch", + producer_version: "2.2.0" + > + main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) + < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> + { + _val_1 = Constant () + _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) + _val_3 = Constant () + getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) + _val_5 = Constant () + getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) + return_val = Concat (getitem_1, getitem) + } + + + aten_split (self, split_size) => (return_val) + { + return_val = SplitToSequence (self, split_size) + } + + + aten_getitem (self, i) => (return_val) + { + return_val = SequenceAt (self, i) + } + + + Rank (input) => (return_val) + { + tmp = Shape (input) + return_val = Size (tmp) + } + + + IsScalar (input) => (return_val) + { + tmp = Shape (input) + tmp_0 = Size (tmp) + tmp_1 = Constant () + return_val = Equal (tmp_0, tmp_1) + } + """ + ) + + def test_static_split_to_sequence_with_uneven_split_proto(self): + model_proto = self._model_proto() + optimized = optimizer.optimize( + model_proto, num_iterations=1, onnx_shape_inference=False + ) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph.node[0].output), 2) + self.assertEqual(optimized.graph.node[0].op_type, "Split") + + def test_static_split_to_sequence_with_uneven_split_ir(self): + model_proto = self._model_proto() + model_ir = ir.serde.deserialize_model(model_proto) + optimizer.optimize_ir(model_ir, num_iterations=1, onnx_shape_inference=False) + self.assertEqual(len(model_ir.graph), 2) + self.assertEqual(len(model_ir.graph.node(0).outputs), 2) + self.assertEqual(model_ir.graph.node(0).op_type, "Split") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py deleted file mode 100644 index 283a13fd13..0000000000 --- a/onnxscript/optimizer/constant_folding.py +++ /dev/null @@ -1,283 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any, Sequence - -import numpy as np -import onnx -import onnx.reference.ops - -import onnxscript._legacy_ir as ir -from onnxscript._legacy_ir import visitor -from onnxscript.optimizer import evaluator -from onnxscript.utils.utils import ( - is_control_flow_op, - is_onnx_domain, -) - -logger = logging.getLogger(__name__) - -_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 - -# Ops excluded from constant-propagation: -# * Random ops, which are not deterministic (checked below) -# * Control flow ops (checked by presence of graph-attribute) - -non_deterministic_ops = frozenset( - { - "RandomUniform", - "RandomNormal", - "RandomUniformLike", - "RandomNormalLike", - "Multinomial", - } -) - -onnx_domain = frozenset({"", "onnx.ai"}) - - -def is_non_deterministic_op(node: onnx.NodeProto) -> bool: - return node.op_type in non_deterministic_ops and is_onnx_domain(node.domain) - - -def is_constant_op(node: onnx.NodeProto) -> bool: - return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain(node.domain) - - -class ConstantFolder(visitor.FunctionCallsiteProtoTransformer): - def __init__( - self, - registry: evaluator.PartialEvaluatorRegistry, - external_data_folder: str, - *, - do_shape_inference: bool, - ) -> None: - self.registry = registry - # TODO: make evaluator a parameter - self.evaluate = evaluator.reference_evaluator.evaluate - self._do_shape_inference = do_shape_inference - self._init() - super().__init__(external_data_folder, do_shape_inference=do_shape_inference) - - def _init(self) -> None: - self.counts = {} - self.sizes = {} - - def add_count(self, op: str, size: int = 1): - self.counts[op] = self.counts.get(op, 0) + 1 - self.sizes[op] = self.sizes.get(op, 0) + size - - def foldable_value(self, name: str, value): - """Checks if a runtime-constant can and should be folded into the graph. - - We fold constants only if they are tensors (not lists of tensors, for example) - and have size below desired limit. - """ - if value is ir.NotConstant: - return None - - if not isinstance(value, np.ndarray): - # ONNX does not have a way to represent non-tensor constants, eg. a sequence. - # So, a constant-value of type sequence is not folded, but it can be used - # to optimize subsequent operations when possible. - logger.warning( - "Skip storing constant folded value %s due to unsupported type %s.", - name, - type(value), - ) - return None - - if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: - logger.warning( - "Skip storing constant folded nvalue %s due to large size %s.", - name, - value.nbytes, - ) - return None - - return onnx.numpy_helper.from_array(value, name) - - def new_constant(self, name, value): - if isinstance(value, (int, float, np.ScalarType)): - value = np.array(value) - - info = self.lookup_or_create(name) - info.value = value - - tensor = self.foldable_value(name, value) - if tensor is None: - return None - - logger.debug( - "New constant for value %s dtype: %s shape: %s", - name, - value.dtype, - value.shape, - ) - info.type = onnx.helper.make_tensor_type_proto( - onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape - ) - node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor) - return [node] - - def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]: - if self.scopes.current_scope().current_function_scope(): - # Need to resolve ref_attr_name if inside a function. - attr_dict = {} - for attribute in attributes: - concrete_attribute = ( - self.lookup_ref_attribute(attribute.ref_attr_name) - if attribute.ref_attr_name - else attribute - ) - if concrete_attribute is None: - continue - attr_dict[attribute.name] = onnx.helper.get_attribute_value(concrete_attribute) - return attr_dict - return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} - - def replace_copy(self, node: onnx.NodeProto) -> None: - for i in range(len(node.input)): - input = self.get_input(node, i) - if input is not None and input.is_copy(): - old_value = self.lookup_or_create(input.name) - assert isinstance(input.symbolic_value, str) - new_value = self.lookup_or_create(input.symbolic_value) - # Merge meta info. It is important to do if the new value - # is created by evaluator, and thus carries zero meta info. - # Since this is a copy, the meta info should be the same. - new_value.identity_merge_from(old_value) - node.input[i] = input.symbolic_value - - def process_function_outputs(self, function: onnx.FunctionProto) -> bool: - # Resolve copy for function subgraph output. - # Avoid copy of function subgraph input, because it is illegal for a direct edge - # from function input to function output. - prohibited_value_set = set(function.input) - updated = False - for i, output_name in enumerate(function.output): - output = self.lookup(output_name) - if ( - output is not None - and output.is_copy() - and output.symbolic_value not in prohibited_value_set - ): - old_value = self.lookup_or_create(output.name) - assert isinstance(output.symbolic_value, str) - new_value = self.lookup_or_create(output.symbolic_value) - new_value.identity_merge_from(old_value) - function.output[i] = output.symbolic_value - updated = True - return updated - - def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: - self.replace_copy(node) - - super().process_node(node) - - inputs = [self.lookup(x) for x in node.input] - attrs = self.convert_attributes(node.attribute) - - domain = node.domain - op = node.op_type - version = self.lookup_version(domain) - - # if any(x is Undefined for x in inputs): - # return None - # Above check ensures that none of the optimizations below need to handle - # undefined inputs - - op_optimizers = self.registry.lookup_evaluators(domain, op, version) - for optimizer in op_optimizers: - assert optimizer - output = optimizer(self, node) - if output is None: - continue - if isinstance(output, list): - return output - else: - # Currently handles single output only - self.add_count(node.op_type, output.size) - return self.new_constant(node.output[0], output) - - if is_control_flow_op(node) or is_non_deterministic_op(node): - return None - - input_values = [x.value if x is not None else None for x in inputs] - if any(x is ir.NotConstant for x in input_values): - return None - - outputs = self.evaluate(domain, op, version, *input_values, **attrs) - # TODO: what if evaluated value is None? - if outputs is None: - return None - if len(node.output) == 1 and not isinstance(outputs, (tuple, list)): - replacement = self.new_constant(node.output[0], outputs) - if is_constant_op(node): - return None - self.add_count(op, outputs.size) - return replacement - else: - logger.warning("Skipping constant folding for op %s with multiple outputs.", op) - return None - - def process_function_node( - self, node: onnx.NodeProto - ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: - self.replace_copy(node) - - _, new_function = super().process_function_node(node) - - # Replace function node with Constant if all outputs are constants - ir_values = [self.lookup(output_name) for output_name in node.output] - tensors = [ - self.foldable_value(output_name, ir_value.value if ir_value is not None else None) - for output_name, ir_value in zip(node.output, ir_values) - ] - if all(tensor is not None for tensor in tensors): - replacements = [] - for output_name, tensor in zip(node.output, tensors): - newnode = onnx.helper.make_node( - "Constant", inputs=[], outputs=[output_name], value=tensor - ) - replacements.append(newnode) - logger.debug( - "Function node replacements: node %s %s (%s/%s)", - node.name, - [replacement.output for replacement in replacements], - len(replacements), - len(node.output), - ) - return replacements, new_function - return None, new_function - - def visit_model(self, model: onnx.ModelProto) -> None: - self._init() - - super().visit_model(model) - - -def fold_constants( - model: onnx.ModelProto, - external_data_folder: str = "", - *, - onnx_shape_inference: bool = False, -) -> bool: - """ - Applies constant folding optimization to the model. - Returns true iff the model was modified. - """ - folder = ConstantFolder( - evaluator.registry, - external_data_folder, - do_shape_inference=onnx_shape_inference, - ) - folder.visit_model(model) - for op in folder.counts: - logger.info( - "Constant-folded '%s' %s times, with %s size.", - op, - folder.counts[op], - folder.sizes[op], - ) - return folder.modified diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/constant_folding_test.py deleted file mode 100644 index 64a27e33de..0000000000 --- a/onnxscript/optimizer/constant_folding_test.py +++ /dev/null @@ -1,444 +0,0 @@ -import unittest - -import onnx -import pytest - -from onnxscript import optimizer - - -class FoldConstantsTest(unittest.TestCase): - def test_fold_add(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - two = Constant () - four = Add(two, two) - z = Mul(x, four) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") - - def test_fold_cast_like(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - two = Constant () - two_float = CastLike(two, x) - four = Add(two_float, two_float) - z = Mul(x, four) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") - - def test_fold_shape(self): - model = onnx.parser.parse_model( - """ - - agraph (float[16, 16] x) => (float[16, 16] z) { - shape = Shape(x) - rank = Size(shape) - two_float = CastLike(rank, x) - four = Add(two_float, two_float) - z = Mul(x, four) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") - - def test_fold_shape_slice(self): - model = onnx.parser.parse_model( - """ - - agraph (float[M, N, 16, 16] x) => (float[M, N, 16, 16] z) { - shape = Shape (x) - two = Size(shape) - two_float = CastLike(two, x) - four = Add(two_float, two_float) - z = Mul(x, four) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") - - def test_fold_if_cond(self): - model = onnx.parser.parse_model( - """ - - agraph (float[16, 16] x) => (float[16, 16] z) { - shape = Shape(x) - rank = Size(shape) - zero = Constant () - zero_cast = CastLike (zero, rank) - is_scalar = Equal(zero_cast, rank) - z = If (is_scalar) < - then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, - else_branch = else_graph () => (else_z) { else_z = Mul (x, x) } - > - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 1) - self.assertEqual(optimized.graph.node[0].output[0], "z") - self.assertEqual(optimized.graph.node[0].op_type, "Mul") - - def test_fold_inside_if_branch(self): - model = onnx.parser.parse_model( - """ - - agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { - two = Constant () - z = If (cond) < - then_branch = then_graph () => (then_z) { - three = Constant () - temp = Add (two, three) - then_z = Mul (temp, x) - }, - else_branch = else_graph () => (else_z) { - four = Constant () - temp = Add (two, four) - else_z = Mul (temp, x) - } - > - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 1) - then_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "then_branch") - self.assertEqual(len(then_graph.node), 2) - else_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "else_branch") - self.assertEqual(len(else_graph.node), 2) - - def test_fold_if_propagate(self): - model = onnx.parser.parse_model( - """ - - agraph (float[16, 16] x) => (float[16, 16] z) { - shape = Shape(x) - rank = Size(shape) - zero = Constant () - two = Constant () - zero_cast = CastLike (zero, rank) - is_scalar = Equal(zero_cast, rank) - m = If (is_scalar) < - then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, - else_branch = else_graph () => (else_z) { else_z = Mul (two, two) } - > - m_square = Mul (m, m) - z = Mul (x, m_square) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - print(onnx.printer.to_text(optimized)) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "m_square") - self.assertEqual(optimized.graph.node[0].op_type, "Constant") - - def test_fold_redundant_cast(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - two = Constant () - x_cast = CastLike(x, two) - z = Mul(x_cast, two) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 2) - - def test_fold_redundant_cast2(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - two = Constant () - z = CastLike(x, two) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 1) - self.assertEqual(optimized.graph.node[0].op_type, "Identity") - self.assertEqual(optimized.graph.node[0].output[0], "z") - self.assertEqual(optimized.graph.node[0].input[0], "x") - - @pytest.mark.skip(reason="Feature removed to catch errors early") - def test_fold_undefined_vars(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - four = Add(two, two) - y = Shape(t1) - w = CastLike(x, t2) - w2 = CastLike(t3, t4) - w3 = Size(t5) - z = Sum (four, y, w, w2, w3) - } - """ - ) - # No optimizations expected. Just make sure it doesn't crash. - optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) - self.assertEqual(len(optimized.graph.node), 6) - - def test_shape_inference(self): - model = onnx.parser.parse_model( - """ - - agraph (int64[64] x) => (int64[N] z) { - one = Constant () - cond = Equal(one, one) - temp = If (cond) < - then_branch = then_graph () => (then_z) { - shape1 = Constant () - then_z = Reshape(x, shape1) - }, - else_branch = else_graph () => (else_z) { - shape2 = Constant () - else_z = Reshape(x, shape2) - }> - shape = Shape(temp) # shape = [8, 8] or [64], but [8, 8] after constant propagation - rank = Size(shape) # rank = 2 or 1, but 2 after constant propagation - C = Add (rank, rank) - z = Mul(x, C) - } - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - print(onnx.printer.to_text(optimized)) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "C") - - def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( - self, - ): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["" : 18] -> -func (float[1,512] x) => ( return_val) { - int64_128 = Constant () - splits = SplitToSequence (x, int64_128) - int64_0 = Constant () - split_0 = SequenceAt (splits, int64_0) - int64_1 = Constant () - split_1 = SequenceAt (splits, int64_1) - int64_2 = Constant () - split_2 = SequenceAt (splits, int64_2) - int64_3 = Constant () - split_3 = SequenceAt (splits, int64_3) - return_val = Concat (split_0, split_1, split_2, split_3) -} - """ - ) - - # TODO: There is an unrelated limitation that `symbolic_value` is not - # utilized when the value is only referenced by graph output. - # E.g., the following test model will not have this optimization - # applied. - """ -< - ir_version: 8, - opset_import: ["" : 18] -> -func (float[1,512] x) => ( split_0, split_1, split_2, split_3) { - int64_128 = Constant () - splits = SplitToSequence (x, int64_128) - int64_0 = Constant () - split_0 = SequenceAt (splits, int64_0) - int64_1 = Constant () - split_1 = SequenceAt (splits, int64_1) - int64_2 = Constant () - split_2 = SequenceAt (splits, int64_2) - int64_3 = Constant () - split_3 = SequenceAt (splits, int64_3) -} - """ - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(len(optimized.graph.node[-2].output), 4) - self.assertEqual(optimized.graph.node[-2].op_type, "Split") - - def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split( - self, - ): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["" : 18] -> -func (float[1,512] x) => ( return_val) { - const = Constant () - splits = SplitToSequence (x, const) - int64_0 = Constant () - split_0 = SequenceAt (splits, int64_0) - int64_1 = Constant () - split_1 = SequenceAt (splits, int64_1) - int64_2 = Constant () - split_2 = SequenceAt (splits, int64_2) - return_val = Concat (split_0, split_1, split_2) -} - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 3) - self.assertEqual(len(optimized.graph.node[-2].output), 3) - self.assertEqual(optimized.graph.node[-2].op_type, "Split") - - def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze( - self, - ): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["" : 18] -> -func (float[1,3] x) => ( return_val) { - const = Constant () - splits = SplitToSequence (x, const) - int64_0 = Constant () - split_0 = SequenceAt (splits, int64_0) - int64_1 = Constant () - split_1 = SequenceAt (splits, int64_1) - int64_2 = Constant () - split_2 = SequenceAt (splits, int64_2) - return_val = Concat (split_0, split_1, split_2) -} - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 7) - self.assertEqual(len(optimized.graph.node[1].output), 3) - self.assertEqual(optimized.graph.node[1].op_type, "Split") - self.assertEqual(len([n for n in optimized.graph.node if n.op_type == "Squeeze"]), 3) - - def test_static_split_to_sequence_with_uneven_split(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], - producer_name: "pytorch", - producer_version: "2.2.0" -> -main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) - < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> -{ - _val_1 = Constant () - _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) - _val_3 = Constant () - getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) - _val_5 = Constant () - getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) - return_val = Concat (getitem_1, getitem) -} -< - domain: "pkg.onnxscript.torch_lib", - opset_import: ["" : 18] -> -aten_split (self, split_size) => (return_val) -{ - return_val = SplitToSequence (self, split_size) -} -< - domain: "pkg.onnxscript.torch_lib", - opset_import: ["" : 18] -> -aten_getitem (self, i) => (return_val) -{ - return_val = SequenceAt (self, i) -} -< - domain: "pkg.onnxscript.torch_lib.common", - opset_import: ["" : 18] -> -Rank (input) => (return_val) -{ - tmp = Shape (input) - return_val = Size (tmp) -} -< - domain: "pkg.onnxscript.torch_lib.common", - opset_import: ["" : 18] -> -IsScalar (input) => (return_val) -{ - tmp = Shape (input) - tmp_0 = Size (tmp) - tmp_1 = Constant () - return_val = Equal (tmp_0, tmp_1) -} - """ - ) - optimized = optimizer.optimize(model, onnx_shape_inference=False) - - print(onnx.printer.to_text(optimized)) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(len(optimized.graph.node[0].output), 2) - self.assertEqual(optimized.graph.node[0].op_type, "Split") - - def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( - self, - ): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["" : 18] -> -func (float[1,3] x) => ( return_val) { - const = Constant () - splits = SplitToSequence (x, const) - return_val = ConcatFromSequence (splits) -} - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 3) - self.assertEqual(optimized.graph.node[2].op_type, "Concat") - onnx.checker.check_model(optimized) - - def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( - self, - ): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["" : 18] -> -func (float[1,3] x) => ( return_val) { - const = Constant () - splits = SplitToSequence (x, const) - return_val = ConcatFromSequence (splits) -} - """ - ) - optimized = optimizer.optimize(model, num_iterations=1) - self.assertEqual(len(optimized.graph.node), 7) - self.assertEqual(optimized.graph.node[6].op_type, "Concat") - onnx.checker.check_model(optimized) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/optimizer/evaluator.py b/onnxscript/optimizer/evaluator.py deleted file mode 100644 index 30ea2823d5..0000000000 --- a/onnxscript/optimizer/evaluator.py +++ /dev/null @@ -1,438 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# ------------------------------------------------------------------------- - -from __future__ import annotations - -import dataclasses -import logging -import math -from typing import Any, Callable, Protocol, Sequence, Union - -import numpy as np -import onnx -import onnx.reference.ops - -import onnxscript._legacy_ir as ir -from onnxscript.utils.utils import ( - get_node_attr_value, -) - -logger = logging.getLogger(__name__) - -# "Standard" evaluators are used to perform constant-folding. -# The API below works only for non-control-flow ops (ops without any graph-attributes). -# This currently used ONNX's reference implementation. But we could also -# use ORT's implementation if we want to. - - -class ReferenceEvaluator: - def get_evaluator(self, domain: str, op: str, version: int) -> callable | None: - try: - op_impl_class = onnx.reference.ops.load_op(domain, op, version) - return op_impl_class.eval # noqa: TRY300 - except Exception: - return None - - def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: - logger.debug("Evaluating %s::%s", domain, op) - evaluator = self.get_evaluator(domain, op, version) - if evaluator is None: - return None - return evaluator(*args, **kwargs) - - -reference_evaluator = ReferenceEvaluator() - -# The "partial evaluators" below are non-standard evaluators. They are used to perform -# partial evaluation and/or static program analysis (abstract interpretation). - - -class IRContext(Protocol): - """A class that represents the context for partial evaluation. - - This is a placeholder, subject to simplification when a proper IR is defined. - """ - - def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... - - def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... - - def input_const_value(self, node: onnx.NodeProto, index: int) -> ir.ConcreteValue: ... - - def input_shape( - self, node: onnx.NodeProto, index: int - ) -> onnx.TensorShapeProto | None: ... - - def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: ... - - def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: ... - - def lookup_version(self, domain: str) -> int: ... - - def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict: ... - - def new_constant(self, name: str, value: Any) -> Sequence[onnx.NodeProto] | None: ... - - -# A partial-evaluator function takes an IRContext and a node, and returns a list of -# replacement nodes or None (if no replacement is needed). We return None instead -# of [input node] so the caller is aware that the node is not replaced. If the node -# is replaced, the caller will recursively visit the replacement nodes to process them. - -PartialEvaluatorFunction = Union[ - Callable[[IRContext, onnx.NodeProto], Sequence[onnx.NodeProto]], None -] - - -@dataclasses.dataclass -class PartialEvaluator: - """A class that represents a partial-evaluator for a particular op. - - It is applicable for a specific version range (min_version, max_version) of the op. - The min_version and max_version can be None, indicating that there is no version - constraint in that direction. - """ - - min_version: int | None - max_version: int | None - function: PartialEvaluatorFunction - - def valid_for(self, version: int) -> bool: - """Returns True if this evaluator is applicable for the given version.""" - return (self.min_version is None or version >= self.min_version) and ( - self.max_version is None or version <= self.max_version - ) - - -class PartialEvaluatorRegistry: - """A class that maintains a registry of evaluators for ops.""" - - def __init__(self): - self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} - - def lookup_evaluators(self, domain: str, opname: str, version: int): - evaluator_list = self.op_evaluators.get((domain, opname), []) - return [ - evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) - ] - - def register(self, opname: str, domain: str = "", version=None): - if (domain, opname) not in self.op_evaluators: - evaluator_list = [] - self.op_evaluators[(domain, opname)] = evaluator_list - else: - evaluator_list = self.op_evaluators[(domain, opname)] - if version is None: - min_version = None - max_version = None - elif isinstance(version, int): - min_version = version - max_version = version - elif isinstance(version, tuple): - min_version, max_version = version - - def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: - evaluator_list.append(PartialEvaluator(min_version, max_version, function)) - return function - - return decorator - - -registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() - -register = registry.register - - -def get_bool_value(val) -> bool | None: - if isinstance(val, bool): - return val - if isinstance(val, np.bool_): - return bool(val) - if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: - return val.item(0) - return None - - -def get_size_info(type: onnx.TypeProto) -> np.ndarray | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): - size = 1 - for d in type.tensor_type.shape.dim: - size *= d.dim_value - return np.array(size, dtype=np.int64) - return None - - -def get_dim_info(type: onnx.TypeProto, dim: int) -> int | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - rank = len(type.tensor_type.shape.dim) - dim = dim if dim >= 0 else dim + rank - if dim < 0 or dim >= rank: - return None - if type.tensor_type.shape.dim[dim].HasField("dim_value"): - return type.tensor_type.shape.dim[dim].dim_value - return None - - -@register("Cast") -def cast(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: - if context.input_shape(node, 0) is not None: - output_value = context.get_output(node, 0) - output_value.type = onnx.TypeProto() - output_value.type.CopyFrom(context.input_type(node, 0)) - output_value.type.tensor_type.elem_type = node.attribute[0].i - return None - - -@register("CastLike") -def cast_like(context: IRContext, node: onnx.NodeProto): - source_element_type = context.input_element_type(node, 0) - target_element_type = context.input_element_type(node, 1) - - if target_element_type is None: - return None - if source_element_type == target_element_type: - node.op_type = "Identity" - del node.input[1] - return [node] - - node.op_type = "Cast" - del node.input[1] - del node.attribute[:] - node.attribute.append(onnx.helper.make_attribute("to", target_element_type)) - return [node] - - -@register("Shape") -def shape(context: IRContext, node: onnx.NodeProto): - shape = context.input_shape(node, 0) - if shape is None: - return None - start = get_node_attr_value(node, "start", 0) - end = get_node_attr_value(node, "end", None) - shape_slice = shape.dim[start:end] - if all(d.HasField("dim_value") for d in shape_slice): - return np.array([d.dim_value for d in shape_slice], dtype=np.int64) - return None - - -@register("Size") -def size(context: IRContext, node: onnx.NodeProto): - type = context.input_type(node, 0) - size = get_size_info(type) if type is not None else None - return size - - -@register("If") -def if_op(context: IRContext, node: onnx.NodeProto): - cond = context.input_const_value(node, 0) - if cond is ir.NotConstant: - # Visitor will recursively visit subgraphs to constant-fold them. - return None - cond = get_bool_value(cond) - if cond is not None: - # cond is a constant-value: inline the branch - branch = "then_branch" if cond else "else_branch" - graph = onnx.helper.get_node_attr_value(node, branch) - - formal_outs = list(graph.output) - actual_outs = node.output - renamings = { - formal.name: actual - for formal, actual in zip(formal_outs, actual_outs) - if actual != "" - } - # TODO: Extend renaming to intermediate values. - - def rename(name): - return renamings.get(name, name) - - for sub_node in graph.node: - # TODO: handle renaming inside subgraphs in nodes - sub_node.input[:] = [rename(name) for name in sub_node.input] - sub_node.output[:] = [rename(name) for name in sub_node.output] - # Avoid name collision. - sub_node.name = f"{node.name}_{sub_node.name}" - - # TODO: we should handle initializers as well! - return list(graph.node) - return None - - -@register("Identity") -def identity(context: IRContext, node: onnx.NodeProto): - input = context.get_input(node, 0) - output = context.get_output(node, 0) - if input is not None and output is not None: - output.symbolic_value = input.name - - -@register("SequenceConstruct") -def sequence_construct( - context: IRContext, node: onnx.NodeProto -) -> Sequence[onnx.NodeProto] | None: - output = context.get_output(node, 0) - if output is not None: - output.symbolic_value = list(node.input) - return None - - -@register("ConcatFromSequence") -def concat_from_sequence( - context: IRContext, node: onnx.NodeProto -) -> Sequence[onnx.NodeProto] | None: - input = context.get_input(node, 0) - attrs = context.convert_attributes(node.attribute) - new_axis = attrs.get("new_axis", 0) - if input is not None and isinstance(input.symbolic_value, list): - if new_axis == 0: - node.op_type = "Concat" - node.input[:] = input.symbolic_value - logger.debug("ConcatFromSequence => Concat: %s", node.input) - for i in range(len(node.attribute)): - if node.attribute[i].name == "new_axis": - del node.attribute[i] - return [node] - return [node] - if new_axis == 1: - # Unsqueeze the inputs with concat axis if new_axis is 1 - axis = attrs.get("axis", None) - assert axis is not None - output = context.get_output(node, 0) - axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0] - unsqueeze_nodes = [] - for node_input in input.symbolic_value: - unsqueeze_node = onnx.helper.make_node( - "Unsqueeze", - [node_input, axis_node.output[0]], - [f"{node_input}_unsqueeze"], - ) - unsqueeze_nodes.append(unsqueeze_node) - unsqueeze_outputs = [n.output[0] for n in unsqueeze_nodes] - unsqueeze_nodes = [axis_node, *unsqueeze_nodes] - - # Send unsqueezed outputs to Concat - node.input[:] = unsqueeze_outputs - node.op_type = "Concat" - logger.debug( - "ConcatFromSequence => UnSqueeze %s + Concat %s", - unsqueeze_outputs, - node.input, - ) - for i in range(len(node.attribute)): - if node.attribute[i].name == "new_axis": - del node.attribute[i] - return [*unsqueeze_nodes, node] - return None - - -@register("SplitToSequence") -def split_to_sequence( - context: IRContext, node: onnx.NodeProto -) -> Sequence[onnx.NodeProto] | None: - """Rewriting pattern. - - From - - splits = onnx::SplitToSequence(input, split, axis=axis) - - to - - split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) - splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) - - or - - split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) - splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) - - where number of output tensors in `splits` is statically known. - onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. - This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. - """ - input = context.get_input(node, 0) - split = context.get_input(node, 1) - attrs = context.convert_attributes(node.attribute) - output = context.get_output(node, 0) - - if input is None or split is None or output is None: - return None - - axis = attrs.get("axis", 0) - if input.type is None: - return None - split_dimension_size = get_dim_info(input.type, axis) - if split_dimension_size is None: - return None - - split_value = split.value - if split_value is None or split_value is ir.NotConstant: - return None - assert isinstance(split_value, np.ndarray) - - if split_value.ndim == 0: - # split into chunks all of size 'split' if possible. - num_outputs = math.ceil(split_dimension_size / split_value.item()) - split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_node = onnx.helper.make_node( - "Split", - [input.name], - split_outputs, - axis=axis, - num_outputs=num_outputs, - ) - else: - # split into 'size(split)' chunks - num_outputs = split_value.size - split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_node = onnx.helper.make_node( - "Split", - [input.name, split.name], - split_outputs, - axis=axis, - ) - - keepdims = attrs.get("keepdims", 1) - squeeze_nodes = [] - if keepdims == 0: - # squeeze the split dimension if keepdims is 0 - axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0] - for i in range(num_outputs): - squeeze_node = onnx.helper.make_node( - "Squeeze", - [split_outputs[i], axis_node.output[0]], - [f"{split_outputs[i]}_squeeze"], - ) - squeeze_nodes.append(squeeze_node) - split_outputs = [n.output[0] for n in squeeze_nodes] - squeeze_nodes = [axis_node, *squeeze_nodes] - - node.op_type = "SequenceConstruct" - node.input[:] = split_outputs - del node.attribute[:] - logger.debug( - "SplitToSequence => Split %s + SequenceConstruct %s", - split_node.input, - node.input, - ) - return [split_node, *squeeze_nodes, node] - - -@register("SequenceAt") -def sequence_at(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: - input = context.get_input(node, 0) - position = context.get_input(node, 1) - output = context.get_output(node, 0) - if input is not None and position is not None: - input_vals = input.symbolic_value - position_val = position.value - if isinstance(input_vals, list) and position_val is not None: - output.symbolic_value = input_vals[position_val] - logger.debug("SequenceAt %s => %s", input, output.symbolic_value) - new_node = onnx.helper.make_node( - "Identity", [output.symbolic_value], [output.name] - ) - return [new_node] - return None diff --git a/onnxscript/optimizer/fold_constants_v0.py b/onnxscript/optimizer/fold_constants_v0.py deleted file mode 100644 index 556f824b8b..0000000000 --- a/onnxscript/optimizer/fold_constants_v0.py +++ /dev/null @@ -1,248 +0,0 @@ -from __future__ import annotations - -from typing import Any, Sequence - -import numpy as np -import onnx -import onnx.reference.ops - -# Excluded ops include -# * Random ops, which are not deterministic -# * Control flow ops - -excluded_ops = frozenset( - { - "RandomUniform", - "RandomNormal", - "RandomUniformLike", - "RandomNormalLike", - "Multinomial", - "If", - "Loop", - "Scan", - "SequenceMap", - } -) - -onnx_domain = frozenset({"", "onnx.ai"}) - - -def get_evaluator(domain: str, op: str, version: int) -> callable | None: - if op in excluded_ops and domain in onnx_domain: - return None - try: - op_impl_class = onnx.reference.ops.load_op(domain, op, version) - except Exception: - return None - else: - return op_impl_class.eval - - -def convert_attributes(attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]: - return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} - - -def is_control_flow_op(node: onnx.NodeProto) -> bool: - return any(attr.HasField("g") or len(attr.graphs) > 0 for attr in node.attribute) - - -def is_constant_op(node: onnx.NodeProto) -> bool: - return node.op_type == "Constant" and node.domain == "" - - -def get_bool_value(val) -> bool | None: - if isinstance(val, bool): - return val - if isinstance(val, np.bool_): - return bool(val) - if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: - return val.item(0) - return None - - -def get_shape_info(type: onnx.TypeProto) -> tuple[int, ...] | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): - return np.array([d.dim_value for d in type.tensor_type.shape.dim], dtype=np.int64) - return None - - -def get_element_type(type: onnx.TypeProto) -> int | None: - if type.HasField("tensor_type"): - return type.tensor_type.elem_type - return None - - -class State: - def __init__(self, default_value) -> None: - self.scopes = [{}] - self.default_value = default_value - - def lookup(self, name: str) -> Any: - for scope in reversed(self.scopes): - if name in scope: - return scope[name] - return self.default_value - - def bind(self, name: str, value: Any) -> None: - self.scopes[-1][name] = value - - def enter_scope(self) -> None: - self.scopes.append({}) - - def exit_scope(self) -> None: - self.scopes.pop() - - -def is_onnx_op(node: onnx.NodeProto, op: str) -> bool: - return (node.op_type == op) and (node.domain in onnx_domain) - - -def matches(node: onnx.NodeProto, op: str, *arg_predicates) -> bool: - if node.op_type != op or node.domain != "": - return False - if len(node.input) < len(arg_predicates): - return False - return all(pred(input) for pred, input in zip(arg_predicates, node.input)) - - -def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto: - type = onnx.TypeProto() - type.tensor_type.elem_type = initializer.data_type - dims = type.tensor_type.shape.dim - for dim in initializer.dims: - dims.add().dim_value = dim - return type - - -def fold_constants(model: onnx.ModelProto): - not_constant = object() - var_info = State(default_value=not_constant) - type_info = State(default_value=None) - counts = {} - sizes = {} - - def add_count(op: str, size: int = 1): - counts[op] = counts.get(op, 0) + 1 - sizes[op] = sizes.get(op, 0) + size - - def new_constant(name, value): - var_info.bind(name, value) - tensor = onnx.numpy_helper.from_array(value, name=name) - node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor) - return node - - def lookup_version(domain: str, op: str) -> int: - for opset in model.opset_import: - if opset.domain == domain: - return opset.version - return 1 # TODO - - def transform_node(node: onnx.NodeProto): - if is_onnx_op(node, "Transpose"): - return [node] - if is_onnx_op(node, "CastLike"): - value = var_info.lookup(node.input[0]) if len(node.input) > 0 else not_constant - if value is not_constant: - return [node] - type = type_info.lookup(node.input[1]) if len(node.input) > 1 else None - element_type = get_element_type(type) if type is not None else None - if element_type is None: - return [node] - evaluator = get_evaluator("", "Cast", lookup_version("", "Cast")) - if evaluator is None: - return [node] - cast_value = evaluator(value, to=element_type) - add_count("CastLike", cast_value.size) - return [new_constant(node.output[0], cast_value)] - if is_onnx_op(node, "Shape"): - type = type_info.lookup(node.input[0]) if len(node.input) > 0 else None - shape = get_shape_info(type) if type is not None else None - if shape is not None: - add_count("Shape", shape.size) - return [new_constant(node.output[0], shape)] - - if is_onnx_op(node, "If"): - cond = var_info.lookup(node.input[0]) if len(node.input) > 0 else None - cond = get_bool_value(cond) - if cond is not None: - # cond is a constant-value: inline the branch - branch = "then_branch" if cond else "else_branch" - graph = onnx.helper.get_node_attr_value(node, branch) - formal_outs = list(graph.output) - actual_outs = node.output - renamings = { - formal.name: actual - for formal, actual in zip(formal_outs, actual_outs) - if actual != "" - } - - def rename(name): - return renamings.get(name, name) - - for node in graph.node: - node.input[:] = [rename(name) for name in node.input] - node.output[:] = [rename(name) for name in node.output] - transform_graph(graph) - add_count("If") - return list(graph.node) - - if is_control_flow_op(node): - for attr in node.attribute: - if attr.HasField("g"): - transform_graph(attr.g) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - transform_graph(graph) - return [node] - - domain = node.domain - op = node.op_type - version = lookup_version(domain, op) - inputs = [] - for x in node.input: - if x == "": - inputs.append(None) - else: - v = var_info.lookup(x) - if v is not_constant: - return [node] - inputs.append(v) - evaluator = get_evaluator(domain, op, version) - if evaluator is None: - return [node] - attrs = convert_attributes(node.attribute) - outputs = evaluator(*inputs, **attrs) - if len(node.output) == 1 and not isinstance(outputs, tuple): - replacement = new_constant(node.output[0], outputs) - if is_constant_op(node): - return [node] - add_count(op, outputs.size) - return [replacement] - else: - add_count(op) - return [new_constant(output, outputs[i]) for i, output in enumerate(node.output)] - - def transform_graph(graph: onnx.GraphProto): - var_info.enter_scope() - type_info.enter_scope() - for initializer in graph.initializer: - array = onnx.numpy_helper.to_array(initializer) - var_info.bind(initializer.name, array) - type_info.bind(initializer.name, get_initializer_type(initializer)) - for input in graph.input: - var_info.bind(input.name, not_constant) - type_info.bind(input.name, input.type) - for valueinfo in graph.value_info: - type_info.bind(valueinfo.name, valueinfo.type) - - replacement = [transform_node(node) for node in graph.node] - flattened = [node for nodes in replacement for node in nodes] - del graph.node[:] - graph.node.extend(flattened) - var_info.exit_scope() - type_info.exit_scope() - - transform_graph(model.graph) - for op in counts: - print(f"Constant-folded '{op}' {counts[op]} times, with {sizes[op]} size.") diff --git a/onnxscript/optimizer/function_folding_test.py b/onnxscript/optimizer/function_folding_test.py deleted file mode 100644 index 296048a442..0000000000 --- a/onnxscript/optimizer/function_folding_test.py +++ /dev/null @@ -1,192 +0,0 @@ -import unittest - -import onnx - -import onnxscript.testing -from onnxscript import optimizer - - -class FunctionFoldingTest(unittest.TestCase): - def test_identity(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x1, bool cond1) => (float[N] z1) { - z1 = local.fun1(x1, cond1) - } - - fun1 (x, cond) => (z) { - t = Identity(x) - t2 = Identity(t) - t3 = If (cond) < - then_branch = then_graph() => (t4) { - t5 = Identity(t2) - t4 = Identity(t5) - }, - else_branch = else__graph() => (t6) { - t7 = Identity(t) - t6 = Identity(t7) - } - > - t4 = Add(t3, t3) - z = Identity(t4) - } - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, - ) - self.assertEqual(len(optimized.functions), 0) - self.assertEqual(len(optimized.graph.node), 2) - - def test_sequence_concat(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x1) => (float[M] z1) { - z1 = local.fun1(x1) - } - - fun1 (x) => (z) { - t0 = Add (x, x) - t2 = Add (x, x) - t3 = SequenceConstruct (x, t0, t2, x) - z = ConcatFromSequence (t3) - } - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, - ) - function_node = optimized.functions[0].node - self.assertEqual(len(function_node), 3) - self.assertEqual(function_node[2].op_type, "Concat") - - def test_sequence_at(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[M] z) { - t0 = Add (x, x) - t1 = Mul (x, x) - s = SequenceConstruct (x, t0, t1) - one = Constant () - z = SequenceAt (s, one) - } - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, - ) - expected = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[M] z) { - t0 = Add (x, x) - z = Identity (t0) - } - """ - ) - onnxscript.testing.assert_isomorphic_graph(optimized.graph, expected.graph) - - def test_single_user_function_is_modified_inplace_after_folding(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x1) => (float[M] z1) { - z1 = local.fun1(x1) - } - - fun1 (x) => (z) { - t0 = Add (x, x) - t2 = Add (x, x) - t3 = SequenceConstruct (x, t0, t2, x) - z = ConcatFromSequence (t3) - } - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, - ) - self.assertEqual(optimized.functions[0].name, "fun1") - - def test_multi_users_function_is_not_modified_inplace_after_folding(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x1) => (float[M] z1, float[M] z2) { - z1 = local.fun1(x1) - z2 = local.fun1(x1) - } - - fun1 (x) => (z) { - t0 = Add (x, x) - t2 = Add (x, x) - t3 = SequenceConstruct (x, t0, t2, x) - z = ConcatFromSequence (t3) - } - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, - ) - self.assertEqual(len(optimized.functions), 2) - self.assertNotEqual(optimized.functions[0].name, "fun1") - self.assertNotEqual(optimized.functions[1].name, "fun1") - - def test_fold_nested_if_function_succeeds(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 9, - opset_import: ["this" : 1, "" : 21] -> -func (float[1,512] x, float[1,512] y) => ( out) { - out = this.foldable_func (x, y) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable_func (x, y) => (z_6) -{ - cond = Constant () - z_6 = If (cond) ( z_2) { - cond_0 = Not (cond) - z_2 = If (cond_0) ( z) { - z = Add (x, x) - }, else_branch: graph = elseGraph_5 () => ( z_1) { - z_1 = Identity (x) - }> - }, else_branch: graph = elseGraph_4 () => ( z_5) { - z_5 = If (cond) ( z_3) { - z_3 = Add (y, y) - }, else_branch: graph = elseGraph_10 () => ( z_4) { - z_4 = Add (x, y) - }> - }> -} - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - ) - - self.assertEqual(len(optimized.functions), 0) - self.assertEqual(len(optimized.graph.node), 1) - self.assertNotIn("If", {n.op_type for n in optimized.graph.node}) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/optimizer/remove_unused.py b/onnxscript/optimizer/remove_unused.py deleted file mode 100644 index 57357f3dbe..0000000000 --- a/onnxscript/optimizer/remove_unused.py +++ /dev/null @@ -1,127 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Sequence - -import onnx -from google.protobuf.internal.containers import ( # type: ignore - RepeatedCompositeFieldContainer, -) - -logger = logging.getLogger(__name__) - - -def remove_unused_optional_outputs( - n: onnx.NodeProto, used: set, opset_import: Sequence[onnx.OperatorSetIdProto] -) -> None: - try: - if n.domain not in {"", "onnx.ai"}: - return - onnx_opset_version = 1 - for opset in opset_import: - if opset.domain == n.domain: - onnx_opset_version = opset.version - op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain) - except Exception: - return - # TODO: If current node is a BatchNormalization node, - # based on training_mode atrribute, number of optional outputs and - # how they are handled varies, handle both training_modes - if n.op_type == "BatchNormalization": - return - optional_info = [] - for o in op_schema.outputs: - # Current ops do not have optional outputs if they have variable number of outputs - if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic: - return - optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional) - # If no optional outputs in spec, skip delete operations - if len([o == 1 for o in optional_info]) == 0: - return - - for i, out in enumerate(n.output): - if out not in used and optional_info[i] is True: - n.output[i] = "" - # Only delete trailing unused optional outputs - for o in n.output[::-1]: # type: ignore[assignment] - if o == "": - n.output.pop() - else: - return - - -def compute_used_in_node(n: onnx.NodeProto) -> set[str]: - used = {n for n in n.input if n != ""} - for attr in n.attribute: - if attr.HasField("g"): - used |= compute_used_in_graph(attr.g) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - used |= compute_used_in_graph(graph) - return used - - -def compute_used_in_graph(g: onnx.GraphProto) -> set[str]: - used = set() - for n in g.node: - used |= compute_used_in_node(n) - return used - - -def process_nodes( - nodes: RepeatedCompositeFieldContainer[onnx.NodeProto], - used: set, - opset_import: Sequence[onnx.OperatorSetIdProto], -) -> int: - count = 0 - i = len(nodes) - 1 - while i >= 0: - node = nodes[i] - remove_unused_optional_outputs(node, used, opset_import) - used_outputs = [x for x in node.output if x in used] - if not used_outputs: - del nodes[i] - count += 1 - i -= 1 - continue - for attr in node.attribute: - if attr.HasField("g"): - process_graph(attr.g, opset_import) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - process_graph(graph, opset_import) - used |= compute_used_in_node(node) - i -= 1 - return count - - -def process_graph( - graph: onnx.GraphProto, opset_import: Sequence[onnx.OperatorSetIdProto] -) -> int: - used = {output.name for output in graph.output} - - count = process_nodes(graph.node, used, opset_import) - - for i in range(len(graph.initializer) - 1, -1, -1): - if graph.initializer[i].name not in used: - del graph.initializer[i] - count += 1 - - return count - - -def process_function( - function: onnx.FunctionProto, opset_import: Sequence[onnx.OperatorSetIdProto] -) -> int: - used = set(function.output) - - return process_nodes(function.node, used, opset_import) - - -def remove_unused_nodes(model: onnx.ModelProto) -> None: - """Removes unused nodes from the model.""" - count = process_graph(model.graph, model.opset_import) - for function in model.functions: - count += process_function(function, model.opset_import) - - logger.info("Removed %s unused nodes", count) diff --git a/onnxscript/optimizer/remove_unused_function.py b/onnxscript/optimizer/remove_unused_function.py deleted file mode 100644 index 573dfaa8b1..0000000000 --- a/onnxscript/optimizer/remove_unused_function.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -import logging - -import onnx -from google.protobuf.internal.containers import ( # type: ignore - RepeatedCompositeFieldContainer, -) - -logger = logging.getLogger(__name__) - - -class UnusedFunctionRemover: - def compute_used_in_node(self, n: onnx.NodeProto) -> set[tuple[str, str]]: - used = {(n.domain, n.op_type)} - for attr in n.attribute: - if attr.HasField("g"): - used |= self.process_graph(attr.g) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - used |= self.process_graph(graph) - if (n.domain, n.op_type) in self._functions: - function = self._functions[(n.domain, n.op_type)] - used |= self.process_function(function) - return used - - def process_nodes( - self, nodes: RepeatedCompositeFieldContainer[onnx.NodeProto] - ) -> set[tuple[str, str]]: - used = set() - for node in nodes: - used |= self.compute_used_in_node(node) - return used - - def process_graph(self, graph: onnx.GraphProto) -> set[tuple[str, str]]: - return self.process_nodes(graph.node) - - def process_function(self, function: onnx.FunctionProto) -> set[tuple[str, str]]: - return self.process_nodes(function.node) - - def process_model(self, model: onnx.ModelProto) -> None: - self._functions = {(f.domain, f.name): f for f in model.functions} - used = self.process_graph(model.graph) - count = 0 - logger.debug("Used function protos: %s", used) - for i in range(len(model.functions) - 1, -1, -1): - if (model.functions[i].domain, model.functions[i].name) not in used: - del model.functions[i] - count += 1 - logger.info("Removed %s unused function protos", count) - logger.debug("Function protos left: %s", [f.name for f in model.functions]) - - -def remove_unused_functions(model: onnx.ModelProto) -> None: - """Removes unused function protos from the model.""" - UnusedFunctionRemover().process_model(model) diff --git a/onnxscript/optimizer/remove_unused_test.py b/onnxscript/optimizer/remove_unused_test.py deleted file mode 100644 index 350808defb..0000000000 --- a/onnxscript/optimizer/remove_unused_test.py +++ /dev/null @@ -1,173 +0,0 @@ -import unittest - -import onnx - -from onnxscript import optimizer - - -class RemoveUnusedTest(unittest.TestCase): - def test_remove_unused_nodes(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - two = Constant () - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "Mul") - - def test_remove_unused_initializers(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - self.assertEqual(len(model.graph.initializer), 1) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "Mul") - self.assertEqual(len(model.graph.initializer), 0) - - def test_partially_used_nodes(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[M] z) { - w1, w2, w3 = Split (x) - z = Mul(w3, w3) - } - """ - ) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 2) - self.assertEqual(model.graph.node[0].op_type, "Split") - - def test_remove_unused_optional_outputs_maxpool(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) { - z, indices = MaxPool (x) - } - """ - ) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(len(model.graph.node[0].output), 2) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(len(model.graph.node[0].output), 1) - - def test_remove_unused_optional_outputs_dropout_in_function(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) - { - z = pkg.custom.afunction (x) - } - - afunction (x) => (z) - { - z, indices = MaxPool (x) - } - """ - ) - self.assertEqual(len(model.functions), 1) - self.assertEqual(len(model.functions[0].node), 1) - self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") - self.assertEqual(len(model.functions[0].node[0].output), 2) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.functions), 1) - self.assertEqual(len(model.functions[0].node), 1) - self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") - self.assertEqual(len(model.functions[0].node[0].output), 1) - - def test_remove_used_optional_outputs_maxpool(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] y, float[1, 1, 5, 5] z) { - y, z = MaxPool (x) - } - """ - ) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(len(model.graph.node[0].output), 2) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(len(model.graph.node[0].output), 2) - - def test_remove_multiple_unused_optional_outputs_layernorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z) { - scale = Constant () - B = Constant () - z, mean, InvStdDev = LayerNormalization(x, scale, B) - } - """ - ) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 3) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 1) - - def test_remove_trailing_unused_optional_outputs_layernorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] mean) { - scale = Constant () - B = Constant () - z, mean, InvStdDev = LayerNormalization(x, scale, B) - } - """ - ) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 3) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 2) - - def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] InvStdDev) { - scale = Constant () - B = Constant () - z, mean, InvStdDev = LayerNormalization(x, scale, B) - } - """ - ) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 3) - optimizer.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 3) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/optimizer/simple_function_folding.py b/onnxscript/optimizer/simple_function_folding.py deleted file mode 100644 index 8b6f6662b0..0000000000 --- a/onnxscript/optimizer/simple_function_folding.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Inlines the function if it only contains very few number of nodes.""" - -from __future__ import annotations - -import logging -from typing import Sequence - -import onnx - -import onnxscript._legacy_ir as ir -from onnxscript._legacy_ir import visitor -from onnxscript.optimizer import remove_unused - -logger = logging.getLogger(__name__) - - -class FunctionInliner(visitor.FunctionCallsiteProtoTransformer): - counts: dict[ir.FunctionId, int] - - def __init__(self, node_count: int) -> None: - super().__init__() - self._node_count = node_count - - def _gather_function_metadata(self, model: onnx.ModelProto) -> None: - super()._gather_function_metadata(model) - self._function_renamer._postfix = "inlined" - - def visit_model(self, model: onnx.ModelProto) -> None: - self.counts = {} - - super().visit_model(model) - - def should_inline_function(self, function: onnx.FunctionProto) -> bool: - return len(function.node) <= self._node_count - - def process_function_node( - self, node: onnx.NodeProto - ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: - # Recursively process sub nodes first. - function_id = (node.domain, node.op_type, getattr(node, "overload", "")) - function = self._functions[function_id] - replacement, new_function = super().process_function_node(node) - function = new_function if new_function else function - - if self.should_inline_function(function): - self.enter_function_scope(function) - sub_scope = self.exit_function_scope(function) - new_nodes = [] - - formal_outs = function.output - actual_outs = node.output - formal_ins = function.input - actual_ins = node.input - # TODO: Potential collision when actual is "". - # formal.name may collide with existing value names. - input_renamings = dict(zip(formal_ins, actual_ins)) - if len(actual_ins) < len(formal_ins): - input_renamings.update(dict.fromkeys(formal_ins[len(actual_ins) :], "")) - output_renamings = { - formal: actual - for formal, actual in zip(formal_outs, actual_outs) - if actual != "" - } - renamings = {**input_renamings, **output_renamings} - - logger.debug("renamings function %s: %s", function.name, renamings) - - def rename(name: str) -> str: - if name == "": - return name - new_name = renamings.get(name) - if new_name is None: - new_name = f"{node.name}_{name}" - logger.debug("renaming %s to %s", name, new_name) - if (ir_value := sub_scope.lookup(name)) is not None: - if ir_value.tensor_shape_proto() is not None and ir_value.type is not None: - ir_value.name = new_name - self.bind(new_name, ir_value) - return new_name - - ref_attrs = {attr.name: attr for attr in node.attribute} - # logger.debug("inlining simple function %s. Ref attrs: %s", function.name, ref_attrs) - - def fill_in_ref(attr: onnx.AttributeProto) -> onnx.AttributeProto: - if attr.ref_attr_name: - new_attr = onnx.AttributeProto() - new_attr.CopyFrom(ref_attrs[attr.ref_attr_name]) - new_attr.name = attr.name - return new_attr - return attr - - def update_graph_attribute( - attr: onnx.AttributeProto, - ) -> onnx.AttributeProto: - if attr.g: - new_attr = onnx.AttributeProto() - new_attr.CopyFrom(attr) - for node in new_attr.g.node: - node.input[:] = [rename(name) for name in node.input] - node.output[:] = [rename(name) for name in node.output] - new_attrs = [] - for attr in node.attribute: - new_attrs.append(update_attribute(attr)) - del node.attribute[:] - node.attribute.extend(new_attrs) - for vi_proto in new_attr.g.input: - vi_proto.name = rename(vi_proto.name) - for vi_proto in new_attr.g.output: - vi_proto.name = rename(vi_proto.name) - return new_attr - return attr - - def update_attribute(attr: onnx.AttributeProto) -> onnx.AttributeProto: - new_attr = fill_in_ref(attr) - new_attr = update_graph_attribute(new_attr) - return new_attr - - for sub_node in function.node: - # logger.debug("inlining simple function. old node: %s", sub_node) - new_node = onnx.NodeProto() - new_node.CopyFrom(sub_node) - new_node.input[:] = [rename(name) for name in new_node.input] - new_node.output[:] = [rename(name) for name in new_node.output] - del new_node.attribute[:] - for attr in sub_node.attribute: - new_node.attribute.append(update_attribute(attr)) - # Avoid name collision. - new_node.name = f"{node.name}_{new_node.name}" - # logger.debug("inlining simple function. new node: %s", new_node) - new_nodes.append(new_node) - - self.counts.setdefault(function_id, 0) - self.counts[function_id] += 1 - - return new_nodes, None - - return replacement, new_function - - -class SelectedFunctionInliner(FunctionInliner): - def __init__(self, functions_to_inline: Sequence[onnx.FunctionProto]): - super().__init__(node_count=0) # node_count unused. - self._functions_to_inline = functions_to_inline - - def should_inline_function(self, function: onnx.FunctionProto) -> bool: - return function in self._functions_to_inline - - -class FindFunctionWithUnusedOutputsVisitor(visitor.ProtoVisitor): - def __init__(self) -> None: - super().__init__() - self._function_with_unused_outputs: dict[ir.FunctionId, onnx.FunctionProto] = {} - self._functions: dict[ir.FunctionId, onnx.FunctionProto] = {} - self._used_nodes: list[onnx.NodeProto] = [] - - def _find_nodes_with_any_unused_output( - self, nodes: Sequence[onnx.NodeProto], used_values: set[str] - ) -> list[onnx.NodeProto]: - target_nodes = [] - for i in range(len(nodes) - 1, -1, -1): - node = nodes[i] - if any(x not in used_values for x in node.output): - # Any unused output means the node is a target node. - target_nodes.append(node) - if all(x not in used_values for x in node.output): - # All unused output means the node is not used at all. - # Hence do not update used_values with the node's inputs. - continue - used_values |= remove_unused.compute_used_in_node(node) - return target_nodes - - def visit_model(self, model: onnx.ModelProto) -> None: - used_values = {output.name for output in model.graph.output} - target_nodes = self._find_nodes_with_any_unused_output(model.graph.node, used_values) - - for function in model.functions: - self._functions[ - (function.domain, function.name, getattr(function, "overload", "")) - ] = function - used_values = set(function.output) - target_nodes.extend( - self._find_nodes_with_any_unused_output(function.node, used_values) - ) - - for node in target_nodes: - if visitor.is_local_function_node(node, self._functions): - function_id = (node.domain, node.op_type, getattr(node, "overload", "")) - self._function_with_unused_outputs[function_id] = self._functions[function_id] - - logger.info( - "Found %s function nodes that have unused outputs.", - len(self._function_with_unused_outputs), - ) - for key in self._function_with_unused_outputs: - logger.info("Function node with unused outputs: %s::%s", key[0], key[1]) - - @property - def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto]: - return self._function_with_unused_outputs - - -def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool: - """Inlines simple functions based on a node count threshold""" - inliner = FunctionInliner(node_count) - inliner.visit_model(model) - logger.info( - "inlined %s simple functions based on node count threshold %s.", - len(inliner.counts), - node_count, - ) - for op in inliner.counts: - logger.info( - "Inlined simple function '%s::%s' %s times.", - op[0], - op[1], - inliner.counts[op], - ) - return inliner.modified - - -def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool: - """Inlines function nodes that have unused outputs.""" - # TODO: Use onnx.inliner after 1.16. - # This visitor based inliner is used to ensure the function inner value info remains consistent. - visitor = FindFunctionWithUnusedOutputsVisitor() - visitor.visit_model(model) - # FIXME: Fix the type of the argument passed into SelectedFunctionInliner - inliner = SelectedFunctionInliner(visitor.function_with_unused_outputs.values()) # type: ignore[arg-type] - inliner.visit_model(model) - logger.info( - "inlined %s function nodes that have unused outputs.", - len(inliner.counts), - ) - for op in inliner.counts: - logger.info( - "Inlined function '%s::%s' %s times.", - op[0], - op[1], - inliner.counts[op], - ) - return inliner.modified diff --git a/onnxscript/optimizer/simple_function_folding_test.py b/onnxscript/optimizer/simple_function_folding_test.py deleted file mode 100644 index df7feaec2b..0000000000 --- a/onnxscript/optimizer/simple_function_folding_test.py +++ /dev/null @@ -1,218 +0,0 @@ -from __future__ import annotations - -import unittest - -import onnx - -from onnxscript.optimizer import remove_unused_function, simple_function_folding - - -class SingleNodeFunctionFoldingTest(unittest.TestCase): - def test_fold_single_node_function(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y) => ( return_val) { - tmp = this.foldable (x) - return_val = Add (tmp, y) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x) => (return_val) -{ - return_val = Identity (x) -} - """ - ) - - simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) - - self.assertEqual(len(model.functions), 0) - - def test_fold_single_node_function_ref_attr(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y, z) => ( return_val) { - tmp = this.foldable (x, y) - return_val = Add (tmp, z) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x, y) => (return_val) -{ - return_val = Concat (x, y) -} - """ - ) - - simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) - - self.assertEqual(len(model.functions), 0) - self.assertFalse(model.graph.node[0].attribute[0].ref_attr_name) - self.assertEqual(model.graph.node[0].attribute[0].name, "axis") - - def test_fold_single_node_function_nested(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y, z) => ( return_val) { - tmp = this.non_foldable (x, y) - return_val = Add (tmp, z) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x, y) => (return_val) -{ - return_val = Concat (x, y) -} -< - domain: "this", - opset_import: ["this" : 1,"" : 18] -> -non_foldable (x, y) => (return_val) -{ - tmp = this.foldable (x, y) - tmp_0 = this.foldable (x, y) - return_val = Add (tmp, tmp_0) -} - """ - ) - - simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) - - self.assertEqual(len(model.functions), 1) - self.assertEqual(model.functions[0].node[0].op_type, "Concat") - self.assertEqual(model.functions[0].node[1].op_type, "Concat") - - def test_fold_single_node_function_create_new_nodes_with_correct_attributes(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 9, - opset_import: ["this" : 1, "" : 21] -> -func (float[1,512] x) => ( a, b, c) { - a = this.prim_cast (x) - b = this.prim_cast (x) - c = this.prim_cast (x) -} -< - domain: "this", - opset_import: ["" : 18] -> -prim_cast (x) => (return_val) -{ - return_val = Cast (x) -} - """ - ) - simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) - self.assertEqual(len(model.functions), 0) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[0].attribute[0].i, 10) - self.assertEqual(model.graph.node[1].attribute[0].i, 6) - self.assertEqual(model.graph.node[2].attribute[0].i, 7) - - def test_fold_nested_if_function_succeeds(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 9, - opset_import: ["this" : 1, "" : 21] -> -func (float[1,512] x, float[1,512] y) => ( out) { - out = this.foldable_func (x, y) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable_func (x, y) => (z_6) -{ - cond = Constant () - z_6 = If (cond) ( z_2) { - cond_0 = Not (cond) - z_2 = If (cond_0) ( z) { - z = Add (x, x) - }, else_branch: graph = elseGraph_5 () => ( z_1) { - z_1 = Identity (x) - }> - }, else_branch: graph = elseGraph_4 () => ( z_5) { - z_5 = If (cond) ( z_3) { - z_3 = Add (y, y) - }, else_branch: graph = elseGraph_10 () => ( z_4) { - z_4 = Add (x, y) - }> - }> -} - """ - ) - - simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) - - self.assertEqual(len(model.functions), 0) - self.assertEqual(len(model.graph.node), 2) - self.assertEqual(model.graph.node[1].op_type, "If") - - def test_fold_function_with_unused_output(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y, z) => ( return_val) { - tmp = this.non_foldable (x, y) - return_val = Add (tmp, z) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x, y) => (return_val, unused, unused1) -{ - return_val = Concat (x, y) - unused = Identity (x) - unused1 = Identity (y) -} -< - domain: "this", - opset_import: ["this" : 1,"" : 18] -> -non_foldable (x, y) => (return_val) -{ - tmp, unused, unused1 = this.foldable (x, y) - tmp_0, unused2, unused3 = this.foldable (x, y) - return_val = Add (tmp, tmp_0) -} - """ - ) - - simple_function_folding.inline_functions_with_unused_outputs(model) - remove_unused_function.remove_unused_functions(model) - self.assertEqual(len(model.functions), 1) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 7dc7846506..fc000dc176 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -1,43 +1,124 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations -from typing import Sequence, Union +from typing import Sequence, TypeVar, Union __all__ = [ - # Modules - "function_rule", "pattern", - # Functions "rewrite", + "RewritePass", + "MatchResult", + "MatchContext", + "RewriteRule", + "RewriteRuleClassBase", + "RewriteRuleSet", + "RewriterContext", + "MatchingTracer", + "MatchStatus", ] import onnx +import onnx_ir.passes.common as common_passes from onnxscript import ir -from onnxscript.optimizer import remove_unused, remove_unused_function -from onnxscript.rewriter import function_rule, pattern +from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus +from onnxscript.rewriter._rewrite_rule import ( + RewriterContext, + RewriteRule, + RewriteRuleClassBase, + RewriteRuleSet, +) +from onnxscript.rewriter.rules.common import ( + _basic_rules, + _broadcast_to_matmul, + _cast_constant_of_shape, + _collapse_slices, + _fuse_batchnorm, + _fuse_pad_into_conv, + _fuse_relus_clips, + _min_max_to_clip, + _no_op, + _redundant_scatter_nd, +) -RewriteRuleSet = pattern.RewriteRuleSet -PatternRewriteRule = pattern.RewriteRule -FunctionRewriteRule = function_rule.FunctionRewriteRule +_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) +_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( + *_no_op.rules, # TODO: merge this rule into constant folding? + *_broadcast_to_matmul.rules, + *_cast_constant_of_shape.rules, + *_collapse_slices.rules, + *_min_max_to_clip.rules, + *_fuse_relus_clips.rules, + *_basic_rules.basic_optimization_rules(), + *_redundant_scatter_nd.rules, + *_fuse_pad_into_conv.rules, + *_fuse_batchnorm.rules, +) -def rewrite( - model: onnx.ModelProto, - function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (), - pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (), -) -> onnx.ModelProto: - model_ir = ir.serde.deserialize_model(model) - if function_rewrite_rules: - for rule_cls in function_rewrite_rules: - count, model_ir = rule_cls().apply_to_model(model_ir) - print(f"Applied {count} of onnxruntime specific function rewrite rules.") - if pattern_rewrite_rules: - if not isinstance(pattern_rewrite_rules, RewriteRuleSet): +class RewritePass(ir.passes.InPlacePass): + def __init__( + self, + rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet, + /, + ) -> None: + super().__init__() + if isinstance(rules, Sequence): + if not rules: + raise ValueError("rules must not be empty") # Create a pattern rule-set using provided rules - pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules) - count = pattern_rewrite_rules.apply_to_model(model_ir) - print(f"Applied {count} of general pattern rewrite rules.") - model = ir.serde.serialize_model(model_ir) - remove_unused.remove_unused_nodes(model) - remove_unused_function.remove_unused_functions(model) - return model + rules = pattern.RewriteRuleSet(rules) + assert isinstance(rules, pattern.RewriteRuleSet) + self.rules: pattern.RewriteRuleSet = rules + + def call(self, model: ir.Model) -> ir.passes.PassResult: + count = self.rules.apply_to_model(model) + if count: + print(f"Applied {count} of general pattern rewrite rules.") + return ir.passes.PassResult(model, bool(count)) + + +def rewrite( + model: _ModelProtoOrIr, + pattern_rewrite_rules: Union[Sequence[pattern.RewriteRule], pattern.RewriteRuleSet] + | None = None, +) -> _ModelProtoOrIr: + """Rewrite the model using the provided pattern rewrite rules. + + Unused nodes, functions, and opsets will be removed after the rewrite. + + Args: + model: The model to be rewritten. Can be an ONNX ModelProto or an ir.Model. + pattern_rewrite_rules: A sequence of pattern rewrite rules or a RewriteRuleSet. + If not provided, default rules will be applied. If empty, no rules will be applied + and the original model will be returned. + + Returns: + The rewritten model as the same type as the input model. + """ + if pattern_rewrite_rules is None: + pattern_rewrite_rules = _DEFAULT_REWRITE_RULES + elif not pattern_rewrite_rules: + return model + + if isinstance(model, onnx.ModelProto): + model_ir = ir.serde.deserialize_model(model) + proto = True + else: + model_ir = model + proto = False + + rewrite_pass = ir.passes.PassManager( + ( + RewritePass(pattern_rewrite_rules), + common_passes.RemoveUnusedNodesPass(), + common_passes.RemoveUnusedFunctionsPass(), + common_passes.RemoveUnusedOpsetsPass(), + ) + ) + model_ir = rewrite_pass(model_ir).model + if proto: + return ir.serde.serialize_model(model_ir) + return model_ir # type: ignore[return-value] diff --git a/onnxscript/rewriter/_basics.py b/onnxscript/rewriter/_basics.py new file mode 100644 index 0000000000..9b66ff49e6 --- /dev/null +++ b/onnxscript/rewriter/_basics.py @@ -0,0 +1,476 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Basic types for the pattern matching and rewriter API.""" + +from __future__ import annotations + +import dataclasses +import enum +from collections import defaultdict +from typing import TYPE_CHECKING, Any, MutableSequence, Sequence, Union + +from onnxscript import ir + +if TYPE_CHECKING: + import onnxscript.rewriter._pattern_ir as _pattern_ir + import onnxscript.rewriter._rewrite_rule as _rewrite_rule + + +class MatchFailureInfo: + """Encapsulates information about a pattern match failure.""" + + def __init__( + self, + reason: str = "", + *failure_source: ir.Node | ir.Value, + ): + self.reason = reason + self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source + assert all(isinstance(item, (ir.Node, ir.Value)) for item in failure_source), ( + f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}" + ) + + def __str__(self): + return f"MatchFailureInfo(reason={self.reason!r}, failure_sources={self.failure_sources!r})" + + +class MatchFailureError(MatchFailureInfo, Exception): + """Exception raised when a pattern match fails. + + This makes it easier to handle match failures in a compositional way, + for example, during the condition-checking phase of a pattern match. + It allows us to define utility functions without having to check for + and propagate match failures explicitly. + """ + + def __init__( + self, + reason: str = "", + *failure_source: ir.Node | ir.Value, + ): + MatchFailureInfo.__init__(self, reason, *failure_source) + Exception.__init__(self, reason) + + +class MatchResult: + """The state object used by the pattern-matching algorithm. + + A match can either succeed or fail. + If it succeeds, it returns a list of nodes that matched the pattern + and a set of bindings for the variables in the pattern. + + Example: + :: + def pattern(x, shape1, shape2): + t1 = op.Reshape(x, shape1) + t2 = op.Reshape(t1, shape2) + return t2 + The above pattern matches a sequence of two Reshape ops. + The matched_nodes will contain the two Reshape ops, and the bindings will + contain the values that are bound to the variables `x`, `shape1`, and `shape2`. + """ + + def __init__(self) -> None: + # We use a stack of partial matches to handle OR patterns that require backtracking. + self._partial_matches: list[PartialMatchResult] = [PartialMatchResult()] + + def __repr__(self) -> str: + """Returns a string representation of the match result.""" + if not self._partial_matches: + return "MatchResult()" + return ( + f"MatchResult(success={bool(self)}, reason={self.reason!r}, nodes={self.nodes!r})" + ) + + @property + def _current_match(self) -> PartialMatchResult: + """Returns the current match result.""" + return self._partial_matches[-1] + + def enter_new_match(self) -> None: + """Starts a new sub-match to try out one of multiple alternatives.""" + match = PartialMatchResult() + self._partial_matches.append(match) + + def abandon_current_match(self) -> PartialMatchResult: + """Abandons the current alternative due to failure.""" + if len(self._partial_matches) < 2: + raise ValueError("No match to abandon.") + return self._partial_matches.pop() + + def merge_current_match(self) -> None: + """Merges a successful sub-match for an alternative with the parent one.""" + if len(self._partial_matches) < 2: + raise ValueError("No match to merge.") + current_match = self._partial_matches.pop() + previous_match = self._partial_matches[-1] + if not current_match: + raise ValueError("Current match is not successful.") + # Merge the two matches. + previous_match.merge(current_match) + + def __bool__(self) -> bool: + """Returns True if the current match is successful.""" + return bool(self._current_match) + + def fail( + self, + reason: str = "", + failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None, + ) -> MatchResult: + self._current_match.fail(reason, failure_source) + return self + + @property + def reason(self) -> str: + """Returns the reason for the failure.""" + return self._current_match.reason + + @property + def nodes(self) -> Sequence[ir.Node]: + """Returns the list of nodes that matched the pattern.""" + return self._current_match.nodes + + def bind_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node): + """Binds a pattern node to a matched node.""" + self.add_node(node) + self._current_match.node_bindings[pattern_node] = node + + def add_node(self, node: ir.Node) -> None: + """Adds a node to the list of matched nodes.""" + self._current_match.add_node(node) + + def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool: + var_name = pattern_value.name + # TODO(rama): Simplify the following. We currently bind values to + # pattern variables in two different ways: via their name, or via the + # pattern-value itself. + if var_name is None: + for match in self._partial_matches: + if pattern_value in match.value_bindings: + # TODO(rama): Use appropriate equality-check here. + if match.value_bindings[pattern_value] == value: + return True + self._current_match.fail( + f"Binding failure: {pattern_value} bound to two different values.", + [match.value_bindings[pattern_value], value], + ) + return False + self._current_match.value_bindings[pattern_value] = value + return True + return self.bind(var_name, value) + + def bind(self, var: str, value: Any) -> bool: + for match in self._partial_matches: + if var in match.bindings: + # TODO(rama): Use appropriate equality-check here. + if match.bindings[var] == value: + return True + self._current_match.fail( + f"Binding failure: {var} bound to two different values.", + [match.bindings[var], value], + ) + return False + self._current_match.bindings[var] = value + return True + + @property + def bindings(self) -> dict[str, Any]: + """Returns the bindings for the pattern variables.""" + if len(self._partial_matches) > 1: + raise ValueError("Bindings can be accessed only at the top-level match.") + return self._current_match.bindings + + @property + def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]: + """Returns the bindings for the value variables.""" + if len(self._partial_matches) > 1: + raise ValueError("Value bindings can be accessed only at the top-level match.") + return self._current_match.value_bindings + + @property + def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]: + """Returns the bindings for the node variables.""" + if len(self._partial_matches) > 1: + raise ValueError("Node bindings can be accessed only at the top-level match.") + return self._current_match.node_bindings + + @property + def outputs(self) -> MutableSequence[ir.Value]: + """Returns the list of output values that matched the pattern.""" + if len(self._partial_matches) > 1: + raise ValueError("Outputs can be accessed only at the top-level match.") + return self._current_match.outputs + + @property + def failure_nodes_and_values(self) -> list[Union[ir.Node, ir.Value]]: + """Returns the nodes and values that caused the failure.""" + return self._current_match._failure_nodes_and_values + + def lookup_node(self, pattern_node: _pattern_ir.NodePattern) -> ir.Node | None: + """Looks up the node that matched the given pattern node.""" + for match in self._partial_matches: + if pattern_node in match.node_bindings: + return match.node_bindings[pattern_node] + return None + + def num_matched_nodes(self) -> int: + """Returns the number of nodes matched so far.""" + return sum(len(match.node_bindings) for match in self._partial_matches) + + +class PartialMatchResult: + """The state object used by the pattern-matching algorithm for a sub-match.""" + + def __init__(self) -> None: + self._success: bool = True + # For a successful match, _matched_nodes is a list of values that matched the pattern. + # These include the internal nodes of the pattern that were matched, but not + # the leaves (sub-trees) that match against the variables in the pattern. + # These represent the values that will be replaced by the replacement pattern. + self._matched_nodes: MutableSequence[ir.Node] = [] + # For a successful match, bindings is a dictionary of mapping pattern-variable-names + # to values. + self._bindings: dict[str, Any] = {} + self._value_bindings: dict[_pattern_ir.ValuePattern, ir.Value] = {} + self._node_bindings: dict[_pattern_ir.NodePattern, ir.Node] = {} + + self._outputs: list[ir.Value] = [] + # For a failed match, _reason is a string that describes the reason for the failure. + self._reason: str = "" + # Track the node(s) or value(s) that caused the failure. + self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = [] + + def __bool__(self): + return self._success + + def fail( + self, + reason: str = "", + failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None, + ) -> None: + self._success = False + self._reason = reason + if failure_source is not None: + if isinstance(failure_source, list): + self._failure_nodes_and_values.extend(failure_source) + else: + self._failure_nodes_and_values.append(failure_source) + + @property + def reason(self) -> str: + return self._reason + + @property + def nodes(self) -> Sequence[ir.Node]: + return tuple(self._matched_nodes) + + def add_node(self, node: ir.Node) -> None: + """Adds a node to the list of matched nodes.""" + self._matched_nodes.append(node) + + @property + def bindings(self) -> dict[str, Any]: + return self._bindings + + @property + def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]: + return self._value_bindings + + @property + def outputs(self) -> MutableSequence[ir.Value]: + return self._outputs + + @property + def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]: + return self._node_bindings + + def merge(self, other: PartialMatchResult) -> None: + """Merges a successful sub-match for an alternative with the parent one.""" + if self._success and other._success: + # Merge the two successful matches. Matching algorithm responsible for ensuring + # that the two matches are compatible. No need to check for conflicts here. + self._bindings.update(other._bindings) + self._matched_nodes.extend(other.nodes) + # Note: outputs should be set only at end of the (top-level) match. There + # should be no outputs in the sub-match. + assert not other._outputs + else: + # This should not happen currently. + raise NotImplementedError("Merging failed matches is not yet supported.") + + +class MatchStatus(enum.IntEnum): + """The status of a pattern-matching operation.""" + + NO_MATCH = 0 # No successful match found for entire pattern graph + CONDITION_FAILED = 1 # Subsequent validation check failed + REPLACEMENT_FAILED = 2 # Replacement subgraph could not be created + SUCCESS = 3 # A successful match was found + + +@dataclasses.dataclass +class MatchInfo: + """The status of a pattern-matching operation. An extension of MatchResult.""" + + match_result: MatchResult + root_node: ir.Node + container: ir.Graph | ir.Function + status: MatchStatus + + def score(self) -> int: + """Return a score for the match.""" + return len(self.match_result.nodes) + int(self.status.value) * 100 + + def print(self): + separator = "-" * 80 + print(separator) + print(f"Status: {self.status.name}") + if self.status != MatchStatus.SUCCESS: + reason = self.match_result.reason + if reason: + if self.status == MatchStatus.CONDITION_FAILED: + print(f"Graph matching failed due to failing check condition : {reason}") + else: + print(f"Graph matching failed: {reason}") + else: + print("Graph matching failed.") + failure_nodes_and_values = self.match_result.failure_nodes_and_values + print("Failure at or around nodes/values:") + if failure_nodes_and_values: + for failure_cause in failure_nodes_and_values: + failure_cause.display() + print("Matched nodes:") + import onnxscript.rewriter._ir_utils as ir_utils + + ir_utils.display_nodes(self.match_result.nodes) + print(separator) + + +class MatchContext: + """A read-only context containing information about a pattern match. + + This class captures information about the context describing a match to a given pattern, + providing access to the model, graph/function, root node, output values, and all + nodes of the matching subgraph. + """ + + def __init__( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + root: ir.Node, + match_result: MatchResult, + ) -> None: + """Initialize the pattern match context. + + Args: + model: The model being matched. + graph_or_function: The graph or function being matched. + root: The root node of the matching subgraph. + match_result: The match result containing matched nodes and outputs. + """ + self._model = model + self._graph_or_function = graph_or_function + self._root = root + self._match_result = match_result + + @property + def model(self) -> ir.Model: + """The model being matched.""" + return self._model + + @property + def graph_or_function(self) -> ir.Graph | ir.Function: + """The graph or function being matched.""" + return self._graph_or_function + + @property + def root(self) -> ir.Node: + """The root node of the matching subgraph.""" + return self._root + + @property + def output_values(self) -> Sequence[ir.Value]: + """The output values of the matching subgraph.""" + return self._match_result.outputs + + @property + def nodes(self) -> Sequence[ir.Node]: + """All the nodes of the matching subgraph.""" + return self._match_result.nodes + + def display(self, *, in_graph_order: bool = True) -> None: + """Display the nodes in the pattern match context. + + Args: + in_graph_order: If True, display nodes in the order they appear in the + graph/function. If False, display nodes in the order they appear + in the match result. + """ + nodes = self.nodes + if not nodes: + return + + if in_graph_order: + # Display nodes in same order as in graph/function + for node in self._graph_or_function: + if node in nodes: + node.display() + else: + # Display nodes in match order + for node in nodes: + node.display() + + +class MatchingTracer: + """A debugging helper class to trace the matching of a pattern against a graph. + + This is used to track the best matches found for each rule, and to report the + results at the end of the matching. + """ + + def __init__(self) -> None: + self._best_matches_map: dict[_rewrite_rule.RewriteRule, list[MatchInfo]] = defaultdict( + list + ) + + @property + def best_matches_map(self) -> dict[_rewrite_rule.RewriteRule, list[MatchInfo]]: + return self._best_matches_map + + def log( + self, + rule: _rewrite_rule.RewriteRule, + container: ir.Graph | ir.Function, + node: ir.Node, + match_result: MatchResult, + status: MatchStatus, + ) -> None: + this_match = MatchInfo(match_result, node, container, status) + this_score = this_match.score() + if this_score == 0: + return + best_matches = self._best_matches_map[rule] + if best_matches: + if this_score < best_matches[0].score(): + return + if this_score > best_matches[0].score(): + best_matches.clear() + best_matches.append(this_match) + + def report(self) -> None: + best_score = 0 + for rule, matches in self._best_matches_map.items(): + if not matches: + continue + if matches[0].score() > best_score: + best_score = matches[0].score() + best_match = matches[0] + best_rule = rule + + if best_score > 0: + print(f"Rule: {best_rule}") + best_match.print() + else: + print("No matches found.") diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py new file mode 100644 index 0000000000..f6a7204ac8 --- /dev/null +++ b/onnxscript/rewriter/_fusion_utils.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Callable, Sequence, Union + +import onnx_ir as ir +import onnx_ir.passes.common as common_passes + +from onnxscript.rewriter._basics import MatchFailureError, MatchingTracer +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet + +Dim = Union[int, ir.SymbolicDim] + + +def check_shape_bool(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: + if val.shape is None: + return False + if val.shape.rank() != len(shape): + return False + for actual, expected in zip(val.shape, shape): + if expected not in bindings: + bindings[expected] = actual # type: ignore[assignment] + elif actual != bindings[expected]: + return False + return True + + +def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]): + if val.shape is None: + raise MatchFailureError(f"The shape of {val} is unknown.", val) + if val.shape.rank() != len(shape): + raise MatchFailureError( + f"The rank of {val} ({val.shape.rank()} does not match the expected rank {len(shape)}.", + val, + ) + for i, (actual, expected) in enumerate(zip(val.shape, shape)): + if expected not in bindings: + bindings[expected] = actual # type: ignore[assignment] + elif actual != bindings[expected]: + raise MatchFailureError( + f"Dimension {i} of {val} ({actual}) does not have expected size ({bindings[expected]}).", + val, + ) + + +def apply_fusion_rules(rules: RewriteRule | RewriteRuleSet) -> Callable: + """ + Apply the given fusion rules to the model and return the number of fusions applied. + + model: The input ONNX model represented as an `ir.Model`. + debug: If debug is True, enable pattern matching tracer for debugging. + apply_shape_inference: If True, apply shape inference after fusions. + """ + + def apply_to( + model: ir.Model, debug: bool = False, apply_shape_inference: bool = False, **kwargs + ) -> int: + count = rules.apply_to_model(model, **kwargs) + if apply_shape_inference: + common_passes.ShapeInferencePass()(model) + if count == 0 and debug: + tracer = MatchingTracer() + rules.apply_to_model(model, tracer=tracer, **kwargs) + tracer.report() + return count + + return apply_to diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index b8dd5f45ff..953d5f33d5 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -1,40 +1,178 @@ -"""This is a temporary utility to assist new IR while it's still under development.""" - +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations +import math +from typing import Callable, Sequence + import numpy as np -from onnxscript import ir - -GRAPH_OUTPUT_META_KEY = "pkg.onnxscript.rewriter.generic_pattern.graph_output" - - -def propagate_const_value(ir_value: ir.Value) -> ir.Value: - node = ir_value.producer() - if ir_value.const_value is None and node is not None and node.op_type == "Constant": - attr_names = [ - "value_float", - "value_int", - "value_string", - "value", - "value_floats", - "value_ints", - "value_strings", - ] - for attr_name in attr_names: - attr_value = node.attributes.get(attr_name) - if attr_value is not None: - # TODO: RefAttr should be also supported? - if isinstance(attr_value, ir.Attr): - ir_value.const_value = attr_value.value # type: ignore[union-attr] - break - return ir_value - - -def get_numpy_from_ir_value(value: ir.Value) -> np.ndarray | None: - constant_value = value.const_value - if constant_value is not None: - if isinstance(constant_value, ir.serde.TensorProtoTensor): - return constant_value.numpy() - return np.array(constant_value) - return constant_value +from onnxscript import ir, optimizer + + +def display_nodes(nodes: Sequence[ir.Node]) -> None: + """Display a list of nodes in the order they appear in the graph.""" + if nodes: + graph = nodes[0].graph + if graph: + # Display nodes in same order as in graph: + # Currently doesn't handle (control-flow) subgraphs + for node in graph: + if node in nodes: + node.display() + else: + for node in nodes: + node.display() + + +def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5) -> None: + """Display the (backward or forward) subgraph from a given value or node upto a certain depth.""" + slice = [] + + def visit(node: ir.Node, depth): + if node in slice: + return + slice.append(node) + if depth < depth_limit: + if backward: + for inp in node.inputs: + if inp is not None and inp.producer() is not None: + visit(inp.producer(), depth + 1) # type: ignore[arg-type] + else: + for out in node.outputs: + for consumer, _ in out.uses(): + visit(consumer, depth + 1) + + if isinstance(x, ir.Node): + visit(x, 0) + elif isinstance(x, ir.Value) and x.producer() is not None: + visit(x.producer(), 0) # type: ignore[arg-type] + display_nodes(slice) + + +def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: + node = value.producer() + if node is not None: + optimizer.basic_constant_propagation([node]) + return value.const_value + + +def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: + """Convenience wrapper to get (optional) numpy value from an optional IR Value. + + This is intended for use in optimizations/rewriting. Note that this does not + yet handle the distinction between inputs with default values (values that are + both graph inputs and graph initializers), which should not be treated as a + constant, and true constant values. The caller should make the distinction, as + a value does not contain enough information to determine this. (TODO) + """ + if val is None: + return None + const_value = get_const_value(val) + if const_value is not None: + try: + return const_value.numpy() + except FileNotFoundError: + # External data is not available. + return None + return None + + +def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None): + """Returns element of a single element tensor constant value, and None otherwise. + + If an int rank is specified, it checks that the value has the given rank. + If the rank is a sequence of ints, it checks that the value has one of the given ranks. + + Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and + `rank=(0,1)` checks for either a scalar or a 1D tensor. + """ + np_val = get_numpy_value(val) + if np_val is not None and np_val.size == 1: + value = np_val.item() + if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)): + return value + if isinstance(rank, Sequence) and (np_val.ndim in rank): + return value + return None + + +def is_singleton_value( + val: ir.Value | None, + expected: float | int | Callable, + *, + rtol: float | None = None, + rank: int | Sequence[int] | None = None, +) -> bool: + """Returns True if the value is a single element tensor with given value, and False otherwise.""" + scalar = get_singleton_value(val, rank=rank) + if scalar is None: + return False + if callable(expected): + return expected(scalar) + if isinstance(expected, int): + return expected == scalar + # rtol must be specified for float comparison + assert rtol is not None + return math.isclose(scalar, expected, rel_tol=rtol) + + +def is_1d_value(val: ir.Value | None, expected: list[int]) -> bool: + """Returns True if the value is a 1d int64 tensor with given value, and False otherwise.""" + if val is None: + return False + if not isinstance(val.type, ir.TypeProtocol): + return False + np_val = get_numpy_value(val) + if np_val is None: + return False + if (np_val.size != len(expected)) or (val.type.dtype != ir.DataType.INT64): + return False + values = np_val.tolist() + return values == expected + + +def has_rank(value: ir.Value | None, rank: int) -> bool: + """Returns True if the value is statically known to have the given rank, and False otherwise.""" + if value is None: + return False + shape = value.shape + return (shape is not None) and (shape.rank() == rank) + + +def get_dim(value: ir.Value | None, dim: int) -> ir.SymbolicDim | int | None: + """Returns the value of the given dimension, or None if it is not statically known.""" + if value is None: + return None + shape = value.shape + if shape is None: + return None + if dim < 0: + dim += shape.rank() + if dim < 0 or dim >= shape.rank(): + return None + return shape[dim] + + +def same_shape(shape1: ir.Shape | None, shape2: ir.Shape | None) -> bool: + """Check if two shapes are semantically the same.""" + if shape1 is None or shape2 is None: + return False + + # If any dim is unknown, the shapes are not the same + if shape1.has_unknown_dim() or shape2.has_unknown_dim(): + return False + + return shape1 == shape2 + + +def same_dim(dim1: ir.SymbolicDim | int, dim2: ir.SymbolicDim | int) -> bool: + """Check if two dimensions are semantically the same.""" + if type(dim1) is not type(dim2): + return False + if isinstance(dim1, int) and isinstance(dim2, int): + return dim1 == dim2 + assert isinstance(dim1, ir.SymbolicDim) and isinstance(dim2, ir.SymbolicDim) + if dim1.value is None or dim2.value is None: + return False + return dim1.value == dim2.value diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py new file mode 100644 index 0000000000..e347b98375 --- /dev/null +++ b/onnxscript/rewriter/_matcher.py @@ -0,0 +1,411 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Implementation of the pattern matching algorithm.""" + +from __future__ import annotations + +import abc +import itertools +import math +from typing import ( + Iterable, + Sequence, +) + +import onnxscript.rewriter._basics as _basics +import onnxscript.rewriter._pattern_ir as _pattern_ir +from onnxscript import ir + + +def _valid_to_replace( + matched_nodes: Sequence[ir.Node], output_values: Sequence[ir.Value] +) -> bool: + """Check that values computed by the matched_nodes, except for output_values, are used only by the matched_nodes.""" + # * Must check that all values matched by pattern are used only by pattern, + # except for the value that is replaced. + # * Must ensure that replacement subgraph does not use any of the deleted + # (intermediate) values. (Not necessary for now. Guaranteed.) + for n in matched_nodes: + for v in n.outputs: + if v in output_values: + continue + if v.is_graph_output(): + # value is an output-value of the graph/function. + return False + for consumer, _ in v.uses(): + if consumer not in matched_nodes: + return False + return True + + +class PatternMatcher(abc.ABC): + def __init__(self, pattern: _pattern_ir.GraphPattern) -> None: + self.pattern = pattern + + @abc.abstractmethod + def match( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + *, + verbose: int = 0, + remove_nodes: bool = True, + tracer: _basics.MatchingTracer | None = None, + ) -> _basics.MatchResult: + """Match the pattern against the subgraph ending at the given node.""" + + def __str__(self) -> str: + return str(self.pattern) + + +class SimplePatternMatcher(PatternMatcher): + def __init__(self, pattern: _pattern_ir.GraphPattern) -> None: + super().__init__(pattern) + self._current_node: ir.Node | None = None + + def fail(self, reason: str, node: ir.Node | None = None) -> bool: + if self._verbose: + num_matched_nodes = self._match.num_matched_nodes() + if num_matched_nodes > 0: # Print only if at least one node successfully matched. + print(f"Match failed after {num_matched_nodes} nodes: {reason}") + self._match.fail(reason, node or self._current_node) + return False + + def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Value) -> bool: + """Match a Constant pattern against a value. + + If the constant value is produced by a Constant node, we do not include + the constant node as part of the matched graph. Thus, it will not be deleted, + if subgraph replacement happens. But subsequent DCE will remove the constant + node if it is not used elsewhere. + """ + constant_value = value.const_value + if constant_value is None: + return self.fail( + f"Value {value.name} is not a constant, expecting {pattern_constant.value}.", + ) + + try: + constant_value_numpy = constant_value.numpy() + except FileNotFoundError: + return self.fail(f"Constant value of {value.name} not available.") + + pattern_constant_value = pattern_constant._value + + if isinstance(pattern_constant_value, list): + expected_shape = (len(pattern_constant_value),) + if constant_value_numpy.shape != expected_shape: + return self.fail(f"Value has mismatching shape, expecting {expected_shape}.") + if not all( + math.isclose( + constant_value_numpy.item(i), + pattern_constant_value[i], + rel_tol=pattern_constant._rel_tol, + abs_tol=pattern_constant._abs_tol, + ) + for i in range(len(pattern_constant_value)) + ): + return self.fail( + f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}." + ) + return True + + # TODO (rama): allow users to specify shape requirement, if desired. + if constant_value_numpy.size != 1: + return self.fail( + f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.", + ) + + if not math.isclose( + constant_value_numpy.item(), + pattern_constant_value, + rel_tol=pattern_constant._rel_tol, + abs_tol=pattern_constant._abs_tol, + ): + return self.fail( + f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.", + ) + + return True + + def _match_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node) -> bool: + """Matches a pattern subgraph against subgraph rooted at node.""" + self._current_node = node + # Graph-matching: we do not allow the same pattern node to be matched against + # different graph nodes. + matched_node = self._match.lookup_node(pattern_node) + if matched_node is not None: + if matched_node is not node: + return self.fail("Same pattern node is matched against different graph nodes.") + return True + match = self._match + if not pattern_node.matches(node, match): + return self.fail(match.reason) + + if self._verbose: + print(f"Matched: {node.op_type}") + + match.bind_node(pattern_node, node) + + # TODO: Revisit this to handle optional trailing inputs better. + + if len(node.inputs) > len(pattern_node.inputs): + if not pattern_node.allow_other_inputs: + return self.fail( + f"Number of inputs ({len(node.inputs)}) is greater than expected ({len(pattern_node.inputs)})" + ) + checked_inputs = zip(node.inputs, pattern_node.inputs) + else: + # In ONNX, trailing Nones can be omitted in the inputs of a node. So, we extend actual + # node inputs with None values to match the pattern node inputs length when zipping. + checked_inputs = itertools.zip_longest( + node.inputs, pattern_node.inputs, fillvalue=None + ) + + for arg_value, arg_pattern in checked_inputs: + # arg_pattern could be a Var, if it's the original arg. + if arg_pattern is None: + if arg_value is None: + continue + else: + return self.fail("(Optional) input is expected to be None but is not.") + if not self._match_value(arg_pattern, arg_value): + return False + + for i, output_value_pattern in enumerate(pattern_node.outputs): + # When trying to bind more outputs (from the pattern) than there are + # actual outputs of the candidate node, reject the node before even + # trying to index into the list of node outputs. + if i >= len(node.outputs): + return False + + if not self._match.bind_value(output_value_pattern, node.outputs[i]): + return False + + return True + + def _match_value( + self, pattern_value: _pattern_ir.ValuePattern, value: ir.Value | None + ) -> bool: + """Match an IR value against a ValuePattern instance.""" + if value is not None and value.graph is not self._graph: + if not isinstance( + pattern_value, (_pattern_ir.Var, _pattern_ir.Constant, _pattern_ir.AnyValue) + ): + # If the pattern value is a Var, Constant, or AnyValue, we allow it to match + # values from other graphs. Otherwise, we fail the match. + return self.fail( + f"Value {value.name} is not in the graph {self._graph.name}. " + f"Pattern matches crossing graph boundaries are not supported." + ) + if isinstance(pattern_value, _pattern_ir.AnyValue): + return True + + if not self._match.bind_value(pattern_value, value): + return False + + if isinstance(pattern_value, _pattern_ir.NodeOutputPattern): + if value is None: + return self.fail("Mismatch: Computed node pattern does not match None.") + return self._match_node_output(pattern_value, value) + if isinstance(pattern_value, _pattern_ir.Constant): + if value is None: + return self.fail("Mismatch: Constant pattern does not match None.") + return self._match_constant(pattern_value, value) + if isinstance(pattern_value, _pattern_ir.BacktrackingOr): + for i, pattern_choice in enumerate(pattern_value._values): + self._match.enter_new_match() + if self._match_value(pattern_choice, value): + if pattern_value.tag_var is not None: + self._match.bind(pattern_value.tag_var, pattern_value._tag_values[i]) + self._match.merge_current_match() + return True + self._match.abandon_current_match() + return self.fail("None of the alternatives matched.") + if isinstance(pattern_value, _pattern_ir.OpIdDispatchOr): + if value is None: + return self.fail("Mismatch: OrValue pattern does not match None.") + alternative = pattern_value.get_pattern(value) + if alternative is None: + return self.fail("Mismatch: OrValue pattern does not match value.") + i, pattern_choice = alternative + result = self._match_value(pattern_choice, value) + if result: + if pattern_value.tag_var is not None: + self._match.bind(pattern_value.tag_var, i) + return result + # Default case: a plain pattern variable (ValuePattern) + if value is None and not pattern_value.can_match_none: + return self.fail( + f"Mismatch: pattern variable {pattern_value} does not match None." + ) + return True + + def _match_node_output( + self, pattern_value: _pattern_ir.NodeOutputPattern, value: ir.Value + ) -> bool: + """Match an IR value against a NodeOutputPattern instance.""" + node = value.producer() + if node is None: + return self.fail( + "Mismatch: Computed node pattern does not match uncomputed IR value." + ) + if value.index() != pattern_value.output_index: + return self.fail( + f"Node output index mismatch: expected {pattern_value._output_index}, got {value.index()}." + ) + return self._match_node(pattern_value.producer(), node) + + def _init_match(self, verbose: int) -> None: + """Initialize the match state. Invoked before starting a new match.""" + self._verbose = verbose + self._match: _basics.MatchResult = _basics.MatchResult() + self._current_node = None + + def _get_output_values(self) -> list[ir.Value] | None: + """Get values bound to the output variables of the pattern.""" + output_values: list[ir.Value] = [] + unbound_values: list[str] = [] + for j, value_pattern in enumerate(self.pattern.outputs): + if value_pattern.name is not None: + if value_pattern.name in self._match.bindings: + output_values.append(self._match.bindings[value_pattern.name]) + else: + unbound_values.append(value_pattern.name) + else: + if value_pattern in self._match.value_bindings: + output_values.append(self._match.value_bindings[value_pattern]) + else: + unbound_values.append(f"output_{j}") + if unbound_values: + self._match.fail(f"Error: Output values not found: {unbound_values}") + return None + return output_values + + def _match_single_output_node( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + check_removable: bool, + ) -> _basics.MatchResult: + del model + del graph_or_function + + pattern = self.pattern + match = self._match + + if not pattern.has_single_output_node: + return match.fail( + "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." + ) + + if not self._match_node(pattern.output_node, node): + return match + + output_values = self._get_output_values() + if output_values is None: + # TODO(rama): Is this a valid (useful) case? + return match + if check_removable and not _valid_to_replace(match.nodes, output_values): + # TODO(rama): Match status should be updated to reflect failure reason. + return match.fail("Matched nodes have other uses preventing replacement.") + + match.outputs.extend(output_values) + return match + + def _multi_match( + self, candidate: Iterable[ir.Node], check_removable: bool + ) -> _basics.MatchResult: + """Find a match for a pattern with multiple output nodes. + + For a pattern with K output nodes, the input candidate should specify K nodes + in the graph that will be matched against the pattern output nodes. + + Args: + candidate: An iterable of nodes that will be matched against the pattern output nodes. + check_removable: If True, check that the matched nodes can be removed (that is, that + they are not used elsewhere in the graph). + """ + match = self._match + for pattern_node, node in zip(self.pattern.output_nodes, candidate): + if not self._match_node(pattern_node, node): + return match + output_values = self._get_output_values() + if output_values is None: + return match + + if check_removable and not _valid_to_replace(match.nodes, output_values): + return match.fail("Matched nodes have other uses preventing replacement.") + + match.outputs.extend(output_values) + return match + + def match( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + *, + verbose: int = 0, + remove_nodes: bool = True, + tracer: _basics.MatchingTracer | None = None, + ) -> _basics.MatchResult: + """Match the pattern against the subgraph ending at the given node. + + For patterns with multiple output nodes, the given node is matched + against the first output node in the pattern. For the remaining + output nodes in the pattern, we use a brute-force algorithm that + enumerates all possible combinations of nodes from the graph (with + a filter based on op-type). + + TODO: Consider omitting parameters model and graph_or_function. With + the new IR, the graph can be obtained from the node, and the model is + not used. But this is a shared abstract method of the Matcher interface, + so other matcher implementation also needs to be updated. More importantly, + matching in the presence of subgraphs (control-flow) can introduce some + complications which require careful consideration. + """ + self._tracer = tracer + if isinstance(graph_or_function, ir.Graph): + self._graph: ir.Graph = graph_or_function + else: + self._graph = graph_or_function.graph + if self.pattern.has_single_output_node: + self._init_match(verbose) + return self._match_single_output_node( + model, graph_or_function, node, check_removable=remove_nodes + ) + else: + # Note: This is a potentially expensive algorithm for matching patterns with + # multiple output nodes. For patterns with N output nodes, we try all possible + # combinations of N nodes from the graph, and check if they match the pattern. + # The first node is fixed to the node argument in this method call. We do + # some simple filtering by restricting the candidates for each remaining + # output nodes to graph nodes with the same op_type as the corresponding pattern + # node. For now, this is intended to be a simple, but robust, implementation + # that can be used for debugging and testing. The GenericPatternMatcher is a + # more sophisticated implementation, but incomplete. + pattern_output_nodes = self.pattern.output_nodes + op_to_nodes: dict[tuple[str, str, str], list[ir.Node]] = {} + for n in graph_or_function: + op_to_nodes.setdefault(n.op_identifier(), []).append(n) + all_nodes = iter(graph_or_function) + + def get_nodes(pattern_node): + id = pattern_node.op_identifier() + if id is None: + return all_nodes + return op_to_nodes.get(id, []) + + candidates = [iter([node])] + [get_nodes(pn) for pn in pattern_output_nodes[1:]] + match = None + for combination in itertools.product(*candidates): + self._init_match(verbose) + match = self._multi_match(combination, check_removable=remove_nodes) + if match: + return match + if match is None: + return _basics.MatchResult().fail("No match found.") + return match diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py new file mode 100644 index 0000000000..9b81e33581 --- /dev/null +++ b/onnxscript/rewriter/_pattern_ir.py @@ -0,0 +1,966 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""The Pattern IR: used to describe (source) patterns of rewrite rules.""" + +from __future__ import annotations + +import abc +import contextlib +import inspect +import itertools +from collections.abc import Mapping +from typing import ( + Any, + Callable, + Iterable, + Iterator, + Protocol, + Sequence, + TypeVar, + Union, +) + +import onnxscript.rewriter._basics as _basics +from onnxscript import ir + +T = TypeVar("T") + + +class Pattern(Protocol[T]): # type: ignore[misc] + """This is essentially a Predicate[T], that is, a Callable[[T], bool] bound to the name "matches".""" + + def matches(self, item: T) -> bool: ... + + +class StringPattern(abc.ABC, Pattern[str]): + """Abstract base class for string patterns.""" + + @abc.abstractmethod + def matches(self, item: str) -> bool: + pass + + @abc.abstractmethod + def __str__(self) -> str: + pass + + +class StringConstantPattern(StringPattern): + """Matches strings with given value.""" + + def __init__(self, value: str): + self._value = value + + def matches(self, item: str) -> bool: + return item == self._value + + def __str__(self) -> str: + return self._value + + def value(self) -> str: + return self._value + + +class PrefixPattern(StringPattern): + """Matches strings with a given prefix.""" + + def __init__(self, value: str) -> None: + self._value = value + + def matches(self, value: str) -> bool: + return value.startswith(self._value) + + def __str__(self) -> str: + return f"{self._value}*" + + +class AttrPattern(Pattern[ir.Attr]): + """Base class for an attribute pattern. Matches any attribute value by default.""" + + def __init__(self, name: str | None, *, can_match_none: bool = False): + self._name = name + self._can_match_none = can_match_none + + @property + def name(self) -> str | None: + return self._name + + @property + def can_match_none(self) -> bool: + """Indicates whether this pattern can match a None attribute.""" + return self._can_match_none + + def matches(self, attr: ir.Attr) -> bool: + return True + + def __str__(self) -> str: + return self._name if self._name is not None else "anonymous:" + str(id(self)) + + +class AttrVar(AttrPattern): + """Represents a pattern variable used to match against attribute values.""" + + def __init__(self, name: str | None, *, can_match_none: bool = False): + super().__init__(name, can_match_none=can_match_none) + + +# TODO: Support tensors. Align with usage elsewhere. +SupportedAttrTypes = Union[ + int, + float, + str, + Sequence[int], + Sequence[float], + Sequence[str], +] + + +class AttrConstantPattern(AttrPattern): + """Matches attributes with given value. + + Uses standard equality for matching. For list-valued attributes, the order of elements matters. + If order is immaterial, we need to define a separate pattern for that. + """ + + def __init__(self, value: SupportedAttrTypes): + super().__init__(None) + self._value = value + + def matches(self, attr: ir.Attr) -> bool: + if attr.type in { + ir.AttributeType.INTS, + ir.AttributeType.FLOATS, + ir.AttributeType.STRINGS, + }: + # Since the type of attr.value is Sequence, we need to convert to the same type for comparison. + return tuple(attr.value) == tuple(self._value) + return attr.value == self._value + + def __str__(self) -> str: + return str(self._value) + + +def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> AttrPattern: + """Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern.""" + if isinstance(value, AttrPattern): + return value + if isinstance(value, Var): + # This is a hack. Currently, when we create pattern-variables, we create them as Var, + # and change them to AttrPattern if/when used in an attribute context. We could use type + # annotations to distinguish between ValuePattern and AttrPattern, but forces users to + # use these type annotations. + # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.) + if value.check_method is not None: + raise ValueError( + "Pattern variables used in attributes must not have check_method set." + ) + return AttrVar(value.name, can_match_none=value.can_match_none) + if isinstance(value, (int, float, str)): + return AttrConstantPattern(value) + if isinstance(value, Sequence): + if all(isinstance(i, (int, float)) for i in value): + return AttrConstantPattern(value) + if all(isinstance(i, str) for i in value): + return AttrConstantPattern(value) + raise ValueError("Only lists of int/float/str can be used as an AttrPattern") + raise TypeError(f"Cannot convert {type(value)} to AttrPattern") + + +class OpsetPatternBuilder: + """Represents an opset pattern and a pattern builder. + + (i) It is used to create a NodePattern (via OpPatternBuilder). + Example usage: + :: + + z = op.Matmul(x, y) + + Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance + of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. + + (ii) It contains a domain pattern matched against the actual opset domain used in the + input model. + """ + + def __init__(self, domain: StringPattern | str, record: bool = False) -> None: + if isinstance(domain, str): + domain = StringConstantPattern(domain) + self._domain_pattern = domain + if record: + self._nodes: list[NodePattern] | None = [] + else: + self._nodes = None + + def domain_pattern(self) -> StringPattern: + return self._domain_pattern + + def __getattr__(self, op_name: str) -> OpPatternBuilder: + return OpPatternBuilder(self, op_name) + + def submodule(self, name: str) -> OpPatternBuilder: + """This method is used to match against submodule ops with prefix.""" + return OpPatternBuilder(self, PrefixPattern(name)) + + def __str__(self) -> str: + return str(self._domain_pattern) + + def add_node(self, node: NodePattern) -> None: + if self._nodes is not None: + self._nodes.append(node) + + def nodes(self) -> Sequence[NodePattern]: + if self._nodes is None: + raise ValueError("Nodes were not recorded.") + return self._nodes + + +onnxop = OpsetPatternBuilder("") + +torch_module_op = OpsetPatternBuilder(PrefixPattern("pkg.torch")) + + +class OpPatternBuilder: + """A utility class to build a NodePattern. + + It is used primarily to create a NodePattern. + Example usage: + :: + + z = op.Matmul(x, y) + + Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance + of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. + + """ + + def __init__( + self, + pattern_builder: OpsetPatternBuilder, + op_name: str | Pattern[str], + ) -> None: + self.pattern_builder = pattern_builder + self.op_name = op_name + + def __call__( + self, + *args, + _domain: str | None = None, + _version: int | None = None, + _outputs: int | list[str | None] = 1, + _allow_other_attributes: bool | None = None, + _allow_other_inputs: bool | None = None, + _check: Callable | None = None, + **kwargs, + ): + if _version is not None: + raise ValueError( + "The pattern builder does not support '_version' keyword argument. " + "Version restrictions should be handled by rewrite rules." + ) + if _domain is None: + opset_pattern = self.pattern_builder.domain_pattern() + elif isinstance(_domain, str): + opset_pattern = StringConstantPattern(_domain) + else: + # TODO(rama): allow OpsetPatternBuilder as _domain. + raise TypeError("_domain must be a string.") + + if isinstance(_outputs, int): + _outputs = [None for _ in range(_outputs)] + elif not isinstance(_outputs, Sequence) or not all( + isinstance(x, (str, type(None))) for x in _outputs + ): + raise ValueError("_outputs must be an int or a list[str|None].") + inputs = [_to_value_pattern(x) for x in args] + attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} + node_pattern = NodePattern( + opset_pattern, + self.op_name, + inputs, + attributes, + _outputs, + allow_other_attributes=_allow_other_attributes, + allow_other_inputs=_allow_other_inputs, + check=_check, + ) + self.pattern_builder.add_node(node_pattern) + output_values = node_pattern.outputs + # Unpack outputs if there is only one output, the common case. + if len(output_values) == 1: + return output_values[0] + else: + return output_values + + +def _to_value_pattern( + x: ValuePattern | int | float | Callable | None, +) -> ValuePattern | None: + """Promotes an input-value used to construct a NodePattern to a ValuePattern. + + Example usage: + :: + x = op.MatMul(a, b) + z = op.Add(x, 0) + + In this example, `a, `b`, and `x` are ValuePatterns used to construct a NodePattern. + `0` is a constant (int) value, and is automatically promoted to a ValuePattern. + + Note that this is a shorthand for creating a Constant pattern. The user can more + explicitly write this as: + :: + z = op.Add(x, op.Constant(0)) + + If a callable is provided, it will be converted to a ValuePattern with the callable as the check attribute. + """ + if x is None or isinstance(x, ValuePattern): + return x + if isinstance(x, (int, float)): + return Constant(x) + if isinstance(x, Sequence): + if all(isinstance(i, (int, float)) for i in x): + return Constant(x) + raise ValueError("Only lists of int/float can be used as a ValuePattern") + if callable(x): + return ValuePattern(None, check=x) + + raise TypeError(f"Cannot convert {type(x)} to ValuePattern") + + +_pattern_builder: OpsetPatternBuilder = onnxop + + +@contextlib.contextmanager +def pattern_builder(builder: OpsetPatternBuilder): + global _pattern_builder + prev_builder = _pattern_builder + _pattern_builder = builder + yield + _pattern_builder = prev_builder + + +class ValuePattern: + """Base class for all patterns that match against IR values. + + This is used primarily to provide operator overloadings for arithmetic + operations, so that we can write patterns like `x + 1` and `1 + x`. + """ + + def __init__( + self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False + ) -> None: + self._name = name + self._check = check + self._can_match_none = can_match_none + # Note: uses will be computed only when the full graph-pattern is constructed. + self._uses: list[tuple[NodePattern, int]] = [] + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern: + del node_map + return ValuePattern(self._name, check=self._check) + + @property + def name(self) -> str | None: + return self._name + + @property + def check_method(self) -> Callable | None: + return self._check + + @property + def can_match_none(self) -> bool: + """Indicates whether this variable can match a None input.""" + return self._can_match_none + + def producer(self) -> NodePattern | None: + return None + + def uses(self) -> Sequence[tuple[NodePattern, int]]: + return self._uses + + def append_use(self, node: NodePattern, index: int): + self._uses.append((node, index)) + + def __repr__(self) -> str: + return f"ValuePattern({self._name!r})" + + def __add__(self, other): + return _pattern_builder.Add(self, other) + + def __radd__(self, other): + return _pattern_builder.Add(other, self) + + def __sub__(self, other): + return _pattern_builder.Sub(self, other) + + def __rsub__(self, other): + return _pattern_builder.Sub(other, self) + + def __mul__(self, other): + return _pattern_builder.Mul(self, other) + + def __rmul__(self, other): + return _pattern_builder.Mul(other, self) + + def __truediv__(self, other): + return _pattern_builder.Div(self, other) + + def __rtruediv__(self, other): + return _pattern_builder.Div(other, self) + + def __pow__(self, other): + return _pattern_builder.Pow(self, other) + + def __str__(self) -> str: + return self._name if self._name is not None else "anonymous:" + str(id(self)) + + +class NodePattern: + """Represents a pattern that matches against a Node. + + This differs from a NodeOutputPattern in that it matches against a node (which + may produce 1 or more outputs), whereas a NodeOutputPattern matches against + a specific output of a node. + + Args: + domain: pattern to match against the domain of the node. + op: pattern or string constant to match against the op_type of the node. + inputs: sequence of ValuePatterns (or constants) to match against the inputs of the node. + attributes: dictionary of attribute patterns to match against the attributes of the node. + outputs: specifies pattern-variable-name for outputs (or None) + allow_other_attributes: specifies whether other attributes (not mentioned in `attributes`) + are allowed in the node. + """ + + def __init__( + self, + domain: StringPattern, + op: str | Pattern[str], + inputs: Sequence[int | float | ValuePattern | None], + attributes: dict[str, AttrPattern], + outputs: Sequence[str | None], + *, + allow_other_attributes: bool | None, + allow_other_inputs: bool | None, + check: Callable | None = None, + ): + if allow_other_attributes is None: + # Default behavior: allow other unmatched attributes in the node. + allow_other_attributes = True + if allow_other_inputs is None: + # TODO(rama): Should we default to True? For now, we preserve the current behavior. + allow_other_inputs = False + self.domain = domain + self.op = StringConstantPattern(op) if isinstance(op, str) else op + self.inputs = [_to_value_pattern(x) for x in inputs] + self.attributes = attributes + self.allow_other_attributes = allow_other_attributes + self.allow_other_inputs = allow_other_inputs + self._check = check + # In the common case, domain and op are constants, which can be used to optimize matching. + if isinstance(op, str) and isinstance(domain, StringConstantPattern): + # TODO(rama): support overloaded operators. + overload = "" + self._op_identifier: ir.OperatorIdentifier | None = ( + domain.value(), + op, + overload, + ) + else: + self._op_identifier = None + self.outputs = [NodeOutputPattern(self, i, name) for i, name in enumerate(outputs)] + + # Update uses for inputs. + for index, value in enumerate(self.inputs): + if value is not None: + value.append_use(self, index) + + def __str__(self) -> str: + inputs = ", ".join(str(v) for v in self.inputs) + outputs = ", ".join(str(v) for v in self.outputs) + attributes = ", ".join(f"{k}={v}" for k, v in self.attributes.items()) + op = str(self.op) + domain = str(self.domain) + qualified_op = f"{domain}.{op}" if domain else op + inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs + return f"{outputs} = {qualified_op} ({inputs_and_attributes})" + + def op_identifier(self) -> ir.OperatorIdentifier | None: + return self._op_identifier + + @property + def op_type(self) -> str: + return str(self.op) + + @property + def check_method(self) -> Callable | None: + return self._check + + def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchResult: + """Matches the pattern represented by self against a node. + + This is purely a local node-level match, and does not consider the subgraph rooted at the node. + We check the domain, op_type, and attributes of the node, but not the inputs. + """ + # TODO(rama): Ensure we handle "" and "onnx.ai" correctly. + if not self.op.matches(node.op_type): + return match.fail( + f"OpType mismatch: expected {self.op}, got {node.op_type}.", node + ) + if not self.domain.matches(node.domain): + return match.fail( + f"Domain mismatch: expected {self.domain}, got {node.domain}.", node + ) + + for name, attr_pattern in self.attributes.items(): + attr_value = node.attributes.get(name) + if attr_value is None: + if not attr_pattern.can_match_none: + return match.fail(f"Attribute {name} not found in node.", node) + elif not attr_pattern.matches(attr_value): + return match.fail( + f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.", + node, + ) + if attr_pattern.name is not None: + if not match.bind(attr_pattern.name, attr_value): + return match + + if not self.allow_other_attributes: + for name in node.attributes: + # TODO: Support matching default nodes for attributes. + if name not in self.attributes: + return match.fail(f"Attribute {name} not expected in node.", node) + + return match + + def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern: + inputs = [(v.clone(node_map) if v is not None else None) for v in self.inputs] + if swap: + assert len(inputs) == 2, ( + "Internal error: commutative swap applies only to binary ops." + ) + inputs = [inputs[1], inputs[0]] + outputs = [value.name for value in self.outputs] + copied = NodePattern( + self.domain, + self.op, + inputs, + self.attributes, + outputs, + allow_other_attributes=self.allow_other_attributes, + allow_other_inputs=self.allow_other_inputs, + check=self._check, + ) + node_map[self] = copied + return copied + + +class NodeOutputPattern(ValuePattern): + """Represents a pattern that matches against a specific output of a Node. + + This is the primary pattern used to match against computed values, that + is values computed using a specific op. + """ + + def __init__( + self, producer: NodePattern, output_index: int, name: str | None = None + ) -> None: + super().__init__(name) + self._producer = producer + self._output_index = output_index + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> NodeOutputPattern: + return node_map[self._producer].outputs[self._output_index] + # return NodeOutputPattern(node_map[self._producer], self._output_index, self._name) + + @property + def output_index(self) -> int: + return self._output_index + + def producer(self) -> NodePattern: + return self._producer + + +class Var(ValuePattern): + """Represents a pattern-variable.""" + + def __init__( + self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False + ) -> None: + super().__init__(name, check=check, can_match_none=can_match_none) + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> Var: + """Clones the pattern-variable, preserving its name and check method.""" + return Var(self.name, check=self.check_method, can_match_none=self.can_match_none) + + +class AnyValue(ValuePattern): + """Represents a pattern that matches against any value.""" + + def __init__(self) -> None: + super().__init__(None) + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> AnyValue: + # A single instance of AnyValue suffices. + return self + + +ANY_VALUE = AnyValue() + + +class Constant(ValuePattern): + """Represents a pattern that matches against a scalar constant value.""" + + def __init__( + self, + value: int | float | Sequence[int] | Sequence[float], + rel_tol: float = 1e-5, + abs_tol: float = 1e-8, + ) -> None: + super().__init__(None) + self._value = list(value) if isinstance(value, Sequence) else value + self._rel_tol = rel_tol + self._abs_tol = abs_tol + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant: + del node_map + return Constant(self._value, self._rel_tol, self._abs_tol) + + @property + def value(self) -> int | float | list[int] | list[float]: + return self._value + + def __str__(self) -> str: + return str(self._value) + + +class OpIdDispatchOr(ValuePattern): + """Represents a (restricted) form of value pattern disjunction that enables deterministic matching.""" + + def __init__( + self, + op_to_pattern: Mapping[ir.OperatorIdentifier, tuple[Any, ValuePattern]], + name: str | None = None, + tag_var: str | None = None, + ) -> None: + """ + Initialize an OpIdDispatchOr pattern. + + Args: + op_to_pattern: A dictionary mapping operator identifiers to tuples of tag values and patterns. + The keys are operator identifiers, and the values are tuples containing a tag value + and a pattern to match against. + name: An optional variable name for the pattern. Defaults to None. If present, + this name will be bound to the value matched by the pattern. + tag_var: An optional variable name for the tag. Defaults to None. If present, + it will be bound to a value indicating which alternative was matched. + """ + super().__init__(name) + self._op_to_pattern = op_to_pattern + self._tag_var = tag_var + + @property + def tag_var(self) -> str | None: + """Returns the tag variable associated with the OrValue pattern.""" + return self._tag_var + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> OpIdDispatchOr: + return OpIdDispatchOr( + {k: (v[0], v[1].clone(node_map)) for k, v in self._op_to_pattern.items()}, + self.name, + self._tag_var, + ) + + def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern] | None: + """Returns the pattern that should be tried for the given value.""" + producer = value.producer() + if producer is not None: + id = producer.op_identifier() + if id is not None and id in self._op_to_pattern: + return self._op_to_pattern[id] + return None + + +class BacktrackingOr(ValuePattern): + """Represents an unrestricted form of OR pattern implemented using backtracking.""" + + def __init__( + self, + values: Sequence[ValuePattern], + name: str | None = None, + tag_var: str | None = None, + tag_values: Sequence[Any] | None = None, + ) -> None: + """ + Initialize a BacktrackingOr pattern. + + Args: + values: A sequence of value patterns to match against. + name: An optional variable name for the pattern. Defaults to None. If present, + this name will be bound to the value matched by the pattern. + tag_var: An optional variable name for the tag. Defaults to None. If present, + it will be bound to a value (from tag_values) indicating which alternative was matched. + tag_values: An optional sequence of values to bind to the tag_var. Defaults to None. + If present, the length of tag_values must match the number of alternatives in values. + In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th + alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used. + """ + super().__init__(name) + if tag_values is not None: + if tag_var is None: + raise ValueError("tag_var must be specified if tag_values is provided.") + if len(tag_values) != len(values): + raise ValueError( + "tag_values must have the same length as the number of alternatives." + ) + else: + tag_values = tuple(range(len(values))) + self._tag_var = tag_var + self._tag_values = tag_values + self._values = values + + @property + def tag_var(self) -> str | None: + """Returns the tag variable associated with the OrValue pattern.""" + return self._tag_var + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> BacktrackingOr: + return BacktrackingOr( + [v.clone(node_map) for v in self._values], + self.name, + self._tag_var, + self._tag_values, + ) + + +def OrValue( + values: Sequence[ValuePattern], + name: str | None = None, + tag_var: str | None = None, + tag_values: Sequence[Any] | None = None, +) -> ValuePattern: + """ + Creates an OR pattern. + + Args: + values: A sequence of value patterns to match against. + name: An optional variable name for the pattern. Defaults to None. If present, + this name will be bound to the value matched by the pattern. + tag_var: An optional variable name for the tag. Defaults to None. If present, + it will be bound to a value (from tag_values) indicating which alternative was matched. + tag_values: An optional sequence of values to bind to the tag_var. Defaults to None. + If present, the length of tag_values must match the number of alternatives in values. + In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th + alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used. + """ + if tag_values is not None: + if tag_var is None: + raise ValueError("tag_var must be specified if tag_values is provided.") + if len(tag_values) != len(values): + raise ValueError( + "tag_values must have the same length as the number of alternatives." + ) + else: + tag_values = tuple(range(len(values))) + + def make_op_id_or_pattern() -> OpIdDispatchOr | None: + mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {} + for i, alternative in enumerate(values): + if not isinstance(alternative, NodeOutputPattern): + return None + producer = alternative.producer() + id = producer.op_identifier() + if id is None or id in mapping: + return None + mapping[id] = (tag_values[i], alternative) + return OpIdDispatchOr(mapping, name, tag_var) + + optimized_pattern = make_op_id_or_pattern() + return optimized_pattern or BacktrackingOr( + values, name, tag_var, tag_values if tag_var else None + ) + + +def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]: + """Returns all nodes used in a pattern, given the outputs of the pattern.""" + node_patterns: list[NodePattern] = [] + + def visit(value_patterns: Sequence[ValuePattern | None]) -> None: + for value_pattern in value_patterns: + if isinstance(value_pattern, NodeOutputPattern): + node_pattern = value_pattern.producer() + if node_pattern not in node_patterns: + node_patterns.append(node_pattern) + visit(node_pattern.inputs) + + visit(outputs) + node_patterns.reverse() + return node_patterns + + +def _add_backward_slice( + node: NodePattern, + backward_slice: set[NodePattern], + backward_slice_values: set[ValuePattern], +) -> None: + """Adds all nodes in the backward slice of given node to the set `backward_slice`. + + The backward slice of a node is the set of all nodes that are reachable from the node + in a backward traversal from the given node. + """ + if node in backward_slice: + return + backward_slice.add(node) + for value_pattern in node.inputs: + if isinstance(value_pattern, NodeOutputPattern): + _add_backward_slice( + value_pattern.producer(), backward_slice, backward_slice_values + ) + elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)): + backward_slice_values.add(value_pattern) + + +class GraphPattern: + """Represents a pattern that can be matched against a subgraph.""" + + def __init__( + self, + inputs: Sequence[ValuePattern], + outputs: Sequence[ValuePattern], + nodes: Sequence[NodePattern], + ) -> None: + self._inputs = inputs + self._outputs = outputs + if len(outputs) == 0: + raise ValueError("GraphPattern must have at least one output") + self._nodes = nodes # _nodes_in_pattern(outputs) + + # Determine the output nodes of the pattern. These are a minimal set of nodes + # whose backward-slices cover the entire pattern. + output_nodes: set[NodePattern] = set() + covered: set[NodePattern] = set() + choice_values_returned: set[ValuePattern] = set() + covered_choice_values: set[ValuePattern] = set() + for value_pattern in outputs: + if not isinstance(value_pattern, ValuePattern): + raise TypeError( + f"Invalid type {type(value_pattern)} for graph pattern output." + ) + if isinstance(value_pattern, NodeOutputPattern): + candidate = value_pattern.producer() + if candidate not in covered: + output_nodes.add(candidate) + _add_backward_slice(candidate, covered, covered_choice_values) + elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)): + choice_values_returned.add(value_pattern) + + # check if all choice_values_returned are contained in covered_choice_values: + # We don't yet support the use of a choice-value as a "root" of the search. + # This is a limitation of the current implementation, and will be fixed in the future. + if not (choice_values_returned <= covered_choice_values): + raise NotImplementedError("Returning uncovered choice-values is not supported.") + + self.output_nodes: list[NodePattern] = list(output_nodes) + + @property + def output_node(self) -> NodePattern: + if len(self.output_nodes) != 1: + raise ValueError("GraphPattern does not have unique output node.") + return self.output_nodes[0] + + def node(self, index: int) -> NodePattern: + return self._nodes[index] + + def num_nodes(self) -> int: + return len(self._nodes) + + def __len__(self) -> int: + return self.num_nodes() + + @property + def inputs(self) -> Sequence[ValuePattern]: + return self._inputs + + @property + def outputs(self) -> Sequence[ValuePattern]: + return self._outputs + + def __iter__(self) -> Iterator[NodePattern]: + return iter(self._nodes) + + def __reversed__(self) -> Iterator[NodePattern]: + return reversed(self._nodes) + + @property + def has_single_output_node(self) -> bool: + return len(self.output_nodes) == 1 + + @property + def num_outputs(self) -> int: + return len(self._outputs) + + def commute(self) -> Sequence[GraphPattern]: + def commute_node(node: NodePattern) -> Iterable[bool]: + if node.op_identifier() == ("", "Add", "") or node.op_identifier() == ( + "", + "Mul", + "", + ): + # Try with and without swapping inputs. + return [False, True] + # No swapping of inputs + return [False] + + iteration_space = [commute_node(node) for node in self._nodes] + + def copy_graph(swap_list: Iterable[bool]) -> GraphPattern: + if not any(swap_list): + # No need to swap inputs of any node + return self + # Create a copy of the graph, with swapped inputs for the nodes that need it. + node_map: dict[NodePattern, NodePattern] = {} + new_inputs = [v.clone(node_map) for v in self._inputs] + new_nodes = [ + node.clone(node_map, swap) for node, swap in zip(self._nodes, swap_list) + ] + new_outputs = [v.clone(node_map) for v in self._outputs] + return GraphPattern(new_inputs, new_outputs, new_nodes) + + return [copy_graph(swap_list) for swap_list in itertools.product(*iteration_space)] + + def __str__(self) -> str: + inputs = ", ".join(str(v) for v in self._inputs) + outputs = ", ".join(str(v) for v in self._outputs) + nodes = "\n ".join(str(n) for n in self._nodes) + return f"pattern ({inputs}) {{\n {nodes}\n return {outputs}\n}}" + + +def _to_graph_pattern(pattern_constructor: Callable) -> GraphPattern: + """Convert a pattern-construction function to a GraphPattern. + + A pattern-construction function will return values as below: + :: + def pattern(op, x: Var, shape1: Var, shape2: Var): + ... + return outputs + + We create a pattern graph by creating pattern-variables for each parameter of the function, + and calling the function. The returned values are normalized to a list of ValuePatterns, + which represent the outputs of the pattern graph. + + Args: + pattern_constructor: Callable + + Returns: + GraphPattern: A representation of the pattern that can be matched against a subgraph. + """ + _pattern_vars = inspect.signature(pattern_constructor).parameters + pattern_inputs = [Var(v) for v in _pattern_vars][1:] # Skip the first parameter + builder = OpsetPatternBuilder("", record=True) + with pattern_builder(builder): + pattern_outputs = pattern_constructor(builder, *pattern_inputs) + # TODO(rama): classify inputs as value/attribute vars + # Returned value could be a single ValuePattern or a list of ValuePatterns. + # Normalize representation to a list of ValuePatterns. + if isinstance(pattern_outputs, ValuePattern): + pattern_outputs = [pattern_outputs] + return GraphPattern(pattern_inputs, pattern_outputs, builder.nodes()) diff --git a/onnxscript/rewriter/_pattern_ir_test.py b/onnxscript/rewriter/_pattern_ir_test.py new file mode 100644 index 0000000000..e5f826b191 --- /dev/null +++ b/onnxscript/rewriter/_pattern_ir_test.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +from onnxscript.rewriter import _pattern_ir + + +class PatternIRTest(unittest.TestCase): + """Test _pattern_ir module functionality.""" + + def test_value_pattern_with_check(self): + """Test ValuePattern with check attribute.""" + + def value_checker(context, value): + return True + + # Test creating ValuePattern with check + value_pattern = _pattern_ir.ValuePattern("test_value", check=value_checker) + self.assertIs(value_pattern._check, value_checker) + self.assertEqual(value_pattern.name, "test_value") + + def test_node_pattern_with_check(self): + """Test NodePattern with check attribute.""" + + def node_checker(context, node): + return True + + # Test creating NodePattern with check + domain_pattern = _pattern_ir.StringConstantPattern("") + inputs = [] + attributes = {} + outputs = ["output"] + + node_pattern = _pattern_ir.NodePattern( + domain_pattern, + "Add", + inputs, + attributes, + outputs, + allow_other_attributes=True, + allow_other_inputs=True, + check=node_checker, + ) + self.assertIs(node_pattern._check, node_checker) + + def test_to_value_pattern_with_callable(self): + """Test _to_value_pattern function with callable input.""" + + def my_checker(context, value): + return True + + result = _pattern_ir._to_value_pattern(my_checker) + self.assertIsInstance(result, _pattern_ir.ValuePattern) + self.assertIs(result._check, my_checker) + self.assertIsNone(result.name) + + def test_op_pattern_builder_with_check(self): + """Test OpPatternBuilder with _check parameter.""" + + def node_checker(context, node): + return True + + # Create OpPatternBuilder and call with _check parameter + opset_builder = _pattern_ir.OpsetPatternBuilder("") + result = opset_builder.Add(None, None, _check=node_checker) + + # The result should be a NodeOutputPattern, and its producer should have the check + self.assertTrue(hasattr(result, "producer")) + producer = result.producer() + self.assertIsNotNone(producer) + self.assertTrue(hasattr(producer, "_check")) + self.assertIs(producer._check, node_checker) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py new file mode 100644 index 0000000000..8964230fe0 --- /dev/null +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -0,0 +1,788 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Rewrite rules for ONNX models.""" + +from __future__ import annotations + +import abc +import dataclasses +import itertools +from typing import ( + Callable, + Sequence, + TypeVar, +) + +import onnxscript.optimizer +import onnxscript.rewriter._basics as _basics +import onnxscript.rewriter._ir_utils as _ir_utils +import onnxscript.rewriter._matcher as _matcher +import onnxscript.rewriter._pattern_ir as _pattern_ir +from onnxscript import ir +from onnxscript.ir import _tape, convenience + +T = TypeVar("T") + +RewriterContext = _tape.Builder + + +@dataclasses.dataclass +class ReplacementSubgraph: + """A subgraph that will replace the matched pattern.""" + + match: _basics.MatchResult + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + new_initializers: Sequence[ir.Value] + used_opsets: _tape.UsedOpsets + + +def always_true(*args, **kwargs) -> bool: + """A condition function that always returns True. + + This is used when no condition function is provided for a rewrite rule. + """ + return True + + +class Pattern: + """A pattern that can be matched against nodes in an ONNX graph. + + This class encapsulates pattern matching functionality, providing the ability to + match patterns against nodes without requiring replacement functionality. + """ + + def __init__( + self, + target_pattern: _pattern_ir.GraphPattern | Callable, + condition_function: Callable | None = None, + matcher: _matcher.PatternMatcher + | Callable[[_pattern_ir.GraphPattern], _matcher.PatternMatcher] + | None = None, + verbose: int = 0, + name: str | None = None, + ) -> None: + """Create a pattern matcher. + + Args: + target_pattern: The _pattern_ir.GraphPattern that will be matched against the IR. + If a callable is provided, it will be converted to a _pattern_ir.GraphPattern. + condition_function: The condition function that will be used to check if + the pattern match found should be rewritten. + matcher: The pattern matcher that will be used to match the pattern. + If not provided, a default matcher will be used. + verbose: The verbosity level of the rule. + name: An optional name for the pattern that will show up in verbose logging. + """ + if not isinstance(target_pattern, _pattern_ir.GraphPattern): + target_pattern = _pattern_ir._to_graph_pattern(target_pattern) + self._target_pattern = target_pattern + + self._condition_function = condition_function or always_true + if isinstance(matcher, _matcher.PatternMatcher): + self._matcher = matcher + elif matcher is None: + self._matcher = _matcher.SimplePatternMatcher(self._target_pattern) + else: + self._matcher = matcher(self._target_pattern) + self._verbose = verbose + self.name = name + + def __str__(self) -> str: + return self.name if self.name else "Anonymous Pattern" + + def match( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + *, + verbose: int | None = None, + check_nodes_are_removable: bool = True, + tracer: _basics.MatchingTracer | None = None, + ) -> _basics.MatchResult | None: + """Check if the node matches the pattern and return the match result. + + Args: + model: The model containing the graph or function. + graph_or_function: The graph or function to match against. + node: The node to try to match the pattern against. + verbose: The verbosity level of messages. + check_nodes_are_removable: If True, validate that matched nodes can be safely removed. + tracer: The tracer for debugging. + + Returns: + MatchResult if the pattern matches successfully and passes the condition function, + None otherwise. + """ + if verbose and verbose > 2: + print(f"[match] {self}") + verbose = verbose if verbose is not None else self._verbose + match = self._matcher.match( + model, + graph_or_function, + node, + verbose=verbose, + remove_nodes=check_nodes_are_removable, + ) + if match: + context = _basics.MatchContext(model, graph_or_function, node, match) + for var in self._target_pattern.inputs: + if var.name is not None: + if var.name not in match.bindings: + match.bind(var.name, None) + + # Perform value/node level checks before condition function + def fail(check_result, default_message, failure_object=None): + """Local utility to handle check failures consistently.""" + if isinstance(check_result, _basics.MatchResult): + match.fail( + check_result.reason, + check_result.failure_nodes_and_values, + ) + else: + match.fail(default_message, failure_object) + if tracer: + tracer.log( + self, # type: ignore[arg-type] + graph_or_function, + node, + match, + _basics.MatchStatus.CONDITION_FAILED, + ) + return None + + def wrap_try(f): + """Encapsulates try-except pattern for check functions.""" + + def wrapped(*args, **kwargs): + try: + return f(*args, **kwargs) + except _basics.MatchFailureError as e: + result = _basics.MatchResult() + result.fail(e.reason, list(e.failure_sources)) + return result + + return wrapped + + # Check node-level checkers + for pattern_node, ir_node in match.node_bindings.items(): + if pattern_node.check_method is not None: + check_result = wrap_try(pattern_node.check_method)(context, ir_node) + if not check_result: + return fail( + check_result, + f"Node-level check failed for pattern node {pattern_node}", + ir_node, + ) + + # Check value-level checkers + for pattern_value, ir_value in match.value_bindings.items(): + if pattern_value.check_method is not None: + check_result = wrap_try(pattern_value.check_method)(context, ir_value) + if not check_result: + return fail( + check_result, + f"Value-level check failed for pattern value {pattern_value}", + ir_value, + ) + + check_match_result = wrap_try(self._condition_function)(context, **match.bindings) + if not check_match_result: + # If check function was provided, but it failed, return the reason for failure to the tracer. + return fail(check_match_result, "Condition function check failed") + if tracer: + tracer.log(self, graph_or_function, node, match, _basics.MatchStatus.SUCCESS) # type: ignore[arg-type] + return match + if tracer: + tracer.log(self, graph_or_function, node, match, _basics.MatchStatus.NO_MATCH) # type: ignore[arg-type] + return match + + +class ReplacementPatternFunction: + """The replacement pattern that will replace the targeted pattern. + + Attributes: + function (Callable): The replacement function that will be used to replace the matched pattern. + """ + + def __init__(self, function) -> None: + self._function = function + + def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None: + context = RewriterContext() + new_outputs = self._function(context, **match.bindings) + if new_outputs is None: + return None # Failed to create replacement subgraph + if not isinstance(new_outputs, Sequence): + new_outputs = [new_outputs] + return ReplacementSubgraph( + match, new_outputs, context.nodes, context.initializers, context.used_opsets + ) + + +def _update_opset_imports( + graph_or_function: ir.Graph | ir.Function, delta: ReplacementSubgraph +): + imports = graph_or_function.opset_imports + for domain, version in delta.used_opsets: + if domain not in imports: + # use 1 as default version if not explicitly specified + imports[domain] = version if version is not None else 1 + elif version is not None and version != imports[domain]: + raise ValueError( + f"Multiple versions of opset {domain} used. " + f"Expected version {imports[domain]}, but got {version}." + ) + + +class RewriteRule(Pattern): + def __init__( + self, + target_pattern: _pattern_ir.GraphPattern | Callable, + replacement_pattern: ReplacementPatternFunction | Callable, + condition_function: Callable | None = None, + matcher: _matcher.PatternMatcher + | Callable[[_pattern_ir.GraphPattern], _matcher.PatternMatcher] + | None = None, + verbose: int = 0, + name: str | None = None, + remove_nodes: bool = True, + graph_pre_visitor: Callable[[], None] | None = None, + graph_post_visitor: Callable[[], None] | None = None, + as_function: bool = False, + ) -> None: + """Create a rewrite rule. + + Args: + target_pattern: The _pattern_ir.GraphPattern that will be matched against the IR. + If a callable is provided, it will be converted to a _pattern_ir.GraphPattern. + replacement_pattern: The ReplacementPatternFunction that will be used to + replace the matched pattern. If a callable is provided, it will be + converted to a ReplacementPatternFunction. + condition_function: The condition function that will be used to check if + the pattern match found should be rewritten. + matcher: The pattern matcher that will be used to match the pattern. + If not provided, a default matcher will be used. + verbose: The verbosity level of the rule. + name: An optional name for the pattern that will show up in verbose logging. + remove_nodes: If True, the matched nodes will be removed from the graph. + graph_pre_visitor: A function that will be called before applying the + rewriting to the top-level graph or a function. + graph_post_visitor: A function that will be called after the rewriting + is complete for a graph or function. + as_function: If True, the matched nodes will be extracted into a model + local function. This is only supported when remove_nodes=True and + when the replacement subgraph has a single node, representing the + function call. + """ + if as_function and not remove_nodes: + raise ValueError("as_function=True is only supported when remove_nodes=True.") + + # Initialize the base pattern matching functionality + super().__init__(target_pattern, condition_function, matcher, verbose, name) + + if not isinstance(replacement_pattern, ReplacementPatternFunction): + replacement_pattern = ReplacementPatternFunction(replacement_pattern) + self._replacement_pattern = replacement_pattern + self.remove_nodes = remove_nodes + self.graph_pre_visitor = graph_pre_visitor + self.graph_post_visitor = graph_post_visitor + self.as_function = as_function + + def __str__(self) -> str: + return self.name if self.name else "Anonymous Rule" + + def try_rewrite( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + *, + verbose: int | None = None, + tracer: _basics.MatchingTracer | None = None, + ) -> ReplacementSubgraph | None: + """If the node matches the pattern, then replace the node with the replacement pattern.""" + # Use the inherited match method from Pattern + match = self.match( + model, + graph_or_function, + node, + verbose=verbose, + check_nodes_are_removable=self.remove_nodes, + tracer=tracer, + ) + if not match: + return None + + replacement_subgraph = self._replacement_pattern.get_replacement(match) + if replacement_subgraph is None: + if tracer: + tracer.log( + self, + graph_or_function, + node, + match, + _basics.MatchStatus.REPLACEMENT_FAILED, + ) + return None + if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs: + raise ValueError( + f"Number of outputs from replacement function does not match the number of outputs from the target pattern. " + f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}." + ) + # TODO(rama): Remove the opset imports from deleted nodes? + _update_opset_imports(graph_or_function, replacement_subgraph) + _update_opset_imports(model.graph, replacement_subgraph) + return replacement_subgraph + + def apply_to_model( + self, + model: ir.Model, + *, + commute: bool = False, + verbose: int | None = None, + tracer: _basics.MatchingTracer | None = None, + ): + # A convenience method to apply the rule to a model. We use a RewriteRuleSet to + # handle commutative rules. + return RewriteRuleSet([self], commute=commute).apply_to_model( + model, verbose=verbose, tracer=tracer + ) + + def commute(self) -> Sequence[RewriteRule]: + def replace_pattern(new_pattern): + """Return a shallow copy of self with node_pattern replaced by new_pattern.""" + # TODO(rama): Maybe we should use a better alternative to construct new matcher. + matcher_class = type(self._matcher) + return RewriteRule( + new_pattern, + self._replacement_pattern, + self._condition_function, + matcher_class(new_pattern), + self._verbose, + self.name, + self.remove_nodes, + self.graph_pre_visitor, + self.graph_post_visitor, + self.as_function, + ) + + return [replace_pattern(p) for p in self._target_pattern.commute()] + + +class PatternBase(abc.ABC): + """Base class for implementing pattern matching as a class. + + This class encapsulates the pattern definition and condition checking + without the replacement functionality. + + Example:: + + class TransposePattern(PatternBase): + def pattern(cls, op, x, perm): + return op.Transpose(x, perm=perm) + + def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: + if perm.is_ref(): + return False + if perm.type == ir.AttributeType.INTS: + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): + return True + return False + """ + + def __init__(self, name: str | None = None, **kwargs) -> None: + self.name = name or self.__class__.__name__ + # Initialize to None and create on demand to avoid construction order issues + self._compiled_pattern: Pattern | None = None + self._pattern_kwargs = kwargs + + @abc.abstractmethod + def pattern(self, op, *args, **kwargs): + raise NotImplementedError("Method 'pattern' must be implemented by derived class.") + + def check(self, op, *args, **kwargs) -> _basics.MatchResult: + """Default check function that returns a _basics.MatchResult object with success always set to True.""" + return _basics.MatchResult() + + def match( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + *, + verbose: int | None = None, + check_nodes_are_removable: bool = True, + tracer: _basics.MatchingTracer | None = None, + ) -> _basics.MatchResult | None: + """Check if the node matches the pattern and return the match result. + + Args: + model: The model containing the graph or function. + graph_or_function: The graph or function to match against. + node: The node to try to match the pattern against. + verbose: The verbosity level of messages. + check_nodes_are_removable: If True, validate that matched nodes can be safely removed. + tracer: The tracer for debugging. + + Returns: + MatchResult if the pattern matches successfully and passes the condition function, + None otherwise. + """ + # Create the compiled pattern on demand if not already created + if self._compiled_pattern is None: + self._compiled_pattern = Pattern( + self.pattern, self.check, name=self.name, **self._pattern_kwargs + ) + return self._compiled_pattern.match( + model, + graph_or_function, + node, + verbose=verbose, + check_nodes_are_removable=check_nodes_are_removable, + tracer=tracer, + ) + + +class RewriteRuleClassBase(PatternBase): + """Base class for implementing rewrite rules as a class. + + Example:: + + class TransposeIdentity(RewriteRuleClassBase): + def pattern(cls, op, x, perm): + return op.Transpose(x, perm=perm) + + def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: + if perm.is_ref(): + return False + if perm.type == ir.AttributeType.INTS: + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): + return True + return False + + def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None): + return op.Identity(x) + + # Then use + # TransposeIdentity.rule() + # to create a RewriteRule object. + + """ + + @classmethod + def rule(cls, *args, **kwargs): + instance = cls(*args, **kwargs) + return RewriteRule( + instance.pattern, + instance.rewrite, + instance.check, + name=instance.name, + remove_nodes=instance.remove_nodes, + graph_pre_visitor=instance.setup, + graph_post_visitor=instance.cleanup, + as_function=instance.as_function, + ) + + def __init__( + self, name: str | None = None, remove_nodes: bool = True, as_function: bool = False + ) -> None: + super().__init__(name) + self.remove_nodes = remove_nodes + self.as_function = as_function + + @abc.abstractmethod + def rewrite(self, op, *args, **kwargs): + raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") + + def setup(self): + """Optional setup function that can be overridden by derived classes. + + Used to do per model/function initialization. + """ + return + + def cleanup(self): + """Optional cleanup function that can be overridden by derived classes. + + Used to do per model/function cleanup. + """ + return + + +def _copy_for_function( + inputs: Sequence[ir.Value | None], nodes: Sequence[ir.Node], outputs: Sequence[ir.Value] +): + """Utility function to extract a subgraph out as a function.""" + value_map: dict[ir.Value, ir.Value] = {} + function_inputs: list[ir.Value] = [] + constant_nodes: list[ir.Node] = [] + for input in inputs: + # Create a function input (formal-parameter value) to represent this value: + new_value = ( + ir.Value( + name=input.name, + shape=input.shape, + type=input.type, + doc_string=input.doc_string, + ) + if input + else ir.Value() # dummy parameter for a None input + ) + if input is not None: + value_map[input] = new_value + function_inputs.append(new_value) + + def copy_value(value: ir.Value | None) -> ir.Value | None: + if value is None: + return None + if value not in value_map: + const_value = value.const_value + if const_value is not None: + # create a Constant node to represent the value + value_attr = ir.AttrTensor("value", const_value) + const_node = ir.Node("", "Constant", [], [value_attr]) + constant_nodes.append(const_node) + value_map[value] = result = const_node.outputs[0] + return result + raise ValueError(f"Value {value} not found in value_map.") + return value_map[value] + + def copy_attr_value(attr: ir.Attr) -> ir.Attr: + if attr.is_ref(): + # No need to support this currently, as rewriting inside a function is + # not used, as it has several challenges. + raise NotImplementedError("RefAttr not supported.") + if attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}: + # No need to support this currently, as rewriting control-flow constructs + # is not used and has several challenges. + raise NotImplementedError("Graph attributes not supported.") + # Primitive attributes are immutable by design and can be shared. + return attr + + def copy_node(node: ir.Node) -> ir.Node: + new_inputs = [copy_value(v) for v in node.inputs] + new_attributes = [copy_attr_value(v) for v in node.attributes.values()] + new_node = ir.Node( + node.domain, + node.op_type, + new_inputs, + new_attributes, + overload=node.overload, + num_outputs=len(node.outputs), + graph=None, + name=node.name, + doc_string=node.doc_string, # type: ignore + metadata_props=node.metadata_props.copy(), + ) + new_outputs = new_node.outputs + for i, output in enumerate(node.outputs): + value_map[output] = new_outputs[i] + if output.name is not None: + new_outputs[i].name = output.name + return new_node + + function_nodes = [copy_node(node) for node in nodes] + function_outputs = [copy_value(v) for v in outputs] + return (function_inputs, constant_nodes + function_nodes, function_outputs) + + +def _get_new_overload(model: ir.Model, domain: str, name: str) -> str: + """Get a new overload for the given domain and name. + + Args: + model: The model to which the new overload will be added. + domain: The domain of the new overload. + name: The opname of the new overload. + + Returns: + The new overload name. + """ + existing_functions = model.functions + # Just a simple implementation for now + overload = 1 + while True: + overload_name = str(overload) + if (domain, name, overload_name) not in existing_functions: + return overload_name + overload += 1 + + +class RewriteRuleSet: + def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: + if not rules: + raise ValueError("rules must contain at least one rule") + if commute: + rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules])) + self.rules = rules + # We call remove_unused_nodes at end of rewriting if there is any rule that does + # NOT remove nodes (immediately when it is applied) + self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.rules})" + + def _apply_to_graph_or_function( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + *, + verbose: int | None, + tracer: _basics.MatchingTracer | None = None, + ) -> int: + """ + Apply the rewrite rules to the given graph or function. + + Args: + model: The model to which the rewrite rules are applied. + graph_or_function: The graph or function to which the rewrite rules are applied. + verbose: The verbosity level. Defaults to None. + tracer: The tracer for debugging. Defaults to None. + + Returns: + The number of rewrite rules applied. + """ + count = 0 + + for rule in self.rules: + if rule.graph_pre_visitor: + rule.graph_pre_visitor() + + for node in graph_or_function: + for rule in self.rules: + delta = rule.try_rewrite( + model, graph_or_function, node, verbose=verbose, tracer=tracer + ) + if delta is None or tracer is not None: + continue + assert isinstance(delta, ReplacementSubgraph) + if delta.new_initializers: + if isinstance(graph_or_function, ir.Function): + # TODO(rama): Can't add initializers to functions. But currently this is not + # an issue, as we apply inlining before applying rewrite rules. + if verbose: + print( + f"Rewrites adding initializers not supported for functions: {rule}" + ) + continue + initializers = graph_or_function.initializers + for initializer in delta.new_initializers: + if initializer.name in initializers: + if verbose: + print(f"Initializer {initializer.name} already exists.") + continue + for initializer in delta.new_initializers: + initializers[initializer.name] = initializer # type: ignore[index] + # TODO: This does not yet handle the problem of determining the correct insertion point + # for inserted nodes in the case of patterns with multiple output-nodes. The following + # is sufficient for patterns with a single output-node "node", which can serve as the + # insertion-point. + onnxscript.optimizer.basic_constant_propagation(delta.new_nodes) + if rule.as_function: + # Create a function out of a copy of the matched nodes + if len(delta.new_nodes) != 1: + raise ValueError( + "as_function=True is only supported for patterns with a single replacement node." + ) + call_node = delta.new_nodes[0] + domain = call_node.domain + name = call_node.op_type + overload = _get_new_overload(model, domain, name) + call_node.overload = overload + + # Create topologically sorted list of nodes to be replaced. + unsorted_nodes = set(delta.match.nodes) + original_nodes = [n for n in graph_or_function if n in unsorted_nodes] + # Create new inputs/nodes/outputs for the function + inputs, nodes, outputs = _copy_for_function( + call_node.inputs, original_nodes, delta.match.outputs + ) + + used_domains: set[str] = {node.domain for node in original_nodes} + parent_opset_imports = graph_or_function.opset_imports + used_opset_imports = { + k: v for k, v in parent_opset_imports.items() if k in used_domains + } + + graph = ir.Graph( + inputs, outputs, nodes=nodes, opset_imports=used_opset_imports + ) + f = ir.Function(domain, name, overload, graph=graph, attributes=()) + model.functions[f.identifier()] = f + + if verbose: + name = f"{rule.name}: " if rule.name else "" + print(f"----{name}Matched Nodes----") + _ir_utils.display_nodes(delta.match.nodes) + print("++++Replacement Nodes++++") + _ir_utils.display_nodes(delta.new_nodes) + print("++++End Replacement Nodes++++") + + convenience.replace_nodes_and_values( + graph_or_function, + node, + delta.match.nodes if rule.remove_nodes else [], + delta.new_nodes, + delta.match.outputs, + delta.new_outputs, + ) + + count += 1 + break + + # Apply rewrite rules to subgraphs of the node. + for attr in node.attributes.values(): + if attr.type == ir.AttributeType.GRAPH: + count += self._apply_to_graph_or_function( + model, attr.value, verbose=verbose, tracer=tracer + ) + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.value: + count += self._apply_to_graph_or_function( + model, graph, verbose=verbose, tracer=tracer + ) + + for rule in self.rules: + if rule.graph_post_visitor: + rule.graph_post_visitor() + + return count + + def apply_to_model( + self, + model: ir.Model, + *, + verbose: int | None = None, + tracer: _basics.MatchingTracer | None = None, + ) -> int: + """Apply the rewrite rules in the set to the model. + + Args: + model: The model to which the rewrite rules are applied. + verbose: The verbosity level of messages. Defaults to None. + tracer: if specified, no changes are made to the model, only + information about the best matches found is computed. + + Returns: + The number of applications of rewrite rules. + """ + assert isinstance(model, ir.Model) + onnxscript.optimizer.basic_constant_propagation(model.graph) + # Rewriting may introduce new functions. In the following loop, + # we restrict rewriting to original functions, not newly introduced ones. + original_functions = list(model.functions.values()) + count = self._apply_to_graph_or_function( + model, model.graph, verbose=verbose, tracer=tracer + ) + for function in original_functions: + onnxscript.optimizer.basic_constant_propagation(function) + count += self._apply_to_graph_or_function( + model, function, verbose=verbose, tracer=tracer + ) + if self.remove_unused_nodes: + onnxscript.optimizer.remove_unused_nodes(model) + return count + + def __iter__(self): + yield from self.rules diff --git a/onnxscript/rewriter/_tape.py b/onnxscript/rewriter/_tape.py deleted file mode 100644 index 5b35b0dbca..0000000000 --- a/onnxscript/rewriter/_tape.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Convenience methods for constructing the IR.""" - -# NOTE: This is a temporary solution for constructing the IR. It should be replaced -# with a more permanent solution in the future. - -from __future__ import annotations - -from typing import Iterable, Mapping, Sequence - -from onnxscript import ir -from onnxscript.ir import _convenience - - -class Tape(Iterable[ir.Node]): - """A tape for recording nodes that are created.""" - - def __init__(self) -> None: - self._nodes: list[ir.Node] = [] - - def __iter__(self) -> Sequence[ir.Node]: - return self._nodes - - @property - def nodes(self) -> Sequence[ir.Node]: - return tuple(self._nodes) - - def op( - self, - op_type: str, - inputs: Sequence[ir.Value | None], - attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, - domain: str = "", - ) -> ir.Value: - if attributes is None: - attrs: Sequence[ir.Attr | ir.RefAttr] = () - else: - attrs = _convenience.convert_attributes(attributes) - node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=1) - self._nodes.append(node) - - return node.outputs[0] - - def op_multi_output( - self, - op_type: str, - inputs: Sequence[ir.Value | None], - attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, - *, - num_outputs: int, - domain: str = "", - ) -> Sequence[ir.Value]: - if attributes is None: - attrs: Sequence[ir.Attr | ir.RefAttr] = () - else: - attrs = _convenience.convert_attributes(attributes) - node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs) - self._nodes.append(node) - - return node.outputs diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/erfgelu.py deleted file mode 100644 index 59d689cee2..0000000000 --- a/onnxscript/rewriter/erfgelu.py +++ /dev/null @@ -1,30 +0,0 @@ -import math - -from onnxscript.rewriter import pattern - -op = pattern.onnxop - - -# Pattern to match against -def erf_gelu_pattern(x): - # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - # half = pattern.Constant(0.5) - # sqrt2 = pattern.Constant(1.4142) - # x_div_sqrt2 = op.Div(x, sqrt2) - # erf = op.Erf(x_div_sqrt2) - # one = pattern.Constant(1.0) - # one_plus_erf = op.Add(erf, one) - # x_mul_one_plus_erf = op.Mul(x, one_plus_erf) - # return op.Mul(half, x_mul_one_plus_erf) - return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) - - -msft_op = pattern.msft_op - - -# Replacement -def gelu(op, x): - return op.Gelu(x, domain="com.microsoft") - - -rule = pattern.RewriteRule(erf_gelu_pattern, gelu) diff --git a/onnxscript/rewriter/function_rule.py b/onnxscript/rewriter/function_rule.py deleted file mode 100644 index b9272dffdb..0000000000 --- a/onnxscript/rewriter/function_rule.py +++ /dev/null @@ -1,230 +0,0 @@ -from __future__ import annotations - -import functools -import logging -from typing import Callable - -import onnx -from packaging import version - -import onnxscript -from onnxscript import ir -from onnxscript.rewriter import pattern - -logger = logging.getLogger(__name__) - - -class FunctionRewriteError(RuntimeError): ... - - -@functools.lru_cache -def parse_domain(function_domain: str) -> tuple[str, version.Version | None]: - splits = function_domain.split(".") - if splits[0] != "pkg": - raise FunctionRewriteError( - f"Invalid domain: {function_domain}. Must start with 'pkg'." - ) - splits = splits[1:] - for i, s in enumerate(splits): - if s.isdigit(): - return ".".join(splits[:i]), version.parse(".".join(splits[i:])) - return ".".join(splits), None - - -MIN_VERSION = version.parse("0") -MAX_VERSION = version.parse("9999") - - -class VersionController: - def __init__(self): - # A dispatch table for rewrite implementation based on the function package version. - self.dispatch_table: dict[tuple[version.Version, version.Version], Callable] = {} - - def register_version( - self, - min_version: version.Version | str | None = None, - max_version: version.Version | str | None = None, - ): - """Register a function implementation for a specific package version range [min_version, max_version). - - Args: - min_version: The minimum version of the package. Inclusive. - max_version: The maximum version of the package. Exclusive. - """ - # TODO: check for version overloap - - min_version = MIN_VERSION if min_version is None else min_version - max_version = MAX_VERSION if max_version is None else max_version - if isinstance(min_version, str): - min_version = version.parse(min_version) - if isinstance(max_version, str): - max_version = version.parse(max_version) - - def deco(func): - self.dispatch_table[(min_version, max_version)] = func - return func - - return deco - - def dispatch(self, version: version.Version | None) -> Callable | None: - if version is None: - if len(self.dispatch_table) == 1: - return next(iter(self.dispatch_table.values())) - raise ValueError( - "No function package version specified, however there are multiple " - f"fusion rules based on package version: {self.dispatch_table.keys()}." - ) - for (min_version, max_version), func in self.dispatch_table.items(): - greater_than_min = min_version is None or min_version <= version - less_than_max = max_version is None or version < max_version - if greater_than_min and less_than_max: - return func - return None - - -class FunctionRewriteRule(pattern.RewriteRule): - FUNCTION_KEYWORD: str | tuple[str] - """The keyword to match the function name. If a tuple, any keyword will match.""" - - PACKAGE_NAME: str - """The package name to match. - - For example, 'transformers' to match for domain name 'pkg.transformers.4.36.2'. - """ - - _opset_imports: dict[str, int] - onnx_opset: onnxscript.values.Opset - - def __init__(self, opset: onnxscript.values.Opset = onnxscript.opset18) -> None: # type: ignore[has-type] - self.onnx_opset = opset - - def _match_function(self, function: ir.Function, pkg_name: str) -> bool: - # TODO: Consolidate more checks from `compose_new_function` to here. - if pkg_name != self.PACKAGE_NAME: - logger.info( - "Rule %s did not match function %s::%s. Package name mismatch '%s' != '%s'.", - self.__class__.__name__, - function.domain, - function.name, - self.PACKAGE_NAME, - pkg_name, - ) - return False - if isinstance(self.FUNCTION_KEYWORD, str): - return function.name.find(self.FUNCTION_KEYWORD) != -1 - elif isinstance(self.FUNCTION_KEYWORD, tuple): - return any(function.name.find(keyword) != -1 for keyword in self.FUNCTION_KEYWORD) - else: - raise ValueError( # noqa: TRY004 - f"Function keyword must be str or tuple, got {self.FUNCTION_KEYWORD}" - ) - - def _find_node_contains_key_in_name( - self, function: onnx.FunctionProto, keyword: str - ) -> onnx.NodeProto | None: - for node in function.node: - if node.name.find(keyword) != -1: - return node - return None - - def _find_node_by_type( - self, function: ir.Function, domain: str, op_type: str - ) -> ir.Node | None: - # Repeat - for node in function: - if node.domain == domain and node.op_type == op_type: - return node - return None - - def compose_new_function( - self, old_function: ir.Function, pkg_version: version.Version | None - ) -> ir.Function: - """Compose a new function from the old function. - - Returns: - A tuple of the new function and the opset imports. - - Raises: - FunctionRewriteError: If the rewrite fails. - """ - # self._version_controller is created in the subclass - func = self._version_controller.dispatch(pkg_version) # type: ignore[attr-defined] - if func is not None: - new_function = func(self, old_function) - return new_function - raise FunctionRewriteError( - f"No rewrite implementation for package version {pkg_version}." - ) - - def try_rewrite_function( - self, function: ir.Function - ) -> tuple[ir.OperatorIdentifier, ir.Function] | None: - try: - pkg_name, pkg_version = parse_domain(function.domain) - except FunctionRewriteError as e: - logger.warning("Could not parse domain: %s", e) - return None - - if pkg_version is None and not pkg_name.startswith("onnxscript"): - logger.warning( - "Could not parse version for domain of function %s::%s. " - "Usually this implies the model source is not from a package, but from arbitrary python files instead. " - "For example, models not defined in huggingface/transformers but loaded via 'trust_remote_code=True'.", - function.domain, - function.name, - ) - - if not self._match_function(function, pkg_name): - return None - logger.info( - "Rule %s matched function %s::%s", - self.__class__.__name__, - function.domain, - function.name, - ) - try: - new_function = self.compose_new_function(function, pkg_version) - except FunctionRewriteError as e: - logger.warning("Could not rewrite function: %s", e) - return None - - new_function.name = function.name - new_function.domain = function.domain - - return function.identifier(), new_function - - def try_rewrite(self, model: ir.Model, value) -> bool: - raise NotImplementedError( - "Use `try_rewrite_function` instead for function based rewrites." - ) - - def apply_to_model( - self, model: ir.Model, *, commute: bool = False - ) -> tuple[int, ir.Model]: - del commute # unused - - old_function_to_new_function: dict[ir.OperatorIdentifier, ir.Function] = {} - for function in model.functions.values(): - rewrite_or_none = self.try_rewrite_function(function) - if rewrite_or_none is not None: - old_function_to_new_function[rewrite_or_none[0]] = rewrite_or_none[1] - model = self.update_to_new_function(model, old_function_to_new_function) - return len(old_function_to_new_function), model - - def update_to_new_function( - self, - model: ir.Model, - old_function_to_new_function: dict[ir.OperatorIdentifier, ir.Function], - ) -> ir.Model: - for old_function_id, new_function_ir in old_function_to_new_function.items(): - model.functions[old_function_id] = new_function_ir - for new_opset, opset_version in new_function_ir.opset_imports.items(): - if new_opset not in model.opset_imports: - model.opset_imports[new_opset] = opset_version - return model - - def count_matches(self, model, *, commute: bool = False) -> int: - raise NotImplementedError() - - def commute(self) -> list[pattern.RewriteRule]: - raise NotImplementedError() diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py deleted file mode 100644 index 2a92cda98d..0000000000 --- a/onnxscript/rewriter/generic_pattern.py +++ /dev/null @@ -1,754 +0,0 @@ -from __future__ import annotations - -import collections -import inspect -import os -import textwrap -from typing import Any, Callable, Iterator, Sequence - -import onnxscript.rewriter.pattern as orp -from onnxscript import ir - - -class PatternMatchResult: - """Stores information about a match if a match was successful. - - * pattern: the instance of :class:`GenericPattern` which found this result - * model_nodes: matched nodes coming from the model - * pattern_nodes: corresponding nodes coming from the pattern - * pattern_input_names: input names of the pattern - * pattern_ouptut_names: output names of the pattern - * kwargs: additional attributes the user may add through the method - :meth:`PatternMatchResult.add_kwargs` - - The class creates one attributes `matched_pattern_to_model_name`, - which maps every result name from the pattern to the corresponding - result name in the model. - """ - - def __init__( - self, - pattern: GenericPattern, - model_nodes: Sequence[ir.Node], - pattern_nodes: Sequence[ir.Node], - pattern_inputs: Sequence[ir.Value], - pattern_outputs: Sequence[ir.Value], - ): - assert len(model_nodes) == len(pattern_nodes) - self.pattern = pattern - self.model_nodes = model_nodes - self.pattern_nodes = pattern_nodes - self.pattern_inputs = pattern_inputs - self.pattern_outputs = pattern_outputs - self.kwargs: dict[str, Any] = {} - - matched_pattern_to_model_value: dict[str, ir.Value] = {} - for gn, pn in zip(model_nodes, pattern_nodes): - assert ( - gn.op_type == pn.op_type - ), f"Unexpected type mismatch {gn.op_type!r} != {pn.op_type!r}" - assert len(gn.inputs) == len( - pn.inputs - ), f"Unexpected number of inputs for type {gn.op_type}" - for a, b in zip(gn.inputs, pn.inputs): - if b is None: - # optional input or not an interesting input - continue - b_name = b.name - assert b_name is not None - if b_name in matched_pattern_to_model_value: - assert matched_pattern_to_model_value[b_name] == a, ( - f"Ambiguities, pattern input '{b_name}' means " - f"'{a!r}' or '{matched_pattern_to_model_value[b_name]}'" - ) - else: - matched_pattern_to_model_value[b_name] = a - - assert len(gn.outputs) == len( - pn.outputs - ), f"Unexpected number of outputs for type {gn.op_type}" - for a, b in zip(gn.outputs, pn.outputs): - b_name = b.name - assert b_name is not None - if b_name in matched_pattern_to_model_value: - assert matched_pattern_to_model_value[b_name] == a, ( - f"Ambiguities, pattern output {b_name!r} means " - f"{a!r} or {matched_pattern_to_model_value[b_name]}" - ) - else: - matched_pattern_to_model_value[b_name] = a - - self.matched_pattern_to_model_value = matched_pattern_to_model_value - - def add_kwargs(self, name: str, value: Any): - """Adds an attribute, it can be done when the match is being validated, - this attribute can be used when building the replacement nodes. - """ - self.kwargs[name] = value - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}([{self.pattern.__class__.__name__}], " - f"... {len(self.model_nodes)} nodes ..., {self.pattern_inputs}, " - f"{self.pattern_outputs})" - ) - - -def _to_match_result(pmr: PatternMatchResult) -> orp.MatchResult: - """Converts a PatternMatchResult into a MatchResult. - - TODO: This is a temporary hack until MatchResult and PatternMatchResult are unified. - """ - result = orp.MatchResult(success=True) - result.nodes.extend(pmr.model_nodes) - for var, val in pmr.matched_pattern_to_model_value.items(): - result.bind(var, val) - result.outputs.extend( - [pmr.matched_pattern_to_model_value[v.name] for v in pmr.pattern_outputs] - ) - return result - - -class GenericRewriteRule(orp.RewriteRule): - """ - Defines a rewriting rule. - - pattern: a pattern defines by :class:`GenericPattern`. - """ - - def __init__(self, pattern: GenericPattern): - self.pattern = pattern - self.verbose: int = 0 # TODO: remove this - - def matches(self, node: ir.Node, model: ir.Model) -> orp.MatchResult: - del model - del node - raise RuntimeError(f"This pattern {self} is meant to replace not to only match.") - - def try_rewrite( - self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node - ) -> orp.ReplacementSubgraph | None: - """See :meth:`RewriteRule.try_rewrite`.""" - - pattern_match_result = self.pattern.match(model.graph, node) - if pattern_match_result: - match_result = _to_match_result(pattern_match_result) - context = None # TODO: create a context - if not self.pattern.validate_mapping(context, **match_result.bindings): - pattern_match_result._hint( - "validate_mapping", "The pattern was rejected by the validation function." - ) - return None - - return self.pattern.apply(model, match_result, verbose=self.verbose) - return None - - def count_matches(self, model: ir.Model, *, commute: bool = False) -> int: - """See :meth:`RewriteRule.count_matches`.""" - raise NotImplementedError("Not supported yet.") - - def commute(self) -> list[orp.RewriteRule]: - """See :meth:`RewriteRule.commute`.""" - raise RuntimeError("Not supported (yet?). It could lead to many patterns.") - - def apply_to_model(self, model: ir.Model, *, commute: bool = False) -> int: - """See :meth:`RewriteRule.apply_to_model`.""" - return orp.RewriteRuleSet([self], commute=commute).apply_to_model(model) - - -class GenericPattern: - """ - Implements a pattern optimization for quick experimentation. - - Current limitation: - - * The current implementation does match on domain name (easy fix). - * It does not compares attributes either (easy fix as well). - """ - - def __init__(self, verbose: int = 0): - self.verbose = verbose - self._cache: dict = {} - - def enumerate_matches( - self, graph: ir.Graph | ir.GraphView, node: ir.Node | None = None - ) -> Iterator: - """Enumerates all the matches.""" - if node is None: - matched = [] - for node in graph: - res = self.match(graph, node) - if res: - matched.append(res) - yield res - else: - res = self.match(graph, node) - if res: - yield res - - def none( - self, - node: ir.Node | None = None, - lineno: int | None = None, - msg: str = "", - ) -> None: - """Must be called every time a match fails to trace it. - - It may be useful which reason made a pattern matching fail. - Instead of returning None, method *match* can return the following - expression: - - :: - - return self.none(node, inspect.currentframe().f_lineno) - - By setting the verbosity (see next Section), the user may then know - which lines in the code returned None and which condition failed. - If logs are fully enabled, it shows information about matched none - and the line deciding the matched failed. - For example, this tells the matching failed at line 601 in ``generic_pattern.py``. - It happens when propagating the match in the backward directions. - The unmatched types are Mul, MatMul and below, - it shows the matched nodes. The first one was Cast. - And the failure happened at iteration 5. - ``139774002356544-139774000632672`` is the pair of ids used in container ``matched``. - ``id(node)`` is used as a unique identifiers of the nodes. - - :: - - [RotaryEmbeddingPattern.match] NONE - line: 601:__main__, op_type=Cast - --hint--: BACKWARD: different node types - --pattern - Mul(pos_ids, cast) -> (mul) - -- model - MatMul(/_original_modu...Expand_output_0, /_original_modu...b/Cast_output_0) -> (/_original_modu...MatMul_output_0) - iteration=5 - --matched-- #6 - Cast(/_original_modu...mb/Cos_output_0) ~ Cast(cos) [139774002356544-139774000632672] - Cos(/_original_modu...ncat_1_output_0) ~ Cos(concattraining-transpose-0) [139774002356448-139774000632048] - ConcatTraining(/_original_modu...nspose_output_0,/_original_modu...nspose_output_0) ~ ConcatTraining(transpose,transpose) [139774002356352-139774000631712] - Transpose(/_original_modu...MatMul_output_0) ~ Transpose(mul) [139774002356256-139774000631184] - Sin(/_original_modu...ncat_1_output_0) ~ Sin(concattraining-transpose-0) [139774002358512-139774000631568] - Cast(/_original_modu...mb/Sin_output_0) ~ Cast(sin) [139774002358608-139774000632384] - len(stack)=0:[] - - 'hints' are not added everywhere. More can easily be added with method ``_hint``. - """ - if node and self.verbose: - if self.verbose >= 10: - if hasattr(self, "_debug"): - msg2 = self._debug_print() - if msg2: - msg2 = f"\n{textwrap.indent(msg2, ' ')}" - else: - msg2 = "" - print( - f"[{self.__class__.__name__}.match] NONE - line: {lineno}:" - f"{os.path.split(self.__class__.__module__)[-1]}, " - f"op_type={node.op_type}{msg}{msg2}" - ) - - def print_match(self, n1: ir.Node, n2: ir.Node) -> str: - s1 = f"{n1.op_type}({n1.inputs})" - s2 = f"{n2.op_type}({n2.inputs})" - return f"match {s1} with {s2} (pattern)" - - def _debug_print(self) -> str: - if not hasattr(self, "_debug"): - return "" - - def _s(s: str) -> str: - if len(s) <= 30: - return s - return f"{s[:15]}...{s[-15:]}" - - def _p(n: ir.Node, full: bool = False) -> str: - if full: - return str(n) - return f"{n.op_type}({', '.join([str(input) for input in n.inputs])})" - - rows = [] - for k, v in sorted(self._debug.items()): - if k == "stack": - rows.append(f"len({k})={len(v)}:{v}") # type: ignore[arg-type] - continue - if k == "iteration": - rows.append(f"{k}={v}") - continue - if k == "matched": - rows.append(f"--matched-- #{len(v)}") # type: ignore[arg-type] - for pattern_node, graph_node in v.items(): - rows.append( - f" {_p(pattern_node)} ~ {_p(graph_node)} [{id(pattern_node)}-{id(graph_node)}]" - ) - continue - if k == "hint": - rows.append(f"--hint--: {v[0]}") # type: ignore[arg-type] - for i in v[1:]: - if isinstance(i, ir.Node): - rows.append(" " + _p(i, full=True)) - continue - if k in {"node", "pattern", "pattern_node", "pattern_nodes"}: - continue - rows.append(f"-- not shown {k}") - - return "\n".join(rows) - - def _hint(self, *args: Any) -> None: - """Add debugging information to help users.""" - self._debug["hint"] = args - - def _match_backward( - self, - node: ir.Node, - matched: dict[ir.Node, ir.Node], - stack: list[ir.Node], - graph_node: ir.Node, - pattern_node: ir.Node, - ) -> int | None: - """ - Matches backward. - - Args: - node: root node (the node the matched begain with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_node: node coming from the graph - pattern_node: node coming from the pattern - - Returns: - number of matched nodes, None or False to indicate a failed match - """ - match_count = 0 - - # predecessors - if len(graph_node.inputs) != len(pattern_node.inputs): - # not the same number of inputs - self._hint( - "BACKWARD: not the same number of inputs", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(node, inspect.currentframe().f_lineno) - for i, pi in zip(graph_node.inputs, pattern_node.inputs): - ppred = pi.producer() - if ppred is None: - # ppred is None means the pattern ends here. - continue - pred = i.producer() - if pred is None: - # No node in the graph. - return self.none(node, inspect.currentframe().f_lineno) - if pred.op_type != ppred.op_type: - self._hint( - "BACKWARD: different node types", - "--pattern", - ppred, - "-- model", - pred, - ) - return self.none(node, inspect.currentframe().f_lineno) - # matching backward - if ppred not in matched: - if self.verbose >= 10: - print(f"[GenericPattern._match_backward] {self.print_match(pred, ppred)}") - matched[ppred] = pred - stack.append(ppred) - match_count += 1 - if self.verbose > 5 and match_count > 0: - print(f"[GenericPattern._match_backward] add {match_count} nodes") - return match_count - - def _match_forward( - self, - root_node: ir.Node, - matched: dict[ir.Node, ir.Node], - stack: list[int], - graph_node: ir.Node, - pattern_node: ir.Node, - ) -> int | None: - """ - Matches forward. - - Args: - root_node: root node (the node the match begins with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_node: node coming from the graph - pattern_node: node coming from the pattern - - Returns: - number of matched nodes to continue, None or False to indicate a failed match - """ - match_count = 0 - - # successors - if len(graph_node.outputs) != len(pattern_node.outputs): - # not the same number of outputs - self._hint( - "FORWARD: not the same number of output_names", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(root_node, inspect.currentframe().f_lineno) - - for o, op in zip(graph_node.outputs, pattern_node.outputs): - graph_node_users = [user for user, _ in o.uses()] - pattern_node_users = [user for user, _ in op.uses()] - if not pattern_node_users: - # The pattern has no node forward, the matching stops. - continue - if len(graph_node_users) < len(pattern_node_users): - # Not enough node in the graph to match the pattern. A match is not possible - return self.none(root_node, inspect.currentframe().f_lineno) - - # Here comes the fun part, there is the same number of successors or more - # nodes in the graph to match with the pattern. - # And we have to handle the nodes already matched as found. - # Hopefully, there is only one option. - - if len(graph_node_users) == len(pattern_node_users) == 1: - # Let's deal with the simple case - if graph_node_users[0].op_type != pattern_node_users[0].op_type: - return self.none(root_node, inspect.currentframe().f_lineno) - - node = pattern_node_users[0] - if node not in matched: - if self.verbose >= 10: - print( - f"[GenericPattern._match_forward]{self.print_match(graph_node_users[0], pattern_node_users[0])}" - ) - matched[node] = graph_node_users[0] - stack.append(node) - match_count += 1 - continue - - # Let's remove the nodes already matched. - pattern_node_users_not_matched = [ - unmatched_node - for unmatched_node in pattern_node_users - if unmatched_node not in matched - ] - pattern_node_users_matched = [ - matched[matched_node] - for matched_node in pattern_node_users - if matched_node in matched - ] - assert len(pattern_node_users_matched) + len( - pattern_node_users_not_matched - ) == len(pattern_node_users), ( - f"pattern_node_users_not_matched={pattern_node_users_not_matched}, " - f"pattern_node_users_matched={pattern_node_users_matched}, " - f"pattern_node_users={pattern_node_users}, " - f"matched={matched}" - ) - free = list(set(graph_node_users) - set(pattern_node_users_matched)) - if not pattern_node_users_not_matched: - # Everything is already matched. - continue - if len(free) < len(pattern_node_users_not_matched): - # Not enough successors to match the remaining patterns. - return self.none(node, inspect.currentframe().f_lineno) - if len(pattern_node_users_not_matched) == len(free) == 1: - # Only one option again. - graph_node = free[0] - if pattern_node_users_not_matched[0].op_type != graph_node.op_type: - return self.none(node, inspect.currentframe().f_lineno) - - key = pattern_node_users_not_matched[0] - if self.verbose >= 10: - print( - f"[GenericPattern._match_forward] {self.print_match(graph_node, pattern_node_users_not_matched[0])}" - ) - matched[key] = graph_node - stack.append(key) - match_count += 1 - continue - - # And now another fun part, let's try to handle the case when - # there is only one option, matching on node type only returns one - # option. - expected_op_type = [_.op_type for _ in pattern_node_users_not_matched] - got_op_type = [_.op_type for _ in free] - - ec = collections.Counter(expected_op_type) - gc = collections.Counter(got_op_type) - if len(ec) != len(gc) or set(ec) != set(gc): - # unique operator types is different. - self._hint( - "FORWARD: unique operator types are different", - "-- pattern", - ec, - pattern_node, - "-- model", - gc, - graph_node, - "-- model-matched", - pattern_node_users_matched, - ) - return self.none(node, inspect.currentframe().f_lineno) - for k, v in ec.items(): - if gc[k] < v: - # Not enough types to match. - return self.none(node, inspect.currentframe().f_lineno) - - # At this stage, we know matching the types is possible. - # We first mark whatever is possible. - ptype_to_node = {_.op_type: _ for _ in pattern_node_users_not_matched} - gtype_to_node = {_.op_type: _ for _ in free} - missing = [] - for k, v in ec.items(): - if gc[k] == v == 1: - key = id(ptype_to_node[k]) - if key not in matched: - if self.verbose >= 10: - print( - f"[GenericPattern._match_forward] match " - f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}" - ) - matched[key] = gtype_to_node[k] - stack.append(key) - match_count += 1 - else: - missing.append(k) - - if not missing: - continue - - # At this stage, there are mutiple options for matching. We can: - # 1. make assumptions and continue - # 2. mark the node as incomplete matching, we could end up stuck anyway. - raise NotImplementedError( - f"There are more than one option, this will be implemented later, " - f"ec={ec}, gc={gc}" - ) - if self.verbose > 5 and match_count > 0: - print(f"[GenericPattern._match_forward] add {match_count} nodes") - return match_count - - def match( - self, - g: ir.Graph | ir.GraphView, - node: ir.Node, - ) -> PatternMatchResult | None: - self._debug = {} - - match_pattern: ir.Graph = self._get_match_pattern(g) - - # Let's match the last node. - # Then we need to match successors and predecessors. - last_pattern_node = match_pattern[-1] - if node.op_type != last_pattern_node.op_type: - # The last node does not have the same op_type. - return self.none() - - if self.verbose > 5: - print(f"[GenericPattern.match] starts with {node}") - if self.verbose >= 10: - print(f"[GenericPattern.match] match pattern {self!r}") - - all_pattern_nodes = set(match_pattern) - matched: dict[ir.Node, ir.Node] = {last_pattern_node: node} - stack: list[ir.Node] = [last_pattern_node] - iteration = 0 - - if self.verbose > 5: - self._debug = dict( - pattern=match_pattern, - matched=matched, - stack=stack, - iteration=iteration, - node=node, - pattern_node=last_pattern_node, - pattern_nodes=match_pattern, - ) - - max_iter = len(match_pattern) * 2 - while stack and iteration < max_iter: - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert not nodes_not_in_pattern, ( - f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - f"\nall_pattern_nodes={all_pattern_nodes}" - ) - - # TODO(justinchuby): Change to a for loop - iteration += 1 - if self.verbose > 5: - print( - f"[GenericPattern.match] iteration={iteration} " - f"n_matched={len(matched)}, n_stack={len(stack)}, " - f"matched_types={collections.Counter(_.op_type for _ in matched)}" - ) - pattern_node_from_stack = stack.pop() - pattern_to_graph_node = matched[pattern_node_from_stack] - - result = self._match_backward( - node, matched, stack, pattern_to_graph_node, pattern_node_from_stack - ) - if result is None: - if self.verbose > 5: - print("[GenericPattern.match] done. backward failed.") - return result - - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert ( - not nodes_not_in_pattern - ), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - - result = self._match_forward( - node, matched, stack, pattern_to_graph_node, pattern_node_from_stack - ) - if result is None: - if self.verbose > 5: - print("[GenericPattern.match] done. forward failed.") - return result - - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert ( - not nodes_not_in_pattern - ), f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - - if self.verbose > 5: - self._debug["iteration"] = iteration - - if iteration >= max_iter and stack: - self._hint("reached {iteration}>={max_iter} iterations") - return self.none(node, inspect.currentframe().f_lineno) - - if self.verbose > 5: - print(f"[GenericPattern.match] done. {len(matched)} matched nodes") - - # At this point, the pattern is matched but let's make sure. - assert len(matched) == len(match_pattern), ( - f"Number of matched nodes is different, {len(matched)} matched nodes, " - f"and {len(match_pattern)} nodes in the pattern, matched is {matched}" - ) - assert len(stack) == 0, f"There are still {len(stack)} nodes to explore." - - # We order the matched nodes in the same order than the pattern - # to let next functions to be able to build the matching again. - matched_nodes = [matched[pattern_node] for pattern_node in match_pattern] - return PatternMatchResult( - self, - matched_nodes, - tuple(match_pattern), - match_pattern.inputs, - match_pattern.outputs, - ) - - def apply( - self, - model: ir.Model, - match_result: orp.MatchResult, - verbose: int = 0, - ) -> orp.ReplacementSubgraph | None: - x = orp.ReplacementPatternFunction(self.apply_pattern) - replacement = x.get_replacement(match_result) - # if replacement is not None: - # TODO(Rama) - # assert len(replacement.new_outputs) == len(match_result.pattern_outputs), ( - # f"Not the same number of outputs, matched " - # f"outputs={match_result.pattern_outputs}, " - # f"got {replacement.new_outputs} in the applied pattern." - # ) - return replacement - - def make_rule(self) -> orp.RewriteRule: - """Creates the corresponding rule for this pattern.""" - return GenericRewriteRule(self) - - -class FunctionPattern(GenericPattern): - """An instance of GenericPattern taking ir.Function. - - It defines the matching pattern and its replacement. - - Args: - match_pattern: the onnx ir function defining the matching pattern - apply_pattern: the onnx ir function defining the new pattern - validate_mapping: the function used to validate a pattern - verbose: in [0, 10], increase the verbosity to understand why a pattern - does not match - - """ - - def __init__( - self, - match_pattern: ir.Function, - apply_pattern: Callable, - validate_mapping: Callable, - verbose: int = 0, - ): - self.match_pattern = match_pattern - self.apply_pattern = apply_pattern - self.validate_mapping = validate_mapping - self.verbose = verbose - - def _get_match_pattern(self, *_, **__): - return self.match_pattern - - -def _build_pattern(match_pattern_function: Callable) -> ir.Graph: - kwargs = {} - args = [] - - # There should be a better way. - sig = inspect.signature(match_pattern_function) - for i, p in enumerate(sig.parameters.values()): - if i == 0: - continue - if p.default is not inspect._empty: - # an attribute - kwargs[p.name] = p.default - else: - args.append(p.name) - - assert len(kwargs) == 0, f"Attributes are not supported yet but kwargs={kwargs}" - - inputs = [ir.Input(name=name) for name in args] - builder = orp.RewriterContext() - outputs = match_pattern_function(builder, *inputs, **kwargs) - if isinstance(outputs, ir.Value): - outputs = [outputs] - # TODO(Rama): Should construct a function! - graph = ir.Graph(inputs=inputs, outputs=outputs, nodes=builder.nodes) - graph.outputs[:] = outputs - return graph - - -def make_pattern_rule( - match_pattern_function: Callable, - apply_pattern_function: Callable, - validate_mapping: Callable | None = None, - verbose: int = 0, -) -> orp.RewriteRule: - """ - Creates a rewriting rule from a callable or a function proto. - - Args: - match_pattern_function: an onnxscript-like function that defines - the pattern subgraph (nodes) to be replaced - apply_pattern_function: an onnxscript-like function that constructs - the replacement subgraph (new nodes replacing the matched nodes) - validate_mapping: a function that validates the matching subgraph once - it is found. If it returns False the pattern is not applied. - If not specified, it is equivalent to a function that always return True - verbose: verbosity level - - Returns: - the rewriting rule - """ - - match_pattern_ir = _build_pattern(match_pattern_function) - - pat = FunctionPattern( - match_pattern_ir, - apply_pattern_function, - validate_mapping or (lambda *_, **__: True), - verbose=verbose, - ) - return pat.make_rule() diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py deleted file mode 100644 index d1184552b8..0000000000 --- a/onnxscript/rewriter/generic_pattern_test.py +++ /dev/null @@ -1,513 +0,0 @@ -from __future__ import annotations - -import contextlib -import io -import os -import unittest - -import numpy as np -import onnx -import onnx.reference -import onnxruntime as ort - -from onnxscript import ir -from onnxscript.rewriter import generic_pattern - -FLOAT = onnx.TensorProto.FLOAT - - -class GenericPatternTest(unittest.TestCase): - def _range(self, *shape, bias: float | None = None): - n = np.prod(shape) - x = np.arange(n).astype(np.float32) / n - if bias: - x = x + bias - return x.reshape(tuple(shape)).astype(np.float32) - - def test_graph_pattern_builder(self): - """Test replacing Add + Add by AddAdd.""" - - def match_pattern(op, x, y, z): - """Builds the pattern to match.""" - tmp = op.Add(x, y) - return op.Add(tmp, z) - - def apply_pattern(op, x, y, z, **_): - """Builds the replacement graph.""" - return op.AddAdd(x, y, z, domain="ZZZ") - - def validate_mapping(context, x, y, z, **_) -> bool: - """Validates the mapping.""" - del context - return True - - rule = generic_pattern.make_pattern_rule( - match_pattern, apply_pattern, validate_mapping, verbose=0 - ) - - class AddAdd(onnx.reference.op_run.OpRun): - op_domain = "ZZZ" - - def _run(self, x, y, z): - return (x + y + z,) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["x", "y"], ["gggg"]), - onnx.helper.make_node("Add", ["gggg", "z"], ["final"]), - ], - "dummy", - [ - onnx.helper.make_tensor_value_info("x", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("y", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("z", FLOAT, [None, None]), - ], - [onnx.helper.make_tensor_value_info("final", FLOAT, [None, None])], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ir_version=9, - ) - onnx.checker.check_model(model) - - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule.apply_to_model(ir_model) - self.assertEqual( - ["AddAdd"], - [n.op_type for n in ir_model.graph], - ) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - rewriten_model = ir.serde.serialize_model(ir_model) - self.assertEqual( - ["AddAdd"], - [n.op_type for n in rewriten_model.graph.node], - ) - - feeds = { - "x": self._range(5, 6), - "y": self._range(5, 6), - "z": self._range(5, 6), - } - ref1 = onnx.reference.ReferenceEvaluator(model) - expected = ref1.run(None, feeds) - - self.assertEqual(0, len(rewriten_model.graph.initializer)) - opsets = {v.domain: v.version for v in rewriten_model.opset_import} - self.assertIn("ZZZ", opsets) - self.assertEqual(opsets["ZZZ"], 1) - - ref2 = onnx.reference.ReferenceEvaluator(rewriten_model, new_ops=[AddAdd]) - got = ref2.run(None, feeds) - np.testing.assert_almost_equal(expected[0], got[0]) - - def test_graph_pattern_builder_multi_outputs(self): - def match_pattern(op, x, y, w, z): - """Builds the pattern to match.""" - tmp = op.Add(x, y) - tmp2 = op.Add(tmp, w) - r1 = op.Add(tmp, z) - return tmp2, r1 - - def apply_pattern(op, x, y, w, z, **_): - """Builds the pattern to match.""" - return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", outputs=2) - - def validate_mapping(context, **_) -> bool: - return True - - rule = generic_pattern.make_pattern_rule( - match_pattern, apply_pattern, validate_mapping, verbose=10 - ) - - class AddAddAddAdd(onnx.reference.op_run.OpRun): - op_domain = "ZZZ" - - def _run(self, x, y, w, z): - return (x + y + w, x + y + z) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["x", "y"], ["gggg"]), - onnx.helper.make_node("Add", ["gggg", "w"], ["f1"]), - onnx.helper.make_node("Add", ["gggg", "z"], ["f2"]), - ], - "dummy", - [ - onnx.helper.make_tensor_value_info("x", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("y", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("z", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("w", FLOAT, [None, None]), - ], - [ - onnx.helper.make_tensor_value_info("f1", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("f2", FLOAT, [None, None]), - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ir_version=9, - ) - onnx.checker.check_model(model) - - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule.apply_to_model(ir_model) - self.assertEqual( - ["AddAddAddAdd"], - [n.op_type for n in ir_model.graph], - ) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - self.assertEqual( - ["AddAddAddAdd"], - [n.op_type for n in rewriten_model.graph.node], - ) - - feeds = { - "x": self._range(5, 6), - "y": self._range(5, 6), - "w": self._range(5, 6), - "z": self._range(5, 6), - } - ref1 = onnx.reference.ReferenceEvaluator(model) - expected = ref1.run(None, feeds) - - self.assertEqual(0, len(rewriten_model.graph.initializer)) - opsets = {v.domain: v.version for v in rewriten_model.opset_import} - self.assertIn("ZZZ", opsets) - self.assertEqual(opsets["ZZZ"], 1) - - ref2 = onnx.reference.ReferenceEvaluator(rewriten_model, new_ops=[AddAddAddAdd]) - got = ref2.run(None, feeds) - np.testing.assert_almost_equal(expected[0], got[0]) - - def check_with_ort(self, model: onnx.ModelProto, providers=None): - if providers is None: - providers = ["CPUExecutionProvider"] - - if isinstance(model, onnx.ModelProto): - model = model.SerializeToString() - session = ort.InferenceSession(model, providers=providers) - return session - - def get_rotary_model(self): - inputs = [ - onnx.helper.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]), - onnx.helper.make_tensor_value_info("pos_ids", FLOAT, shape=[]), - onnx.helper.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]), - ] - nodes = [ - onnx.helper.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]), - onnx.helper.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1), - onnx.helper.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]), - onnx.helper.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]), - onnx.helper.make_node( - "ConcatTraining", - ["_onx_transpose0", "_onx_transpose0"], - ["_onx_concattraining0", "_onx_concattraining1"], - domain="com.microsoft", - ), - onnx.helper.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]), - onnx.helper.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1), - onnx.helper.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]), - onnx.helper.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1), - ] - outputs = [ - onnx.helper.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []), - onnx.helper.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []), - ] - model = onnx.helper.make_model( - onnx.helper.make_graph( - nodes, - "experiment", - inputs, - outputs, - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 18), - ], - ) - return model - - def test_rotary_embedding(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def match_pattern(op, x, pos_ids, axis): - # original code: the code does verifies the constant yet - # unsqueeze = op.Unsqueeze(x, [1]) - - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, length = op.ConcatTraining( - transpose, - transpose, - domain="com.microsoft", - outputs=2, - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_mapping(match_result, **_) -> bool: - del match_result - return True - - def apply_pattern(op, x, pos_ids, axis, **_): - del axis - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - return op.RotaryEmbedding( - x, - pos_ids, - cos_cache, - sin_cache, - domain="com.microsoft", - outputs=2, - ) - - rule = generic_pattern.make_pattern_rule( - match_pattern, apply_pattern, validate_mapping, verbose=10 - ) - - model = self.get_rotary_model() - - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - # back to ir - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - # starts matching - rule.apply_to_model(ir_model) - ir_model.opset_imports["com.microsoft"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Constant", "Constant", "RotaryEmbedding"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - out = buffer.getvalue() - # TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working? - self.assertIn("[GenericPattern.match", out) - - def test_rotary_embedding_onnxscript(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def rotary_match_pattern(op, x, pos_ids, axis): - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, length = op.ConcatTraining( - transpose, transpose, domain="com.microsoft", outputs=2 - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_rotary_mapping(match_result, **_) -> bool: - # If some pattern needs to be rejected. - del match_result - return True - - def rotary_apply_pattern(op, x, pos_ids, axis, **_): - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 - ) - return part1, part2 - - rule = generic_pattern.make_pattern_rule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, - verbose=10, - ) - - model = self.get_rotary_model() - - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - # back to ir - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - # starts matching - rule.apply_to_model(ir_model) - ir_model.opset_imports["com.microsoft"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Constant", "Constant", "RotaryEmbedding"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - out = buffer.getvalue() - # TODO(justinchuby): Remove this assert - capturing stdout is not robust - self.assertIn("[GenericPattern.match", out) - - def test_rotary_emb_file_onnxscript(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def rotary_match_pattern(op, x, pos_ids, axis): - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, length = op.ConcatTraining( - transpose, transpose, domain="com.microsoft", outputs=2 - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_rotary_mapping(match_result, **_) -> bool: - # If some pattern needs to be rejected. - del match_result - return True - - def rotary_apply_pattern(op, x, pos_ids, axis): - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 - ) - return part1, part2 - - model_path = "gemma_optimized_pre_grad_training_2.onnx" - if not os.path.exists(model_path): - raise unittest.SkipTest(f"{model_path!r} is missing") - model = onnx.load(model_path) - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule = generic_pattern.make_pattern_rule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, - verbose=10, - ) - - rule.apply_to_model(ir_model) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - buffer = rewriten_model.SerializeToString() - with open(f"{model}.opt.onnx", "wb") as f: - f.write(buffer) - self.check_with_ort(rewriten_model) - - def test_transpose_transpose_onnxscript(self): - # TODO(rama): Attribute-parameters not yet supported in multi-output matching. - # def transpose_transpose_pattern(op, X, perm0, perm1): - # xt = op.Transpose(X, perm=perm0) - # Y = op.Transpose(xt, perm=perm1) - # return Y - - def transpose_transpose_pattern(op, X): - XT = op.Transpose(X, outputs=["XT"]) - Y = op.Transpose(XT, outputs=["Y"]) - return Y - - def transpose_transpose_mapping(perm0, perm1): - new_perm = [0 for p in perm0] - for i, p in enumerate(perm1): - new_perm[i] = perm0[p] - # replace by return [perm0[p] for p in perm1] ? - return new_perm - - def transpose_transpose_check(op, **_) -> bool: - return True - - def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): - perm0 = XT.producer().attributes.get("perm") - if perm0 is not None: - perm0 = perm0.value # TODO(rama): handle RefAttr - perm1 = Y.producer().attributes.get("perm") - if perm1 is not None: - perm1 = perm1.value # TODO(rama): handle RefAttr - if perm0 is None and perm1 is None: - return op.Identity(X) - if perm0 is None: - perm0 = range(len(perm1) - 1, -1, -1) - if perm1 is None: - perm1 = range(len(perm0) - 1, -1, -1) - composed_perm = transpose_transpose_mapping(perm0, perm1) - return op.Transpose(X, perm=composed_perm) - - rule = generic_pattern.make_pattern_rule( - transpose_transpose_pattern, - transpose_transpose_apply_pattern, - transpose_transpose_check, - verbose=0, - ) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]), - onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ) - - # back to ir - ir_model = ir.serde.deserialize_model(model) - - # starts matching - - rule.apply_to_model(ir_model) - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Transpose"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - node = rewriten_model.graph.node[0] - self.assertEqual(len(node.attribute), 1) - att = node.attribute[0] - self.assertEqual(att.name, "perm") - self.assertEqual(list(att.ints), [2, 0, 1]) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/match_context_test.py b/onnxscript/rewriter/match_context_test.py new file mode 100644 index 0000000000..e45b8e9ab5 --- /dev/null +++ b/onnxscript/rewriter/match_context_test.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Test for MatchContext functionality.""" + +import unittest + +import onnx.parser + +from onnxscript import ir +from onnxscript.rewriter import pattern + + +class MatchContextTest(unittest.TestCase): + def test_context_usage_in_condition_function(self): + """Test that MatchContext can be meaningfully used in condition functions.""" + + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + c1 = Constant() + t1 = Div(c1, x) + z = Mul(t1, y) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + + def condition_using_context(context, x, y): + # Use context to check properties of the match + self.assertIs(context.model, model) + self.assertIs(context.graph_or_function, model.graph) + self.assertIs(context.root, model.graph[2]) + + # Verify that we can inspect the matched nodes + self.assertEqual(len(context.nodes), 2) + + return True # Allow the rewrite + + def reciprocal_mul_pattern(op, x, y): + return (1 / x) * y + + def replacement(op, x, y): + return op.Div(y, x) + + rule = pattern.RewriteRule( + reciprocal_mul_pattern, replacement, condition_function=condition_using_context + ) + + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/models/_bart_encoder.py b/onnxscript/rewriter/models/_bart_encoder.py new file mode 100644 index 0000000000..2e5bcce5c0 --- /dev/null +++ b/onnxscript/rewriter/models/_bart_encoder.py @@ -0,0 +1,701 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Onnxscript version of "hf-internal-testing_tiny-random-bart". + +See: https://huggingface.co/hf-internal-testing/tiny-random-bart +""" + +import numpy as np + +import onnxscript.ir as ir +from onnxscript import script +from onnxscript.onnx_opset import opset20 +from onnxscript.onnx_types import FLOAT, INT64 + + +def make_model( + encoder_embed_tokens_weight, + encoder_embed_positions_weight, + encoder_layers_0_self_attn_k_proj_bias, + encoder_layers_0_self_attn_layer_norm_weight, + encoder_layers_0_fc1_bias, + matmul_257, + matmul_267, + matmul_268, + matmul_270, + matmul_271, + matmul_272, + matmul_273, + matmul_283, + matmul_284, + matmul_286, + matmul_287, + matmul_288, +): + @script() + def main_graph(input_ids: INT64[1, None]) -> FLOAT[None, None, 16]: + encoder_layernorm_embedding_bias = opset20.Identity( + encoder_layers_0_self_attn_layer_norm_weight + ) + encoder_layernorm_embedding_weight = opset20.Identity( + encoder_layers_0_self_attn_layer_norm_weight + ) + + encoder_layers_1_final_layer_norm_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_1_final_layer_norm_weight = opset20.Identity( + encoder_layers_0_self_attn_layer_norm_weight + ) + + encoder_layers_1_fc2_bias = opset20.Identity(encoder_layers_0_self_attn_k_proj_bias) + encoder_layers_1_self_attn_layer_norm_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_1_self_attn_layer_norm_weight = opset20.Identity( + encoder_layers_0_self_attn_layer_norm_weight + ) + encoder_layers_1_self_attn_out_proj_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_1_self_attn_q_proj_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_1_self_attn_v_proj_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_1_self_attn_k_proj_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_0_final_layer_norm_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_0_final_layer_norm_weight = opset20.Identity( + encoder_layers_0_self_attn_layer_norm_weight + ) + encoder_layers_0_fc2_bias = opset20.Identity(encoder_layers_0_self_attn_k_proj_bias) + encoder_layers_1_fc1_bias = opset20.Identity(encoder_layers_0_fc1_bias) + encoder_layers_0_self_attn_out_proj_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_0_self_attn_q_proj_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + encoder_layers_0_self_attn_v_proj_bias = opset20.Identity( + encoder_layers_0_self_attn_k_proj_bias + ) + + encoder_shape_output_0 = opset20.Shape(input_ids) + encoder_constant_output_0 = opset20.Constant(value=1) + encoder_gather_output_0 = opset20.Gather( + encoder_shape_output_0, encoder_constant_output_0 + ) + + encoder_constant_1_output_0 = opset20.Constant(value=[-1]) + unsqueeze_43 = opset20.Constant(value=[0]) + encoder_unsqueeze_output_0 = opset20.Unsqueeze(encoder_gather_output_0, unsqueeze_43) + encoder_concat_output_0 = opset20.Concat( + encoder_constant_1_output_0, encoder_unsqueeze_output_0, axis=0 + ) + encoder_reshape_output_0 = opset20.Reshape( + input_ids, encoder_concat_output_0, allowzero=0 + ) + encoder_embed_tokens_gather_output_0 = opset20.Gather( + encoder_embed_tokens_weight, encoder_reshape_output_0 + ) + encoder_embed_tokens_constant_output_0 = opset20.Constant(value=[1.0]) + encoder_embed_tokens_mul_output_0 = opset20.Mul( + encoder_embed_tokens_gather_output_0, encoder_embed_tokens_constant_output_0 + ) + encoder_embed_positions_shape_output_0 = opset20.Shape(input_ids) + encoder_embed_positions_constant_output_0 = opset20.Constant(value=0) + encoder_embed_positions_gather_output_0 = opset20.Gather( + encoder_embed_positions_shape_output_0, + encoder_embed_positions_constant_output_0, + axis=0, + ) + encoder_embed_positions_constant_1_output_0 = opset20.Constant(value=0) + encoder_embed_positions_cast_output_0 = opset20.Cast(encoder_gather_output_0, to=7) + encoder_embed_positions_constant_2_output_0 = opset20.Constant(value=1) + encoder_embed_positions_range_output_0 = opset20.Range( + encoder_embed_positions_constant_1_output_0, + encoder_embed_positions_cast_output_0, + encoder_embed_positions_constant_2_output_0, + ) + encoder_embed_positions_constant_3_output_0 = opset20.Constant(value=[0]) + encoder_embed_positions_unsqueeze_output_0 = opset20.Unsqueeze( + encoder_embed_positions_gather_output_0, + encoder_embed_positions_constant_3_output_0, + ) + encoder_embed_positions_constant_4_output_0 = opset20.Constant(value=[-1]) + encoder_embed_positions_concat_output_0 = opset20.Concat( + encoder_embed_positions_unsqueeze_output_0, + encoder_embed_positions_constant_4_output_0, + axis=0, + ) + encoder_embed_positions_constant_5_output_0 = opset20.Constant(value=[-1]) + encoder_embed_positions_reshape_output_0 = opset20.Reshape( + encoder_embed_positions_concat_output_0, + encoder_embed_positions_constant_5_output_0, + ) + encoder_embed_positions_shape_1_output_0 = opset20.Shape( + encoder_embed_positions_reshape_output_0 + ) + encoder_embed_positions_constantofshape_output_0 = opset20.ConstantOfShape( + encoder_embed_positions_shape_1_output_0, + value=ir.tensor(np.array([1], dtype=np.int64)), + ) + encoder_embed_positions_constant_6_output_0 = opset20.Constant(value=[-1]) + encoder_embed_positions_mul_output_0 = opset20.Mul( + encoder_embed_positions_constantofshape_output_0, + encoder_embed_positions_constant_6_output_0, + ) + encoder_embed_positions_equal_output_0 = opset20.Equal( + encoder_embed_positions_reshape_output_0, encoder_embed_positions_mul_output_0 + ) + encoder_embed_positions_where_output_0 = opset20.Where( + encoder_embed_positions_equal_output_0, + encoder_embed_positions_constantofshape_output_0, + encoder_embed_positions_reshape_output_0, + ) + encoder_embed_positions_expand_output_0 = opset20.Expand( + encoder_embed_positions_range_output_0, encoder_embed_positions_where_output_0 + ) + encoder_embed_positions_constant_7_output_0 = opset20.Constant(value=2) + encoder_embed_positions_add_output_0 = opset20.Add( + encoder_embed_positions_expand_output_0, + encoder_embed_positions_constant_7_output_0, + ) + encoder_embed_positions_gather_1_output_0 = opset20.Gather( + encoder_embed_positions_weight, encoder_embed_positions_add_output_0 + ) + encoder_cast_output_0 = opset20.Cast(encoder_embed_positions_gather_1_output_0, to=1) + encoder_add_output_0 = opset20.Add( + encoder_embed_tokens_mul_output_0, encoder_cast_output_0 + ) + encoder_layernorm_embedding_layernormalization_output_0 = opset20.LayerNormalization( + encoder_add_output_0, + encoder_layernorm_embedding_weight, + encoder_layernorm_embedding_bias, + axis=-1, + epsilon=9.999999747378752e-06, + ) + encoder_layers_0_self_attn_shape_output_0 = opset20.Shape( + encoder_layernorm_embedding_layernormalization_output_0 + ) + encoder_layers_0_self_attn_constant_output_0 = opset20.Constant(value=0) + encoder_layers_0_self_attn_gather_output_0 = opset20.Gather( + encoder_layers_0_self_attn_shape_output_0, + encoder_layers_0_self_attn_constant_output_0, + axis=0, + ) + encoder_layers_0_self_attn_shape_1_output_0 = opset20.Shape( + encoder_layernorm_embedding_layernormalization_output_0 + ) + encoder_layers_0_self_attn_constant_1_output_0 = opset20.Constant(value=1) + encoder_layers_0_self_attn_gather_1_output_0 = opset20.Gather( + encoder_layers_0_self_attn_shape_1_output_0, + encoder_layers_0_self_attn_constant_1_output_0, + axis=0, + ) + encoder_layers_0_self_attn_q_proj_matmul_output_0 = opset20.MatMul( + encoder_layernorm_embedding_layernormalization_output_0, matmul_257 + ) + encoder_layers_0_self_attn_q_proj_add_output_0 = opset20.Add( + encoder_layers_0_self_attn_q_proj_bias, + encoder_layers_0_self_attn_q_proj_matmul_output_0, + ) + unsqueeze_88 = opset20.Constant(value=[0]) + encoder_layers_0_self_attn_unsqueeze_output_0 = opset20.Unsqueeze( + encoder_layers_0_self_attn_gather_output_0, unsqueeze_88 + ) + encoder_layers_0_self_attn_constant_2_output_0 = opset20.Constant(value=[-1]) + encoder_layers_0_self_attn_constant_3_output_0 = opset20.Constant(value=[4]) + encoder_layers_0_self_attn_constant_4_output_0 = opset20.Constant(value=[4]) + encoder_layers_0_self_attn_concat_output_0 = opset20.Concat( + encoder_layers_0_self_attn_unsqueeze_output_0, + encoder_layers_0_self_attn_constant_2_output_0, + encoder_layers_0_self_attn_constant_3_output_0, + encoder_layers_0_self_attn_constant_4_output_0, + axis=0, + ) + unsqueeze_97 = opset20.Constant(value=[0]) + encoder_layers_0_self_attn_unsqueeze_1_output_0 = opset20.Unsqueeze( + encoder_layers_0_self_attn_gather_output_0, unsqueeze_97 + ) + encoder_layers_0_self_attn_constant_5_output_0 = opset20.Constant(value=[-1]) + encoder_layers_0_self_attn_constant_6_output_0 = opset20.Constant(value=[4]) + encoder_layers_0_self_attn_constant_7_output_0 = opset20.Constant(value=[4]) + encoder_layers_0_self_attn_concat_1_output_0 = opset20.Concat( + encoder_layers_0_self_attn_unsqueeze_1_output_0, + encoder_layers_0_self_attn_constant_5_output_0, + encoder_layers_0_self_attn_constant_6_output_0, + encoder_layers_0_self_attn_constant_7_output_0, + axis=0, + ) + unsqueeze_106 = opset20.Constant(value=[0]) + encoder_layers_0_self_attn_unsqueeze_2_output_0 = opset20.Unsqueeze( + encoder_layers_0_self_attn_gather_output_0, unsqueeze_106 + ) + encoder_layers_0_self_attn_constant_8_output_0 = opset20.Constant(value=[-1]) + encoder_layers_0_self_attn_constant_9_output_0 = opset20.Constant(value=[4]) + encoder_layers_0_self_attn_constant_10_output_0 = opset20.Constant(value=[4]) + encoder_layers_0_self_attn_concat_2_output_0 = opset20.Concat( + encoder_layers_0_self_attn_unsqueeze_2_output_0, + encoder_layers_0_self_attn_constant_8_output_0, + encoder_layers_0_self_attn_constant_9_output_0, + encoder_layers_0_self_attn_constant_10_output_0, + axis=0, + ) + + encoder_layers_0_self_attn_reshape_output_0 = opset20.Reshape( + encoder_layers_0_self_attn_q_proj_add_output_0, + encoder_layers_0_self_attn_concat_output_0, + allowzero=0, + ) + encoder_layers_0_self_attn_transpose_output_0 = opset20.Transpose( + encoder_layers_0_self_attn_reshape_output_0, perm=[0, 2, 1, 3] + ) + encoder_layers_0_self_attn_k_proj_matmul_output_0 = opset20.MatMul( + encoder_layernorm_embedding_layernormalization_output_0, matmul_267 + ) + encoder_layers_0_self_attn_k_proj_add_output_0 = opset20.Add( + encoder_layers_0_self_attn_k_proj_bias, + encoder_layers_0_self_attn_k_proj_matmul_output_0, + ) + encoder_layers_0_self_attn_v_proj_matmul_output_0 = opset20.MatMul( + encoder_layernorm_embedding_layernormalization_output_0, matmul_268 + ) + encoder_layers_0_self_attn_v_proj_add_output_0 = opset20.Add( + encoder_layers_0_self_attn_v_proj_bias, + encoder_layers_0_self_attn_v_proj_matmul_output_0, + ) + encoder_layers_0_self_attn_reshape_1_output_0 = opset20.Reshape( + encoder_layers_0_self_attn_k_proj_add_output_0, + encoder_layers_0_self_attn_concat_1_output_0, + allowzero=0, + ) + encoder_layers_0_self_attn_reshape_2_output_0 = opset20.Reshape( + encoder_layers_0_self_attn_v_proj_add_output_0, + encoder_layers_0_self_attn_concat_2_output_0, + allowzero=0, + ) + encoder_layers_0_self_attn_transpose_1_output_0 = opset20.Transpose( + encoder_layers_0_self_attn_reshape_2_output_0, perm=[0, 2, 1, 3] + ) + encoder_layers_0_self_attn_shape_2_output_0 = opset20.Shape( + encoder_layers_0_self_attn_transpose_output_0 + ) + encoder_layers_0_self_attn_constant_11_output_0 = opset20.Constant(value=[-1]) + encoder_layers_0_self_attn_constant_12_output_0 = opset20.Constant( + value=[9223372036854775807] + ) + encoder_layers_0_self_attn_slice_output_0 = opset20.Slice( + encoder_layers_0_self_attn_shape_2_output_0, + encoder_layers_0_self_attn_constant_11_output_0, + encoder_layers_0_self_attn_constant_12_output_0, + ) + encoder_layers_0_self_attn_cast_output_0 = opset20.Cast( + encoder_layers_0_self_attn_slice_output_0, to=1 + ) + encoder_layers_0_self_attn_sqrt_output_0 = opset20.Sqrt( + encoder_layers_0_self_attn_cast_output_0 + ) + encoder_layers_0_self_attn_constant_13_output_0 = opset20.Constant(value=[1.0]) + encoder_layers_0_self_attn_div_output_0 = opset20.Div( + encoder_layers_0_self_attn_constant_13_output_0, + encoder_layers_0_self_attn_sqrt_output_0, + ) + encoder_layers_0_self_attn_cast_1_output_0 = opset20.Cast( + encoder_layers_0_self_attn_div_output_0, to=1 + ) + encoder_layers_0_self_attn_transpose_2_output_0 = opset20.Transpose( + encoder_layers_0_self_attn_reshape_1_output_0, perm=[0, 2, 3, 1] + ) + encoder_layers_0_self_attn_sqrt_1_output_0 = opset20.Sqrt( + encoder_layers_0_self_attn_cast_1_output_0 + ) + encoder_layers_0_self_attn_mul_output_0 = opset20.Mul( + encoder_layers_0_self_attn_transpose_output_0, + encoder_layers_0_self_attn_sqrt_1_output_0, + ) + encoder_layers_0_self_attn_sqrt_2_output_0 = opset20.Sqrt( + encoder_layers_0_self_attn_cast_1_output_0 + ) + encoder_layers_0_self_attn_mul_1_output_0 = opset20.Mul( + encoder_layers_0_self_attn_transpose_2_output_0, + encoder_layers_0_self_attn_sqrt_2_output_0, + ) + encoder_layers_0_self_attn_matmul_output_0 = opset20.MatMul( + encoder_layers_0_self_attn_mul_output_0, encoder_layers_0_self_attn_mul_1_output_0 + ) + encoder_layers_0_self_attn_softmax_output_0 = opset20.Softmax( + encoder_layers_0_self_attn_matmul_output_0, axis=-1 + ) + encoder_layers_0_self_attn_matmul_1_output_0 = opset20.MatMul( + encoder_layers_0_self_attn_softmax_output_0, + encoder_layers_0_self_attn_transpose_1_output_0, + ) + encoder_layers_0_self_attn_transpose_3_output_0 = opset20.Transpose( + encoder_layers_0_self_attn_matmul_1_output_0, perm=[0, 2, 1, 3] + ) + unsqueeze_145 = opset20.Constant(value=[0]) + encoder_layers_0_self_attn_unsqueeze_3_output_0 = opset20.Unsqueeze( + encoder_layers_0_self_attn_gather_output_0, unsqueeze_145 + ) + unsqueeze_147 = opset20.Constant(value=[0]) + encoder_layers_0_self_attn_unsqueeze_4_output_0 = opset20.Unsqueeze( + encoder_layers_0_self_attn_gather_1_output_0, unsqueeze_147 + ) + encoder_layers_0_self_attn_constant_14_output_0 = opset20.Constant(value=[16]) + encoder_layers_0_self_attn_concat_3_output_0 = opset20.Concat( + encoder_layers_0_self_attn_unsqueeze_3_output_0, + encoder_layers_0_self_attn_unsqueeze_4_output_0, + encoder_layers_0_self_attn_constant_14_output_0, + axis=0, + ) + encoder_layers_0_self_attn_reshape_3_output_0 = opset20.Reshape( + encoder_layers_0_self_attn_transpose_3_output_0, + encoder_layers_0_self_attn_concat_3_output_0, + allowzero=0, + ) + encoder_layers_0_self_attn_out_proj_matmul_output_0 = opset20.MatMul( + encoder_layers_0_self_attn_reshape_3_output_0, matmul_270 + ) + encoder_layers_0_self_attn_out_proj_add_output_0 = opset20.Add( + encoder_layers_0_self_attn_out_proj_bias, + encoder_layers_0_self_attn_out_proj_matmul_output_0, + ) + encoder_layers_0_add_output_0 = opset20.Add( + encoder_layernorm_embedding_layernormalization_output_0, + encoder_layers_0_self_attn_out_proj_add_output_0, + ) + encoder_layers_0_self_attn_layer_norm_layernormalization_output_0 = ( + opset20.LayerNormalization( + encoder_layers_0_add_output_0, + encoder_layers_0_self_attn_layer_norm_weight, + encoder_layernorm_embedding_bias, + axis=-1, + epsilon=9.999999747378752e-0, + ) + ) + encoder_layers_0_fc1_matmul_output_0 = opset20.MatMul( + encoder_layers_0_self_attn_layer_norm_layernormalization_output_0, matmul_271 + ) + encoder_layers_0_fc1_add_output_0 = opset20.Add( + encoder_layers_0_fc1_bias, encoder_layers_0_fc1_matmul_output_0 + ) + encoder_layers_0_activation_fn_gelu_output_0 = opset20.Gelu( + encoder_layers_0_fc1_add_output_0, approximate="none" + ) + encoder_layers_0_fc2_matmul_output_0 = opset20.MatMul( + encoder_layers_0_activation_fn_gelu_output_0, matmul_272 + ) + encoder_layers_0_fc2_add_output_0 = opset20.Add( + encoder_layers_0_fc2_bias, encoder_layers_0_fc2_matmul_output_0 + ) + encoder_layers_0_add_1_output_0 = opset20.Add( + encoder_layers_0_self_attn_layer_norm_layernormalization_output_0, + encoder_layers_0_fc2_add_output_0, + ) + encoder_layers_0_final_layer_norm_layernormalization_output_0 = ( + opset20.LayerNormalization( + encoder_layers_0_add_1_output_0, + encoder_layers_0_final_layer_norm_weight, + encoder_layers_0_final_layer_norm_bias, + axis=-1, + epsilon=9.999999747378752e-06, + ) + ) + encoder_layers_1_self_attn_shape_output_0 = opset20.Shape( + encoder_layers_0_final_layer_norm_layernormalization_output_0 + ) + encoder_layers_1_self_attn_constant_output_0 = opset20.Constant(value=0) + encoder_layers_1_self_attn_gather_output_0 = opset20.Gather( + encoder_layers_1_self_attn_shape_output_0, + encoder_layers_1_self_attn_constant_output_0, + axis=0, + ) + encoder_layers_1_self_attn_shape_1_output_0 = opset20.Shape( + encoder_layers_0_final_layer_norm_layernormalization_output_0 + ) + encoder_layers_1_self_attn_constant_1_output_0 = opset20.Constant(value=1) + encoder_layers_1_self_attn_gather_1_output_0 = opset20.Gather( + encoder_layers_1_self_attn_shape_1_output_0, + encoder_layers_1_self_attn_constant_1_output_0, + axis=0, + ) + encoder_layers_1_self_attn_q_proj_matmul_output_0 = opset20.MatMul( + encoder_layers_0_final_layer_norm_layernormalization_output_0, matmul_273 + ) + encoder_layers_1_self_attn_q_proj_add_output_0 = opset20.Add( + encoder_layers_1_self_attn_q_proj_bias, + encoder_layers_1_self_attn_q_proj_matmul_output_0, + ) + unsqueeze_176 = opset20.Constant(value=[0]) + encoder_layers_1_self_attn_unsqueeze_output_0 = opset20.Unsqueeze( + encoder_layers_1_self_attn_gather_output_0, unsqueeze_176 + ) + encoder_layers_1_self_attn_constant_2_output_0 = opset20.Constant(value=[-1]) + encoder_layers_1_self_attn_constant_3_output_0 = opset20.Constant(value=[4]) + encoder_layers_1_self_attn_constant_4_output_0 = opset20.Constant(value=[4]) + encoder_layers_1_self_attn_concat_output_0 = opset20.Concat( + encoder_layers_1_self_attn_unsqueeze_output_0, + encoder_layers_1_self_attn_constant_2_output_0, + encoder_layers_1_self_attn_constant_3_output_0, + encoder_layers_1_self_attn_constant_4_output_0, + axis=0, + ) + unsqueeze_185 = opset20.Constant(value=[0]) + encoder_layers_1_self_attn_unsqueeze_1_output_0 = opset20.Unsqueeze( + encoder_layers_1_self_attn_gather_output_0, unsqueeze_185 + ) + encoder_layers_1_self_attn_constant_5_output_0 = opset20.Constant(value=[-1]) + encoder_layers_1_self_attn_constant_6_output_0 = opset20.Constant(value=[4]) + encoder_layers_1_self_attn_constant_7_output_0 = opset20.Constant(value=[4]) + encoder_layers_1_self_attn_concat_1_output_0 = opset20.Concat( + encoder_layers_1_self_attn_unsqueeze_1_output_0, + encoder_layers_1_self_attn_constant_5_output_0, + encoder_layers_1_self_attn_constant_6_output_0, + encoder_layers_1_self_attn_constant_7_output_0, + axis=0, + ) + unsqueeze_194 = opset20.Constant(value=[0]) + encoder_layers_1_self_attn_unsqueeze_2_output_0 = opset20.Unsqueeze( + encoder_layers_1_self_attn_gather_output_0, unsqueeze_194 + ) + encoder_layers_1_self_attn_constant_8_output_0 = opset20.Constant(value=[-1]) + encoder_layers_1_self_attn_constant_9_output_0 = opset20.Constant(value=[4]) + encoder_layers_1_self_attn_constant_10_output_0 = opset20.Constant(value=[4]) + encoder_layers_1_self_attn_concat_2_output_0 = opset20.Concat( + encoder_layers_1_self_attn_unsqueeze_2_output_0, + encoder_layers_1_self_attn_constant_8_output_0, + encoder_layers_1_self_attn_constant_9_output_0, + encoder_layers_1_self_attn_constant_10_output_0, + axis=0, + ) + encoder_layers_1_self_attn_reshape_output_0 = opset20.Reshape( + encoder_layers_1_self_attn_q_proj_add_output_0, + encoder_layers_1_self_attn_concat_output_0, + allowzero=0, + ) + encoder_layers_1_self_attn_transpose_output_0 = opset20.Transpose( + encoder_layers_1_self_attn_reshape_output_0, perm=[0, 2, 1, 3] + ) + encoder_layers_1_self_attn_k_proj_matmul_output_0 = opset20.MatMul( + encoder_layers_0_final_layer_norm_layernormalization_output_0, matmul_283 + ) + encoder_layers_1_self_attn_k_proj_add_output_0 = opset20.Add( + encoder_layers_1_self_attn_k_proj_bias, + encoder_layers_1_self_attn_k_proj_matmul_output_0, + ) + encoder_layers_1_self_attn_v_proj_matmul_output_0 = opset20.MatMul( + encoder_layers_0_final_layer_norm_layernormalization_output_0, matmul_284 + ) + encoder_layers_1_self_attn_v_proj_add_output_0 = opset20.Add( + encoder_layers_1_self_attn_v_proj_bias, + encoder_layers_1_self_attn_v_proj_matmul_output_0, + ) + encoder_layers_1_self_attn_reshape_1_output_0 = opset20.Reshape( + encoder_layers_1_self_attn_k_proj_add_output_0, + encoder_layers_1_self_attn_concat_1_output_0, + allowzero=0, + ) + encoder_layers_1_self_attn_reshape_2_output_0 = opset20.Reshape( + encoder_layers_1_self_attn_v_proj_add_output_0, + encoder_layers_1_self_attn_concat_2_output_0, + allowzero=0, + ) + encoder_layers_1_self_attn_transpose_1_output_0 = opset20.Transpose( + encoder_layers_1_self_attn_reshape_2_output_0, perm=[0, 2, 1, 3] + ) + encoder_layers_1_self_attn_shape_2_output_0 = opset20.Shape( + encoder_layers_1_self_attn_transpose_output_0 + ) + encoder_layers_1_self_attn_constant_11_output_0 = opset20.Constant(value=[-1]) + encoder_layers_1_self_attn_constant_12_output_0 = opset20.Constant( + value=[9223372036854775807] + ) + encoder_layers_1_self_attn_slice_output_0 = opset20.Slice( + encoder_layers_1_self_attn_shape_2_output_0, + encoder_layers_1_self_attn_constant_11_output_0, + encoder_layers_1_self_attn_constant_12_output_0, + ) + encoder_layers_1_self_attn_cast_output_0 = opset20.Cast( + encoder_layers_1_self_attn_slice_output_0, to=1 + ) + encoder_layers_1_self_attn_sqrt_output_0 = opset20.Sqrt( + encoder_layers_1_self_attn_cast_output_0 + ) + encoder_layers_1_self_attn_constant_13_output_0 = opset20.Constant(value=[1.0]) + encoder_layers_1_self_attn_div_output_0 = opset20.Div( + encoder_layers_1_self_attn_constant_13_output_0, + encoder_layers_1_self_attn_sqrt_output_0, + ) + encoder_layers_1_self_attn_cast_1_output_0 = opset20.Cast( + encoder_layers_1_self_attn_div_output_0, to=1 + ) + encoder_layers_1_self_attn_transpose_2_output_0 = opset20.Transpose( + encoder_layers_1_self_attn_reshape_1_output_0, perm=[0, 2, 3, 1] + ) + encoder_layers_1_self_attn_sqrt_1_output_0 = opset20.Sqrt( + encoder_layers_1_self_attn_cast_1_output_0 + ) + encoder_layers_1_self_attn_mul_output_0 = opset20.Mul( + encoder_layers_1_self_attn_transpose_output_0, + encoder_layers_1_self_attn_sqrt_1_output_0, + ) + encoder_layers_1_self_attn_sqrt_2_output_0 = opset20.Sqrt( + encoder_layers_1_self_attn_cast_1_output_0 + ) + encoder_layers_1_self_attn_mul_1_output_0 = opset20.Mul( + encoder_layers_1_self_attn_transpose_2_output_0, + encoder_layers_1_self_attn_sqrt_2_output_0, + ) + encoder_layers_1_self_attn_matmul_output_0 = opset20.MatMul( + encoder_layers_1_self_attn_mul_output_0, encoder_layers_1_self_attn_mul_1_output_0 + ) + encoder_layers_1_self_attn_softmax_output_0 = opset20.Softmax( + encoder_layers_1_self_attn_matmul_output_0, axis=-1 + ) + encoder_layers_1_self_attn_matmul_1_output_0 = opset20.MatMul( + encoder_layers_1_self_attn_softmax_output_0, + encoder_layers_1_self_attn_transpose_1_output_0, + ) + encoder_layers_1_self_attn_transpose_3_output_0 = opset20.Transpose( + encoder_layers_1_self_attn_matmul_1_output_0, perm=[0, 2, 1, 3] + ) + unsqueeze_232 = opset20.Constant(value=[0]) + encoder_layers_1_self_attn_unsqueeze_3_output_0 = opset20.Unsqueeze( + encoder_layers_1_self_attn_gather_output_0, unsqueeze_232 + ) + unsqueeze_234 = opset20.Constant(value=[0]) + encoder_layers_1_self_attn_unsqueeze_4_output_0 = opset20.Unsqueeze( + encoder_layers_1_self_attn_gather_1_output_0, unsqueeze_234 + ) + encoder_layers_1_self_attn_constant_14_output_0 = opset20.Constant(value=[16]) + + encoder_layers_1_self_attn_concat_3_output_0 = opset20.Concat( + encoder_layers_1_self_attn_unsqueeze_3_output_0, + encoder_layers_1_self_attn_unsqueeze_4_output_0, + encoder_layers_1_self_attn_constant_14_output_0, + axis=0, + ) + encoder_layers_1_self_attn_reshape_3_output_0 = opset20.Reshape( + encoder_layers_1_self_attn_transpose_3_output_0, + encoder_layers_1_self_attn_concat_3_output_0, + allowzero=0, + ) + encoder_layers_1_self_attn_out_proj_matmul_output_0 = opset20.MatMul( + encoder_layers_1_self_attn_reshape_3_output_0, matmul_286 + ) + encoder_layers_1_self_attn_out_proj_add_output_0 = opset20.Add( + encoder_layers_1_self_attn_out_proj_bias, + encoder_layers_1_self_attn_out_proj_matmul_output_0, + ) + encoder_layers_1_add_output_0 = opset20.Add( + encoder_layers_0_final_layer_norm_layernormalization_output_0, + encoder_layers_1_self_attn_out_proj_add_output_0, + ) + encoder_layers_1_self_attn_layer_norm_layernormalization_output_0 = ( + opset20.LayerNormalization( + encoder_layers_1_add_output_0, + encoder_layers_1_self_attn_layer_norm_weight, + encoder_layers_1_self_attn_layer_norm_bias, + axis=-1, + epsilon=9.999999747378752e-06, + ) + ) + encoder_layers_1_fc1_matmul_output_0 = opset20.MatMul( + encoder_layers_1_self_attn_layer_norm_layernormalization_output_0, matmul_287 + ) + encoder_layers_1_fc1_add_output_0 = opset20.Add( + encoder_layers_1_fc1_bias, encoder_layers_1_fc1_matmul_output_0 + ) + encoder_layers_1_activation_fn_gelu_output_0 = opset20.Gelu( + encoder_layers_1_fc1_add_output_0, approximate="none" + ) + encoder_layers_1_fc2_matmul_output_0 = opset20.MatMul( + encoder_layers_1_activation_fn_gelu_output_0, matmul_288 + ) + encoder_layers_1_fc2_add_output_0 = opset20.Add( + encoder_layers_1_fc2_bias, encoder_layers_1_fc2_matmul_output_0 + ) + encoder_layers_1_add_1_output_0 = opset20.Add( + encoder_layers_1_self_attn_layer_norm_layernormalization_output_0, + encoder_layers_1_fc2_add_output_0, + ) + encoder_output = opset20.LayerNormalization( + encoder_layers_1_add_1_output_0, + encoder_layers_1_final_layer_norm_weight, + encoder_layers_1_final_layer_norm_bias, + axis=-1, + epsilon=9.999999747378752e-06, + ) + return encoder_output + + return main_graph.to_model_proto() + + +def make_model_with_random_weights(): + encoder_embed_tokens_weight = np.random.rand(1024, 16).astype(np.float32) + encoder_embed_positions_weight = np.random.rand(102, 16).astype(np.float32) + encoder_layers_0_self_attn_k_proj_bias = np.random.rand(16).astype(np.float32) + encoder_layers_0_self_attn_layer_norm_weight = np.random.rand(16).astype(np.float32) + encoder_layers_0_fc1_bias = np.zeros((4), dtype=np.float32) + + matmul_257 = np.random.rand(16, 16).astype(np.float32) + matmul_267 = np.random.rand(16, 16).astype(np.float32) + matmul_268 = np.random.rand(16, 16).astype(np.float32) + matmul_270 = np.random.rand(16, 16).astype(np.float32) + matmul_271 = np.random.rand(16, 4).astype(np.float32) + matmul_272 = np.random.rand(4, 16).astype(np.float32) + matmul_273 = np.random.rand(16, 16).astype(np.float32) + matmul_283 = np.random.rand(16, 16).astype(np.float32) + matmul_284 = np.random.rand(16, 16).astype(np.float32) + matmul_286 = np.random.rand(16, 16).astype(np.float32) + matmul_287 = np.random.rand(16, 4).astype(np.float32) + matmul_288 = np.random.rand(4, 16).astype(np.float32) + + model = make_model( + encoder_embed_positions_weight=encoder_embed_positions_weight, + encoder_embed_tokens_weight=encoder_embed_tokens_weight, + encoder_layers_0_self_attn_k_proj_bias=encoder_layers_0_self_attn_k_proj_bias, + encoder_layers_0_self_attn_layer_norm_weight=encoder_layers_0_self_attn_layer_norm_weight, + encoder_layers_0_fc1_bias=encoder_layers_0_fc1_bias, + matmul_257=matmul_257, + matmul_267=matmul_267, + matmul_268=matmul_268, + matmul_270=matmul_270, + matmul_271=matmul_271, + matmul_272=matmul_272, + matmul_273=matmul_273, + matmul_283=matmul_283, + matmul_284=matmul_284, + matmul_286=matmul_286, + matmul_287=matmul_287, + matmul_288=matmul_288, + ) + return model + + +class _BartEncoderTest: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "input_ids": np.random.randint(0, 1024, (1, 16)).astype(np.int64), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def bart_encoder_test(): + return _BartEncoderTest() diff --git a/onnxscript/rewriter/models/_phi2lm.py b/onnxscript/rewriter/models/_phi2lm.py new file mode 100644 index 0000000000..08f529a6de --- /dev/null +++ b/onnxscript/rewriter/models/_phi2lm.py @@ -0,0 +1,508 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Generated from Phi2LM 1 Layer ONNX model produced by the new (Dynamo) exporter +# ruff: noqa: F821 + +import numpy +import onnx_ir as ir + +from onnxscript import script +from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import BOOL, FLOAT, INT64 + +value_infos = { + "model_embed_tokens_weight": FLOAT[51200, 2560], + "model_layers_0_self_attn_q_proj_weight": FLOAT[2560, 2560], + "model_layers_0_self_attn_q_proj_bias": FLOAT[2560], + "model_layers_0_self_attn_k_proj_weight": FLOAT[2560, 2560], + "model_layers_0_self_attn_k_proj_bias": FLOAT[2560], + "model_layers_0_self_attn_v_proj_weight": FLOAT[2560, 2560], + "model_layers_0_self_attn_v_proj_bias": FLOAT[2560], + "model_layers_0_self_attn_dense_weight": FLOAT[2560, 2560], + "model_layers_0_self_attn_dense_bias": FLOAT[2560], + "model_layers_0_mlp_fc1_weight": FLOAT[10240, 2560], + "model_layers_0_mlp_fc1_bias": FLOAT[10240], + "model_layers_0_mlp_fc2_weight": FLOAT[2560, 10240], + "model_layers_0_mlp_fc2_bias": FLOAT[2560], + "model_layers_0_input_layernorm_weight": FLOAT[2560], + "model_layers_0_input_layernorm_bias": FLOAT[2560], + "model_final_layernorm_weight": FLOAT[2560], + "model_final_layernorm_bias": FLOAT[2560], + "lm_head_weight": FLOAT[51200, 2560], + "lm_head_bias": FLOAT[51200], + "expand_2": FLOAT[1, 16, 1], + "val_1": INT64[1], + "sym_size_int_44": INT64, + "val_4": INT64[1], + "val_5": INT64[1], + "sym_size_int_50": INT64, + "embedding": FLOAT["s34", "s16", 2560], + "add_4": INT64, + "val_6": FLOAT, + "val_7": INT64, + "arange": INT64["s16"], + "val_8": INT64[1], + "unsqueeze": INT64[1, "s16"], + "val_10": FLOAT, + "val_13": INT64[1], + "val_14": INT64[1], + "val_15": INT64[2], + "full": FLOAT["s16", "s16 + s62"], + "diagonal": INT64, + "triu": FLOAT["s16", "s16 + s62"], + "val_18": INT64, + "val_19": INT64, + "arange_1": INT64["s16 + s62"], + "val_21": INT64[2], + "view": INT64["s16", 1], + "gt": BOOL["s16", "s16 + s62"], + "convert_element_type_default": FLOAT["s16", "s16 + s62"], + "mul_16": FLOAT["s16", "s16 + s62"], + "val_22": INT64[1], + "val_421": INT64[2], + "unsqueeze_4": FLOAT[1, 1, "s16", "s16 + s62"], + "val_23": INT64, + "val_31": INT64, + "val_49": INT64[1], + "val_50": INT64[4], + "val_52": INT64[4], + "expand_1": FLOAT["s34", 1, "s16", "s16 + s62"], + "val_61": INT64, + "val_72": INT64[1], + "val_74": INT64[1], + "val_75": INT64[1], + "val_78": INT64[1], + "val_79": INT64[1], + "slice_8": FLOAT["s34", 1, "s16", "s16 + s62"], + "val_422": INT64[2], + "unsqueeze_6": INT64["s34", 1, 1, "s16 + s62"], + "convert_element_type_default_1": FLOAT["s34", 1, 1, "s16 + s62"], + "add_89": FLOAT["s34", 1, "s16", "s16 + s62"], + "scalar_tensor_default": FLOAT, + "eq_64": BOOL["s34", 1, "s16", "s16 + s62"], + "val_119": INT64[1], + "val_121": INT64[1], + "val_122": INT64[1], + "val_125": INT64[1], + "val_126": INT64[1], + "slice_14": FLOAT["s34", 1, "s16", "s16 + s62"], + "val_127": FLOAT, + "masked_fill": FLOAT["s34", 1, "s16", "s16 + s62"], + "val_179": INT64[4], + "val_180": INT64, + "val_181": INT64[None], + "val_186": INT64[None, 1], + "val_187": FLOAT["s16", 1, "s34", "s16 + s62"], + "val_188": FLOAT["s16", 1, "s34", "s16 + s62"], + "val_189": FLOAT["s16", 1, "s34", "s16 + s62"], + "val_191": INT64[4], + "val_192": INT64, + "val_193": INT64[None], + "val_198": INT64[None, 1], + "val_199": FLOAT[1, "s34", "s16", "s16 + s62"], + "val_200": FLOAT[1, "s34", "s16", "s16 + s62"], + "val_201": FLOAT[1, "s34", "s16", "s16 + s62"], + "slice_scatter_1": FLOAT["s34", 1, "s16", "s16 + s62"], + "val_203": INT64[4], + "val_204": INT64, + "val_205": INT64[None], + "val_210": INT64[None, 1], + "slice_scatter_2": FLOAT["s34", 1, "s16", "s16 + s62"], + "unsqueeze_9": INT64[1, 1, "s16"], + "_to_copy": FLOAT[1, 1, "s16"], + "matmul": FLOAT[1, 16, "s16"], + "transpose": FLOAT[1, "s16", 16], + "cat": FLOAT[1, "s16", 32], + "cos": FLOAT[1, "s16", 32], + "sin": FLOAT[1, "s16", 32], + "layer_norm": FLOAT["s34", "s16", 2560], + "val_246": FLOAT[2560, 2560], + "val_247": FLOAT["s34", "s16", 2560], + "linear": FLOAT["s34", "s16", 2560], + "val_252": INT64[1], + "val_253": INT64[4], + "view_1": FLOAT["s34", "s16", 32, 80], + "transpose_1": FLOAT["s34", 32, "s16", 80], + "val_255": FLOAT[2560, 2560], + "val_256": FLOAT["s34", "s16", 2560], + "linear_1": FLOAT["s34", "s16", 2560], + "val_261": INT64[4], + "view_2": FLOAT["s34", "s16", 32, 80], + "transpose_2": FLOAT["s34", 32, "s16", 80], + "val_263": FLOAT[2560, 2560], + "val_264": FLOAT["s34", "s16", 2560], + "linear_2": FLOAT["s34", "s16", 2560], + "val_269": INT64[4], + "view_3": FLOAT["s34", "s16", 32, 80], + "transpose_3": FLOAT["s34", 32, "s16", 80], + "val_273": INT64[1], + "val_277": INT64[1], + "val_280": INT64[1], + "val_281": INT64[1], + "slice_26": FLOAT["s34", 32, "s16", 32], + "val_284": INT64[1], + "val_287": INT64[1], + "val_290": INT64[1], + "val_291": INT64[1], + "slice_27": FLOAT["s34", 32, "s16", 48], + "val_294": INT64[1], + "val_297": INT64[1], + "val_300": INT64[1], + "val_301": INT64[1], + "slice_28": FLOAT["s34", 32, "s16", 32], + "val_304": INT64[1], + "val_307": INT64[1], + "val_310": INT64[1], + "val_311": INT64[1], + "slice_29": FLOAT["s34", 32, "s16", 48], + "unsqueeze_10": FLOAT[1, 1, "s16", 32], + "unsqueeze_11": FLOAT[1, 1, "s16", 32], + "mul_213": FLOAT["s34", 32, "s16", 32], + "val_314": INT64[1], + "val_318": INT64[1], + "val_321": INT64[1], + "val_322": INT64[1], + "slice_30": FLOAT["s34", 32, "s16", 16], + "val_325": INT64[1], + "val_328": INT64[1], + "val_331": INT64[1], + "val_332": INT64[1], + "slice_31": FLOAT["s34", 32, "s16", 16], + "neg": FLOAT["s34", 32, "s16", 16], + "cat_1": FLOAT["s34", 32, "s16", 32], + "mul_230": FLOAT["s34", 32, "s16", 32], + "add_290": FLOAT["s34", 32, "s16", 32], + "mul_238": FLOAT["s34", 32, "s16", 32], + "val_335": INT64[1], + "val_338": INT64[1], + "val_341": INT64[1], + "val_342": INT64[1], + "slice_32": FLOAT["s34", 32, "s16", 16], + "val_345": INT64[1], + "val_348": INT64[1], + "val_351": INT64[1], + "val_352": INT64[1], + "slice_33": FLOAT["s34", 32, "s16", 16], + "neg_1": FLOAT["s34", 32, "s16", 16], + "cat_2": FLOAT["s34", 32, "s16", 32], + "mul_255": FLOAT["s34", 32, "s16", 32], + "add_326": FLOAT["s34", 32, "s16", 32], + "cat_3": FLOAT["s34", 32, "s16", 80], + "cat_4": FLOAT["s34", 32, "s16", 80], + "transpose_4": FLOAT["s34", 32, 80, "s16 + s62"], + "matmul_1": FLOAT["s34", 32, "s16", "s16 + s62"], + "val_353": FLOAT, + "mul_287": FLOAT["s34", 32, "s16", "s16 + s62"], + "val_372": INT64[1], + "val_374": INT64[1], + "val_375": INT64[1], + "val_378": INT64[1], + "val_379": INT64[1], + "slice_41": FLOAT["s34", 1, "s16", "s16 + s62"], + "add_387": FLOAT["s34", 32, "s16", "s16 + s62"], + "val_380": FLOAT["s34", 32, "s16", "s16 + s62"], + "matmul_2": FLOAT["s34", 32, "s16", 80], + "transpose_5": FLOAT["s34", "s16", 32, 80], + "val_385": INT64[3], + "view_4": FLOAT["s34", "s16", 2560], + "val_387": FLOAT[2560, 2560], + "val_388": FLOAT["s34", "s16", 2560], + "linear_3": FLOAT["s34", "s16", 2560], + "val_389": FLOAT[2560, 10240], + "val_390": FLOAT["s34", "s16", 10240], + "linear_4": FLOAT["s34", "s16", 10240], + "val_391": FLOAT, + "mul_351": FLOAT["s34", "s16", 10240], + "val_392": FLOAT, + "pow_1": FLOAT["s34", "s16", 10240], + "val_393": FLOAT, + "mul_358": FLOAT["s34", "s16", 10240], + "add_446": FLOAT["s34", "s16", 10240], + "val_394": FLOAT, + "mul_365": FLOAT["s34", "s16", 10240], + "tanh": FLOAT["s34", "s16", 10240], + "add_459": FLOAT["s34", "s16", 10240], + "mul_375": FLOAT["s34", "s16", 10240], + "val_395": FLOAT[10240, 2560], + "val_396": FLOAT["s34", "s16", 2560], + "linear_5": FLOAT["s34", "s16", 2560], + "add_476": FLOAT["s34", "s16", 2560], + "add_481": FLOAT["s34", "s16", 2560], + "layer_norm_1": FLOAT["s34", "s16", 2560], + "val_419": FLOAT[2560, 51200], + "val_420": FLOAT["s34", "s16", 51200], +} + + +def make_model( + model_embed_tokens_weight, + model_layers_0_self_attn_q_proj_weight, + model_layers_0_self_attn_q_proj_bias, + model_layers_0_self_attn_k_proj_weight, + model_layers_0_self_attn_k_proj_bias, + model_layers_0_self_attn_v_proj_weight, + model_layers_0_self_attn_v_proj_bias, + model_layers_0_self_attn_dense_weight, + model_layers_0_self_attn_dense_bias, + model_layers_0_mlp_fc1_weight, + model_layers_0_mlp_fc1_bias, + model_layers_0_mlp_fc2_weight, + model_layers_0_mlp_fc2_bias, + model_layers_0_input_layernorm_weight, + model_layers_0_input_layernorm_bias, + model_final_layernorm_weight, + model_final_layernorm_bias, + lm_head_weight, + lm_head_bias, + expand_2, +): + @script() + def main_graph( + input_ids: INT64["s34", "s16"], + attention_mask: INT64["s34", "s16 + s62"], + past_key_values_key_cache_0: FLOAT["s34", 32, "s62", 80], + past_key_values_value_cache_0: FLOAT["s34", 32, "s62", 80], + ) -> ( + FLOAT["s34", "s16", 51200], + FLOAT["s34", 32, "s16 + s62", 80], + FLOAT["s34", 32, "s16 + s62", 80], + ): + val_1 = opset18.Shape(input_ids, end=2, start=1) + sym_size_int_44 = opset18.Squeeze(val_1) + val_4 = opset18.Shape(past_key_values_value_cache_0, end=1, start=0) + val_5 = opset18.Shape(past_key_values_value_cache_0, end=3, start=2) + sym_size_int_50 = opset18.Squeeze(val_5) + embedding = opset18.Gather(model_embed_tokens_weight, input_ids, axis=0) + add_4 = opset18.Add(sym_size_int_50, sym_size_int_44) + arange = opset18.Range(sym_size_int_50, add_4, 1) + unsqueeze = opset18.Unsqueeze(arange, [0]) + val_14 = opset18.Reshape(add_4, [-1], allowzero=0) + val_15 = opset18.Concat(val_1, val_14, axis=0) + full = opset18.Expand(-3.4028235e38, val_15) + diagonal = opset18.Constant(value_int=1) + triu = opset18.Trilu(full, diagonal, upper=1) + arange_1 = opset18.Range(0, add_4, 1) + view = opset18.Reshape(arange, [-1, 1], allowzero=1) + gt = opset18.Greater(arange_1, view) + convert_element_type_default = opset18.Cast(gt, to=1) + mul_16 = opset18.Mul(triu, convert_element_type_default) + unsqueeze_4 = opset18.Unsqueeze(mul_16, [0, 1]) + val_50 = opset18.Concat(val_4, [1], [-1], [-1], axis=0) + val_52 = opset18.Abs(val_50) + expand_1 = opset18.Expand(unsqueeze_4, val_52) + val_72 = opset18.Constant(value_ints=[0]) + val_74 = opset18.Constant(value_ints=[-1]) + val_75 = opset18.Reshape(add_4, val_74, allowzero=0) + val_79 = opset18.Constant(value_ints=[1]) + slice_8 = opset18.Slice(expand_1, val_72, val_75, [3], val_79) + unsqueeze_6 = opset18.Unsqueeze(attention_mask, [1, 2]) + convert_element_type_default_1 = opset18.Cast(unsqueeze_6, to=1) + add_89 = opset18.Add(slice_8, convert_element_type_default_1) + eq_64 = opset18.Equal(add_89, 0.0) + val_119 = opset18.Constant(value_ints=[0]) + val_121 = opset18.Constant(value_ints=[-1]) + val_122 = opset18.Reshape(add_4, val_121, allowzero=0) + val_126 = opset18.Constant(value_ints=[1]) + slice_14 = opset18.Slice(expand_1, val_119, val_122, [3], val_126) + masked_fill = opset18.Where(eq_64, -3.4028235e38, slice_14) + val_179 = opset18.Shape(expand_1, start=0) + val_180 = opset18.Gather(val_179, 2, axis=0) + val_181 = opset18.Range(0, val_180, 1) + val_186 = opset18.Unsqueeze(val_181, [-1]) + val_187 = opset18.Transpose(masked_fill, perm=[2, 1, 0, 3]) + val_188 = opset18.Transpose(expand_1, perm=[2, 1, 0, 3]) + val_189 = opset18.ScatterND(val_188, val_186, val_187, reduction="none") + val_191 = opset18.Shape(expand_1, start=0) + val_192 = opset18.Gather(val_191, 1, axis=0) + val_193 = opset18.Range(0, val_192, 1) + val_198 = opset18.Unsqueeze(val_193, [-1]) + val_199 = opset18.Transpose(val_189, perm=[1, 2, 0, 3]) + val_200 = opset18.Transpose(expand_1, perm=[1, 0, 2, 3]) + val_201 = opset18.ScatterND(val_200, val_198, val_199, reduction="none") + slice_scatter_1 = opset18.Transpose(val_201, perm=[1, 0, 2, 3]) + val_203 = opset18.Shape(expand_1, start=0) + val_204 = opset18.Gather(val_203, 0, axis=0) + val_205 = opset18.Range(0, val_204, 1) + val_210 = opset18.Unsqueeze(val_205, [-1]) + slice_scatter_2 = opset18.ScatterND( + expand_1, val_210, slice_scatter_1, reduction="none" + ) + unsqueeze_9 = opset18.Unsqueeze(unsqueeze, [1]) + _to_copy = opset18.Cast(unsqueeze_9, to=1) + matmul = opset18.MatMul(expand_2, _to_copy) + transpose = opset18.Transpose(matmul, perm=[0, 2, 1]) + cat = opset18.Concat(transpose, transpose, axis=-1) + cos = opset18.Cos(cat) + sin = opset18.Sin(cat) + layer_norm = opset18.LayerNormalization( + embedding, + model_layers_0_input_layernorm_weight, + model_layers_0_input_layernorm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_246 = opset18.Transpose(model_layers_0_self_attn_q_proj_weight, perm=[1, 0]) + val_247 = opset18.MatMul(layer_norm, val_246) + linear = opset18.Add(val_247, model_layers_0_self_attn_q_proj_bias) + val_253 = opset18.Concat(val_4, val_1, [-1], [80], axis=0) + view_1 = opset18.Reshape(linear, val_253, allowzero=1) + transpose_1 = opset18.Transpose(view_1, perm=[0, 2, 1, 3]) + val_255 = opset18.Transpose(model_layers_0_self_attn_k_proj_weight, perm=[1, 0]) + val_256 = opset18.MatMul(layer_norm, val_255) + linear_1 = opset18.Add(val_256, model_layers_0_self_attn_k_proj_bias) + val_261 = opset18.Concat(val_4, val_1, [-1], [80], axis=0) + view_2 = opset18.Reshape(linear_1, val_261, allowzero=1) + transpose_2 = opset18.Transpose(view_2, perm=[0, 2, 1, 3]) + val_263 = opset18.Transpose(model_layers_0_self_attn_v_proj_weight, perm=[1, 0]) + val_264 = opset18.MatMul(layer_norm, val_263) + linear_2 = opset18.Add(val_264, model_layers_0_self_attn_v_proj_bias) + val_269 = opset18.Concat(val_4, val_1, [-1], [80], axis=0) + view_3 = opset18.Reshape(linear_2, val_269, allowzero=1) + transpose_3 = opset18.Transpose(view_3, perm=[0, 2, 1, 3]) + val_281 = opset18.Constant(value_ints=[1]) + slice_26 = opset18.Slice(transpose_1, [0], [32], [3], val_281) + val_291 = opset18.Constant(value_ints=[1]) + slice_27 = opset18.Slice(transpose_1, [32], [9223372036854775807], [3], val_291) + val_301 = opset18.Constant(value_ints=[1]) + slice_28 = opset18.Slice(transpose_2, [0], [32], [3], val_301) + val_311 = opset18.Constant(value_ints=[1]) + slice_29 = opset18.Slice(transpose_2, [32], [9223372036854775807], [3], val_311) + unsqueeze_10 = opset18.Unsqueeze(cos, [1]) + unsqueeze_11 = opset18.Unsqueeze(sin, [1]) + mul_213 = opset18.Mul(slice_26, unsqueeze_10) + val_322 = opset18.Constant(value_ints=[1]) + slice_30 = opset18.Slice(slice_26, [0], [16], [3], val_322) + val_332 = opset18.Constant(value_ints=[1]) + slice_31 = opset18.Slice(slice_26, [16], [9223372036854775807], [3], val_332) + neg = opset18.Neg(slice_31) + cat_1 = opset18.Concat(neg, slice_30, axis=-1) + mul_230 = opset18.Mul(cat_1, unsqueeze_11) + add_290 = opset18.Add(mul_213, mul_230) + mul_238 = opset18.Mul(slice_28, unsqueeze_10) + val_342 = opset18.Constant(value_ints=[1]) + slice_32 = opset18.Slice(slice_28, [0], [16], [3], val_342) + val_352 = opset18.Constant(value_ints=[1]) + slice_33 = opset18.Slice(slice_28, [16], [9223372036854775807], [3], val_352) + neg_1 = opset18.Neg(slice_33) + cat_2 = opset18.Concat(neg_1, slice_32, axis=-1) + mul_255 = opset18.Mul(cat_2, unsqueeze_11) + add_326 = opset18.Add(mul_238, mul_255) + cat_3 = opset18.Concat(add_290, slice_27, axis=-1) + cat_4 = opset18.Concat(add_326, slice_29, axis=-1) + cat_5 = opset18.Concat(past_key_values_key_cache_0, cat_4, axis=-2) + cat_6 = opset18.Concat(past_key_values_value_cache_0, transpose_3, axis=-2) + transpose_4 = opset18.Transpose(cat_5, perm=[0, 1, 3, 2]) + matmul_1 = opset18.MatMul(cat_3, transpose_4) + mul_287 = opset18.Mul(matmul_1, 0.1118034) + val_372 = opset18.Constant(value_ints=[0]) + val_374 = opset18.Constant(value_ints=[-1]) + val_375 = opset18.Reshape(add_4, val_374, allowzero=0) + val_379 = opset18.Constant(value_ints=[1]) + slice_41 = opset18.Slice(slice_scatter_2, val_372, val_375, [3], val_379) + add_387 = opset18.Add(mul_287, slice_41) + val_380 = opset18.Softmax(add_387, axis=-1) + matmul_2 = opset18.MatMul(val_380, cat_6) + transpose_5 = opset18.Transpose(matmul_2, perm=[0, 2, 1, 3]) + val_385 = opset18.Concat(val_4, val_1, [-1], axis=0) + view_4 = opset18.Reshape(transpose_5, val_385, allowzero=1) + val_387 = opset18.Transpose(model_layers_0_self_attn_dense_weight, perm=[1, 0]) + val_388 = opset18.MatMul(view_4, val_387) + linear_3 = opset18.Add(val_388, model_layers_0_self_attn_dense_bias) + val_389 = opset18.Transpose(model_layers_0_mlp_fc1_weight, perm=[1, 0]) + val_390 = opset18.MatMul(layer_norm, val_389) + linear_4 = opset18.Add(val_390, model_layers_0_mlp_fc1_bias) + mul_351 = opset18.Mul(linear_4, 0.5) + pow_1 = opset18.Pow(linear_4, 3.0) + mul_358 = opset18.Mul(pow_1, 0.044715) + add_446 = opset18.Add(linear_4, mul_358) + mul_365 = opset18.Mul(add_446, 0.7978846) + tanh = opset18.Tanh(mul_365) + add_459 = opset18.Add(tanh, 1.0) + mul_375 = opset18.Mul(mul_351, add_459) + val_395 = opset18.Transpose(model_layers_0_mlp_fc2_weight, perm=[1, 0]) + val_396 = opset18.MatMul(mul_375, val_395) + linear_5 = opset18.Add(val_396, model_layers_0_mlp_fc2_bias) + add_476 = opset18.Add(linear_3, linear_5) + add_481 = opset18.Add(add_476, embedding) + layer_norm_1 = opset18.LayerNormalization( + add_481, + model_final_layernorm_weight, + model_final_layernorm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_419 = opset18.Transpose(lm_head_weight, perm=[1, 0]) + val_420 = opset18.MatMul(layer_norm_1, val_419) + linear_6 = opset18.Add(val_420, lm_head_bias) + return linear_6, cat_5, cat_6 + + model = main_graph.to_model_proto(value_infos=value_infos) + return model + + +def make_model_with_random_weights(): + model_embed_tokens_weight = numpy.random.rand(51200, 2560).astype(numpy.float32) + model_layers_0_self_attn_q_proj_weight = numpy.random.rand(2560, 2560).astype( + numpy.float32 + ) + model_layers_0_self_attn_q_proj_bias = numpy.random.rand(2560).astype(numpy.float32) + model_layers_0_self_attn_k_proj_weight = numpy.random.rand(2560, 2560).astype( + numpy.float32 + ) + model_layers_0_self_attn_k_proj_bias = numpy.random.rand(2560).astype(numpy.float32) + model_layers_0_self_attn_v_proj_weight = numpy.random.rand(2560, 2560).astype( + numpy.float32 + ) + model_layers_0_self_attn_v_proj_bias = numpy.random.rand(2560).astype(numpy.float32) + model_layers_0_self_attn_dense_weight = numpy.random.rand(2560, 2560).astype(numpy.float32) + model_layers_0_self_attn_dense_bias = numpy.random.rand(2560).astype(numpy.float32) + model_layers_0_mlp_fc1_weight = numpy.random.rand(10240, 2560).astype(numpy.float32) + model_layers_0_mlp_fc1_bias = numpy.random.rand(10240).astype(numpy.float32) + model_layers_0_mlp_fc2_weight = numpy.random.rand(2560, 10240).astype(numpy.float32) + model_layers_0_mlp_fc2_bias = numpy.random.rand(2560).astype(numpy.float32) + model_layers_0_input_layernorm_weight = numpy.random.rand(2560).astype(numpy.float32) + model_layers_0_input_layernorm_bias = numpy.random.rand(2560).astype(numpy.float32) + model_final_layernorm_weight = numpy.random.rand(2560).astype(numpy.float32) + model_final_layernorm_bias = numpy.random.rand(2560).astype(numpy.float32) + lm_head_weight = numpy.random.rand(51200, 2560).astype(numpy.float32) + lm_head_bias = numpy.random.rand(51200).astype(numpy.float32) + expand_2 = numpy.random.rand(1, 16, 1).astype(numpy.float32) + model = make_model( + model_embed_tokens_weight, + model_layers_0_self_attn_q_proj_weight, + model_layers_0_self_attn_q_proj_bias, + model_layers_0_self_attn_k_proj_weight, + model_layers_0_self_attn_k_proj_bias, + model_layers_0_self_attn_v_proj_weight, + model_layers_0_self_attn_v_proj_bias, + model_layers_0_self_attn_dense_weight, + model_layers_0_self_attn_dense_bias, + model_layers_0_mlp_fc1_weight, + model_layers_0_mlp_fc1_bias, + model_layers_0_mlp_fc2_weight, + model_layers_0_mlp_fc2_bias, + model_layers_0_input_layernorm_weight, + model_layers_0_input_layernorm_bias, + model_final_layernorm_weight, + model_final_layernorm_bias, + lm_head_weight, + lm_head_bias, + expand_2, + ) + return model + + +class _Phi2LMTest: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + +def phi2lm_test(): + return _Phi2LMTest() diff --git a/onnxscript/rewriter/models/_phi4lm.py b/onnxscript/rewriter/models/_phi4lm.py new file mode 100644 index 0000000000..8a911095b5 --- /dev/null +++ b/onnxscript/rewriter/models/_phi4lm.py @@ -0,0 +1,747 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Generated from Phi4LM 2 Layer ONNX model produced by the new (Dynamo) exporter +# ruff: noqa: F821 + +import numpy +import onnx_ir as ir + +from onnxscript import script +from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import BOOL, FLOAT, INT64 + +value_infos = { + "model_embed_tokens_weight": FLOAT[100352, 5120], + "model_layers_0_self_attn_o_proj_weight": FLOAT[5120, 5120], + "model_layers_0_self_attn_qkv_proj_weight": FLOAT[7680, 5120], + "model_layers_0_mlp_gate_up_proj_weight": FLOAT[35840, 5120], + "model_layers_0_mlp_down_proj_weight": FLOAT[5120, 17920], + "model_layers_0_input_layernorm_weight": FLOAT[5120], + "model_layers_0_post_attention_layernorm_weight": FLOAT[5120], + "model_layers_1_self_attn_o_proj_weight": FLOAT[5120, 5120], + "model_layers_1_self_attn_qkv_proj_weight": FLOAT[7680, 5120], + "model_layers_1_mlp_gate_up_proj_weight": FLOAT[35840, 5120], + "model_layers_1_mlp_down_proj_weight": FLOAT[5120, 17920], + "model_layers_1_input_layernorm_weight": FLOAT[5120], + "model_layers_1_post_attention_layernorm_weight": FLOAT[5120], + "model_norm_weight": FLOAT[5120], + "lm_head_weight": FLOAT[100352, 5120], + "expand_2": FLOAT[1, 64, 1], + "val_1": INT64[1], + "sym_size_int_61": INT64, + "val_5": INT64[1], + "sym_size_int_67": INT64, + "val_6": INT64[1], + "embedding": FLOAT["s34", "s16", 5120], + "add_4": INT64, + "val_11": INT64, + "arange": INT64["s16"], + "val_12": INT64[1], + "unsqueeze": INT64[1, "s16"], + "val_14": FLOAT, + "val_17": INT64[1], + "val_18": INT64[1], + "val_19": INT64[2], + "full": FLOAT["s16", "s16 + s17"], + "val_22": INT64, + "val_23": INT64, + "arange_1": INT64["s16 + s17"], + "val_25": INT64[2], + "view": INT64["s16", 1], + "gt": BOOL["s16", "s16 + s17"], + "convert_element_type_default": FLOAT["s16", "s16 + s17"], + "mul_14": FLOAT["s16", "s16 + s17"], + "val_26": INT64[1], + "val_805": INT64[2], + "unsqueeze_4": FLOAT[1, 1, "s16", "s16 + s17"], + "val_27": INT64, + "val_35": INT64, + "val_53": INT64[1], + "val_54": INT64[4], + "val_56": INT64[4], + "expand_1": FLOAT["s34", 1, "s16", "s16 + s17"], + "val_65": INT64, + "val_76": INT64[1], + "val_78": INT64[1], + "val_79": INT64[1], + "val_82": INT64[1], + "val_83": INT64[1], + "slice_8": FLOAT["s34", 1, "s16", "s16 + s17"], + "val_94": INT64[1], + "val_806": INT64[2], + "unsqueeze_6": INT64["s34", 1, 1, "s16 + s17"], + "convert_element_type_default_1": FLOAT["s34", 1, 1, "s16 + s17"], + "add_86": FLOAT["s34", 1, "s16", "s16 + s17"], + "scalar_tensor_default": FLOAT, + "eq_65": BOOL["s34", 1, "s16", "s16 + s17"], + "val_123": INT64[1], + "val_125": INT64[1], + "val_126": INT64[1], + "val_129": INT64[1], + "val_130": INT64[1], + "slice_14": FLOAT["s34", 1, "s16", "s16 + s17"], + "val_131": FLOAT, + "masked_fill": FLOAT["s34", 1, "s16", "s16 + s17"], + "val_183": INT64[4], + "val_184": INT64, + "val_185": INT64[None], + "val_190": INT64[None, 1], + "val_191": FLOAT["s16", 1, "s34", "s16 + s17"], + "val_192": FLOAT["s16", 1, "s34", "s16 + s17"], + "val_193": FLOAT["s16", 1, "s34", "s16 + s17"], + "val_195": INT64[4], + "val_196": INT64, + "val_197": INT64[None], + "val_202": INT64[None, 1], + "val_203": FLOAT[1, "s34", "s16", "s16 + s17"], + "val_204": FLOAT[1, "s34", "s16", "s16 + s17"], + "val_205": FLOAT[1, "s34", "s16", "s16 + s17"], + "slice_scatter_1": FLOAT["s34", 1, "s16", "s16 + s17"], + "val_207": INT64[4], + "val_208": INT64, + "val_209": INT64[None], + "val_214": INT64[None, 1], + "slice_scatter_2": FLOAT["s34", 1, "s16", "s16 + s17"], + "unsqueeze_9": INT64[1, 1, "s16"], + "_to_copy": FLOAT[1, 1, "s16"], + "matmul": FLOAT[1, 64, "s16"], + "transpose": FLOAT[1, "s16", 64], + "cat": FLOAT[1, "s16", 128], + "cos": FLOAT[1, "s16", 128], + "sin": FLOAT[1, "s16", 128], + "val_248": FLOAT, + "pow_1": FLOAT["s34", "s16", 5120], + "val_250": INT64[1], + "mean": FLOAT["s34", "s16", 1], + "val_251": FLOAT, + "add_189": FLOAT["s34", "s16", 1], + "val_252": FLOAT["s34", "s16", 1], + "rsqrt": FLOAT["s34", "s16", 1], + "mul_167": FLOAT["s34", "s16", 5120], + "mul_171": FLOAT["s34", "s16", 5120], + "val_253": FLOAT[5120, 7680], + "linear": FLOAT["s34", "s16", 7680], + "val_256": INT64[1], + "val_260": INT64[1], + "val_263": INT64[1], + "val_264": INT64[1], + "slice_26": FLOAT["s34", "s16", 5120], + "val_267": INT64[1], + "val_271": INT64[1], + "val_274": INT64[1], + "val_275": INT64[1], + "slice_27": FLOAT["s34", "s16", 1280], + "val_278": INT64[1], + "val_281": INT64[1], + "val_284": INT64[1], + "val_285": INT64[1], + "slice_28": FLOAT["s34", "s16", 1280], + "val_290": INT64[1], + "val_291": INT64[4], + "view_1": FLOAT["s34", "s16", 40, 128], + "transpose_1": FLOAT["s34", 40, "s16", 128], + "val_297": INT64[4], + "view_2": FLOAT["s34", "s16", 10, 128], + "transpose_2": FLOAT["s34", 10, "s16", 128], + "val_303": INT64[4], + "view_3": FLOAT["s34", "s16", 10, 128], + "transpose_3": FLOAT["s34", 10, "s16", 128], + "unsqueeze_10": FLOAT[1, 1, "s16", 128], + "unsqueeze_11": FLOAT[1, 1, "s16", 128], + "mul_223": FLOAT["s34", 40, "s16", 128], + "val_328": INT64[1], + "val_332": INT64[1], + "val_335": INT64[1], + "val_336": INT64[1], + "slice_31": FLOAT["s34", 40, "s16", 64], + "val_339": INT64[1], + "val_342": INT64[1], + "val_345": INT64[1], + "val_346": INT64[1], + "slice_32": FLOAT["s34", 40, "s16", 64], + "neg": FLOAT["s34", 40, "s16", 64], + "cat_1": FLOAT["s34", 40, "s16", 128], + "mul_240": FLOAT["s34", 40, "s16", 128], + "add_304": FLOAT["s34", 40, "s16", 128], + "mul_252": FLOAT["s34", 10, "s16", 128], + "val_349": INT64[1], + "val_352": INT64[1], + "val_355": INT64[1], + "val_356": INT64[1], + "slice_33": FLOAT["s34", 10, "s16", 64], + "val_359": INT64[1], + "val_362": INT64[1], + "val_365": INT64[1], + "val_366": INT64[1], + "slice_34": FLOAT["s34", 10, "s16", 64], + "neg_1": FLOAT["s34", 10, "s16", 64], + "cat_3": FLOAT["s34", 10, "s16", 128], + "mul_269": FLOAT["s34", 10, "s16", 128], + "add_345": FLOAT["s34", 10, "s16", 128], + "unsqueeze_12": FLOAT["s34", 10, 1, "s16 + s17", 128], + "val_410": INT64[1], + "val_411": INT64[1], + "val_412": INT64[1], + "val_413": INT64[1], + "val_414": INT64[5], + "val_416": INT64[5], + "expand_3": FLOAT["s34", 10, 4, "s16 + s17", 128], + "val_419": INT64[1], + "val_420": INT64[1], + "val_421": INT64[1], + "val_422": INT64[4], + "_unsafe_view": FLOAT["s34", 40, "s16 + s17", 128], + "unsqueeze_13": FLOAT["s34", 10, 1, "s16 + s17", 128], + "val_466": INT64[1], + "val_467": INT64[1], + "val_468": INT64[5], + "val_470": INT64[5], + "expand_4": FLOAT["s34", 10, 4, "s16 + s17", 128], + "val_473": INT64[1], + "val_474": INT64[1], + "val_475": INT64[4], + "_unsafe_view_1": FLOAT["s34", 40, "s16 + s17", 128], + "transpose_4": FLOAT["s34", 40, 128, "s16 + s17"], + "matmul_1": FLOAT["s34", 40, "s16", "s16 + s17"], + "val_477": FLOAT, + "mul_433": FLOAT["s34", 40, "s16", "s16 + s17"], + "val_496": INT64[1], + "val_498": INT64[1], + "val_499": INT64[1], + "val_502": INT64[1], + "val_503": INT64[1], + "slice_50": FLOAT["s34", 1, "s16", "s16 + s17"], + "add_491": FLOAT["s34", 40, "s16", "s16 + s17"], + "val_504": FLOAT["s34", 40, "s16", "s16 + s17"], + "matmul_2": FLOAT["s34", 40, "s16", 128], + "transpose_5": FLOAT["s34", "s16", 40, 128], + "val_509": INT64[3], + "view_4": FLOAT["s34", "s16", 5120], + "val_511": FLOAT[5120, 5120], + "linear_1": FLOAT["s34", "s16", 5120], + "add_534": FLOAT["s34", "s16", 5120], + "val_512": FLOAT, + "pow_2": FLOAT["s34", "s16", 5120], + "val_514": INT64[1], + "mean_1": FLOAT["s34", "s16", 1], + "add_547": FLOAT["s34", "s16", 1], + "val_515": FLOAT["s34", "s16", 1], + "rsqrt_1": FLOAT["s34", "s16", 1], + "mul_506": FLOAT["s34", "s16", 5120], + "mul_510": FLOAT["s34", "s16", 5120], + "val_516": FLOAT[5120, 35840], + "linear_2": FLOAT["s34", "s16", 35840], + "split_split_0": FLOAT["s34", "s16", 17920], + "split_split_1": FLOAT["s34", "s16", 17920], + "val_518": FLOAT["s34", "s16", 17920], + "silu": FLOAT["s34", "s16", 17920], + "mul_526": FLOAT["s34", "s16", 17920], + "val_519": FLOAT[17920, 5120], + "linear_3": FLOAT["s34", "s16", 5120], + "add_592": FLOAT["s34", "s16", 5120], + "val_520": FLOAT, + "pow_3": FLOAT["s34", "s16", 5120], + "val_522": INT64[1], + "mean_2": FLOAT["s34", "s16", 1], + "add_605": FLOAT["s34", "s16", 1], + "val_523": FLOAT["s34", "s16", 1], + "rsqrt_2": FLOAT["s34", "s16", 1], + "mul_548": FLOAT["s34", "s16", 5120], + "mul_552": FLOAT["s34", "s16", 5120], + "val_524": FLOAT[5120, 7680], + "linear_4": FLOAT["s34", "s16", 7680], + "val_527": INT64[1], + "val_530": INT64[1], + "val_533": INT64[1], + "val_534": INT64[1], + "slice_51": FLOAT["s34", "s16", 5120], + "val_537": INT64[1], + "val_540": INT64[1], + "val_543": INT64[1], + "val_544": INT64[1], + "slice_52": FLOAT["s34", "s16", 1280], + "val_547": INT64[1], + "val_550": INT64[1], + "val_553": INT64[1], + "val_554": INT64[1], + "slice_53": FLOAT["s34", "s16", 1280], + "val_559": INT64[4], + "view_5": FLOAT["s34", "s16", 40, 128], + "transpose_6": FLOAT["s34", 40, "s16", 128], + "val_565": INT64[4], + "view_6": FLOAT["s34", "s16", 10, 128], + "transpose_7": FLOAT["s34", 10, "s16", 128], + "val_571": INT64[4], + "view_7": FLOAT["s34", "s16", 10, 128], + "transpose_8": FLOAT["s34", 10, "s16", 128], + "unsqueeze_14": FLOAT[1, 1, "s16", 128], + "unsqueeze_15": FLOAT[1, 1, "s16", 128], + "mul_604": FLOAT["s34", 40, "s16", 128], + "val_595": INT64[1], + "val_598": INT64[1], + "val_601": INT64[1], + "val_602": INT64[1], + "slice_56": FLOAT["s34", 40, "s16", 64], + "val_605": INT64[1], + "val_608": INT64[1], + "val_611": INT64[1], + "val_612": INT64[1], + "slice_57": FLOAT["s34", 40, "s16", 64], + "neg_2": FLOAT["s34", 40, "s16", 64], + "cat_7": FLOAT["s34", 40, "s16", 128], + "mul_621": FLOAT["s34", 40, "s16", 128], + "add_720": FLOAT["s34", 40, "s16", 128], + "mul_633": FLOAT["s34", 10, "s16", 128], + "val_615": INT64[1], + "val_618": INT64[1], + "val_621": INT64[1], + "val_622": INT64[1], + "slice_58": FLOAT["s34", 10, "s16", 64], + "val_625": INT64[1], + "val_628": INT64[1], + "val_631": INT64[1], + "val_632": INT64[1], + "slice_59": FLOAT["s34", 10, "s16", 64], + "neg_3": FLOAT["s34", 10, "s16", 64], + "cat_9": FLOAT["s34", 10, "s16", 128], + "mul_650": FLOAT["s34", 10, "s16", 128], + "add_761": FLOAT["s34", 10, "s16", 128], + "unsqueeze_16": FLOAT["s34", 10, 1, "s16 + s17", 128], + "val_675": INT64[1], + "val_676": INT64[1], + "val_677": INT64[5], + "val_679": INT64[5], + "expand_5": FLOAT["s34", 10, 4, "s16 + s17", 128], + "val_682": INT64[1], + "val_683": INT64[1], + "val_684": INT64[4], + "_unsafe_view_2": FLOAT["s34", 40, "s16 + s17", 128], + "unsqueeze_17": FLOAT["s34", 10, 1, "s16 + s17", 128], + "val_728": INT64[1], + "val_729": INT64[1], + "val_730": INT64[5], + "val_732": INT64[5], + "expand_6": FLOAT["s34", 10, 4, "s16 + s17", 128], + "val_735": INT64[1], + "val_736": INT64[1], + "val_737": INT64[4], + "_unsafe_view_3": FLOAT["s34", 40, "s16 + s17", 128], + "transpose_9": FLOAT["s34", 40, 128, "s16 + s17"], + "matmul_3": FLOAT["s34", 40, "s16", "s16 + s17"], + "mul_814": FLOAT["s34", 40, "s16", "s16 + s17"], + "val_757": INT64[1], + "val_759": INT64[1], + "val_760": INT64[1], + "val_763": INT64[1], + "val_764": INT64[1], + "slice_75": FLOAT["s34", 1, "s16", "s16 + s17"], + "add_907": FLOAT["s34", 40, "s16", "s16 + s17"], + "val_765": FLOAT["s34", 40, "s16", "s16 + s17"], + "matmul_4": FLOAT["s34", 40, "s16", 128], + "transpose_10": FLOAT["s34", "s16", 40, 128], + "val_770": INT64[3], + "view_8": FLOAT["s34", "s16", 5120], + "val_772": FLOAT[5120, 5120], + "linear_5": FLOAT["s34", "s16", 5120], + "add_950": FLOAT["s34", "s16", 5120], + "val_773": FLOAT, + "pow_4": FLOAT["s34", "s16", 5120], + "val_775": INT64[1], + "mean_3": FLOAT["s34", "s16", 1], + "add_963": FLOAT["s34", "s16", 1], + "val_776": FLOAT["s34", "s16", 1], + "rsqrt_3": FLOAT["s34", "s16", 1], + "mul_887": FLOAT["s34", "s16", 5120], + "mul_891": FLOAT["s34", "s16", 5120], + "val_777": FLOAT[5120, 35840], + "linear_6": FLOAT["s34", "s16", 35840], + "split_1_split_0": FLOAT["s34", "s16", 17920], + "split_1_split_1": FLOAT["s34", "s16", 17920], + "val_778": FLOAT["s34", "s16", 17920], + "silu_1": FLOAT["s34", "s16", 17920], + "mul_907": FLOAT["s34", "s16", 17920], + "val_779": FLOAT[17920, 5120], + "linear_7": FLOAT["s34", "s16", 5120], + "add_1008": FLOAT["s34", "s16", 5120], + "val_780": FLOAT, + "pow_5": FLOAT["s34", "s16", 5120], + "val_782": INT64[1], + "mean_4": FLOAT["s34", "s16", 1], + "add_1021": FLOAT["s34", "s16", 1], + "val_783": FLOAT["s34", "s16", 1], + "rsqrt_4": FLOAT["s34", "s16", 1], + "mul_929": FLOAT["s34", "s16", 5120], + "mul_933": FLOAT["s34", "s16", 5120], + "val_804": FLOAT[5120, 100352], +} + + +def make_model( + model_embed_tokens_weight, + model_layers_0_self_attn_o_proj_weight, + model_layers_0_self_attn_qkv_proj_weight, + model_layers_0_mlp_gate_up_proj_weight, + model_layers_0_mlp_down_proj_weight, + model_layers_0_input_layernorm_weight, + model_layers_0_post_attention_layernorm_weight, + model_layers_1_self_attn_o_proj_weight, + model_layers_1_self_attn_qkv_proj_weight, + model_layers_1_mlp_gate_up_proj_weight, + model_layers_1_mlp_down_proj_weight, + model_layers_1_input_layernorm_weight, + model_layers_1_post_attention_layernorm_weight, + model_norm_weight, + lm_head_weight, + expand_2, +): + @script() + def main_graph( + input_ids: INT64["s34", "s16"], + attention_mask: INT64["s34", "s16 + s17"], + past_key_values_key_cache_0: FLOAT["s34", 10, "s17", 128], + past_key_values_key_cache_1: FLOAT["s34", 10, "s17", 128], + past_key_values_value_cache_0: FLOAT["s34", 10, "s17", 128], + past_key_values_value_cache_1: FLOAT["s34", 10, "s17", 128], + ) -> ( + FLOAT["s34", "s16", 100352], + FLOAT["s34", 10, "s16 + s17", 128], + FLOAT["s34", 10, "s16 + s17", 128], + FLOAT["s34", 10, "s16 + s17", 128], + FLOAT["s34", 10, "s16 + s17", 128], + ): + val_1 = opset18.Shape(input_ids, end=2, start=1) + sym_size_int_61 = opset18.Squeeze(val_1) + val_5 = opset18.Shape(past_key_values_key_cache_1, end=3, start=2) + sym_size_int_67 = opset18.Squeeze(val_5) + val_6 = opset18.Shape(past_key_values_value_cache_0, end=1, start=0) + embedding = opset18.Gather(model_embed_tokens_weight, input_ids, axis=0) + add_4 = opset18.Add(sym_size_int_67, sym_size_int_61) + arange = opset18.Range(sym_size_int_67, add_4, 1) + unsqueeze = opset18.Unsqueeze(arange, [0]) + val_18 = opset18.Reshape(add_4, [-1], allowzero=0) + val_19 = opset18.Concat(val_1, val_18, axis=0) + full = opset18.Expand(-3.4028235e38, val_19) + arange_1 = opset18.Range(0, add_4, 1) + view = opset18.Reshape(arange, [-1, 1], allowzero=1) + gt = opset18.Greater(arange_1, view) + convert_element_type_default = opset18.Cast(gt, to=1) + mul_14 = opset18.Mul(full, convert_element_type_default) + unsqueeze_4 = opset18.Unsqueeze(mul_14, [0, 1]) + val_54 = opset18.Concat(val_6, [1], [-1], [-1], axis=0) + val_56 = opset18.Abs(val_54) + expand_1 = opset18.Expand(unsqueeze_4, val_56) + val_76 = opset18.Constant(value_ints=[0]) + val_78 = opset18.Constant(value_ints=[-1]) + val_79 = opset18.Reshape(add_4, val_78, allowzero=0) + val_83 = opset18.Constant(value_ints=[1]) + slice_8 = opset18.Slice(expand_1, val_76, val_79, [3], val_83) + unsqueeze_6 = opset18.Unsqueeze(attention_mask, [1, 2]) + convert_element_type_default_1 = opset18.Cast(unsqueeze_6, to=1) + add_86 = opset18.Add(slice_8, convert_element_type_default_1) + eq_65 = opset18.Equal(add_86, 0.0) + val_123 = opset18.Constant(value_ints=[0]) + val_125 = opset18.Constant(value_ints=[-1]) + val_126 = opset18.Reshape(add_4, val_125, allowzero=0) + val_130 = opset18.Constant(value_ints=[1]) + slice_14 = opset18.Slice(expand_1, val_123, val_126, [3], val_130) + masked_fill = opset18.Where(eq_65, -3.4028235e38, slice_14) + val_183 = opset18.Shape(expand_1, start=0) + val_184 = opset18.Gather(val_183, 2, axis=0) + val_185 = opset18.Range(0, val_184, 1) + val_190 = opset18.Unsqueeze(val_185, [-1]) + val_191 = opset18.Transpose(masked_fill, perm=[2, 1, 0, 3]) + val_192 = opset18.Transpose(expand_1, perm=[2, 1, 0, 3]) + val_193 = opset18.ScatterND(val_192, val_190, val_191, reduction="none") + val_195 = opset18.Shape(expand_1, start=0) + val_196 = opset18.Gather(val_195, 1, axis=0) + val_197 = opset18.Range(0, val_196, 1) + val_202 = opset18.Unsqueeze(val_197, [-1]) + val_203 = opset18.Transpose(val_193, perm=[1, 2, 0, 3]) + val_204 = opset18.Transpose(expand_1, perm=[1, 0, 2, 3]) + val_205 = opset18.ScatterND(val_204, val_202, val_203, reduction="none") + slice_scatter_1 = opset18.Transpose(val_205, perm=[1, 0, 2, 3]) + val_207 = opset18.Shape(expand_1, start=0) + val_208 = opset18.Gather(val_207, 0, axis=0) + val_209 = opset18.Range(0, val_208, 1) + val_214 = opset18.Unsqueeze(val_209, [-1]) + slice_scatter_2 = opset18.ScatterND( + expand_1, val_214, slice_scatter_1, reduction="none" + ) + unsqueeze_9 = opset18.Unsqueeze(unsqueeze, [1]) + _to_copy = opset18.Cast(unsqueeze_9, to=1) + matmul = opset18.MatMul(expand_2, _to_copy) + transpose = opset18.Transpose(matmul, perm=[0, 2, 1]) + cat = opset18.Concat(transpose, transpose, axis=-1) + cos = opset18.Cos(cat) + sin = opset18.Sin(cat) + pow_1 = opset18.Pow(embedding, 2.0) + mean = opset18.ReduceMean(pow_1, [-1], noop_with_empty_axes=0, keepdims=1) + add_189 = opset18.Add(mean, 1e-05) + val_252 = opset18.Sqrt(add_189) + rsqrt = opset18.Reciprocal(val_252) + mul_167 = opset18.Mul(embedding, rsqrt) + mul_171 = opset18.Mul(model_layers_0_input_layernorm_weight, mul_167) + val_253 = opset18.Transpose(model_layers_0_self_attn_qkv_proj_weight, perm=[1, 0]) + linear = opset18.MatMul(mul_171, val_253) + val_264 = opset18.Constant(value_ints=[1]) + slice_26 = opset18.Slice(linear, [0], [5120], [2], val_264) + val_275 = opset18.Constant(value_ints=[1]) + slice_27 = opset18.Slice(linear, [5120], [6400], [2], val_275) + val_285 = opset18.Constant(value_ints=[1]) + slice_28 = opset18.Slice(linear, [6400], [9223372036854775807], [2], val_285) + val_291 = opset18.Concat(val_6, val_1, [-1], [128], axis=0) + view_1 = opset18.Reshape(slice_26, val_291, allowzero=1) + transpose_1 = opset18.Transpose(view_1, perm=[0, 2, 1, 3]) + val_297 = opset18.Concat(val_6, val_1, [-1], [128], axis=0) + view_2 = opset18.Reshape(slice_27, val_297, allowzero=1) + transpose_2 = opset18.Transpose(view_2, perm=[0, 2, 1, 3]) + val_303 = opset18.Concat(val_6, val_1, [-1], [128], axis=0) + view_3 = opset18.Reshape(slice_28, val_303, allowzero=1) + transpose_3 = opset18.Transpose(view_3, perm=[0, 2, 1, 3]) + unsqueeze_10 = opset18.Unsqueeze(cos, [1]) + unsqueeze_11 = opset18.Unsqueeze(sin, [1]) + mul_223 = opset18.Mul(transpose_1, unsqueeze_10) + val_336 = opset18.Constant(value_ints=[1]) + slice_31 = opset18.Slice(transpose_1, [0], [64], [3], val_336) + val_346 = opset18.Constant(value_ints=[1]) + slice_32 = opset18.Slice(transpose_1, [64], [9223372036854775807], [3], val_346) + neg = opset18.Neg(slice_32) + cat_1 = opset18.Concat(neg, slice_31, axis=-1) + mul_240 = opset18.Mul(cat_1, unsqueeze_11) + add_304 = opset18.Add(mul_223, mul_240) + mul_252 = opset18.Mul(transpose_2, unsqueeze_10) + val_356 = opset18.Constant(value_ints=[1]) + slice_33 = opset18.Slice(transpose_2, [0], [64], [3], val_356) + val_366 = opset18.Constant(value_ints=[1]) + slice_34 = opset18.Slice(transpose_2, [64], [9223372036854775807], [3], val_366) + neg_1 = opset18.Neg(slice_34) + cat_3 = opset18.Concat(neg_1, slice_33, axis=-1) + mul_269 = opset18.Mul(cat_3, unsqueeze_11) + add_345 = opset18.Add(mul_252, mul_269) + cat_5 = opset18.Concat(past_key_values_key_cache_0, add_345, axis=-2) + cat_6 = opset18.Concat(past_key_values_value_cache_0, transpose_3, axis=-2) + unsqueeze_12 = opset18.Unsqueeze(cat_5, [2]) + val_413 = opset18.Reshape(add_4, [-1], allowzero=0) + val_414 = opset18.Concat(val_6, [10], [4], val_413, [128], axis=0) + val_416 = opset18.Abs(val_414) + expand_3 = opset18.Expand(unsqueeze_12, val_416) + val_421 = opset18.Reshape(add_4, [-1], allowzero=0) + val_422 = opset18.Concat(val_6, [40], val_421, [128], axis=0) + _unsafe_view = opset18.Reshape(expand_3, val_422, allowzero=1) + unsqueeze_13 = opset18.Unsqueeze(cat_6, [2]) + val_467 = opset18.Reshape(add_4, [-1], allowzero=0) + val_468 = opset18.Concat(val_6, [10], [4], val_467, [128], axis=0) + val_470 = opset18.Abs(val_468) + expand_4 = opset18.Expand(unsqueeze_13, val_470) + val_474 = opset18.Reshape(add_4, [-1], allowzero=0) + val_475 = opset18.Concat(val_6, [40], val_474, [128], axis=0) + _unsafe_view_1 = opset18.Reshape(expand_4, val_475, allowzero=1) + transpose_4 = opset18.Transpose(_unsafe_view, perm=[0, 1, 3, 2]) + matmul_1 = opset18.MatMul(add_304, transpose_4) + mul_433 = opset18.Mul(matmul_1, 0.088388346) + val_496 = opset18.Constant(value_ints=[0]) + val_498 = opset18.Constant(value_ints=[-1]) + val_499 = opset18.Reshape(add_4, val_498, allowzero=0) + val_503 = opset18.Constant(value_ints=[1]) + slice_50 = opset18.Slice(slice_scatter_2, val_496, val_499, [3], val_503) + add_491 = opset18.Add(mul_433, slice_50) + val_504 = opset18.Softmax(add_491, axis=-1) + matmul_2 = opset18.MatMul(val_504, _unsafe_view_1) + transpose_5 = opset18.Transpose(matmul_2, perm=[0, 2, 1, 3]) + val_509 = opset18.Concat(val_6, val_1, [-1], axis=0) + view_4 = opset18.Reshape(transpose_5, val_509, allowzero=1) + val_511 = opset18.Transpose(model_layers_0_self_attn_o_proj_weight, perm=[1, 0]) + linear_1 = opset18.MatMul(view_4, val_511) + add_534 = opset18.Add(embedding, linear_1) + pow_2 = opset18.Pow(add_534, 2.0) + mean_1 = opset18.ReduceMean(pow_2, [-1], noop_with_empty_axes=0, keepdims=1) + add_547 = opset18.Add(mean_1, 1e-05) + val_515 = opset18.Sqrt(add_547) + rsqrt_1 = opset18.Reciprocal(val_515) + mul_506 = opset18.Mul(add_534, rsqrt_1) + mul_510 = opset18.Mul(model_layers_0_post_attention_layernorm_weight, mul_506) + val_516 = opset18.Transpose(model_layers_0_mlp_gate_up_proj_weight, perm=[1, 0]) + linear_2 = opset18.MatMul(mul_510, val_516) + split_split_0, split_split_1 = opset18.Split(linear_2, axis=2, num_outputs=2) + val_518 = opset18.Sigmoid(split_split_0) + silu = opset18.Mul(split_split_0, val_518) + mul_526 = opset18.Mul(split_split_1, silu) + val_519 = opset18.Transpose(model_layers_0_mlp_down_proj_weight, perm=[1, 0]) + linear_3 = opset18.MatMul(mul_526, val_519) + add_592 = opset18.Add(add_534, linear_3) + pow_3 = opset18.Pow(add_592, 2.0) + mean_2 = opset18.ReduceMean(pow_3, [-1], noop_with_empty_axes=0, keepdims=1) + add_605 = opset18.Add(mean_2, 1e-05) + val_523 = opset18.Sqrt(add_605) + rsqrt_2 = opset18.Reciprocal(val_523) + mul_548 = opset18.Mul(add_592, rsqrt_2) + mul_552 = opset18.Mul(model_layers_1_input_layernorm_weight, mul_548) + val_524 = opset18.Transpose(model_layers_1_self_attn_qkv_proj_weight, perm=[1, 0]) + linear_4 = opset18.MatMul(mul_552, val_524) + val_534 = opset18.Constant(value_ints=[1]) + slice_51 = opset18.Slice(linear_4, [0], [5120], [2], val_534) + val_544 = opset18.Constant(value_ints=[1]) + slice_52 = opset18.Slice(linear_4, [5120], [6400], [2], val_544) + val_554 = opset18.Constant(value_ints=[1]) + slice_53 = opset18.Slice(linear_4, [6400], [9223372036854775807], [2], val_554) + val_559 = opset18.Concat(val_6, val_1, [-1], [128], axis=0) + view_5 = opset18.Reshape(slice_51, val_559, allowzero=1) + transpose_6 = opset18.Transpose(view_5, perm=[0, 2, 1, 3]) + val_565 = opset18.Concat(val_6, val_1, [-1], [128], axis=0) + view_6 = opset18.Reshape(slice_52, val_565, allowzero=1) + transpose_7 = opset18.Transpose(view_6, perm=[0, 2, 1, 3]) + val_571 = opset18.Concat(val_6, val_1, [-1], [128], axis=0) + view_7 = opset18.Reshape(slice_53, val_571, allowzero=1) + transpose_8 = opset18.Transpose(view_7, perm=[0, 2, 1, 3]) + unsqueeze_14 = opset18.Unsqueeze(cos, [1]) + unsqueeze_15 = opset18.Unsqueeze(sin, [1]) + mul_604 = opset18.Mul(transpose_6, unsqueeze_14) + val_602 = opset18.Constant(value_ints=[1]) + slice_56 = opset18.Slice(transpose_6, [0], [64], [3], val_602) + val_612 = opset18.Constant(value_ints=[1]) + slice_57 = opset18.Slice(transpose_6, [64], [9223372036854775807], [3], val_612) + neg_2 = opset18.Neg(slice_57) + cat_7 = opset18.Concat(neg_2, slice_56, axis=-1) + mul_621 = opset18.Mul(cat_7, unsqueeze_15) + add_720 = opset18.Add(mul_604, mul_621) + mul_633 = opset18.Mul(transpose_7, unsqueeze_14) + val_622 = opset18.Constant(value_ints=[1]) + slice_58 = opset18.Slice(transpose_7, [0], [64], [3], val_622) + val_632 = opset18.Constant(value_ints=[1]) + slice_59 = opset18.Slice(transpose_7, [64], [9223372036854775807], [3], val_632) + neg_3 = opset18.Neg(slice_59) + cat_9 = opset18.Concat(neg_3, slice_58, axis=-1) + mul_650 = opset18.Mul(cat_9, unsqueeze_15) + add_761 = opset18.Add(mul_633, mul_650) + cat_11 = opset18.Concat(past_key_values_key_cache_1, add_761, axis=-2) + cat_12 = opset18.Concat(past_key_values_value_cache_1, transpose_8, axis=-2) + unsqueeze_16 = opset18.Unsqueeze(cat_11, [2]) + val_676 = opset18.Reshape(add_4, [-1], allowzero=0) + val_677 = opset18.Concat(val_6, [10], [4], val_676, [128], axis=0) + val_679 = opset18.Abs(val_677) + expand_5 = opset18.Expand(unsqueeze_16, val_679) + val_683 = opset18.Reshape(add_4, [-1], allowzero=0) + val_684 = opset18.Concat(val_6, [40], val_683, [128], axis=0) + _unsafe_view_2 = opset18.Reshape(expand_5, val_684, allowzero=1) + unsqueeze_17 = opset18.Unsqueeze(cat_12, [2]) + val_729 = opset18.Reshape(add_4, [-1], allowzero=0) + val_730 = opset18.Concat(val_6, [10], [4], val_729, [128], axis=0) + val_732 = opset18.Abs(val_730) + expand_6 = opset18.Expand(unsqueeze_17, val_732) + val_736 = opset18.Reshape(add_4, [-1], allowzero=0) + val_737 = opset18.Concat(val_6, [40], val_736, [128], axis=0) + _unsafe_view_3 = opset18.Reshape(expand_6, val_737, allowzero=1) + transpose_9 = opset18.Transpose(_unsafe_view_2, perm=[0, 1, 3, 2]) + matmul_3 = opset18.MatMul(add_720, transpose_9) + mul_814 = opset18.Mul(matmul_3, 0.088388346) + val_757 = opset18.Constant(value_ints=[0]) + val_759 = opset18.Constant(value_ints=[-1]) + val_760 = opset18.Reshape(add_4, val_759, allowzero=0) + val_764 = opset18.Constant(value_ints=[1]) + slice_75 = opset18.Slice(slice_scatter_2, val_757, val_760, [3], val_764) + add_907 = opset18.Add(mul_814, slice_75) + val_765 = opset18.Softmax(add_907, axis=-1) + matmul_4 = opset18.MatMul(val_765, _unsafe_view_3) + transpose_10 = opset18.Transpose(matmul_4, perm=[0, 2, 1, 3]) + val_770 = opset18.Concat(val_6, val_1, [-1], axis=0) + view_8 = opset18.Reshape(transpose_10, val_770, allowzero=1) + val_772 = opset18.Transpose(model_layers_1_self_attn_o_proj_weight, perm=[1, 0]) + linear_5 = opset18.MatMul(view_8, val_772) + add_950 = opset18.Add(add_592, linear_5) + pow_4 = opset18.Pow(add_950, 2.0) + mean_3 = opset18.ReduceMean(pow_4, [-1], noop_with_empty_axes=0, keepdims=1) + add_963 = opset18.Add(mean_3, 1e-05) + val_776 = opset18.Sqrt(add_963) + rsqrt_3 = opset18.Reciprocal(val_776) + mul_887 = opset18.Mul(add_950, rsqrt_3) + mul_891 = opset18.Mul(model_layers_1_post_attention_layernorm_weight, mul_887) + val_777 = opset18.Transpose(model_layers_1_mlp_gate_up_proj_weight, perm=[1, 0]) + linear_6 = opset18.MatMul(mul_891, val_777) + split_1_split_0, split_1_split_1 = opset18.Split(linear_6, axis=2, num_outputs=2) + val_778 = opset18.Sigmoid(split_1_split_0) + silu_1 = opset18.Mul(split_1_split_0, val_778) + mul_907 = opset18.Mul(split_1_split_1, silu_1) + val_779 = opset18.Transpose(model_layers_1_mlp_down_proj_weight, perm=[1, 0]) + linear_7 = opset18.MatMul(mul_907, val_779) + add_1008 = opset18.Add(add_950, linear_7) + pow_5 = opset18.Pow(add_1008, 2.0) + mean_4 = opset18.ReduceMean(pow_5, [-1], noop_with_empty_axes=0, keepdims=1) + add_1021 = opset18.Add(mean_4, 1e-05) + val_783 = opset18.Sqrt(add_1021) + rsqrt_4 = opset18.Reciprocal(val_783) + mul_929 = opset18.Mul(add_1008, rsqrt_4) + mul_933 = opset18.Mul(model_norm_weight, mul_929) + val_804 = opset18.Transpose(lm_head_weight, perm=[1, 0]) + linear_8 = opset18.MatMul(mul_933, val_804) + return linear_8, cat_5, cat_11, cat_6, cat_12 + + model = main_graph.to_model_proto(value_infos=value_infos) + return model + + +def make_model_with_random_weights(): + model_embed_tokens_weight = numpy.random.rand(100352, 5120).astype(numpy.float32) + model_layers_0_self_attn_o_proj_weight = numpy.random.rand(5120, 5120).astype( + numpy.float32 + ) + model_layers_0_self_attn_qkv_proj_weight = numpy.random.rand(7680, 5120).astype( + numpy.float32 + ) + model_layers_0_mlp_gate_up_proj_weight = numpy.random.rand(35840, 5120).astype( + numpy.float32 + ) + model_layers_0_mlp_down_proj_weight = numpy.random.rand(5120, 17920).astype(numpy.float32) + model_layers_0_input_layernorm_weight = numpy.random.rand(5120).astype(numpy.float32) + model_layers_0_post_attention_layernorm_weight = numpy.random.rand(5120).astype( + numpy.float32 + ) + model_layers_1_self_attn_o_proj_weight = numpy.random.rand(5120, 5120).astype( + numpy.float32 + ) + model_layers_1_self_attn_qkv_proj_weight = numpy.random.rand(7680, 5120).astype( + numpy.float32 + ) + model_layers_1_mlp_gate_up_proj_weight = numpy.random.rand(35840, 5120).astype( + numpy.float32 + ) + model_layers_1_mlp_down_proj_weight = numpy.random.rand(5120, 17920).astype(numpy.float32) + model_layers_1_input_layernorm_weight = numpy.random.rand(5120).astype(numpy.float32) + model_layers_1_post_attention_layernorm_weight = numpy.random.rand(5120).astype( + numpy.float32 + ) + model_norm_weight = numpy.random.rand(5120).astype(numpy.float32) + lm_head_weight = numpy.random.rand(100352, 5120).astype(numpy.float32) + expand_2 = numpy.random.rand(1, 64, 1).astype(numpy.float32) + model = make_model( + model_embed_tokens_weight, + model_layers_0_self_attn_o_proj_weight, + model_layers_0_self_attn_qkv_proj_weight, + model_layers_0_mlp_gate_up_proj_weight, + model_layers_0_mlp_down_proj_weight, + model_layers_0_input_layernorm_weight, + model_layers_0_post_attention_layernorm_weight, + model_layers_1_self_attn_o_proj_weight, + model_layers_1_self_attn_qkv_proj_weight, + model_layers_1_mlp_gate_up_proj_weight, + model_layers_1_mlp_down_proj_weight, + model_layers_1_input_layernorm_weight, + model_layers_1_post_attention_layernorm_weight, + model_norm_weight, + lm_head_weight, + expand_2, + ) + return model + + +class _Phi4LMTest: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + +def phi4lm_test(): + return _Phi4LMTest() diff --git a/onnxscript/rewriter/models/_rotary_embedding_models.py b/onnxscript/rewriter/models/_rotary_embedding_models.py new file mode 100644 index 0000000000..ecdb7d138b --- /dev/null +++ b/onnxscript/rewriter/models/_rotary_embedding_models.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Small test case models for rotary embedding.""" + +import numpy +import onnx_ir as ir + +from onnxscript import script +from onnxscript.onnx_opset import opset18 as op +from onnxscript.onnx_types import FLOAT, INT64 + +# A simple rotary embedding example + + +# x: [B, H, S, E] +# position_ids: [B, S] +@script() +def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOAT[1, 4, 8, 8]: + inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0]) + inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2]) + position_ids_expanded = op.Unsqueeze(position_ids, [1]) # => [B, 1, S] + position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + freqs = op.MatMul(inv_freq_3d, position_ids_float) # [B, E, S] + freqs = op.Transpose(freqs, perm=[0, 2, 1]) # [B, S, E] + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(emb) + sin = op.Sin(emb) + cos_4d = op.Unsqueeze(cos, 1) + sin_4d = op.Unsqueeze(sin, 1) + + x1 = op.Slice(x, [0], [4], [3], [1]) + x2 = op.Slice(x, [4], [8], [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + rotary_embedding = op.Add(x * cos_4d, rotated_x * sin_4d) + return rotary_embedding + + +class _TestCase1: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = _test_case_1_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "x": numpy.random.rand(1, 4, 8, 8).astype(numpy.float32), + "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def test_case_1(): + return _TestCase1() + + +# A simple rotary embedding example with 1D position_ids +# x: [B, H, S, E] +# position_ids: [S] +@script() +def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1, 4, 8, 8]: + inv_freq = op.Constant(value_floats=[1.0, 2.0, 3.0, 4.0]) + inv_freq_3d = op.Unsqueeze(inv_freq, [0, 2]) + position_ids_expanded = op.Unsqueeze(position_ids, [0, 1]) # => [1, 1, S] + position_ids_float = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + freqs = op.MatMul(inv_freq_3d, position_ids_float) # [B, E, S] + freqs = op.Transpose(freqs, perm=[0, 2, 1]) # [B, S, E] + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(emb) + sin = op.Sin(emb) + cos_4d = op.Unsqueeze(cos, 1) + sin_4d = op.Unsqueeze(sin, 1) + + x1 = op.Slice(x, [0], [4], [3], [1]) + x2 = op.Slice(x, [4], [8], [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + rotary_embedding = op.Add(x * cos_4d, rotated_x * sin_4d) + return rotary_embedding + + +class _TestCase2: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = _test_case_2_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "x": numpy.random.rand(1, 4, 8, 8).astype(numpy.float32), + "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(8), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def test_case_2(): + return _TestCase2() + + +# A partial rotary embedding example: + +rotary_embedding_dim = 32 # Abbreviated as "rd" in shape descriptors below +half_rotary_embedding_dim = rotary_embedding_dim // 2 +# A random inverse frequency tensor for the sake of this example. +inv_freqs_value = numpy.random.rand(1, half_rotary_embedding_dim, 1).astype(numpy.float32) + + +@script() +def _partial_rotary_script(position_ids, query): + inv_freqs = op.Constant(value=inv_freqs_value) # [1, rd/2, 1] + position_ids_3d = op.Unsqueeze(position_ids, 1) # [B, 1, S] + position_ids_3d_float = op.Cast(position_ids_3d, to=1) + matmul = op.MatMul(inv_freqs, position_ids_3d_float) # [B, rd/2, S] + transpose = op.Transpose(matmul, perm=[0, 2, 1]) # [B, S, rd/2] + cat = op.Concat(transpose, transpose, axis=-1) # [B, S, rd] + cos_3d = op.Cos(cat) # [B, S, rd] + sin_3d = op.Sin(cat) # [B, S, rd] + # Split the query for partial embedding + to_embed = op.Slice(query, [0], [32], [3], [1]) + unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1]) + cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd] + sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd] + # Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X) + # essentially represents X rotated by 90 degrees + to_embed_times_cos = op.Mul(to_embed, cos_4d) + to_embed_x = op.Slice(to_embed, [0], [16], [3], [1]) + to_embed_y = op.Slice(to_embed, [16], [9223372036854775807], [3], [1]) + minus_to_embed_y = op.Neg(to_embed_y) + to_embed_rotated_90 = op.Concat(minus_to_embed_y, to_embed_x, axis=-1) + to_embed_rotated_90_times_sin = op.Mul(to_embed_rotated_90, sin_4d) + embedded = op.Add(to_embed_times_cos, to_embed_rotated_90_times_sin) + final = op.Concat(embedded, unembedded, axis=-1) + return final + + +class _PartialRotaryTestCase: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = _partial_rotary_script.to_model_proto( + input_types=( + INT64["Batchsize", "Sequence"], + FLOAT["Batchsize", 32, "Sequence", 80], + ), + output_types=(FLOAT["Batchsize", 32, "Sequence", 80],), + ) + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "query": numpy.random.rand(1, 32, 8, 80).astype(numpy.float32), + "position_ids": numpy.arange(8, dtype=numpy.int64).reshape(1, 8), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def partial_rotary_test_case(): + return _PartialRotaryTestCase() diff --git a/onnxscript/rewriter/models/_smollm_1.py b/onnxscript/rewriter/models/_smollm_1.py new file mode 100644 index 0000000000..d592eb2572 --- /dev/null +++ b/onnxscript/rewriter/models/_smollm_1.py @@ -0,0 +1,256 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +A one-layer SmolLM model test case, with inputs: input_ids, attention_mask, and position_ids. +This is an onnxscript version of the model. +""" + +import numpy as np +import onnx_ir as ir + +from onnxscript import script +from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import FLOAT, INT64 + + +def make_model( + input_layernorm_weight_0, + post_attention_layernorm_weight0, + norm_weight, + head_weight, + self_attn_q_proj_weight0, + self_attn_k_proj_weight0, + self_attn_v_proj_weight0, + self_attn_o_proj_weight0, + mlp_gate_proj_weight0, + mlp_up_proj_weight0, + mlp_down_proj_weight0, +): + @script() + def main_graph( + input0: INT64[1, 10], input1: FLOAT[1, 10], input2: INT64[1, 10] + ) -> (FLOAT[1, 10, 49152], FLOAT[1, 32, 10, 64], FLOAT[1, 32, 10, 64]): + model_layers_0_input_layernorm_weight = opset18.Constant( + value=input_layernorm_weight_0 + ) + model_layers_0_post_attention_layernorm_weight = opset18.Constant( + value=post_attention_layernorm_weight0 + ) + model_norm_weight = opset18.Constant(value=norm_weight) + lm_head_weight = opset18.Constant(value=head_weight) + model_layers_0_self_attn_q_proj_weight = opset18.Constant( + value=self_attn_q_proj_weight0 + ) + model_layers_0_self_attn_k_proj_weight = opset18.Constant( + value=self_attn_k_proj_weight0 + ) + model_layers_0_self_attn_v_proj_weight = opset18.Constant( + value=self_attn_v_proj_weight0 + ) + model_layers_0_self_attn_o_proj_weight = opset18.Constant( + value=self_attn_o_proj_weight0 + ) + model_layers_0_mlp_gate_proj_weight = opset18.Constant(value=mlp_gate_proj_weight0) + model_layers_0_mlp_up_proj_weight = opset18.Constant(value=mlp_up_proj_weight0) + model_layers_0_mlp_down_proj_weight = opset18.Constant(value=mlp_down_proj_weight0) + + embedding = opset18.Gather(lm_head_weight, input0, axis=0) + minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38]) + mask_10x10 = opset18.Trilu(minus_inf_10x10, 1) + slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10]) + unsqueeze_2 = opset18.Unsqueeze(input1, 1) + unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2) + add = slice_5 + unsqueeze_3 + eq = add == 0.0 + slice_10 = slice_5 + masked_fill = opset18.Where(eq, -3.4028235e38, slice_10) + val_179 = opset18.Transpose(masked_fill, perm=[2, 1, 0, 3]) + slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3]) + val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3]) + slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3]) + unsqueeze_6 = opset18.Unsqueeze(input2, 1) + to_copy_1 = opset18.Cast(unsqueeze_6, to=1) + view_1 = opset18.Constant( + value=ir.tensor( + np.array( + [ + 1.0, + 0.7498942017555237, + 0.5623413324356079, + 0.4216965138912201, + 0.3162277638912201, + 0.23713736236095428, + 0.17782793939113617, + 0.1333521455526352, + 0.10000000149011612, + 0.07498941570520401, + 0.05623412877321243, + 0.04216964915394783, + 0.03162277862429619, + 0.0237137358635664, + 0.017782794311642647, + 0.01333521492779255, + 0.009999999776482582, + 0.007498942315578461, + 0.005623413249850273, + 0.0042169648222625256, + 0.003162277862429619, + 0.0023713738191872835, + 0.0017782794311642647, + 0.0013335214462131262, + 0.0010000000474974513, + 0.0007498941849917173, + 0.000562341301701963, + 0.00042169648804701865, + 0.0003162277862429619, + 0.0002371373848291114, + 0.00017782794020604342, + 0.0001333521504420787, + ], + dtype=np.float32, + ).reshape([1, 32, 1]) + ) + ) + view_2 = opset18.Reshape(to_copy_1, [1, 1, 10], allowzero=0) + bmm = view_1 @ view_2 + view_3 = opset18.Reshape(bmm, [1, 32, 10], allowzero=0) + transpose = opset18.Transpose(view_3, perm=[0, 2, 1]) + cat = opset18.Concat(transpose, transpose, axis=-1) + cos = opset18.Cos(cat) + sin = opset18.Sin(cat) + pow_1 = embedding**2.0 + mean = opset18.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0) + add_1 = mean + 1e-05 + val_244 = opset18.Sqrt(add_1) + rsqrt = opset18.Reciprocal(val_244) + mul_3 = embedding * rsqrt + mul_4 = model_layers_0_input_layernorm_weight * mul_3 + t = opset18.Transpose(model_layers_0_self_attn_q_proj_weight, perm=[1, 0]) + view_5 = mul_4 @ t + t_1 = opset18.Transpose(model_layers_0_self_attn_k_proj_weight, perm=[1, 0]) + view_7 = mul_4 @ t_1 + t_2 = opset18.Transpose(model_layers_0_self_attn_v_proj_weight, perm=[1, 0]) + view_9 = mul_4 @ t_2 + view_10 = opset18.Reshape(view_5, [1, 10, 32, 64], allowzero=0) + transpose_1 = opset18.Transpose(view_10, perm=[0, 2, 1, 3]) + view_11 = opset18.Reshape(view_7, [1, 10, 32, 64], allowzero=0) + transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3]) + view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0) + transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3]) + unsqueeze_7 = opset18.Unsqueeze(cos, 1) + unsqueeze_8 = opset18.Unsqueeze(sin, 1) + mul_5 = transpose_1 * unsqueeze_7 + val_267 = opset18.Constant(value_ints=[1]) + slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267) + val_277 = opset18.Constant(value_ints=[1]) + slice_20 = opset18.Slice(transpose_1, [32], [9223372036854775807], [3], val_277) + neg = opset18.Neg(slice_20) + cat_1 = opset18.Concat(neg, slice_19, axis=-1) + mul_6 = cat_1 * unsqueeze_8 + add_2 = mul_5 + mul_6 + mul_7 = transpose_2 * unsqueeze_7 + val_287 = opset18.Constant(value_ints=[1]) + slice_21 = opset18.Slice(transpose_2, [0], [32], [3], val_287) + val_297 = opset18.Constant(value_ints=[1]) + slice_22 = opset18.Slice(transpose_2, [32], [9223372036854775807], [3], val_297) + neg_1 = opset18.Neg(slice_22) + cat_2 = opset18.Concat(neg_1, slice_21, axis=-1) + mul_8 = cat_2 * unsqueeze_8 + add_3 = mul_7 + mul_8 + val_346 = opset18.Reshape(add_3, [-1, 10, 64], allowzero=0) + val_347 = opset18.Transpose(val_346, perm=[0, 2, 1]) + val_349 = opset18.Reshape(val_347, [1, 32, 64, 10], allowzero=0) + val_351 = add_2 * [0.35355338] + val_353 = val_349 * [0.35355338] + val_354 = val_351 @ val_353 + val_355 = val_354 + slice_scatter_1 + val_356 = opset18.Softmax(val_355, axis=-1) + getitem = val_356 @ transpose_3 + transpose_4 = opset18.Transpose(getitem, perm=[0, 2, 1, 3]) + view_13 = opset18.Reshape(transpose_4, [1, 10, -1], allowzero=0) + t_3 = opset18.Transpose(model_layers_0_self_attn_o_proj_weight, perm=[1, 0]) + view_15 = view_13 @ t_3 + add_4 = embedding + view_15 + pow_2 = add_4**2.0 + mean_1 = opset18.ReduceMean(pow_2, [-1], keepdims=1, noop_with_empty_axes=0) + add_5 = mean_1 + 1e-05 + val_379 = opset18.Sqrt(add_5) + rsqrt_1 = opset18.Reciprocal(val_379) + mul_9 = add_4 * rsqrt_1 + mul_10 = model_layers_0_post_attention_layernorm_weight * mul_9 + t_4 = opset18.Transpose(model_layers_0_mlp_gate_proj_weight, perm=[1, 0]) + view_17 = mul_10 @ t_4 + val_383 = opset18.Sigmoid(view_17) + silu = view_17 * val_383 + t_5 = opset18.Transpose(model_layers_0_mlp_up_proj_weight, perm=[1, 0]) + view_19 = mul_10 @ t_5 + mul_11 = silu * view_19 + t_6 = opset18.Transpose(model_layers_0_mlp_down_proj_weight, perm=[1, 0]) + view_21 = mul_11 @ t_6 + add_6 = add_4 + view_21 + pow_3 = add_6**2.0 + mean_2 = opset18.ReduceMean(pow_3, [-1], keepdims=1, noop_with_empty_axes=0) + add_7 = mean_2 + 1e-05 + val_391 = opset18.Sqrt(add_7) + rsqrt_2 = opset18.Reciprocal(val_391) + mul_12 = add_6 * rsqrt_2 + mul_13 = model_norm_weight * mul_12 + t_7 = opset18.Transpose(lm_head_weight, perm=[1, 0]) + view_23 = mul_13 @ t_7 + to_copy_12 = opset18.Identity(view_23) + return to_copy_12, add_3, transpose_3 + + model = main_graph.to_model_proto() + return model + + +def make_model_with_random_weights(): + input_layernorm_weight_0 = np.random.rand(2048).astype(np.float32) + post_attention_layernorm_weight0 = np.random.rand(2048).astype(np.float32) + norm_weight = np.random.rand(2048).astype(np.float32) + head_weight = np.random.rand(49152, 2048).astype(np.float32) + self_attn_q_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32) + self_attn_k_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32) + self_attn_v_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32) + self_attn_o_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32) + mlp_gate_proj_weight0 = np.random.rand(8192, 2048).astype(np.float32) + mlp_up_proj_weight0 = np.random.rand(8192, 2048).astype(np.float32) + mlp_down_proj_weight0 = np.random.rand(2048, 8192).astype(np.float32) + model = make_model( + ir.tensor(input_layernorm_weight_0), + ir.tensor(post_attention_layernorm_weight0), + ir.tensor(norm_weight), + ir.tensor(head_weight), + ir.tensor(self_attn_q_proj_weight0), + ir.tensor(self_attn_k_proj_weight0), + ir.tensor(self_attn_v_proj_weight0), + ir.tensor(self_attn_o_proj_weight0), + ir.tensor(mlp_gate_proj_weight0), + ir.tensor(mlp_up_proj_weight0), + ir.tensor(mlp_down_proj_weight0), + ) + return model + + +class _SmollmTest1: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "input0": np.random.randint(0, 49152, (1, 10)).astype(np.int64), + "input1": np.ones((1, 10), dtype=np.float32), + "input2": np.arange(10, dtype=np.int64).reshape(1, 10), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def smollm_test_1(): + return _SmollmTest1() diff --git a/onnxscript/rewriter/models/_smollm_2.py b/onnxscript/rewriter/models/_smollm_2.py new file mode 100644 index 0000000000..62d857a2d6 --- /dev/null +++ b/onnxscript/rewriter/models/_smollm_2.py @@ -0,0 +1,471 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +A one-layer SmolLM model test case, with inputs: input_ids, position_ids, and pask key/values. +This is an onnxscript version of the model. +""" + +import numpy +import onnx_ir as ir + +from onnxscript import script +from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import FLOAT, INT64 + + +def make_model( + model_layers_0_input_layernorm_weight, + model_layers_0_post_attention_layernorm_weight, + model_norm_weight, + lm_head_weight, + model_layers_0_self_attn_q_proj_weight, + model_layers_0_self_attn_k_proj_weight, + model_layers_0_self_attn_v_proj_weight, + model_layers_0_self_attn_o_proj_weight, + model_layers_0_mlp_gate_proj_weight, + model_layers_0_mlp_up_proj_weight, + model_layers_0_mlp_down_proj_weight, + model_rotary_emb_inv_freq, +): + @script() + def main_graph( + input_ids: INT64[1, 30], + position_ids: INT64[1, 30], + past_key_values_0_0: FLOAT[1, 32, 16, 64], + past_key_values_0_1: FLOAT[1, 32, 16, 64], + ) -> (FLOAT[1, 30, 49152], FLOAT[1, 32, 46, 64], FLOAT[1, 32, 46, 64]): + embedding = opset18.Gather(lm_head_weight, input_ids, axis=0) + val_2 = opset18.CastLike(1.0, 46) + arange = opset18.Range(16, 46, val_2) + val_5 = opset18.Cast(-3.4028235e38, to=1) + val_7 = opset18.Cast([30, 47], to=7) + full = opset18.Expand(val_5, val_7) + diagonal__1 = opset18.Constant(value_int=1) + triu = opset18.Trilu(full, diagonal__1, upper=1) + val_10 = opset18.CastLike(0.0, 47) + val_11 = opset18.CastLike(1.0, 47) + arange_1 = opset18.Range(val_10, 47, val_11) + val_13 = opset18.Cast([-1, 1], to=7) + view = opset18.Reshape(arange, val_13, allowzero=0) + gt = arange_1 > view + convert_element_type_default = opset18.Cast(gt, to=1) + mul = triu * convert_element_type_default + dim__2 = opset18.Constant(value_int=0) + dim_0__2 = opset18.Cast(dim__2, to=7) + unsqueeze = opset18.Unsqueeze(model_rotary_emb_inv_freq, dim_0__2) + val_15 = opset18.Cast(0, to=7) + val_16 = opset18.Constant(value_ints=[-1]) + val_17 = opset18.Reshape(val_15, val_16, allowzero=0) + val_19 = opset18.Cast(9223372036854775807, to=7) + val_20 = opset18.Constant(value_ints=[-1]) + val_21 = opset18.Reshape(val_19, val_20, allowzero=0) + val_23 = opset18.Cast(1, to=7) + val_24 = opset18.Constant(value_ints=[-1]) + val_25 = opset18.Reshape(val_23, val_24, allowzero=0) + val_26 = opset18.Constant(value_ints=[1]) + slice_1 = opset18.Slice(unsqueeze, val_17, val_21, val_25, val_26) + dim__3 = opset18.Constant(value_int=2) + dim_0__3 = opset18.Cast(dim__3, to=7) + unsqueeze_1 = opset18.Unsqueeze(slice_1, dim_0__3) + _to_copy = opset18.Cast(unsqueeze_1, to=1) + size_0__4 = opset18.Cast([1, -1, 1], to=7) + size_1__4 = opset18.Abs(size_0__4) + expand = opset18.Expand(_to_copy, size_1__4) + val_28 = opset18.Cast(0, to=7) + val_29 = opset18.Constant(value_ints=[-1]) + val_30 = opset18.Reshape(val_28, val_29, allowzero=0) + val_31 = opset18.Cast(9223372036854775807, to=7) + val_32 = opset18.Constant(value_ints=[-1]) + val_33 = opset18.Reshape(val_31, val_32, allowzero=0) + val_34 = opset18.Cast(0, to=7) + val_35 = opset18.Constant(value_ints=[-1]) + val_36 = opset18.Reshape(val_34, val_35, allowzero=0) + val_37 = opset18.Constant(value_ints=[1]) + slice_2 = opset18.Slice(position_ids, val_30, val_33, val_36, val_37) + dim__5 = opset18.Constant(value_int=1) + dim_0__5 = opset18.Cast(dim__5, to=7) + unsqueeze_2 = opset18.Unsqueeze(slice_2, dim_0__5) + val_38 = opset18.Cast(0, to=7) + val_39 = opset18.Constant(value_ints=[-1]) + val_40 = opset18.Reshape(val_38, val_39, allowzero=0) + val_41 = opset18.Cast(9223372036854775807, to=7) + val_42 = opset18.Constant(value_ints=[-1]) + val_43 = opset18.Reshape(val_41, val_42, allowzero=0) + val_45 = opset18.Cast(2, to=7) + val_46 = opset18.Constant(value_ints=[-1]) + val_47 = opset18.Reshape(val_45, val_46, allowzero=0) + val_48 = opset18.Constant(value_ints=[1]) + slice_3 = opset18.Slice(unsqueeze_2, val_40, val_43, val_47, val_48) + _to_copy_1 = opset18.Cast(slice_3, to=1) + _to_copy_2 = opset18.Cast(expand, to=1) + _to_copy_3 = opset18.Cast(_to_copy_1, to=1) + size_0__6 = opset18.Cast([1, 32, 1], to=7) + size_1__6 = opset18.Abs(size_0__6) + expand_1 = opset18.Expand(_to_copy_2, size_1__6) + val_50 = opset18.Cast([1, 32, 1], to=7) + view_1 = opset18.Reshape(expand_1, val_50, allowzero=0) + size_0__7 = opset18.Cast([1, 1, 30], to=7) + size_1__7 = opset18.Abs(size_0__7) + expand_2 = opset18.Expand(_to_copy_3, size_1__7) + val_52 = opset18.Cast([1, 1, 30], to=7) + view_2 = opset18.Reshape(expand_2, val_52, allowzero=0) + bmm = view_1 @ view_2 + val_54 = opset18.Cast([1, 32, 30], to=7) + view_3 = opset18.Reshape(bmm, val_54, allowzero=0) + transpose = opset18.Transpose(view_3, perm=[0, 2, 1]) + cat = opset18.Concat(transpose, transpose, axis=-1) + cos = opset18.Cos(cat) + sin = opset18.Sin(cat) + mul_1 = cos * 1.0 + mul_2 = sin * 1.0 + _to_copy_4 = opset18.Cast(mul_1, to=1) + _to_copy_5 = opset18.Cast(mul_2, to=1) + _to_copy_6 = opset18.Cast(embedding, to=1) + scalar_tensor_default = opset18.Cast(2, to=1) + pow_1 = _to_copy_6**scalar_tensor_default + val_55 = opset18.Constant(value_ints=[-1]) + val_57 = opset18.Reshape([-1], val_55, allowzero=0) + mean = opset18.ReduceMean(pow_1, val_57, keepdims=1, noop_with_empty_axes=0) + add = mean + 1e-05 + val_59 = opset18.Sqrt(add) + rsqrt = opset18.Reciprocal(val_59) + mul_3 = _to_copy_6 * rsqrt + _to_copy_7 = opset18.Cast(mul_3, to=1) + mul_4 = model_layers_0_input_layernorm_weight * _to_copy_7 + t = opset18.Transpose(model_layers_0_self_attn_q_proj_weight, perm=[1, 0]) + val_61 = opset18.Cast([30, 2048], to=7) + view_4 = opset18.Reshape(mul_4, val_61, allowzero=0) + mm = view_4 @ t + val_63 = opset18.Cast([1, 30, 2048], to=7) + view_5 = opset18.Reshape(mm, val_63, allowzero=0) + t_1 = opset18.Transpose(model_layers_0_self_attn_k_proj_weight, perm=[1, 0]) + val_64 = opset18.Cast([30, 2048], to=7) + view_6 = opset18.Reshape(mul_4, val_64, allowzero=0) + mm_1 = view_6 @ t_1 + val_65 = opset18.Cast([1, 30, 2048], to=7) + view_7 = opset18.Reshape(mm_1, val_65, allowzero=0) + t_2 = opset18.Transpose(model_layers_0_self_attn_v_proj_weight, perm=[1, 0]) + val_66 = opset18.Cast([30, 2048], to=7) + view_8 = opset18.Reshape(mul_4, val_66, allowzero=0) + mm_2 = view_8 @ t_2 + val_67 = opset18.Cast([1, 30, 2048], to=7) + view_9 = opset18.Reshape(mm_2, val_67, allowzero=0) + val_69 = opset18.Cast([1, 30, 32, 64], to=7) + view_10 = opset18.Reshape(view_5, val_69, allowzero=0) + transpose_1 = opset18.Transpose(view_10, perm=[0, 2, 1, 3]) + val_70 = opset18.Cast([1, 30, 32, 64], to=7) + view_11 = opset18.Reshape(view_7, val_70, allowzero=0) + transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3]) + val_71 = opset18.Cast([1, 30, 32, 64], to=7) + view_12 = opset18.Reshape(view_9, val_71, allowzero=0) + transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3]) + dim__8 = opset18.Constant(value_int=1) + dim_0__8 = opset18.Cast(dim__8, to=7) + unsqueeze_3 = opset18.Unsqueeze(_to_copy_4, dim_0__8) + dim__9 = opset18.Constant(value_int=1) + dim_0__9 = opset18.Cast(dim__9, to=7) + unsqueeze_4 = opset18.Unsqueeze(_to_copy_5, dim_0__9) + mul_5 = transpose_1 * unsqueeze_3 + val_72 = opset18.Cast(0, to=7) + val_73 = opset18.Constant(value_ints=[-1]) + val_74 = opset18.Reshape(val_72, val_73, allowzero=0) + val_76 = opset18.Cast(32, to=7) + val_77 = opset18.Constant(value_ints=[-1]) + val_78 = opset18.Reshape(val_76, val_77, allowzero=0) + val_80 = opset18.Cast(3, to=7) + val_81 = opset18.Constant(value_ints=[-1]) + val_82 = opset18.Reshape(val_80, val_81, allowzero=0) + val_83 = opset18.Constant(value_ints=[1]) + slice_4 = opset18.Slice(transpose_1, val_74, val_78, val_82, val_83) + val_84 = opset18.Cast(32, to=7) + val_85 = opset18.Constant(value_ints=[-1]) + val_86 = opset18.Reshape(val_84, val_85, allowzero=0) + val_87 = opset18.Cast(9223372036854775807, to=7) + val_88 = opset18.Constant(value_ints=[-1]) + val_89 = opset18.Reshape(val_87, val_88, allowzero=0) + val_90 = opset18.Cast(3, to=7) + val_91 = opset18.Constant(value_ints=[-1]) + val_92 = opset18.Reshape(val_90, val_91, allowzero=0) + val_93 = opset18.Constant(value_ints=[1]) + slice_5 = opset18.Slice(transpose_1, val_86, val_89, val_92, val_93) + neg = opset18.Neg(slice_5) + cat_1 = opset18.Concat(neg, slice_4, axis=-1) + mul_6 = cat_1 * unsqueeze_4 + add_1 = mul_5 + mul_6 + mul_7 = transpose_2 * unsqueeze_3 + val_94 = opset18.Cast(0, to=7) + val_95 = opset18.Constant(value_ints=[-1]) + val_96 = opset18.Reshape(val_94, val_95, allowzero=0) + val_97 = opset18.Cast(32, to=7) + val_98 = opset18.Constant(value_ints=[-1]) + val_99 = opset18.Reshape(val_97, val_98, allowzero=0) + val_100 = opset18.Cast(3, to=7) + val_101 = opset18.Constant(value_ints=[-1]) + val_102 = opset18.Reshape(val_100, val_101, allowzero=0) + val_103 = opset18.Constant(value_ints=[1]) + slice_6 = opset18.Slice(transpose_2, val_96, val_99, val_102, val_103) + val_104 = opset18.Cast(32, to=7) + val_105 = opset18.Constant(value_ints=[-1]) + val_106 = opset18.Reshape(val_104, val_105, allowzero=0) + val_107 = opset18.Cast(9223372036854775807, to=7) + val_108 = opset18.Constant(value_ints=[-1]) + val_109 = opset18.Reshape(val_107, val_108, allowzero=0) + val_110 = opset18.Cast(3, to=7) + val_111 = opset18.Constant(value_ints=[-1]) + val_112 = opset18.Reshape(val_110, val_111, allowzero=0) + val_113 = opset18.Constant(value_ints=[1]) + slice_7 = opset18.Slice(transpose_2, val_106, val_109, val_112, val_113) + neg_1 = opset18.Neg(slice_7) + cat_2 = opset18.Concat(neg_1, slice_6, axis=-1) + mul_8 = cat_2 * unsqueeze_4 + add_2 = mul_7 + mul_8 + cat_3 = opset18.Concat(past_key_values_0_0, add_2, axis=-2) + cat_4 = opset18.Concat(past_key_values_0_1, transpose_3, axis=-2) + dim__10 = opset18.Constant(value_int=0) + dim_0__10 = opset18.Cast(dim__10, to=7) + unsqueeze_5 = opset18.Unsqueeze(mul, dim_0__10) + dim__11 = opset18.Constant(value_int=1) + dim_0__11 = opset18.Cast(dim__11, to=7) + unsqueeze_6 = opset18.Unsqueeze(unsqueeze_5, dim_0__11) + val_114 = opset18.Cast(0, to=7) + val_115 = opset18.Constant(value_ints=[-1]) + val_116 = opset18.Reshape(val_114, val_115, allowzero=0) + val_117 = opset18.Cast(9223372036854775807, to=7) + val_118 = opset18.Constant(value_ints=[-1]) + val_119 = opset18.Reshape(val_117, val_118, allowzero=0) + val_120 = opset18.Cast(2, to=7) + val_121 = opset18.Constant(value_ints=[-1]) + val_122 = opset18.Reshape(val_120, val_121, allowzero=0) + val_123 = opset18.Constant(value_ints=[1]) + slice_8 = opset18.Slice(unsqueeze_6, val_116, val_119, val_122, val_123) + val_124 = opset18.Cast(0, to=7) + val_125 = opset18.Constant(value_ints=[-1]) + val_126 = opset18.Reshape(val_124, val_125, allowzero=0) + val_127 = opset18.Cast(9223372036854775807, to=7) + val_128 = opset18.Constant(value_ints=[-1]) + val_129 = opset18.Reshape(val_127, val_128, allowzero=0) + val_130 = opset18.Cast(3, to=7) + val_131 = opset18.Constant(value_ints=[-1]) + val_132 = opset18.Reshape(val_130, val_131, allowzero=0) + val_133 = opset18.Constant(value_ints=[1]) + slice_9 = opset18.Slice(slice_8, val_126, val_129, val_132, val_133) + size_0__12 = opset18.Cast([1, 1, -1, -1], to=7) + size_1__12 = opset18.Abs(size_0__12) + expand_3 = opset18.Expand(slice_9, size_1__12) + val_135 = opset18.Cast(0, to=7) + val_136 = opset18.Constant(value_ints=[-1]) + val_137 = opset18.Reshape(val_135, val_136, allowzero=0) + val_138 = opset18.Cast(9223372036854775807, to=7) + val_139 = opset18.Constant(value_ints=[-1]) + val_140 = opset18.Reshape(val_138, val_139, allowzero=0) + val_141 = opset18.Cast(0, to=7) + val_142 = opset18.Constant(value_ints=[-1]) + val_143 = opset18.Reshape(val_141, val_142, allowzero=0) + val_144 = opset18.Constant(value_ints=[1]) + slice_10 = opset18.Slice(expand_3, val_137, val_140, val_143, val_144) + val_145 = opset18.Cast(0, to=7) + val_146 = opset18.Constant(value_ints=[-1]) + val_147 = opset18.Reshape(val_145, val_146, allowzero=0) + val_148 = opset18.Cast(9223372036854775807, to=7) + val_149 = opset18.Constant(value_ints=[-1]) + val_150 = opset18.Reshape(val_148, val_149, allowzero=0) + val_151 = opset18.Cast(1, to=7) + val_152 = opset18.Constant(value_ints=[-1]) + val_153 = opset18.Reshape(val_151, val_152, allowzero=0) + val_154 = opset18.Constant(value_ints=[1]) + slice_11 = opset18.Slice(slice_10, val_147, val_150, val_153, val_154) + val_155 = opset18.Cast(0, to=7) + val_156 = opset18.Constant(value_ints=[-1]) + val_157 = opset18.Reshape(val_155, val_156, allowzero=0) + val_158 = opset18.Cast(9223372036854775807, to=7) + val_159 = opset18.Constant(value_ints=[-1]) + val_160 = opset18.Reshape(val_158, val_159, allowzero=0) + val_161 = opset18.Cast(2, to=7) + val_162 = opset18.Constant(value_ints=[-1]) + val_163 = opset18.Reshape(val_161, val_162, allowzero=0) + val_164 = opset18.Constant(value_ints=[1]) + slice_12 = opset18.Slice(slice_11, val_157, val_160, val_163, val_164) + val_165 = opset18.Cast(0, to=7) + val_166 = opset18.Constant(value_ints=[-1]) + val_167 = opset18.Reshape(val_165, val_166, allowzero=0) + val_168 = opset18.Cast(46, to=7) + val_169 = opset18.Constant(value_ints=[-1]) + val_170 = opset18.Reshape(val_168, val_169, allowzero=0) + val_171 = opset18.Cast(3, to=7) + val_172 = opset18.Constant(value_ints=[-1]) + val_173 = opset18.Reshape(val_171, val_172, allowzero=0) + val_174 = opset18.Constant(value_ints=[1]) + slice_13 = opset18.Slice(slice_12, val_167, val_170, val_173, val_174) + val_175 = opset18.Shape(add_1, start=0) + val_176 = opset18.Constant(value_ints=[-1]) + val_177 = opset18.Gather(val_175, val_176, axis=0) + val_178 = opset18.CastLike(val_177, add_1) + val_179 = opset18.Constant(value_float=1.0) + val_180 = opset18.CastLike(val_179, add_1) + val_181 = opset18.Sqrt(val_178) + val_182 = val_180 / val_181 + val_183 = opset18.CastLike(val_182, add_1) + val_184 = opset18.Shape(cat_3, start=0) + val_185 = opset18.Constant(value_ints=[9223372036854775807]) + val_186 = opset18.Slice(val_184, [-1], val_185) + val_188 = opset18.Slice(val_184, [-2], [-1]) + val_189 = opset18.Constant(value_ints=[-9223372036854775808]) + val_190 = opset18.Slice(val_184, val_189, [-2]) + val_191 = opset18.Constant(value_ints=[-1]) + val_192 = opset18.Concat(val_191, val_188, val_186, axis=0) + val_193 = opset18.Reshape(cat_3, val_192, allowzero=0) + val_194 = opset18.Transpose(val_193, perm=[0, 2, 1]) + val_195 = opset18.Concat(val_190, val_186, val_188, axis=0) + val_196 = opset18.Reshape(val_194, val_195, allowzero=0) + val_197 = opset18.Sqrt(val_183) + val_198 = add_1 * val_197 + val_199 = opset18.Sqrt(val_183) + val_200 = val_196 * val_199 + val_201 = val_198 @ val_200 + val_202 = val_201 + slice_13 + val_203 = opset18.Softmax(val_202, axis=-1) + val_204, _unused = opset18.Dropout(val_203, 0.0) + getitem = val_204 @ cat_4 + val_206 = opset18.Shape(add_1, start=0) + val_209 = opset18.Slice(val_206, [0], [1]) + val_211 = opset18.Slice(val_206, [1], [2]) + val_212 = opset18.Slice(val_206, [-2], [-1]) + val_213 = opset18.Cast(val_211, to=1) + val_215 = val_213 / 32.0 + val_216 = opset18.Ceil(val_215) + val_217 = val_216 * 32.0 + val_218 = opset18.Cast(val_217, to=7) + val_219 = opset18.Concat(val_209, val_212, val_218, axis=0) + _scaled_dot_product_flash_attention_for_cpu__1 = opset18.Expand(0.0, val_219) + transpose_4 = opset18.Transpose(getitem, perm=[0, 2, 1, 3]) + val_221 = opset18.Cast([1, 30, -1], to=7) + view_13 = opset18.Reshape(transpose_4, val_221, allowzero=0) + t_3 = opset18.Transpose(model_layers_0_self_attn_o_proj_weight, perm=[1, 0]) + val_222 = opset18.Cast([30, 2048], to=7) + view_14 = opset18.Reshape(view_13, val_222, allowzero=0) + mm_3 = view_14 @ t_3 + val_223 = opset18.Cast([1, 30, 2048], to=7) + view_15 = opset18.Reshape(mm_3, val_223, allowzero=0) + add_3 = embedding + view_15 + _to_copy_8 = opset18.Cast(add_3, to=1) + scalar_tensor_default_1 = opset18.Cast(2, to=1) + pow_2 = _to_copy_8**scalar_tensor_default_1 + val_224 = opset18.Constant(value_ints=[-1]) + val_225 = opset18.Reshape([-1], val_224, allowzero=0) + mean_1 = opset18.ReduceMean(pow_2, val_225, keepdims=1, noop_with_empty_axes=0) + add_4 = mean_1 + 1e-05 + val_226 = opset18.Sqrt(add_4) + rsqrt_1 = opset18.Reciprocal(val_226) + mul_9 = _to_copy_8 * rsqrt_1 + _to_copy_9 = opset18.Cast(mul_9, to=1) + mul_10 = model_layers_0_post_attention_layernorm_weight * _to_copy_9 + t_4 = opset18.Transpose(model_layers_0_mlp_gate_proj_weight, perm=[1, 0]) + val_227 = opset18.Cast([30, 2048], to=7) + view_16 = opset18.Reshape(mul_10, val_227, allowzero=0) + mm_4 = view_16 @ t_4 + val_229 = opset18.Cast([1, 30, 8192], to=7) + view_17 = opset18.Reshape(mm_4, val_229, allowzero=0) + val_230 = opset18.Sigmoid(view_17) + silu = view_17 * val_230 + t_5 = opset18.Transpose(model_layers_0_mlp_up_proj_weight, perm=[1, 0]) + val_231 = opset18.Cast([30, 2048], to=7) + view_18 = opset18.Reshape(mul_10, val_231, allowzero=0) + mm_5 = view_18 @ t_5 + val_232 = opset18.Cast([1, 30, 8192], to=7) + view_19 = opset18.Reshape(mm_5, val_232, allowzero=0) + mul_11 = silu * view_19 + t_6 = opset18.Transpose(model_layers_0_mlp_down_proj_weight, perm=[1, 0]) + val_234 = opset18.Cast([30, 8192], to=7) + view_20 = opset18.Reshape(mul_11, val_234, allowzero=0) + mm_6 = view_20 @ t_6 + val_235 = opset18.Cast([1, 30, 2048], to=7) + view_21 = opset18.Reshape(mm_6, val_235, allowzero=0) + add_5 = add_3 + view_21 + _to_copy_10 = opset18.Cast(add_5, to=1) + scalar_tensor_default_2 = opset18.Cast(2, to=1) + pow_3 = _to_copy_10**scalar_tensor_default_2 + val_236 = opset18.Constant(value_ints=[-1]) + val_237 = opset18.Reshape([-1], val_236, allowzero=0) + mean_2 = opset18.ReduceMean(pow_3, val_237, keepdims=1, noop_with_empty_axes=0) + add_6 = mean_2 + 1e-05 + val_238 = opset18.Sqrt(add_6) + rsqrt_2 = opset18.Reciprocal(val_238) + mul_12 = _to_copy_10 * rsqrt_2 + _to_copy_11 = opset18.Cast(mul_12, to=1) + mul_13 = model_norm_weight * _to_copy_11 + t_7 = opset18.Transpose(lm_head_weight, perm=[1, 0]) + val_239 = opset18.Cast([30, 2048], to=7) + view_22 = opset18.Reshape(mul_13, val_239, allowzero=0) + mm_7 = view_22 @ t_7 + val_241 = opset18.Cast([1, 30, 49152], to=7) + view_23 = opset18.Reshape(mm_7, val_241, allowzero=0) + _to_copy_12 = opset18.Cast(view_23, to=1) + return _to_copy_12, cat_3, cat_4 + + model = main_graph.to_model_proto() + return model + + +def make_model_with_random_weights(): + model_layers_0_input_layernorm_weight = numpy.random.rand(2048).astype(numpy.float32) + model_layers_0_post_attention_layernorm_weight = numpy.random.rand(2048).astype( + numpy.float32 + ) + model_norm_weight = numpy.random.rand(2048).astype(numpy.float32) + lm_head_weight = numpy.random.rand(49152, 2048).astype(numpy.float32) + model_layers_0_self_attn_q_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_self_attn_k_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_self_attn_v_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_self_attn_o_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_mlp_gate_proj_weight = numpy.random.rand(8192, 2048).astype(numpy.float32) + model_layers_0_mlp_up_proj_weight = numpy.random.rand(8192, 2048).astype(numpy.float32) + model_layers_0_mlp_down_proj_weight = numpy.random.rand(2048, 8192).astype(numpy.float32) + model_rotary_emb_inv_freq = numpy.random.rand(32).astype(numpy.float32) + model = make_model( + model_layers_0_input_layernorm_weight, + model_layers_0_post_attention_layernorm_weight, + model_norm_weight, + lm_head_weight, + model_layers_0_self_attn_q_proj_weight, + model_layers_0_self_attn_k_proj_weight, + model_layers_0_self_attn_v_proj_weight, + model_layers_0_self_attn_o_proj_weight, + model_layers_0_mlp_gate_proj_weight, + model_layers_0_mlp_up_proj_weight, + model_layers_0_mlp_down_proj_weight, + model_rotary_emb_inv_freq, + ) + return model + + +class _SmollmTest2: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "input_ids": numpy.random.randint(0, 49152, (1, 30)).astype(numpy.int64), + "position_ids": numpy.arange(30).reshape(1, 30).astype(numpy.int64), + "past_key_values_0_0": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32), + "past_key_values_0_1": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def smollm_test_2(): + return _SmollmTest2() diff --git a/onnxscript/rewriter/models/_test_models.py b/onnxscript/rewriter/models/_test_models.py new file mode 100644 index 0000000000..38de87fa21 --- /dev/null +++ b/onnxscript/rewriter/models/_test_models.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnx_ir as ir +import torch +import transformers +from transformers import LlamaConfig + +import onnxscript.optimizer + +# Create a LlamaConfig object with the desired parameters +_config = LlamaConfig( + _name_or_path="HuggingFaceTB/SmolLM-1.7B", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=0, + eos_token_id=0, + hidden_act="silu", + hidden_size=2048, + initializer_range=0.02, + intermediate_size=8192, + max_position_embeddings=2048, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype="float32", + transformers_version="4.37.2", + use_cache=True, + vocab_size=49152, +) + +# Dimensions for inputs: +_batch_size = 1 +_seq_len = 10 +_hidden_size = _config.hidden_size +_num_attention_heads = _config.num_attention_heads +dim = _hidden_size // _num_attention_heads +_vocab_size = _config.vocab_size + + +class _SmollmTestData: + def __init__(self): + pass + + def get_torch_model(self): + if not hasattr(self, "_torch_model"): + model = transformers.LlamaForCausalLM(_config) + model.eval() + self._torch_model = model + return self._torch_model + + def get_onnx_model(self) -> ir.Model: + model = self.get_torch_model() + inputs = self.get_inputs() + input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None] + exported = torch.onnx.export( + model, inputs, input_names=input_names, dynamo=True, fallback=True + ) + # ORT Transformer optimizations are applied after basic optimization. + exported_model = exported.model # type: ignore[union-attr] + onnxscript.optimizer.optimize(exported_model) + return exported_model + + def get_inputs(self): + if not hasattr(self, "_inputs"): + input_ids = torch.randint(0, _vocab_size, (_batch_size, _seq_len)).to(torch.int64) + attention_mask = torch.ones(input_ids.shape) + position_ids = torch.arange(0, input_ids.size(-1)).unsqueeze(0) + self._inputs = (input_ids, attention_mask, position_ids) + return self._inputs + + def get_torch_outputs(self): + output = self.get_torch_model()(*self.get_inputs()) + logits = output.logits + past_key_value = output.past_key_values[0] + key = past_key_value[0] + value = past_key_value[1] + return (logits.detach().numpy(), key.detach().numpy(), value.detach().numpy()) + + def get_ort_inputs(self): + inputs = self.get_inputs() + return { + f"input{i}": input.numpy() for i, input in enumerate(inputs) if input is not None + } diff --git a/onnxscript/rewriter/models/_whisper_decoder.py b/onnxscript/rewriter/models/_whisper_decoder.py new file mode 100644 index 0000000000..20af1e05b7 --- /dev/null +++ b/onnxscript/rewriter/models/_whisper_decoder.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +A one-layer Whisper decoder model test case, with inputs: audio_features. +This model contains one layer of self-attention and one layer of cross-attention. +This is an onnxscript version of the model. +""" + +import numpy as np +import onnx_ir as ir + +from onnxscript import script +from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import FLOAT, INT32 + + +def make_model( + decoder_embed_positions_weight, + proj_out_weight, + decoder_layers_0_self_attn_layer_norm_weight, + decoder_layers_0_self_attn_layer_norm_bias, + decoder_layers_0_self_attn_q_proj_weight, + decoder_layers_0_self_attn_q_proj_bias, + decoder_layers_0_self_attn_k_proj_weight, + decoder_layers_0_self_attn_v_proj_weight, + decoder_layers_0_self_attn_v_proj_bias, + decoder_layers_0_self_attn_out_proj_weight, + decoder_layers_0_self_attn_out_proj_bias, + decoder_layers_0_encoder_attn_layer_norm_weight, + decoder_layers_0_encoder_attn_layer_norm_bias, + decoder_layers_0_encoder_attn_q_proj_weight, + decoder_layers_0_encoder_attn_q_proj_bias, + decoder_layers_0_encoder_attn_out_proj_weight, + decoder_layers_0_encoder_attn_out_proj_bias, + decoder_layers_0_final_layer_norm_weight, + decoder_layers_0_final_layer_norm_bias, + decoder_layers_0_fc1_weight, + decoder_layers_0_fc1_bias, + decoder_layers_0_fc2_weight, + decoder_layers_0_fc2_bias, + decoder_layer_norm_weight, + decoder_layer_norm_bias, +): + @script() + def main_graph( + # TODO: Fix test case for dynamic batch size and past sequence length + decoder_input_ids: INT32[1, 1], + encoder_hidden_states: FLOAT[1, 1500, 384], + past_key_values_0_0: FLOAT[1, 6, 32, 64], + past_key_values_0_1: FLOAT[1, 6, 32, 64], + past_key_values_0_2: FLOAT[1, 6, 32, 64], + past_key_values_0_3: FLOAT[1, 6, 32, 64], + ) -> ( + FLOAT[1, 1, 51865], + FLOAT[1, 6, 33, 64], + FLOAT[1, 6, 33, 64], + ): + val_0 = opset18.Shape(decoder_input_ids, end=1, start=0) + val_1 = opset18.Shape(past_key_values_0_0, end=3, start=2) + sym_size_int_42 = opset18.Squeeze(val_1) + view = opset18.Reshape(decoder_input_ids, [-1, 1], allowzero=0) + embedding = opset18.Gather(proj_out_weight, view, axis=0) + add_7 = opset18.Add(sym_size_int_42, 1) + arange = opset18.Range(sym_size_int_42, add_7, 1) + unsqueeze = opset18.Unsqueeze(arange, [0]) + val_16 = opset18.Concat(val_0, [1], axis=0) + repeat = opset18.Tile(unsqueeze, val_16) + val_22 = opset18.Unsqueeze(repeat, [-1]) + val_24 = opset18.GatherND(decoder_embed_positions_weight, val_22, batch_dims=0) + add_15 = opset18.Add(embedding, val_24) + add_24 = opset18.Add(add_7, 1) + val_28 = opset18.Reshape(add_24, [-1], allowzero=0) + val_29 = opset18.Concat([1], val_28, axis=0) + full = opset18.Expand(-3.4028235e38, val_29) + arange_1 = opset18.Range(0, add_24, 1) + view_1 = opset18.Reshape(arange, [-1, 1], allowzero=0) + gt = opset18.Greater(arange_1, view_1) + convert_element_type_default = opset18.Cast(gt, to=1) + mul_17 = opset18.Mul(full, convert_element_type_default) + layer_norm = opset18.LayerNormalization( + add_15, + decoder_layers_0_self_attn_layer_norm_weight, + decoder_layers_0_self_attn_layer_norm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_37 = opset18.Transpose(decoder_layers_0_self_attn_q_proj_weight, perm=[1, 0]) + val_38 = opset18.MatMul(layer_norm, val_37) + linear = opset18.Add(val_38, decoder_layers_0_self_attn_q_proj_bias) + mul_43 = opset18.Mul(linear, 0.125) + val_44 = opset18.Concat(val_0, [1], [6], [64], axis=0) + view_2 = opset18.Reshape(mul_43, val_44, allowzero=0) + transpose = opset18.Transpose(view_2, perm=[0, 2, 1, 3]) + val_46 = opset18.Transpose(decoder_layers_0_self_attn_k_proj_weight, perm=[1, 0]) + linear_1 = opset18.MatMul(layer_norm, val_46) + val_49 = opset18.Concat(val_0, [-1], [6], [64], axis=0) + view_3 = opset18.Reshape(linear_1, val_49, allowzero=0) + transpose_1 = opset18.Transpose(view_3, perm=[0, 2, 1, 3]) + val_51 = opset18.Transpose(decoder_layers_0_self_attn_v_proj_weight, perm=[1, 0]) + val_52 = opset18.MatMul(layer_norm, val_51) + linear_2 = opset18.Add(val_52, decoder_layers_0_self_attn_v_proj_bias) + val_55 = opset18.Concat(val_0, [-1], [6], [64], axis=0) + view_4 = opset18.Reshape(linear_2, val_55, allowzero=0) + transpose_2 = opset18.Transpose(view_4, perm=[0, 2, 1, 3]) + cat = opset18.Concat(past_key_values_0_0, transpose_1, axis=-2) + cat_1 = opset18.Concat(past_key_values_0_1, transpose_2, axis=-2) + transpose_3 = opset18.Transpose(cat, perm=[0, 1, 3, 2]) + matmul = opset18.MatMul(transpose, transpose_3) + unsqueeze_4 = opset18.Unsqueeze(mul_17, [0, 1]) + val_83 = opset18.Concat(val_0, [1], [-1], [-1], axis=0) + val_85 = opset18.Abs(val_83) + expand_1 = opset18.Expand(unsqueeze_4, val_85) + val_104 = opset18.Constant(value_ints=[0]) + val_106 = opset18.Constant(value_ints=[-1]) + val_107 = opset18.Reshape(add_7, val_106, allowzero=0) + val_111 = opset18.Constant(value_ints=[1]) + slice_12 = opset18.Slice(expand_1, val_104, val_107, [3], val_111) + add_125 = opset18.Add(matmul, slice_12) + softmax = opset18.Softmax(add_125, axis=-1) + matmul_1 = opset18.MatMul(softmax, cat_1) + transpose_4 = opset18.Transpose(matmul_1, perm=[0, 2, 1, 3]) + val_115 = opset18.Concat(val_0, [1], [384], axis=0) + view_5 = opset18.Reshape(transpose_4, val_115, allowzero=0) + val_117 = opset18.Transpose(decoder_layers_0_self_attn_out_proj_weight, perm=[1, 0]) + val_118 = opset18.MatMul(view_5, val_117) + linear_3 = opset18.Add(val_118, decoder_layers_0_self_attn_out_proj_bias) + add_163 = opset18.Add(add_15, linear_3) + layer_norm_1 = opset18.LayerNormalization( + add_163, + decoder_layers_0_encoder_attn_layer_norm_weight, + decoder_layers_0_encoder_attn_layer_norm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_121 = opset18.Transpose(decoder_layers_0_encoder_attn_q_proj_weight, perm=[1, 0]) + val_122 = opset18.MatMul(layer_norm_1, val_121) + linear_4 = opset18.Add(val_122, decoder_layers_0_encoder_attn_q_proj_bias) + mul_125 = opset18.Mul(linear_4, 0.125) + val_125 = opset18.Concat(val_0, [1], [6], [64], axis=0) + view_6 = opset18.Reshape(mul_125, val_125, allowzero=0) + transpose_5 = opset18.Transpose(view_6, perm=[0, 2, 1, 3]) + transpose_6 = opset18.Transpose(past_key_values_0_2, perm=[0, 1, 3, 2]) + matmul_2 = opset18.MatMul(transpose_5, transpose_6) + softmax_1 = opset18.Softmax(matmul_2, axis=-1) + matmul_3 = opset18.MatMul(softmax_1, past_key_values_0_3) + transpose_7 = opset18.Transpose(matmul_3, perm=[0, 2, 1, 3]) + val_129 = opset18.Concat(val_0, [1], [384], axis=0) + view_7 = opset18.Reshape(transpose_7, val_129, allowzero=0) + val_131 = opset18.Transpose(decoder_layers_0_encoder_attn_out_proj_weight, perm=[1, 0]) + val_132 = opset18.MatMul(view_7, val_131) + linear_5 = opset18.Add(val_132, decoder_layers_0_encoder_attn_out_proj_bias) + add_232 = opset18.Add(add_163, linear_5) + layer_norm_2 = opset18.LayerNormalization( + add_232, + decoder_layers_0_final_layer_norm_weight, + decoder_layers_0_final_layer_norm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_135 = opset18.Transpose(decoder_layers_0_fc1_weight, perm=[1, 0]) + val_136 = opset18.MatMul(layer_norm_2, val_135) + linear_6 = opset18.Add(val_136, decoder_layers_0_fc1_bias) + val_138 = opset18.Div(linear_6, 1.4142135) + val_139 = opset18.Erf(val_138) + val_141 = opset18.Add(val_139, 1.0) + val_143 = opset18.Mul(0.5, val_141) + gelu = opset18.Mul(linear_6, val_143) + val_144 = opset18.Transpose(decoder_layers_0_fc2_weight, perm=[1, 0]) + val_145 = opset18.MatMul(gelu, val_144) + linear_7 = opset18.Add(val_145, decoder_layers_0_fc2_bias) + add_261 = opset18.Add(add_232, linear_7) + layer_norm_12 = opset18.LayerNormalization( + add_261, + decoder_layer_norm_weight, + decoder_layer_norm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_457 = opset18.Transpose(proj_out_weight, perm=[1, 0]) + linear_32 = opset18.MatMul(layer_norm_12, val_457) + return linear_32, cat, cat_1 + + model = main_graph.to_model_proto() + return model + + +def make_model_with_random_weights(): + np.random.seed(10) # Set a fixed seed + decoder_embed_positions_weight = np.random.rand(448, 384).astype(np.float32) + proj_out_weight = np.random.rand(51865, 384).astype(np.float32) + decoder_layers_0_self_attn_layer_norm_weight = np.random.rand(384).astype(np.float32) + decoder_layers_0_self_attn_layer_norm_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_self_attn_q_proj_weight = np.random.rand(384, 384).astype(np.float32) + decoder_layers_0_self_attn_q_proj_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_self_attn_k_proj_weight = np.random.rand(384, 384).astype(np.float32) + decoder_layers_0_self_attn_v_proj_weight = np.random.rand(384, 384).astype(np.float32) + decoder_layers_0_self_attn_v_proj_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_self_attn_out_proj_weight = np.random.rand(384, 384).astype(np.float32) + decoder_layers_0_self_attn_out_proj_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_encoder_attn_layer_norm_weight = np.random.rand(384).astype(np.float32) + decoder_layers_0_encoder_attn_layer_norm_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_encoder_attn_q_proj_weight = np.random.rand(384, 384).astype(np.float32) + decoder_layers_0_encoder_attn_q_proj_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_encoder_attn_out_proj_weight = np.random.rand(384, 384).astype(np.float32) + decoder_layers_0_encoder_attn_out_proj_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_final_layer_norm_weight = np.random.rand(384).astype(np.float32) + decoder_layers_0_final_layer_norm_bias = np.random.rand(384).astype(np.float32) + decoder_layers_0_fc1_weight = np.random.rand(1536, 384).astype(np.float32) + decoder_layers_0_fc1_bias = np.random.rand(1536).astype(np.float32) + decoder_layers_0_fc2_weight = np.random.rand(384, 1536).astype(np.float32) + decoder_layers_0_fc2_bias = np.random.rand(384).astype(np.float32) + decoder_layer_norm_weight = np.random.rand(384).astype(np.float32) + decoder_layer_norm_bias = np.random.rand(384).astype(np.float32) + + model = make_model( + decoder_embed_positions_weight, + proj_out_weight, + decoder_layers_0_self_attn_layer_norm_weight, + decoder_layers_0_self_attn_layer_norm_bias, + decoder_layers_0_self_attn_q_proj_weight, + decoder_layers_0_self_attn_q_proj_bias, + decoder_layers_0_self_attn_k_proj_weight, + decoder_layers_0_self_attn_v_proj_weight, + decoder_layers_0_self_attn_v_proj_bias, + decoder_layers_0_self_attn_out_proj_weight, + decoder_layers_0_self_attn_out_proj_bias, + decoder_layers_0_encoder_attn_layer_norm_weight, + decoder_layers_0_encoder_attn_layer_norm_bias, + decoder_layers_0_encoder_attn_q_proj_weight, + decoder_layers_0_encoder_attn_q_proj_bias, + decoder_layers_0_encoder_attn_out_proj_weight, + decoder_layers_0_encoder_attn_out_proj_bias, + decoder_layers_0_final_layer_norm_weight, + decoder_layers_0_final_layer_norm_bias, + decoder_layers_0_fc1_weight, + decoder_layers_0_fc1_bias, + decoder_layers_0_fc2_weight, + decoder_layers_0_fc2_bias, + decoder_layer_norm_weight, + decoder_layer_norm_bias, + ) + return model + + +class _WhisperDecoderTest: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + np.random.seed(10) # Set a fixed seed + inputs = { + "decoder_input_ids": np.random.randint(0, 49152, (1, 1)).astype(np.int32), + "encoder_hidden_states": np.random.rand(1, 1500, 384).astype(np.float32), + "past_key_values_0_0": np.random.rand(1, 6, 32, 64).astype(np.float32), + "past_key_values_0_1": np.random.rand(1, 6, 32, 64).astype(np.float32), + "past_key_values_0_2": np.random.rand(1, 6, 32, 64).astype(np.float32), + "past_key_values_0_3": np.random.rand(1, 6, 32, 64).astype(np.float32), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def whisper_decoder_test(): + return _WhisperDecoderTest() diff --git a/onnxscript/rewriter/models/_whisper_encoder.py b/onnxscript/rewriter/models/_whisper_encoder.py new file mode 100644 index 0000000000..25a7ffe296 --- /dev/null +++ b/onnxscript/rewriter/models/_whisper_encoder.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +A one-layer Whisper encoder model test case, with inputs: audio_features. +This is an onnxscript version of the model. +""" + +import numpy as np +import onnx_ir as ir + +from onnxscript import script +from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import FLOAT + + +def make_model( + encoder_encoder_embed_positions_weight, + encoder_encoder_conv1_weight, + encoder_encoder_conv1_bias, + encoder_encoder_conv2_weight, + encoder_encoder_conv2_bias, + encoder_encoder_layers_0_self_attn_layer_norm_weight, + encoder_encoder_layers_0_self_attn_layer_norm_bias, + encoder_encoder_layers_0_self_attn_q_proj_weight, + encoder_encoder_layers_0_self_attn_q_proj_bias, + encoder_encoder_layers_0_self_attn_k_proj_weight, + encoder_encoder_layers_0_self_attn_v_proj_weight, + encoder_encoder_layers_0_self_attn_v_proj_bias, + encoder_encoder_layers_0_self_attn_out_proj_weight, + encoder_encoder_layers_0_self_attn_out_proj_bias, + encoder_encoder_layers_0_final_layer_norm_weight, + encoder_encoder_layers_0_final_layer_norm_bias, + encoder_encoder_layers_0_fc1_weight, + encoder_encoder_layers_0_fc1_bias, + encoder_encoder_layers_0_fc2_weight, + encoder_encoder_layers_0_fc2_bias, + encoder_encoder_layer_norm_weight, + encoder_encoder_layer_norm_bias, +): + @script() + def main_graph( + audio_features: FLOAT[1, 80, 3000], + ) -> FLOAT[1, 1500, 384]: + val_0 = opset18.Shape(audio_features, end=1, start=0) + conv1d = opset18.Conv( + audio_features, + encoder_encoder_conv1_weight, + encoder_encoder_conv1_bias, + group=1, + pads=[1, 1], + auto_pad="NOTSET", + strides=[1], + dilations=[1], + ) + val_2 = opset18.Div(conv1d, 1.4142135) + val_3 = opset18.Erf(val_2) + val_5 = opset18.Add(val_3, 1.0) + val_7 = opset18.Mul(0.5, val_5) + gelu = opset18.Mul(conv1d, val_7) + conv1d_1 = opset18.Conv( + gelu, + encoder_encoder_conv2_weight, + encoder_encoder_conv2_bias, + group=1, + pads=[1, 1], + auto_pad="NOTSET", + strides=[2], + dilations=[1], + ) + val_9 = opset18.Div(conv1d_1, 1.4142135) + val_10 = opset18.Erf(val_9) + val_12 = opset18.Add(val_10, 1.0) + val_14 = opset18.Mul(0.5, val_12) + gelu_1 = opset18.Mul(conv1d_1, val_14) + permute = opset18.Transpose(gelu_1, perm=[0, 2, 1]) + add_20 = opset18.Add(permute, encoder_encoder_embed_positions_weight) + layer_norm = opset18.LayerNormalization( + add_20, + encoder_encoder_layers_0_self_attn_layer_norm_weight, + encoder_encoder_layers_0_self_attn_layer_norm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_17 = opset18.Transpose( + encoder_encoder_layers_0_self_attn_q_proj_weight, perm=[1, 0] + ) + val_18 = opset18.MatMul(layer_norm, val_17) + linear = opset18.Add(val_18, encoder_encoder_layers_0_self_attn_q_proj_bias) + mul_18 = opset18.Mul(linear, 0.125) + val_25 = opset18.Concat(val_0, [1500], [6], [64], axis=0) + view = opset18.Reshape(mul_18, val_25, allowzero=0) + transpose = opset18.Transpose(view, perm=[0, 2, 1, 3]) + val_27 = opset18.Transpose( + encoder_encoder_layers_0_self_attn_k_proj_weight, perm=[1, 0] + ) + linear_1 = opset18.MatMul(layer_norm, val_27) + val_31 = opset18.Concat(val_0, [-1], [6], [64], axis=0) + view_1 = opset18.Reshape(linear_1, val_31, allowzero=0) + val_33 = opset18.Transpose( + encoder_encoder_layers_0_self_attn_v_proj_weight, perm=[1, 0] + ) + val_34 = opset18.MatMul(layer_norm, val_33) + linear_2 = opset18.Add(val_34, encoder_encoder_layers_0_self_attn_v_proj_bias) + val_37 = opset18.Concat(val_0, [-1], [6], [64], axis=0) + view_2 = opset18.Reshape(linear_2, val_37, allowzero=0) + transpose_2 = opset18.Transpose(view_2, perm=[0, 2, 1, 3]) + transpose_3 = opset18.Transpose(view_1, perm=[0, 2, 3, 1]) + matmul = opset18.MatMul(transpose, transpose_3) + softmax = opset18.Softmax(matmul, axis=-1) + matmul_1 = opset18.MatMul(softmax, transpose_2) + transpose_4 = opset18.Transpose(matmul_1, perm=[0, 2, 1, 3]) + val_42 = opset18.Concat(val_0, [1500], [384], axis=0) + _unsafe_view = opset18.Reshape(transpose_4, val_42, allowzero=0) + val_44 = opset18.Transpose( + encoder_encoder_layers_0_self_attn_out_proj_weight, perm=[1, 0] + ) + val_45 = opset18.MatMul(_unsafe_view, val_44) + linear_3 = opset18.Add(val_45, encoder_encoder_layers_0_self_attn_out_proj_bias) + add_141 = opset18.Add(add_20, linear_3) + layer_norm_1 = opset18.LayerNormalization( + add_141, + encoder_encoder_layers_0_final_layer_norm_weight, + encoder_encoder_layers_0_final_layer_norm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + val_48 = opset18.Transpose(encoder_encoder_layers_0_fc1_weight, perm=[1, 0]) + val_49 = opset18.MatMul(layer_norm_1, val_48) + linear_4 = opset18.Add(val_49, encoder_encoder_layers_0_fc1_bias) + val_51 = opset18.Div(linear_4, 1.4142135) + val_52 = opset18.Erf(val_51) + val_54 = opset18.Add(val_52, 1.0) + val_56 = opset18.Mul(0.5, val_54) + gelu_2 = opset18.Mul(linear_4, val_56) + val_57 = opset18.Transpose(encoder_encoder_layers_0_fc2_weight, perm=[1, 0]) + val_58 = opset18.MatMul(gelu_2, val_57) + linear_5 = opset18.Add(val_58, encoder_encoder_layers_0_fc2_bias) + add_170 = opset18.Add(add_141, linear_5) + layer_norm_2 = opset18.LayerNormalization( + add_170, + encoder_encoder_layer_norm_weight, + encoder_encoder_layer_norm_bias, + stash_type=1, + epsilon=9.999999747378752e-06, + axis=-1, + ) + return layer_norm_2 + + model = main_graph.to_model_proto() + return model + + +def make_model_with_random_weights(): + np.random.seed(10) # Set a fixed seed + encoder_encoder_embed_positions_weight = np.random.rand(1500, 384).astype(np.float32) + encoder_encoder_conv1_weight = np.random.rand(384, 80, 3).astype(np.float32) + encoder_encoder_conv1_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_conv2_weight = np.random.rand(384, 384, 3).astype(np.float32) + encoder_encoder_conv2_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_layers_0_self_attn_layer_norm_weight = np.random.rand(384).astype( + np.float32 + ) + encoder_encoder_layers_0_self_attn_layer_norm_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_layers_0_self_attn_q_proj_weight = np.random.rand(384, 384).astype( + np.float32 + ) + encoder_encoder_layers_0_self_attn_q_proj_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_layers_0_self_attn_k_proj_weight = np.random.rand(384, 384).astype( + np.float32 + ) + encoder_encoder_layers_0_self_attn_v_proj_weight = np.random.rand(384, 384).astype( + np.float32 + ) + encoder_encoder_layers_0_self_attn_v_proj_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_layers_0_self_attn_out_proj_weight = np.random.rand(384, 384).astype( + np.float32 + ) + encoder_encoder_layers_0_self_attn_out_proj_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_layers_0_final_layer_norm_weight = np.random.rand(384).astype(np.float32) + encoder_encoder_layers_0_final_layer_norm_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_layers_0_fc1_weight = np.random.rand(1536, 384).astype(np.float32) + encoder_encoder_layers_0_fc1_bias = np.random.rand(1536).astype(np.float32) + encoder_encoder_layers_0_fc2_weight = np.random.rand(384, 1536).astype(np.float32) + encoder_encoder_layers_0_fc2_bias = np.random.rand(384).astype(np.float32) + encoder_encoder_layer_norm_weight = np.random.rand(384).astype(np.float32) + encoder_encoder_layer_norm_bias = np.random.rand(384).astype(np.float32) + model = make_model( + encoder_encoder_embed_positions_weight, + encoder_encoder_conv1_weight, + encoder_encoder_conv1_bias, + encoder_encoder_conv2_weight, + encoder_encoder_conv2_bias, + encoder_encoder_layers_0_self_attn_layer_norm_weight, + encoder_encoder_layers_0_self_attn_layer_norm_bias, + encoder_encoder_layers_0_self_attn_q_proj_weight, + encoder_encoder_layers_0_self_attn_q_proj_bias, + encoder_encoder_layers_0_self_attn_k_proj_weight, + encoder_encoder_layers_0_self_attn_v_proj_weight, + encoder_encoder_layers_0_self_attn_v_proj_bias, + encoder_encoder_layers_0_self_attn_out_proj_weight, + encoder_encoder_layers_0_self_attn_out_proj_bias, + encoder_encoder_layers_0_final_layer_norm_weight, + encoder_encoder_layers_0_final_layer_norm_bias, + encoder_encoder_layers_0_fc1_weight, + encoder_encoder_layers_0_fc1_bias, + encoder_encoder_layers_0_fc2_weight, + encoder_encoder_layers_0_fc2_bias, + encoder_encoder_layer_norm_weight, + encoder_encoder_layer_norm_bias, + ) + return model + + +class _WhisperEncoderTest: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + np.random.seed(10) # Set a fixed seed + inputs = { + "audio_features": np.random.rand(1, 80, 3000).astype(np.float32), + } + self._ort_inputs = inputs + return self._ort_inputs + + +def whisper_encoder_test(): + return _WhisperEncoderTest() diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py deleted file mode 100644 index bd9b1c3703..0000000000 --- a/onnxscript/rewriter/no_op.py +++ /dev/null @@ -1,44 +0,0 @@ -from onnxscript.rewriter import pattern - -op = pattern.onnxop - -# TODO: Support 1-D constant tensors -# https://github.com/microsoft/onnx-rewriter/issues/186 - - -# Pattern to match against -def mul_by_1(x): - return x * 1 - - -def add_0(x): - return x + 0 - - -def sub_0(x): - return x - 0 - - -def div_by_1(x): - return x / 1 - - -# Replacement -def identity(op, x): - return op.Identity(x) - - -mul_by_1_rule = pattern.RewriteRule(mul_by_1, identity) -add_0_rule = pattern.RewriteRule(add_0, identity) -sub_0_rule = pattern.RewriteRule(sub_0, identity) -div_by_1_rule = pattern.RewriteRule(div_by_1, identity) -# TODO: Include Mul by 0, 0 by Mul, 0 by Div? Those would be 0s, but not no-ops - -rules = pattern.RewriteRuleSet( - [ - *mul_by_1_rule.commute(), - *add_0_rule.commute(), - sub_0_rule, - div_by_1_rule, - ] -) diff --git a/onnxscript/rewriter/onnx_fusions/__init__.py b/onnxscript/rewriter/onnx_fusions/__init__.py new file mode 100644 index 0000000000..d2e8d885f0 --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter.onnx_fusions._onnx_fusions import fuse + +__all__ = [ + "fuse", +] diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py new file mode 100644 index 0000000000..008a995764 --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnx_ir as ir + +from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding + + +def _get_onnx_opset_version(model: ir.Model) -> int | None: + """Get the ONNX opset version imported by the model.""" + model_version1 = model.opset_imports.get("") + model_version2 = model.opset_imports.get("ai.onnx") + if model_version1 is not None and model_version2 is not None: + if model_version1 != model_version2: + raise ValueError( + f"Model imports multiple onnx opsets: {model_version1} and {model_version2}." + ) + return model_version1 or model_version2 + + +def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: + """Apply fusions targeting ONNX opset 23.""" + counts: dict[str, int] = {} + counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug) + counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug) + counts["GQA"] = _gqa.fuse_gqa(model, debug=debug) + return counts + + +def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: + """Apply fusions targeting ONNX ops.""" + model_opset_version = _get_onnx_opset_version(model) + if model_opset_version == 23: + return _opset_23_fuse(model, debug=debug) + return {} diff --git a/onnxscript/rewriter/onnxruntime/README.md b/onnxscript/rewriter/onnxruntime/README.md new file mode 100644 index 0000000000..b1a5d205a0 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/README.md @@ -0,0 +1 @@ +This folder (and function_rule based rewrites) are deprecated. The folder will be removed soon. diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index 4e9007e36b..6ca67d171b 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -1,58 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Deprecated. This module is kept for backward compatibility.""" + from __future__ import annotations +from typing import Sequence + import onnx -from onnxscript import ir -from onnxscript.optimizer import remove_unused, remove_unused_function -from onnxscript.rewriter import function_rule, pattern -from onnxscript.rewriter.onnxruntime import ( - group_normalization_merge_silu, - instance_to_group_normalization, - softmax, - transformers, -) - -ORT_FUNCTION_REWRITE_RULES = [*transformers.TRANSFORMERS_FUNCTION_REWRITE_RULES] - -ORT_PATTERN_REWRITE_RULES = [ - *softmax.rules.rules, - *instance_to_group_normalization.rules.rules, - # NOTE: group normalization merge silu should be applied after instance to group normalization - *group_normalization_merge_silu.rules.rules, +from onnxscript.rewriter import pattern +from onnxscript.rewriter import rewrite as _rewrite +from onnxscript.rewriter.ort_fusions import ORT_PATTERN_REWRITE_RULES + +__all__ = [ + "rewrite", + "ORT_PATTERN_REWRITE_RULES", ] def rewrite( model_proto: onnx.ModelProto, /, - function_rules: list[type[function_rule.FunctionRewriteRule]] | None = None, - pattern_rules: list[pattern.RewriteRule] | None = None, + pattern_rules: Sequence[pattern.RewriteRule] | None = None, ) -> onnx.ModelProto: """Rewrite the model using the given rules. Args: model_proto: The model to rewrite. - function_rules: The function rewrite rules to apply. If None, the default rules - for onnxruntime are used. pattern_rules: The pattern rewrite rules to apply. If None, the default rules for onnxruntime are used. Returns: The rewritten model. """ - function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES - model = ir.serde.deserialize_model(model_proto) - # TODO(bowenbao): Function rules first, or pattern rules first? - if function_rules: - for rule_cls in function_rules: - count, model = rule_cls().apply_to_model(model) - print(f"Applied {count} of onnxruntime specific function rewrite rules.") - if pattern_rules: - count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model) - print(f"Applied {count} of onnxruntime specific pattern rewrite rules.") - - model_proto = ir.serde.serialize_model(model) - remove_unused.remove_unused_nodes(model_proto) - remove_unused_function.remove_unused_functions(model_proto) - return model_proto + return _rewrite(model_proto, pattern_rewrite_rules=pattern_rules) diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py index e4afb432d7..42a3837aa7 100644 --- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py +++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py @@ -1,12 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import logging from onnxscript import ir logger = logging.getLogger(__name__) -CREATED_CAST_BFLOAT16_NAME_SUFFIX = "_cast_bfloat16" -def _convert_inputs_from_bfloat16_to_float16(value: ir.Input) -> None: +def _convert_inputs_from_bfloat16_to_float16(value: ir.Value) -> None: if value.dtype != ir.DataType.BFLOAT16: return value.dtype = ir.DataType.FLOAT16 @@ -19,7 +20,7 @@ def _convert_outputs_from_bfloat16_to_float16(value: ir.Value) -> None: _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value) -def _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value: ir.Input) -> None: +def _insert_cast_nodes_for_float16_to_bfloat16_to_inputs(value: ir.Value) -> None: user_nodes_and_indices = tuple(value.uses()) attr = ir.AttrInt64(name="to", value=ir.DataType.BFLOAT16) @@ -61,9 +62,6 @@ def _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value: ir.Value) -> No ) cast.outputs[0].dtype = ir.DataType.FLOAT16 cast.outputs[0].shape = node.outputs[index].shape - # To prevent naming conflicts, we need to append suffix to the output name of the cast node - # TODO: Remove this after naming authority covers this case - cast.outputs[0].name = node.outputs[index].name + CREATED_CAST_BFLOAT16_NAME_SUFFIX # type: ignore[operator] node.append(cast) assert node.graph is not None, "Node graph should not be None" diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py index 8effd0b28f..a64d6e6023 100644 --- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py +++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import numpy as np @@ -12,11 +14,11 @@ class Bfloat16ConversionTest(unittest.TestCase): def setUp(self) -> None: - self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4])) + self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4])) self.v0.dtype = ir.DataType.BFLOAT16 - self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4])) + self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4])) self.v1.dtype = ir.DataType.BFLOAT16 - self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4])) + self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4])) self.v2.dtype = ir.DataType.BFLOAT16 self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1) diff --git a/onnxscript/rewriter/onnxruntime/transformers/__init__.py b/onnxscript/rewriter/onnxruntime/transformers/__init__.py deleted file mode 100644 index 84c73d7b74..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from onnxscript.rewriter import function_rule -from onnxscript.rewriter.onnxruntime.transformers import ( - biassplitgelu, - fastgelu, - layernorm, - multihead_attention, -) - -TRANSFORMERS_FUNCTION_REWRITE_RULES: list[type[function_rule.FunctionRewriteRule]] = [ - multihead_attention.GQALlama2RewriteRule, - multihead_attention.GQALlamaSdpa2RewriteRule, - multihead_attention.AttnPhi15RewriteRule, - multihead_attention.MHAStableDiffusionUnetRewriteRule, - layernorm.LNRewriteRule, - fastgelu.GeluRewriteRule, - biassplitgelu.GegluRewriteRule, -] diff --git a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py deleted file mode 100644 index 591527b597..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -import logging - -import onnxscript -from onnxscript import ir -from onnxscript.rewriter import function_rule - -logger = logging.getLogger(__name__) - - -class GegluRewriteRule(function_rule.FunctionRewriteRule): - FUNCTION_KEYWORD = "GEGLU" - PACKAGE_NAME = "diffusers" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version() # type: ignore[misc] - def _fusion(self, function: ir.Function) -> ir.Function: - del function # Unused - op = self.onnx_opset - msft_opset = onnxscript.values.Opset("com.microsoft", 1) - - def ggelu(input, weight, bias): - weight_transpose = op.Transpose(weight, [1, 0]) - matmul_input = op.MatMul(input, weight_transpose) - return msft_opset.BiasSplitGelu(matmul_input, bias) - - function_proto = onnxscript.script(default_opset=op)(ggelu).to_function_proto() # type: ignore[arg-type] - return ir.serde.deserialize_function(function_proto) diff --git a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py deleted file mode 100644 index 196367c006..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu_test.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -import unittest - -import numpy as np - -from tests.common import testutils - - -class BiasSplitGeluParityTest(unittest.TestCase): - def setUp(self): - np.random.seed(0) - - @testutils.skip_if_no_cuda("BiasSplitGelu Kernel unsupported on CPU.") - def test_geglu_stable_diffusion_unet(self): - testutils.test_onnxruntime_rewrite( - "geglu_stable_diffusion_unet", 4, {("com.microsoft", "BiasSplitGelu", "")} - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py b/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py deleted file mode 100644 index b852401f9b..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -import logging - -import onnxscript -from onnxscript import ir -from onnxscript.rewriter import function_rule - -logger = logging.getLogger(__name__) - - -class GeluRewriteRule(function_rule.FunctionRewriteRule): - FUNCTION_KEYWORD = "GELUActivation" - PACKAGE_NAME = "transformers" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version() - def _fusion(self, function: ir.Function) -> ir.Function: - del function # Unused - op = self.onnx_opset - msft_opset = onnxscript.values.Opset("com.microsoft", 1) - - def gelu(input): - return msft_opset.FastGelu(input) - - function_proto = onnxscript.script(default_opset=op)(gelu).to_function_proto() - return ir.serde.deserialize_function(function_proto) diff --git a/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py b/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py deleted file mode 100644 index db26adf284..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -import unittest - -import numpy as np - -from tests.common import testutils - - -class FastGeluParityTest(unittest.TestCase): - def setUp(self): - np.random.seed(0) - - def test_gelu_phi_1_5(self): - testutils.test_onnxruntime_rewrite( - "gelu_phi_1_5", 4, {("com.microsoft", "FastGelu", "")} - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py deleted file mode 100644 index 54ccfa86ba..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -import logging - -import onnxscript -from onnxscript import ir -from onnxscript.rewriter import _ir_utils, function_rule - -logger = logging.getLogger(__name__) - - -class LNRewriteRule(function_rule.FunctionRewriteRule): - FUNCTION_KEYWORD = "layernorm" - PACKAGE_NAME = "transformers" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version() - def _fusion(self, function: ir.Function) -> ir.Function: - # TODO(bowbao): Might be more desirable to annotate as attribute in nn.Module - aten_add_node = self._find_node_by_type(function, "", "Add") - if aten_add_node is None: - raise function_rule.FunctionRewriteError("Could not find Add node") - - eps_ir_value = _ir_utils.propagate_const_value(aten_add_node.inputs[1]) - eps_numpy_value = _ir_utils.get_numpy_from_ir_value(eps_ir_value) - if eps_numpy_value is None: - raise function_rule.FunctionRewriteError("Could not find eps") - eps = eps_numpy_value.item() - logger.info("eps: %s", eps) - - # TODO(ORT): SimplifiedLayerNormalization in ort is defined under onnx domain. - # https://github.com/microsoft/onnxruntime/issues/7573 - # msft_op = onnxscript.values.Opset("com.microsoft", 1) - op = self.onnx_opset - - def ln(input, weight): - return op.SimplifiedLayerNormalization( - input, weight, axis=-1, epsilon=eps, stash_type=1 - ) - - function_proto = onnxscript.script(default_opset=op)(ln).to_function_proto() - return ir.serde.deserialize_function(function_proto) diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py deleted file mode 100644 index f4f494aa10..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -import unittest - -import numpy as np - -from tests.common import testutils - - -class LNParityTest(unittest.TestCase): - def setUp(self): - np.random.seed(0) - - def test_ln_llama2(self): - testutils.test_onnxruntime_rewrite( - "ln_llama2", 4, {("", "SimplifiedLayerNormalization", "")} - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py deleted file mode 100644 index 9c16ef975e..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py +++ /dev/null @@ -1,706 +0,0 @@ -r"""POC experimenting function aware pattern re-write. - -In this case we don't want to spell-out the entire source pattern. -Instead, we want to replace an entire function call a new subgraph. - -Source function: LlamaAttention -inputs (positional args, the names in function definition are unfortunately arbitrary and don't provide value): - - hidden_states - - position_id - - attention_mask - - q_proj.weight - - k_proj.weight - - v_proj.weight - - cos_cached - - sin_cached - - o_proj.weight -outputs (similarly, positional) - - present_value - - present_key - - attn_output (o_proj) - -The rewriting algorithm is as follows: - -The final new function graph should look like this: - - function_proj_q function_proj_k - | | - | | -com.microsoft::RotaryEmbedding com.microsoft::RotaryEmbedding function_proj_v - \ / / - \ / / - \ / / - \--------------- / -----------------------/ - com.microsoft::MultiHeadAttention - | | | - attn_output (present_key) (present_value) - | - function_proj_o - | - (output) - -So all we need, is to locate 'function_proj_q', 'function_proj_k', 'function_proj_v', 'function_proj_o'. -Construct the 4 nodes with new contrib op nodes, and properly name their inputs/outputs. - -""" - -from __future__ import annotations - -import abc -import dataclasses -import logging - -import onnx -from onnx import helper as onnx_helper - -import onnxscript -from onnxscript import ir -from onnxscript.rewriter import _ir_utils, function_rule - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class AttnSizeConfig: - num_attention_heads: int - num_key_value_heads: int | None - head_size: int - hidden_size: int - - -class AttentionRewriteRule(function_rule.FunctionRewriteRule, abc.ABC): - def infer_attn_size_config(self, function: ir.Function) -> AttnSizeConfig: - if len(function.outputs) == 3: - # Usually the Attention related modules have 3 outputs: - # present_value, present_key, attn_output - present_value, _, attn_output = function.outputs - if present_value.shape is None: - raise function_rule.FunctionRewriteError( - "Failed to find shape for present_value." - ) - if attn_output.shape is None: - raise function_rule.FunctionRewriteError( - "Failed to find shape for attn_output." - ) - head_size = present_value.shape[3] - num_key_value_heads = present_value.shape[1] - hidden_size = attn_output.shape[2] - num_attention_heads = hidden_size // head_size - return AttnSizeConfig( - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_size=head_size, - hidden_size=hidden_size, - ) - elif any("scaled_dot_product_attention" in node.op_type for node in function): - # If the Attention related modules use scaled_dot_product_attention, - # present_value and present_key are not present in the output. - hidden_size = function.outputs[0].shape[2] - # Get head size and number of heads from the Reshape node. - # Reference: - # https://github.com/huggingface/diffusers/blob/ae05050db9d37d5af48a6cd0d6510a5ffb1c1cd4/src/diffusers/models/attention_processor.py#L1269 - reshape_nodes = [node for node in function if node.op_type == "Reshape"] - assert ( - len(reshape_nodes) == 4 - ), "Expected 3 Reshape nodes for Q, K and V, and 1 reshape node for output of scaled_dot_product_attention." - for reshape_node in reshape_nodes: - constant_node = reshape_node.inputs[1].producer() - assert ( - constant_node.op_type == "Constant" - ), "Expected the second input to Reshape to be a Constant node." - value = _ir_utils.propagate_const_value(reshape_node.inputs[1]) - constant_numpy_value = _ir_utils.get_numpy_from_ir_value(value) - if constant_numpy_value.shape[0] == 4: - num_attention_heads = constant_numpy_value[2] - head_size = constant_numpy_value[3] - return AttnSizeConfig( - num_attention_heads=num_attention_heads, - num_key_value_heads=None, - head_size=head_size, - hidden_size=hidden_size, - ) - raise function_rule.FunctionRewriteError( - "Failed to infer head size and number of heads from QKV Reshape nodes. \ - Expected 4D shape in the constant node (batch_size, seq_length, num_attention_heads, head_size)." - ) - raise function_rule.FunctionRewriteError( - f"Attenion modules should have 3 outputs or scaled_dot_product_attention node, " - f"got output: {len(function.outputs)} and no scaled_dot_product_attention." - ) - - -class MHALlama2RewriteRule(AttentionRewriteRule): - FUNCTION_KEYWORD = "LlamaAttention" - PACKAGE_NAME = "transformers" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version(min_version="4.33", max_version="4.36") - def _fusion_with_4d_cache(self, function: ir.Function) -> ir.Function: - if len(function.inputs) != 9: - raise function_rule.FunctionRewriteError( - f"Unexpected number of inputs. Expected 9, got {len(function.inputs)}." - ) - - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_op = onnxscript.values.Opset("com.microsoft", 1) - - # Workaround onnxscript error by specifying the output shape here. - cos_sin_gather_size = [attn_size_config.head_size // 2] - expand_shape = [1, attn_size_config.num_attention_heads, 1, 1] - - def mha( - hidden_states, - position_id, - attention_mask, - q_proj_weight, - k_proj_weight, - v_proj_weight, - cos_cached, - sin_cached, - o_proj_weight, - ): - q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) - k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) - v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) - - # TODO(onnxscript) - # ValueError: ERROR: Unsupported expression type . - # at: Function 'mha', line 16 - # cos = op.Slice(op.Squeeze(cos_cached, [0, 1]), [0], [cos_sin_gather_size], [1]) - # NOTE: Depending on transformers version, the shape of cos/sin is different. - # In later version, the shape is [seq_len, head_size], so the Squeeze is not needed. - # In this version, the shape is [1, 1, seq_len, head_size], hence the below Squeeze. - cos = op.Slice(op.Squeeze(cos_cached, [0, 1]), [0], cos_sin_gather_size, [1]) - sin = op.Slice(op.Squeeze(sin_cached, [0, 1]), [0], cos_sin_gather_size, [1]) - - q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) - k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) - - # TODO(onnxscript) - # ValueError: ERROR: Unsupported expression type . - # expanded_mask = op.Expand(attention_mask, [1, self.num_heads, 1, 1]) - expanded_mask = op.Expand(attention_mask, expand_shape) - - mha_output, present_key, present_value = msft_op.MultiHeadAttention( - q_rope, - k_rope, - v, - None, - None, - expanded_mask, - num_heads=attn_size_config.num_attention_heads, - ) - attn_output = op.MatMul(mha_output, op.Transpose(o_proj_weight, [1, 0])) - return present_value, present_key, attn_output - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - mha - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - @_version_controller.register_version(min_version="4.36", max_version="4.38") - def _fusion_with_2d_cache(self, function: ir.Function) -> ir.Function: - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - - if len(function.inputs) != 9: - raise function_rule.FunctionRewriteError( - f"Unexpected number of inputs. Expected 9, got {len(function.inputs)}." - ) - - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_op = onnxscript.values.Opset("com.microsoft", 1) - - # Workaround onnxscript error by specifying the output shape here. - cos_sin_gather_size = [attn_size_config.head_size // 2] - expand_shape = [1, attn_size_config.num_attention_heads, 1, 1] - - def mha( - hidden_states, - position_id, - attention_mask, - q_proj_weight, - k_proj_weight, - v_proj_weight, - cos_cached, - sin_cached, - o_proj_weight, - ): - q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) - k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) - v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) - - cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) - sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) - - q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) - k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) - - # TODO(onnxscript) - # ValueError: ERROR: Unsupported expression type . - # expanded_mask = op.Expand(attention_mask, [1, self.num_heads, 1, 1]) - expanded_mask = op.Expand(attention_mask, expand_shape) - - mha_output, present_key, present_value = msft_op.MultiHeadAttention( - q_rope, - k_rope, - v, - None, - None, - expanded_mask, - num_heads=attn_size_config.num_attention_heads, - ) - attn_output = op.MatMul(mha_output, op.Transpose(o_proj_weight, [1, 0])) - return present_value, present_key, attn_output - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - mha - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - -class GQALlama2RewriteRule(AttentionRewriteRule): - FUNCTION_KEYWORD = "LlamaAttention" - PACKAGE_NAME = "transformers" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version(min_version="4.33", max_version="4.36") - def _fusion_with_4d_cache(self, function: ir.Function) -> ir.Function: - if len(function.inputs) != 9: - raise function_rule.FunctionRewriteError( - f"Unexpected number of inputs. Expected 9, got {len(function.inputs)}." - ) - - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_op = onnxscript.values.Opset("com.microsoft", 1) - - # Workaround onnxscript error by specifying the output shape here. - cos_sin_gather_size = [attn_size_config.head_size // 2] - - def gqa( - hidden_states, - position_id, - attention_mask, - q_proj_weight, - k_proj_weight, - v_proj_weight, - cos_cached, - sin_cached, - o_proj_weight, - ): - q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) - k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) - v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) - - # NOTE: Depending on transformers version, the shape of cos/sin is different. - # In later version, the shape is [seq_len, head_size], so the Squeeze is not needed. - # In this version, the shape is [1, 1, seq_len, head_size], hence the below Squeeze. - cos = op.Slice(op.Squeeze(cos_cached, [0, 1]), [0], cos_sin_gather_size, [1]) - sin = op.Slice(op.Squeeze(sin_cached, [0, 1]), [0], cos_sin_gather_size, [1]) - - q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) - k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) - - batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) - sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) - past_seq_lengths = op.ConstantOfShape( - batch_size, - value=onnx_helper.make_tensor( - "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] - ), - ) - total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) - - gqa_output, present_key, present_value = msft_op.GroupQueryAttention( - q_rope, - k_rope, - v, - None, - None, - past_seq_lengths, - total_seq_lengths, - kv_num_heads=attn_size_config.num_key_value_heads, - num_heads=attn_size_config.num_attention_heads, - ) - attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) - return present_value, present_key, attn_output - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - gqa - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - @_version_controller.register_version(min_version="4.36", max_version="4.38") - def _fusion_with_2d_cache(self, function: ir.Function) -> ir.Function: - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - - if len(function.inputs) != 9: - raise function_rule.FunctionRewriteError( - f"Unexpected number of inputs. Expected 9, got {len(function.inputs)}." - ) - - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_op = onnxscript.values.Opset("com.microsoft", 1) - - # Workaround onnxscript error by specifying the output shape here. - cos_sin_gather_size = [attn_size_config.head_size // 2] - - def gqa( - hidden_states, - position_id, - attention_mask, - q_proj_weight, - k_proj_weight, - v_proj_weight, - cos_cached, - sin_cached, - o_proj_weight, - ): - q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) - k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) - v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) - - cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) - sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) - - q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) - k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) - - batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) - sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) - past_seq_lengths = op.ConstantOfShape( - batch_size, - value=onnx_helper.make_tensor( - "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] - ), - ) - total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) - - gqa_output, present_key, present_value = msft_op.GroupQueryAttention( - q_rope, - k_rope, - v, - None, - None, - past_seq_lengths, - total_seq_lengths, - kv_num_heads=attn_size_config.num_key_value_heads, - num_heads=attn_size_config.num_attention_heads, - ) - attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) - return present_value, present_key, attn_output - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - gqa - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - -class GQALlamaSdpa2RewriteRule(AttentionRewriteRule): - # TODO: There are a lot of duplicated code with `MHALlama2RewriteRule`. - # The pitfall is that the source function signature is slightly different. - # One has `attention_mask` as input while the other does not. - # Possibly designing a function template system could help reduce the boilerplate. - FUNCTION_KEYWORD = "LlamaSdpaAttention" - PACKAGE_NAME = "transformers" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version(min_version="4.36", max_version="4.38") - def _fusion(self, function: ir.Function) -> ir.Function: - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_op = onnxscript.values.Opset("com.microsoft", 1) - - cos_sin_gather_size = [attn_size_config.head_size // 2] - - def gqa( - hidden_states, - position_id, - q_proj_weight, - k_proj_weight, - v_proj_weight, - cos_cached, - sin_cached, - o_proj_weight, - ): - q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) - k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) - v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) - - cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) - sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) - - q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) - k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) - - batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) - sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) - past_seq_lengths = op.ConstantOfShape( - batch_size, - value=onnx_helper.make_tensor( - "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] - ), - ) - total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) - - gqa_output, present_key, present_value = msft_op.GroupQueryAttention( - q_rope, - k_rope, - v, - None, - None, - past_seq_lengths, - total_seq_lengths, - kv_num_heads=attn_size_config.num_key_value_heads, - num_heads=attn_size_config.num_attention_heads, - ) - attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) - return present_value, present_key, attn_output - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - gqa - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - @_version_controller.register_version(min_version="4.38") - def _fusion_without_cos_sin_cache(self, function: ir.Function) -> ir.Function: - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_op = onnxscript.values.Opset("com.microsoft", 1) - - cos_sin_gather_size = [attn_size_config.head_size // 2] - - def gqa( - hidden_states, - position_id, - causal_mask, - cache_position, - q_proj_weight, - k_proj_weight, - v_proj_weight, - inv_freq, - o_proj_weight, - ): - q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) - k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) - v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) - - # In 4.38 and later, cos/sin are not cached, but computed on the fly. - # This can be further optimized by constant folding for scenarios where - # the position_id is known at compile time. - seq_len = op.Slice(op.Shape(hidden_states), [1], [2], [0]) - seq_len_scalar = op.Squeeze(seq_len, [0]) - t = op.Unsqueeze( - op.Cast(op.Range(0, seq_len_scalar, 1), to=onnx.TensorProto.FLOAT), [1] - ) - inv_freq = op.Cast(op.Unsqueeze(inv_freq, [0]), to=onnx.TensorProto.FLOAT) - freqs = op.MatMul(t, inv_freq) - - emb = op.Concat(freqs, freqs, axis=-1) - cos = op.CastLike(op.Cos(emb), hidden_states) - sin = op.CastLike(op.Sin(emb), hidden_states) - cos = op.Slice(cos, [0], cos_sin_gather_size, [1]) - sin = op.Slice(sin, [0], cos_sin_gather_size, [1]) - - q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) - k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) - - batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) - sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) - past_seq_lengths = op.ConstantOfShape( - batch_size, - value=onnx_helper.make_tensor( - "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] - ), - ) - total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) - - gqa_output, present_key, present_value = msft_op.GroupQueryAttention( - q_rope, - k_rope, - v, - None, - None, - past_seq_lengths, - total_seq_lengths, - kv_num_heads=attn_size_config.num_key_value_heads, - num_heads=attn_size_config.num_attention_heads, - ) - attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) - return present_value, present_key, attn_output - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - gqa - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - -class AttnPhi15RewriteRule(AttentionRewriteRule): - FUNCTION_KEYWORD = "PhiAttention" - PACKAGE_NAME = "transformers_modules" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version() - def _fusion(self, function: ir.Function) -> ir.Function: - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_opset = onnxscript.values.Opset("com.microsoft", 1) - - def phi_attention( - hidden_states, - position_id, - attention_mask, - q_proj_weight, - q_proj_bias, - k_proj_weight, - k_proj_bias, - v_proj_weight, - v_proj_bias, - cos_cached, - sin_cached, - dense_weight, - dense_bias, - ): - qkv_weight = op.Transpose( - op.Concat(q_proj_weight, k_proj_weight, v_proj_weight, axis=0), - perm=[1, 0], - ) - qkv_bias = op.Concat(q_proj_bias, k_proj_bias, v_proj_bias, axis=0) - - # [batch_size, sequence_length] - attention_mask_shape = op.Slice(op.Shape(hidden_states), [0], [2], [0]) - - # Create 2d mask to mimic 4d causal mask. - attention_mask = op.ConstantOfShape( - attention_mask_shape, - value=onnx_helper.make_tensor("mask_value", onnx.TensorProto.INT32, [1], [1]), - ) - attn_output, present = msft_opset.Attention( - hidden_states, - qkv_weight, - qkv_bias, - attention_mask, - unidirectional=1, - do_rotary=1, - # Attention.rotary_embedding_dim only supports 32, 64 or 128 - rotary_embedding_dim=attn_size_config.head_size // 2 // 32 * 32, - num_heads=attn_size_config.num_attention_heads, - ) - present_key = op.Gather(present, 0) - present_value = op.Gather(present, 1) - output = op.Add( - op.MatMul(attn_output, op.Transpose(dense_weight, [1, 0])), dense_bias - ) - - return present_value, present_key, output - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - phi_attention - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - -class MHAStableDiffusionUnetRewriteRule(AttentionRewriteRule): - """Rewrite rule for Attention in diffusers.""" - - FUNCTION_KEYWORD = "Attention" - PACKAGE_NAME = "diffusers" - _version_controller = function_rule.VersionController() - - @_version_controller.register_version() - def _fusion(self, function: ir.Function) -> ir.Function: - # Attention inputs could be 6 or 7: - # hidden_states, encoder_hidden_states(optional), q_weight, k_weight, v_weight, o_weight, o_bias - if len(function.inputs) != 6 and len(function.inputs) != 7: - raise function_rule.FunctionRewriteError( - f"Unexpected number of inputs. Expected 6 or 7, got {len(function.inputs)}." - ) - - # Infer size configurations from the function. - attn_size_config = self.infer_attn_size_config(function) - - # Code new pattern with onnxscript. - op = onnxscript.opset18 - msft_op = onnxscript.values.Opset("com.microsoft", 1) - - def attention( - hidden_states, - q_weight, - k_weight, - v_weight, - o_weight, - o_bias, - ): - qkv_weight = op.Transpose( - op.Concat(q_weight, k_weight, v_weight, axis=0), - perm=[1, 0], - ) - - # NOTE: MHA does not work when Q, K, and V has the same root inputs. - attn_output, _ = msft_op.Attention( - hidden_states, - qkv_weight, - None, - None, - num_heads=attn_size_config.num_attention_heads, - ) - - # linear projection - output = op.Add(op.MatMul(attn_output, op.Transpose(o_weight, [1, 0])), o_bias) - return output - - def mha( - hidden_states, - encoder_hidden_states, - q_weight, - k_weight, - v_weight, - o_weight, - o_bias, - ): - q = op.MatMul(hidden_states, op.Transpose(q_weight, [1, 0])) - k = op.MatMul(encoder_hidden_states, op.Transpose(k_weight, [1, 0])) - v = op.MatMul(encoder_hidden_states, op.Transpose(v_weight, [1, 0])) - - # NOTE: Q and K needs to have the sequence length (dim 1) to use - # GQA. - mha_output, _, _ = msft_op.MultiHeadAttention( - q, - k, - v, - None, - None, - num_heads=attn_size_config.num_attention_heads, - ) - attn_output = op.Add(op.MatMul(mha_output, op.Transpose(o_weight, [1, 0])), o_bias) - return attn_output - - if len(function.inputs) == 6: - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - attention - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) - - function_proto = onnxscript.script(default_opset=onnxscript.opset18)( - mha - ).to_function_proto() - return ir.serde.deserialize_function(function_proto) diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py deleted file mode 100644 index 1e2f1d51ca..0000000000 --- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py +++ /dev/null @@ -1,85 +0,0 @@ -from __future__ import annotations - -import unittest - -import numpy as np - -from tests.common import testutils - - -class MHAParityTest(unittest.TestCase): - def setUp(self): - np.random.seed(0) - - @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") - def test_attn_llama2_4_34(self): - testutils.test_onnxruntime_rewrite( - "attn_llama2_4_34", 2, {("com.microsoft", "GroupQueryAttention", "")} - ) - - @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") - def test_attn_llama2_4_36(self): - testutils.test_onnxruntime_rewrite( - "attn_llama2_4_36", 1, {("com.microsoft", "GroupQueryAttention", "")} - ) - - @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") - def test_attn_yi_4_37(self): - testutils.test_onnxruntime_rewrite( - "attn_yi_4_37", 1, {("com.microsoft", "GroupQueryAttention", "")} - ) - - @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") - def test_sdpa_llama2_4_36(self): - # TODO: Clean-up naming logic of test models. - # Package version was not considered. - testutils.test_onnxruntime_rewrite( - "sdpa_llama2", 4, {("com.microsoft", "GroupQueryAttention", "")} - ) - - @unittest.skip("TODO: Fails parity check") - def test_sdpa_llama2_4_38(self): - testutils.test_onnxruntime_rewrite( - "sdpa_llama2_4_38", 1, {("com.microsoft", "GroupQueryAttention", "")} - ) - - @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") - def test_sdpa_yi_4_36(self): - testutils.test_onnxruntime_rewrite( - "sdpa_yi", 2, {("com.microsoft", "GroupQueryAttention", "")} - ) - - @unittest.skip("TODO: Fails parity check") - def test_sdpa_yi_4_38(self): - testutils.test_onnxruntime_rewrite( - "sdpa_yi_4_38", 1, {("com.microsoft", "GroupQueryAttention", "")} - ) - - @testutils.skip_if_no_cuda("CPU has parity issue.") - def test_attn_stable_diffusion_unet(self): - testutils.test_onnxruntime_rewrite( - "attn_stable_diffusion_unet", 2, {("com.microsoft", "MultiHeadAttention", "")} - ) - - -class AttnParityTest(unittest.TestCase): - def setUp(self): - np.random.seed(0) - - @testutils.skip_if_no_cuda("CPU has parity issue.") - def test_attn_phi_1_5(self): - testutils.test_onnxruntime_rewrite( - "attn_phi_1_5", 4, {("com.microsoft", "Attention", "")} - ) - - @testutils.skip_if_no_cuda("CPU has parity issue.") - def test_attn_stable_diffusion_unet_without_encoder_hidden_states(self): - testutils.test_onnxruntime_rewrite( - "attn_stable_diffusion_unet_without_encoder_hidden_states", - 2, - {("com.microsoft", "Attention", "")}, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/__init__.py b/onnxscript/rewriter/ort_fusions/__init__.py new file mode 100644 index 0000000000..963fb47ef8 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fusion optimizations for ORT backend.""" + +__all__ = [ + "optimize_for_ort", + "ORT_PATTERN_REWRITE_RULES", +] + + +from onnxscript.rewriter.ort_fusions._core import ORT_PATTERN_REWRITE_RULES, optimize_for_ort diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py new file mode 100644 index 0000000000..ea7af31b3e --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnx_ir as ir +import onnx_ir.passes.common as common_passes + +import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets +import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization +from onnxscript.optimizer import optimize +from onnxscript.rewriter import rewrite +from onnxscript.rewriter.ort_fusions import ( + instance_to_group_normalization, + softmax, +) +from onnxscript.rewriter.ort_fusions.attention import fuse_attention +from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu +from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu +from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu +from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa +from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa +from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2 +from onnxscript.rewriter.ort_fusions.mha_bias import fuse_mha_bias +from onnxscript.rewriter.ort_fusions.mha_scale import fuse_mha_scale +from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization +from onnxscript.rewriter.ort_fusions.rotary_embedding import ( + fuse_partial_rotary_embedding, + fuse_rotary_embedding, +) +from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa +from onnxscript.rewriter.ort_fusions.skip_normalization import ( + fuse_skip_layer_normalization, + fuse_skip_rms_normalization, +) +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add + +ORT_PATTERN_REWRITE_RULES = [ + *softmax.rules.rules, + *instance_to_group_normalization.rules.rules, + # NOTE: group normalization merge silu should be applied after instance to group normalization + # *group_normalization_merge_silu.rules.rules, + *fused_matmul_rule_sets.fused_matmul_rule_sets(), +] + + +# Preliminary optimizations before applying the transformer fusions. +# TODO: There are some potential redundancies below. Can be targeted for optimization +# once we have robust fusion. +def _pre_optimize(model: ir.Model) -> ir.Model: + # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some + # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet + # incorporated in our optimizer. + common_passes.ShapeInferencePass()(model) + optimize(model) + shape_optimization.rules.apply_to_model(model) + optimize(model) + return model + + +def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[str, int]]: + """ + Apply transformer-specific fusions to the given model. + + Args: + model: The input ONNX model represented as an `ir.Model`. + debug: If debug is True, enable pattern matching tracer for debugging. + + Returns: + A tuple containing: + - The optimized `ir.Model` after applying transformer-specific fusions. + - A dictionary with a count of each of the fusions applied. + """ + fusion_count = dict() + + model = _pre_optimize(model) + + def fuse(func, **kwargs): + return func(model, debug=debug, **kwargs) + + fusion_count["erf_gelu"] = fuse(fuse_erfgelu) + fusion_count["rms_normalization"] = fuse(fuse_rms_normalization) + fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization) + fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization) + fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding) + fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache) + common_passes.CommonSubexpressionEliminationPass()(model) + fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding) + + # We apply shape inference after the SDPA fusion as new nodes are added + # in the rewrite rule for certain patterns of SDPA. + fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True) + + fusion_count["gqa"] = fuse(fuse_gqa) + fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa) + fusion_count["mha1"] = fuse(fuse_mha1) + fusion_count["mha2"] = fuse(fuse_mha2) + fusion_count["mha_scale"] = fuse(fuse_mha_scale) + if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0): + fusion_count["mha_bias"] = 0 + fusion_count["attention"] = 0 + else: + fusion_count["mha_bias"] = fuse(fuse_mha_bias) + fusion_count["attention"] = fuse(fuse_attention) + fusion_count["gelu"] = fuse(fuse_gelu) + fusion_count["bias_gelu"] = fuse(fuse_bias_gelu) + # Finally: inline any intermediate fusion functions introduced that were not + # consumed by other fusions, and eliminate any remaining unused nodes. + optimize(model) + return model, fusion_count + + +def optimize_for_ort( + model: ir.Model, + config_name: str | None = None, + *, + debug: bool = False, +) -> tuple[ir.Model, dict[str, int]]: + """ + Optimize the model for ORT backend. + + TODO: config_name is not used yet. It should be used to select the appropriate + optimization configuration (for an EP). Currently, a default implementation is used. + + Args: + model: The model to optimize. + config_name: The name of the configuration to use for optimization. + Typically it identifies the Execution Provider (EP) to optimize for. + If None, the default configuration will be used. + debug: If debug is True, enable pattern matching tracer for debugging. + + Returns: + A tuple containing: + - The optimized `ir.Model` after applying transformer-specific fusions. + - A dictionary with a count of each of the fusions applied. + """ + rewrite(model, [_gemm_to_matmul_add.gemm_to_matmul_add_rule]) + model, fusion_count = fuse_xformers( + model, + debug=debug, + ) + # Apply the ORT pattern rewrite rules. + rewrite(model, ORT_PATTERN_REWRITE_RULES) + + passes = ir.passes.Sequential( + # Apply the ORT optimization passes. + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172 + common_passes.ClearMetadataAndDocStringPass(), + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139 + common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1), + common_passes.RemoveInitializersFromInputsPass(), + common_passes.ShapeInferencePass(), + ) + assert passes.in_place + result = passes(model) + assert result.model is model + return model, fusion_count diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py new file mode 100644 index 0000000000..24e9bcce61 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np +import onnx_ir as ir +import onnxruntime +import packaging.version + +ORT_VERSION = packaging.version.Version(onnxruntime.__version__) + + +def ort_run(model_name: str, model, inputs): + providers = ["CPUExecutionProvider"] + model_proto = ir.serde.serialize_model(model) + options = onnxruntime.SessionOptions() + options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL + session = onnxruntime.InferenceSession( + model_proto.SerializeToString(), options, providers=providers + ) + return session.run(None, inputs) + + +def assert_allclose(outputs, expected_outputs, rtol=1e-3, atol=1e-3): + for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)): + try: + np.testing.assert_equal(baseline_output.shape, optimized_output.shape) + np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol) + except AssertionError as e: + diff_mask = ~np.isclose(baseline_output, optimized_output, rtol=rtol, atol=atol) + diff = np.where(diff_mask, "X", " ") + print(diff) + print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}") + raise diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py new file mode 100644 index 0000000000..ce234bbb63 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/attention.py @@ -0,0 +1,340 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Sequence, Union + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +Dim = Union[int, ir.SymbolicDim] + + +# TODO: Maybe add this check to utilities + + +class AttentionFusion(pattern.RewriteRuleClassBase): + def __init__( + self, + name, + *, + has_past: bool, + no_slice: bool, + ): + super().__init__(name) + self._has_past = has_past + self._no_slice = no_slice + + def pattern( + self, + op, + input, + qkv_weight, + qkv_bias, + # mask_index, + past, + num_heads, + # scale, + start1, + end1, + start2, + end2, + start3, + end3, + q_mul, + k_mul, + v_mul, + ): + if self._no_slice: + query_BSD = op.MatMul(input, q_mul) + key_BSD = op.MatMul(input, k_mul) + value_BSD = op.MatMul(input, v_mul) + else: + projected = op.MatMul(input, qkv_weight, _outputs=["projected"]) + + # Slice packed Matmul QKV into Q, K, and V + # Q, K, and V are of shape (B, S, D) + query_BSD = op.Slice( + projected, + start1, # starts + end1, # ends + [2], # axes + _outputs=["query_mm_sliced"], + ) + key_BSD = op.Slice( + projected, + start2, # starts + end2, # ends + [2], # axes + _outputs=["key_mm_sliced"], + ) + value_BSD = op.Slice( + projected, + start3, # starts + end3, # ends + [2], # axes + _outputs=["value_mm_sliced"], + ) + + # TODO: Add other attributes + + if self._has_past: + # Split past into past_key and past_value + # past_key and past_value are of shape (B, H, S, D/H) + past_key = op.Slice( + past, + [0], # starts + [1], # ends + [0], # axes + _outputs=["past_key_sliced"], + ) + past_key = op.Squeeze(past_key, [0]) + past_value = op.Slice( + past, + [1], # starts + [2], # ends + [0], # axes + _outputs=["past_value_sliced"], + ) + past_value = op.Squeeze(past_value, [0]) + + attention, present_key, present_value = op.MultiHeadAttention( + query_BSD, + key_BSD, + value_BSD, + qkv_bias, + None, # key_padding_mask + pattern.Var("attention_bias", can_match_none=True), + past_key, + past_value, + num_heads=num_heads, + # scale=scale, + _domain="com.microsoft", + _outputs=["mha_output", "present_key", "present_value"], + ) + # Concat present_key and present_value to form present + present_key = op.Unsqueeze(present_key, [0]) + present_value = op.Unsqueeze(present_value, [0]) + present = op.Concat(present_key, present_value, axis=0) + # Return present output first as it captures the complete pattern graph + return present, attention + else: + attention = op.MultiHeadAttention( + query_BSD, + key_BSD, + value_BSD, + qkv_bias, + None, # key_padding_mask + pattern.Var("attention_bias", can_match_none=True), + None, # past_key + None, # past_value + num_heads=num_heads, + # scale=scale, + _domain="com.microsoft", + _outputs=["mha_output"], + ) + return attention + + def check( + self, + op, + input, + qkv_weight, + projected=None, + query_mm_sliced=None, + key_mm_sliced=None, + value_mm_sliced=None, + start1=None, + end1=None, + start2=None, + end2=None, + start3=None, + end3=None, + q_mul=None, + k_mul=None, + v_mul=None, + **_, + ): + check_result = pattern.MatchResult() + self.bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) + + if no_match(input, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']", + input, + ) + if not self._no_slice: + # Ensure slicing is done correctly + if projected is None or projected.shape is None or len(projected.shape) != 3: + return check_result.fail("Input projection is not a 3D tensor.", projected) + hidden_size = projected.shape[2] + if not isinstance(hidden_size, int): + return check_result.fail("Hidden size is not an integer.", projected) + if not ( + _ir_utils.is_singleton_value(start1, 0) + and _ir_utils.get_singleton_value(end1) + == _ir_utils.get_singleton_value(start2) + and _ir_utils.get_singleton_value(end2) + == _ir_utils.get_singleton_value(start3) + and _ir_utils.is_singleton_value(end3, lambda x: x >= hidden_size) + ): + return check_result.fail( + "Projected input is not being split into q, k, v correctly based on hidden sizes.", + projected, + ) + + if no_match(qkv_weight, ["D", "Dh"]): + return check_result.fail( + f"Shape mismatch: {qkv_weight} does not match expected dimensions ['D', 'Dh']", + qkv_weight, + ) + if no_match(query_mm_sliced, ["B", "S", "Dh_q"]): + return check_result.fail( + f"Shape mismatch: {query_mm_sliced} does not match expected dimensions ['B', 'S', 'Dh_q']", + query_mm_sliced, + ) + if no_match(key_mm_sliced, ["B", "S", "Dh_k"]): + return check_result.fail( + f"Shape mismatch: {key_mm_sliced} does not match expected dimensions ['B', 'S', 'Dh_k']", + key_mm_sliced, + ) + if no_match(value_mm_sliced, ["B", "S", "Dh_v"]): + return check_result.fail( + f"Shape mismatch: {value_mm_sliced} does not match expected dimensions ['B', 'S', 'Dh_v']", + value_mm_sliced, + ) + else: + if no_match(q_mul, ["D", "Dh_q"]): + return check_result.fail( + f"Shape mismatch: {q_mul} does not match expected dimensions ['D', 'Dh_q']", + q_mul, + ) + if no_match(k_mul, ["D", "Dh_k"]): + return check_result.fail( + f"Shape mismatch: {k_mul} does not match expected dimensions ['D', 'Dh_k']", + k_mul, + ) + if no_match(v_mul, ["D", "Dh_v"]): + return check_result.fail( + f"Shape mismatch: {v_mul} does not match expected dimensions ['D', 'Dh_v']", + v_mul, + ) + + # Ensure Dh = Dh_q + Dh_k + Dh_v + Dh = self.bindings.get("Dh") + Dh_q = self.bindings.get("Dh_q") + Dh_k = self.bindings.get("Dh_k") + Dh_v = self.bindings.get("Dh_v") + + if not isinstance(Dh_q, int) or not isinstance(Dh_k, int) or not isinstance(Dh_v, int): + return check_result.fail( + "Could not determine the hidden sizes of query, key, and value.", + ) + + if not self._no_slice: + if not isinstance(Dh, int): + return check_result.fail( + "Could not determine the total hidden size of weight.", + ) + + if Dh != Dh_q + Dh_k + Dh_v: # type: ignore[operator] + return check_result.fail( + f"Hidden size of query, key and value do not add up to hidden size: {Dh} != {Dh_q} + {Dh_k} + {Dh_v}", + ) + + # TODO: Add mask check once mask is added to the pattern + return check_result + + def rewrite( + self, + op, + input, + qkv_weight, + qkv_bias, + # mask_index, + past, + attention_bias, + num_heads, + # scale, + mha_output, + q_mul=None, + k_mul=None, + v_mul=None, + **_, + ): + # Use bindings to get the values of Dh_q, Dh_k, and Dh_v + # and construct qkv_hidden_sizes + Dh_q = self.bindings.get("Dh_q") + Dh_k = self.bindings.get("Dh_k") + Dh_v = self.bindings.get("Dh_v") + qkv_hidden_sizes = [Dh_q, Dh_k, Dh_v] + if self._no_slice: + qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=1) + + scale = mha_output.producer().attributes.get_float("scale", None) + + if self._has_past: + attention, present = op.Attention( + input, + qkv_weight, + qkv_bias, + None, # mask_index + past, + attention_bias, + # past_sequence_length + num_heads=num_heads, + qkv_hidden_sizes=qkv_hidden_sizes, + scale=scale, + _domain="com.microsoft", + _outputs=2, + ) + # Use same output ordering as in pattern + return present, attention + else: + return op.Attention( + input, + qkv_weight, + qkv_bias, + None, # mask_index + None, # past + attention_bias, + None, # past_sequence_length + num_heads=num_heads, + qkv_hidden_sizes=qkv_hidden_sizes, + scale=scale, + _domain="com.microsoft", + _outputs=1, + ) + + +# Define all combinations of parameters +parameter_combinations = [ + { + "name": f"attention_{'with_past_' if has_past else ''}{'no_slice' if no_slice else ''}".strip( + "_" + ), + "has_past": has_past, + "no_slice": no_slice, + } + for has_past in [False, True] + for no_slice in [False, True] +] + +# Dynamically create the rules +attention_rules = pattern.RewriteRuleSet( + [ + AttentionFusion.rule( + params["name"], + has_past=params["has_past"], + no_slice=params["no_slice"], + ) + for params in parameter_combinations + ] +) + + +fuse_attention = _fusion_utils.apply_fusion_rules(attention_rules) diff --git a/onnxscript/rewriter/ort_fusions/attention_test.py b/onnxscript/rewriter/ort_fusions/attention_test.py new file mode 100644 index 0000000000..4559bc205c --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/attention_test.py @@ -0,0 +1,195 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np +import onnx_ir as ir +import onnx_ir.passes.common as common_passes +import packaging.version +import parameterized + +import onnxscript +import onnxscript.optimizer +import onnxscript.rewriter.ort_fusions._core as xformers +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.rewriter.models._whisper_encoder import whisper_encoder_test +from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run + +msft_op = onnxscript.values.Opset("com.microsoft", 1) + + +class TestAttentionFusion(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.batchsize = 2 + self.seqlen = 8 + self.past_seqlen = 32 + self.headsize = 16 + self.num_heads = 10 + self.input_hidden_size = self.headsize * self.num_heads + self.q_hidden_size = 160 + self.k_hidden_size = 160 + self.v_hidden_size = 160 + + def random_inputs(self, with_past=False): + """Generate random inputs for the model.""" + B = self.batchsize + S = self.seqlen + Sp = self.past_seqlen + D = self.input_hidden_size + N = self.num_heads + H = self.headsize + D_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size + + inputs = { + "input": np.random.rand(B, S, D).astype(np.float32), + "weight": np.random.rand(D, D_qkv).astype(np.float32), + "bias": np.random.rand(D_qkv).astype(np.float32), + } + if with_past: + inputs["past"] = np.random.rand(2, B, N, Sp, H).astype(np.float32) + return inputs + + def create_model(self, with_past=False): + """Create a model with or without past inputs.""" + D = self.input_hidden_size + D_qkv = self.q_hidden_size + self.k_hidden_size + self.v_hidden_size + + @script() + def model_with_mha(input, weight, bias): + qkv = op.MatMul(input, weight) + + query_BSDh = op.Slice(qkv, [0], [160], [2]) + key_BSDh = op.Slice(qkv, [160], [320], [2]) + value_BSDh = op.Slice(qkv, [320], [480], [2]) + + mha = msft_op.MultiHeadAttention( + query_BSDh, + key_BSDh, + value_BSDh, + bias, + None, + None, + None, + None, + num_heads=self.num_heads, + ) + return mha + + @script() + def model_with_mha_past(input, weight, bias, past): + qkv = op.MatMul(input, weight) + + query_BSDh = op.Slice(qkv, [0], [160], [2]) + key_BSDh = op.Slice(qkv, [160], [320], [2]) + value_BSDh = op.Slice(qkv, [320], [480], [2]) + + past_key_5d = op.Slice(past, [0], [1], [0]) + past_value_5d = op.Slice(past, [1], [2], [0]) + past_key = op.Squeeze(past_key_5d, [0]) + past_value = op.Squeeze(past_value_5d, [0]) + + mha, present_key, present_value = msft_op.MultiHeadAttention( + query_BSDh, + key_BSDh, + value_BSDh, + bias, + None, + None, + past_key, + past_value, + num_heads=self.num_heads, + ) + + present_key = op.Unsqueeze(present_key, [0]) + present_value = op.Unsqueeze(present_value, [0]) + present = op.Concat(present_key, present_value, axis=0) + return mha, present + + input_types = ( + FLOAT["B", "S", D], + FLOAT[D, D_qkv], + FLOAT[D_qkv], + ) + output_types = (FLOAT["B", "S", self.v_hidden_size],) + + if with_past: + # "T" indicates total sequence length (after concatenation of past and current key/value) + input_types += (FLOAT[2, "B", self.num_heads, "S", self.headsize],) + output_types += (FLOAT[2, "B", self.num_heads, "T", self.headsize],) + model_proto = model_with_mha_past.to_model_proto( + input_types=input_types, + output_types=output_types, + ) + else: + model_proto = model_with_mha.to_model_proto( + input_types=input_types, + output_types=output_types, + ) + return ir.serde.deserialize_model(model_proto) + + @parameterized.parameterized.expand( + [ + ("without_past", False), + ("with_past", True), + ] + ) + def test_model_with_mha(self, name, with_past): + """Test the model with or without past inputs.""" + inputs = self.random_inputs(with_past=with_past) + model = self.create_model(with_past=with_past) + model = common_passes.ShapeInferencePass()(model).model + + test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION + if test_with_ort: + # Run model + original_outputs = ort_run("original", model, inputs) + + # Fuse Attention + attention_count = xformers.fuse_attention(model, debug=True) + self.assertGreater(attention_count, 0) + + if test_with_ort: + # Run model again + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_whisper_encoder(self): + # Generate model + whisper_encoder = whisper_encoder_test() + model = whisper_encoder.get_onnx_model() + onnxscript.optimizer.optimize(model) + + test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION + if test_with_ort: + # Run model + inputs = whisper_encoder.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + # Fuse SDPA and MHA + sdpa_count = xformers.fuse_sdpa(model) + self.assertGreater(sdpa_count, 0) + model = common_passes.ShapeInferencePass()(model).model + mha_count = xformers.fuse_mha1(model) + mha_count += xformers.fuse_mha2(model) + self.assertGreater(mha_count, 0) + mha_scale_count = xformers.fuse_mha_scale(model) + self.assertGreater(mha_scale_count, 0) + fused_mha_bias_count = xformers.fuse_mha_bias(model) + self.assertGreater(fused_mha_bias_count, 0) + # TODO: Enable once source of discrepancy is found + # attention_count = xformers.fuse_attention(model) + # self.assertGreater(attention_count, 0) + onnxscript.optimizer.optimize(model) + + if test_with_ort: + # Run model again + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu.py b/onnxscript/rewriter/ort_fusions/bias_gelu.py new file mode 100644 index 0000000000..eff36e8940 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/bias_gelu.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + + +class BiasGeluFusion(pattern.RewriteRuleClassBase): + """Fuses a Bias-Gelu pattern into a single BiasGelu operator. + + Attributes: + contrib_op (bool): If True, matches the Gelu operator from the 'com.microsoft' domain. + If False, matches the standard ONNX Gelu operator. + """ + + def __init__( + self, + name: str, + *, + contrib_op: bool, + ): + super().__init__(name) + self._contrib_op = contrib_op + + def pattern(self, op, input, bias): + gelu_add = op.Add(input, bias) + + if self._contrib_op: + return op.Gelu(gelu_add, _domain="com.microsoft", _outputs=["gelu"]) + else: + return op.Gelu(gelu_add, _outputs=["gelu"]) + + def check(self, op, gelu, input, bias, **_) -> pattern.MatchResult: + check_result = pattern.MatchResult() + approximate = gelu.producer().attributes.get_string("approximate") + if approximate is not None and approximate == "tanh": + return check_result.fail( + "Gelu operator with 'approximate' set to 'tanh' is not supported." + ) + + if not _ir_utils.has_rank(bias, 1): + return check_result.fail("bias is not of shape 1D tensor", bias) + + return check_result + + def rewrite(self, op, input, bias, **_): + return op.BiasGelu(input, bias, _domain="com.microsoft") + + +bias_gelu_rules = pattern.RewriteRuleSet( + [ + *BiasGeluFusion.rule("gelu_onnx_op", contrib_op=False).commute(), + *BiasGeluFusion.rule("gelu_contrib_op", contrib_op=True).commute(), + ] +) + + +fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py new file mode 100644 index 0000000000..964fed6285 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx_ir as ir +import parameterized + +import onnxscript +import onnxscript.rewriter.ort_fusions._test_utils as test_utils +from onnxscript import FLOAT, OnnxFunction, script +from onnxscript import opset20 as op +from onnxscript.optimizer import optimize, remove_unused_nodes +from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu + +msft_op = onnxscript.values.Opset("com.microsoft", 1) + + +@script() +def _test_script_onnx_default(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]: + gelu_add = op.Add(x, y) + return op.Gelu(gelu_add) + + +@script() +def _test_script_onnx_none(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]: + gelu_add = op.Add(x, y) + return op.Gelu(gelu_add, approximate="none") + + +@script() +def _test_script_msft_op(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]: + gelu_add = op.Add(x, y) + return msft_op.Gelu(gelu_add) + + +@script() +def _test_script_reversed_order(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]: + gelu_add = op.Add(y, x) + return op.Gelu(gelu_add) + + +@script() +def _test_script_onnx_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]: + gelu_add = op.Add(x, y) + return op.Gelu(gelu_add, approximate="tanh") + + +@script() +def _test_script_shape_unsupported(x: FLOAT[10, 10], y: FLOAT[10]) -> FLOAT[10]: + gelu_add = op.Add(x, x) + return op.Gelu(gelu_add) + + +class BiasGeluFusionTest(unittest.TestCase): + def _check( + self, + test_data_constructor: OnnxFunction, + expected_graph_len: int, + expected_op_type: str, + ): + """Helper method to run a fusion test scenario.""" + model_proto = test_data_constructor.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + optimize(model) + + input = { + "x": np.random.randn(10, 10).astype(np.float32), + "y": np.random.randn(10).astype(np.float32), + } + original_output = test_utils.ort_run("Original", model, input) + + fuse_bias_gelu(model) + remove_unused_nodes(model) + + self.assertEqual(len(model.graph), expected_graph_len) + self.assertEqual(model.graph.node(0).op_type, expected_op_type) + + optimized_output = test_utils.ort_run("Optimized", model, input) + test_utils.assert_allclose(original_output, optimized_output) + + @parameterized.parameterized.expand( + [ + ("with_onnx_op_default", _test_script_onnx_default, 1, "BiasGelu"), + ("with_onnx_op_none", _test_script_onnx_none, 1, "BiasGelu"), + ("with_contrib_op", _test_script_msft_op, 1, "BiasGelu"), + ("reversed_order", _test_script_reversed_order, 1, "BiasGelu"), + ] + ) + def test_bias_gelu_fusion( + self, + _, + test_data_constructor: OnnxFunction, + expected_graph_len: int, + expected_op_type: str, + ): + self._check(test_data_constructor, expected_graph_len, expected_op_type) + + @parameterized.parameterized.expand( + [ + ("approximate_tanh", _test_script_onnx_unsupported, 2, "Add"), + ("unsupported_shape", _test_script_shape_unsupported, 2, "Add"), + ] + ) + def test_bias_gelu_fusion_unsupported_attr( + self, + _, + test_data_constructor: OnnxFunction, + expected_graph_len: int, + expected_op_type: str, + ): + self._check(test_data_constructor, expected_graph_len, expected_op_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py new file mode 100644 index 0000000000..cba06d2fb7 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -0,0 +1,232 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. + +# We match against the following code pattern: +# Original code (from transformers) for computing cos/sin cache for RoPE: +# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 +# position_ids_expanded = position_ids[:, None, :].float() +# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) +# emb = torch.cat((freqs, freqs), dim=-1) +# cos = emb.cos() +# sin = emb.sin() +# +# We rewrite this pattern into the following form: +# inv_freq_values = inv_freq_expanded.reshape(1, -1) +# pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1) +# angles = np.matmul(pos_id_range, inv_freq_values) +# cos_value = np.cos(angles) +# sin_value = np.sin(angles) +# cos_2d = op.Constant(value=ir.tensor(cos_value)) +# sin_2d = op.Constant(value=ir.tensor(sin_value)) +# +# This produces cos/sin values in a form that can be used by ORT's custom ops. + + +class CosSinCacheFusion(pattern.RewriteRuleClassBase): + def __init__( + self, + name: str, + *, + cast: bool = False, + reshape: bool = False, + const_freqs: bool = False, + ): + # This pattern makes use of shared Cos/Sin values. So, we can't remove the + # matched nodes as part of the rewrite-step. We apply a separate final + # pass to remove unused nodes. + super().__init__(name, remove_nodes=False) + # TODO: Determine what should be the default max_pos_id value + self._max_pos_id = None + # map from inv_freq to (cos, sin) values for transformed graph + self._inv_freq_cos_sin_cache: dict[ir.Value, tuple[ir.Value, ir.Value]] = {} + self._reshape = reshape + self._cast = cast + self._const_freqs = const_freqs + + @property + def max_pos_id(self) -> int | None: + return self._max_pos_id + + @max_pos_id.setter + def max_pos_id(self, max_pos_id: int): + self._max_pos_id = max_pos_id # type: ignore[assignment] + + def _compute_const_freqs(self, op, angles: np.ndarray): + """Compute cos/sin values when frequencies are constant.""" + cos_value = np.cos(angles) + sin_value = np.sin(angles) + cos_2d = op.Constant(value=ir.tensor(cos_value)) + sin_2d = op.Constant(value=ir.tensor(sin_value)) + return cos_2d, sin_2d + + def _compute_dynamic_freqs(self, op, inv_freq, position_ids, dtype): + """Compute cos/sin values dynamically based on inv_freq and position_ids.""" + if self._max_pos_id is not None: + # Use max_pos_id from the model metadata + max_pos_id = self._max_pos_id + elif position_ids.const_value is not None: + # Calculate max_pos_id from the position_ids tensor + max_pos_id = int(np.max(position_ids.const_value.numpy())) + else: + # Dynamically compute max_pos_id from position_ids using ONNX ops + inv_freq = op.Reshape(inv_freq, op.Constant(value_ints=[1, -1])) + max_pos_id = op.ReduceMax(position_ids, keepdims=0) + max_pos_id = op.Add(max_pos_id, op.Constant(value_int=1)) + pos_id_range = op.Range( + op.Constant(value_int=0), + max_pos_id, + op.Constant(value_int=1), + ) + pos_id_range = op.Reshape(pos_id_range, op.Constant(value_ints=[-1, 1])) + pos_id_range = op.Cast(pos_id_range, to=ir.DataType.FLOAT) + # Compute angles and cos/sin values + angles = op.MatMul(pos_id_range, inv_freq) + cos_2d = op.Cos(angles) + sin_2d = op.Sin(angles) + return cos_2d, sin_2d + + # If we do not compute max_pos_id using ONNX ops, use inv_freq and position_ids + # to compute angles and cos/sin values + # Note: The one is added to max_pos_id as position_ids are 0-indexed + # and the range of position ids should be [0, max_pos_id], max_pos_id inclusive. + inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) + pos_id_range = np.arange(max_pos_id + 1, dtype=np.float32).reshape(-1, 1) + angles = np.matmul(pos_id_range, inv_freq_values) + return self._compute_const_freqs(op, angles) + + def cleanup(self): + self._inv_freq_cos_sin_cache.clear() + + def pattern( + self, + op, + x, + inv_freq, + position_ids, + interleaved, + num_heads, + freqs, + dtype, + extra_dims, + ): + if not self._const_freqs: + # Compute freqs from inv_freq and position_ids. In the _const_freqs case, + # this computation has been constant-folded away and freqs is a constant. + # B: batch size, S: sequence length, E: embedding dimension + # position_ids: [B, S] or [S] + # inv_freq: [1, E, 1] + position_ids_expanded = op.Unsqueeze( + position_ids, extra_dims + ) # [B, S] | [S] => [B, 1, S] + position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + # if self._reshape: + # position_ids_expanded = op.Expand(position_ids_expanded, _allow_other_inputs=True) + # position_ids_expanded = op.Reshape(position_ids_expanded, _allow_other_inputs=True) + # inv_freq may optionally be expanded to shape [B, E, 1] + inv_freq = pattern.OrValue( + [ + op.Expand(inv_freq, pattern.ANY_VALUE, _outputs=["expanded_inv_freq"]), + inv_freq, + ] + ) + freqs = op.MatMul(inv_freq, position_ids_expanded) # [B, E, S] + # if self._reshape: + # freqs = op.Reshape(freqs, freqs_3d_shape) # redundant reshape + freqs = op.Transpose(freqs, perm=[0, 2, 1]) # [B, S, E] + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(emb) + if self._cast: + cos = op.Cast(cos, to=dtype) + sin = op.Sin(emb) + if self._cast: + sin = op.Cast(sin, to=dtype) + cos_4d = op.Unsqueeze(cos, 1) # convert + sin_4d = op.Unsqueeze(sin, 1) + return op.RotaryEmbedding( + x, + cos_4d, + sin_4d, + interleaved=interleaved, + num_heads=num_heads, + _domain="ai.onnxruntime._fusion", + ) + + def check( + self, context, inv_freq, position_ids, freqs, extra_dims, expanded_inv_freq=None, **_ + ) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() + # TODO(rama): handle redundant reshape/expand + if self._const_freqs: + if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3): + return check_result.fail("freqs is not a constant or not 3D.", freqs) + else: + return check_result + if ( + _ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1) + ) or ( + _ir_utils.has_rank(position_ids, 1) and _ir_utils.is_1d_value(extra_dims, [0, 1]) + ): + pass + else: + return check_result.fail("position_ids is not a 1D or 2D tensor.", position_ids) + if not _ir_utils.has_rank(inv_freq, 3): + return check_result.fail("inv_freq is not 3D.", inv_freq) + inv_freq_shape = inv_freq.shape + if expanded_inv_freq is not None: + if not _ir_utils.has_rank(expanded_inv_freq, 3): + return check_result.fail("expanded_inv_freq is not 3D.", expanded_inv_freq) + # TODO: check expanded_inv_freq shape + if inv_freq.const_value is None: # TODO: should this be inv_freq_shape? + return check_result.fail("inv_freq is not a constant.", inv_freq) + if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1: + return check_result.fail("inv_freq is not of shape [1, ., 1].", inv_freq) + return check_result + + def rewrite( + self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, **_ + ): + if inv_freq in self._inv_freq_cos_sin_cache: + cos_2d, sin_2d = self._inv_freq_cos_sin_cache[inv_freq] + else: + # Compute cos/sin values based on whether frequencies are constant + if self._const_freqs: + cos_2d, sin_2d = self._compute_const_freqs(op, freqs.const_value.numpy()) + else: + cos_2d, sin_2d = self._compute_dynamic_freqs(op, inv_freq, position_ids, dtype) + if self._cast: + cos_2d = op.Cast(cos_2d, to=dtype) + sin_2d = op.Cast(sin_2d, to=dtype) + self._inv_freq_cos_sin_cache[inv_freq] = (cos_2d, sin_2d) + if _ir_utils.has_rank(position_ids, 1): + zero_1d = op.Constant(value_ints=[0]) + position_ids = op.Unsqueeze(position_ids, zero_1d) + return op.RotaryEmbedding( + x, + position_ids, + cos_2d, + sin_2d, + interleaved=interleaved, + num_heads=num_heads, + _domain="com.microsoft", + ) + + +_cast_const_freqs = CosSinCacheFusion.rule( + "CosSinCache_cast_const_freqs", cast=True, const_freqs=True +) +_cast = CosSinCacheFusion.rule("CosSinCache_cast", cast=True, const_freqs=False) +_const_freqs = CosSinCacheFusion.rule("CosSinCache_const_freqs", cast=False, const_freqs=True) +_basic = CosSinCacheFusion.rule("CosSinCache", cast=False) + +cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic]) + + +fuse_cos_sin_cache = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py new file mode 100644 index 0000000000..48842aa429 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +from parameterized import parameterized + +import onnxscript.optimizer +from onnxscript.rewriter.models import _rotary_embedding_models, _smollm_1 +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run +from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.ort_fusions.rotary_embedding import ( + fuse_partial_rotary_embedding, + fuse_rotary_embedding, +) + + +class TestCosSinCacheTransform(unittest.TestCase): + @parameterized.expand( + [ + ( + "smollm_test_1", + _smollm_1.smollm_test_1, + ), + ( + "test_case_1", + _rotary_embedding_models.test_case_1, + ), + ( + "test_case_2", + _rotary_embedding_models.test_case_2, + ), + ( + "partial_rotary_test_case", + _rotary_embedding_models.partial_rotary_test_case, + ), + ] + ) + def test_cos_sin_fusion(self, name, test_data_constructor): + test = test_data_constructor() + model = test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + count = fuse_rotary_embedding(model) + self.assertGreater(count, 0) + count = fuse_cos_sin_cache(model) + self.assertGreater(count, 0) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_partial_rotary_fusion(self): + test = _rotary_embedding_models.partial_rotary_test_case() + model = test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + count = fuse_rotary_embedding(model) + self.assertGreater(count, 0) + count = fuse_cos_sin_cache(model) + self.assertGreater(count, 0) + count = fuse_partial_rotary_embedding(model) + self.assertGreater(count, 0) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/erfgelu.py b/onnxscript/rewriter/ort_fusions/erfgelu.py new file mode 100644 index 0000000000..ba515a5572 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/erfgelu.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import math + +from onnxscript.rewriter import _fusion_utils, pattern + + +# Pattern to match against +def erf_gelu_pattern_1(op, x): + # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + # half = pattern.Constant(0.5) + # sqrt2 = pattern.Constant(1.4142) + # x_div_sqrt2 = op.Div(x, sqrt2) + # erf = op.Erf(x_div_sqrt2) + # one = pattern.Constant(1.0) + # one_plus_erf = op.Add(erf, one) + # x_mul_one_plus_erf = op.Mul(x, one_plus_erf) + # return op.Mul(half, x_mul_one_plus_erf) + return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) + + +def erf_gelu_pattern_2(op, x): + return x * (0.5 * (op.Erf(x / math.sqrt(2)) + 1.0)) + + +# Replacement +def gelu(op, x): + return op.Gelu(x, _domain="com.microsoft") + + +rule1 = pattern.RewriteRule(erf_gelu_pattern_1, gelu) +rule2 = pattern.RewriteRule(erf_gelu_pattern_2, gelu) + +rules = pattern.RewriteRuleSet([rule1, rule2]) + +fuse_erfgelu = _fusion_utils.apply_fusion_rules(rules) diff --git a/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py b/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py new file mode 100644 index 0000000000..e7808ea699 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/fuse_xformers_test.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.models._smollm_1 import smollm_test_1 +from onnxscript.rewriter.ort_fusions._core import fuse_xformers +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run + + +class TestFuseXformers(unittest.TestCase): + def test_fuse_xformers(self): + test = smollm_test_1() + model = test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + model, fusion_count = fuse_xformers(model) + + # Check if the number of fusions applied for each fusion is correct + self.assertEqual(fusion_count["rms_normalization"], 3) + self.assertEqual(fusion_count["skip_layer_normalization"], 0) + self.assertEqual(fusion_count["skip_rms_normalization"], 2) + self.assertEqual(fusion_count["rotary_embedding"], 2) + self.assertEqual(fusion_count["partial_rotary_embedding"], 0) + self.assertEqual(fusion_count["cos_sin_cache"], 2) + self.assertEqual(fusion_count["sdpa"], 1) + self.assertEqual(fusion_count["mha1"] + fusion_count["mha2"], 1) + self.assertEqual(fusion_count["attention"], 0) + self.assertEqual(fusion_count["gqa"], 0) + self.assertEqual(fusion_count["gelu"], 0) + + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py new file mode 100644 index 0000000000..cdc50c99ae --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -0,0 +1,353 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import ClassVar + +import onnxscript.rewriter.pattern as orp +from onnxscript import ir +from onnxscript.rewriter import _ir_utils + + +def _get_node(value: ir.Value, name: str) -> ir.Node: + """Get the node from the output value.""" + node = value.producer() + assert node is not None, f"{name} node should not be None" + return node + + +def _get_kwargs(node: ir.Node) -> dict[str, float | int]: + """Get the kwargs from the node.""" + kwargs = {key: val.value for key, val in node.attributes.items()} + return kwargs + + +class FusedMatMulDiv1(orp.RewriteRuleClassBase): + """Replaces ``MatMul + Div`` with MatMul.""" + + def pattern(self, op, x, y, cst): + return op.Div(op.MatMul(x, y), cst) + + def check(self, context, x, y, cst) -> orp.MatchResult: + check_result = orp.MatchResult() + if cst.const_value is None: + return check_result.fail("Divisor is not a constant value.") + value = cst.const_value.numpy() + if value.size > 1: + return check_result.fail("Divisor is not a scalar value.") + return check_result + + def rewrite(self, op, x, y, cst): + value = cst.const_value.numpy() + c = float(value[0] if value.shape == (1,) else value) + return op.FusedMatMul(x, y, alpha=1 / c, _domain="com.microsoft") + + +class FusedMatMulDiv2(orp.RewriteRuleClassBase): + """Replaces ``FusedMatMul + Div`` with FusedMatMul.""" + + def pattern(self, op, x, y, cst): + return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), cst) + + def check(self, context, x, y, cst, **_) -> orp.MatchResult: + check_result = orp.MatchResult() + if cst.const_value is None: + return check_result.fail("Divisor is not a constant value.") + if cst.const_value.numpy().size > 1: + return check_result.fail("Divisor is not a scalar value.") + return check_result + + def rewrite(self, op, x, y, cst, fused: ir.Value): + value = cst.const_value.numpy() + c = float(value[0] if value.shape == (1,) else value) + fused_node = _get_node(fused, "FusedMatMul") + kwargs = _get_kwargs(fused_node) + kwargs["alpha"] = kwargs.get("alpha", 1.0) / c + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") + + +class _TransposeMatMulBase(orp.RewriteRuleClassBase): + _pos: ClassVar = 1 + + def check( + self, context, x, y, transposed: ir.Value, fused: ir.Value | None = None, **_ + ) -> orp.MatchResult: + check_result = orp.MatchResult() + transposed_node = _get_node(transposed, "Transpose") + perm = transposed_node.attributes.get_ints("perm") + if perm: + # Check that last two dimensions are swapped + expected_perm = list(range(len(perm))) + expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] + if list(perm) != expected_perm: + return check_result.fail("Permutation values for Transpose are not correct.") + elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or ( + self._pos == 2 and not _ir_utils.has_rank(y, 2) + ): + # If perm is not defined, the default transpose behavior is to swap + # all dimensions, which is correct for MatMul with rank = 2. + return check_result.fail( + "If perm is not defined, rank must be 2 for TransposeMatMul rule." + ) + if fused: + fused_node = _get_node(fused, "FusedMatMul") + trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" + if fused_node.attributes.get_int(trans_batch_property, 0): + return check_result.fail( + "FusedMatMul with transposed batch cannot be used with op.Transpose in this rule." + ) + return check_result + + def rewrite(self, op, x, y, fused: ir.Value | None = None, **_): + kwargs = {} + if fused: + fused_node = _get_node(fused, "FusedMatMul") + kwargs = _get_kwargs(fused_node) + trans_name = "transA" if self._pos == 1 else "transB" + kwargs[trans_name] = 1 - kwargs.get(trans_name, 0) + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") + + +class TransposeMatMul1(_TransposeMatMulBase): + """Replaces ``Transpose + MatMul`` with FusedMatMul.""" + + def pattern(self, op, x, y): + return op.MatMul(op.Transpose(x, _outputs=["transposed"]), y) + + +class TransposeFusedMatMul1(TransposeMatMul1): + """Replaces ``Transpose + FusedMatMul`` with FusedMatMul.""" + + def pattern(self, op, x, y): + return op.FusedMatMul( + op.Transpose(x, _outputs=["transposed"]), + y, + _domain="com.microsoft", + _outputs=["fused"], + ) + + +class TransposeMatMul2(_TransposeMatMulBase): + """Replaces ``Transpose + MatMul`` with FusedMatMul.""" + + _pos: ClassVar = 2 + + def pattern(self, op, x, y): + return op.MatMul(x, op.Transpose(y, _outputs=["transposed"])) + + +class TransposeFusedMatMul2(TransposeMatMul2): + """Replaces ``Transpose + FusedMatMul`` with FusedMatMul.""" + + def pattern(self, op, x, y): + return op.FusedMatMul( + x, + op.Transpose(y, _outputs=["transposed"]), + _domain="com.microsoft", + _outputs=["fused"], + ) + + +class _TransposeFusedMatMulBaseWithBatch(orp.RewriteRuleClassBase): + """Replaces ``Transpose + FusedMatMul`` with FusedMatMul, either + when transBatchA or transBatchB in FusedMatMul is 1, or + can be inverted based on the permutation dims of the Transpose, in + contrast to the original FusedMatMul rule which assumes that + transBatchA and transBatchB are always 0 before and after rewriting. + + transBatchA = 1, transA = 0 applies a batch transpose by moving the first dimension to the second-to-last position + i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-2, 0, N-1]. + transBatchA = 0, transA = 1 flips the last two dimensions + i.e., equivalent to a Transpose with "perm" [0, 1, ... N-3, N-1, N-2]. + transBatchA = 1, transA = 1 applies a batch transpose, then flips the last two dimensions + i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-1, 0]. + + The flipping logic is based on the following cases: + Case 1: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-1, 0] + or transBatchA is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2] + - Then transBatchA and transA can be flipped in FusedMatMul when rewriting. + Case 2: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1] + or transBatchA is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1] + - Then transBatchA can be flipped in FusedMatMul when rewriting. + Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0] + - Then transA can be flipped in FusedMatMul when rewriting. + The same logic applies for transBatchB and transB, when _pos is set to 2. + The _flip_transpose_batch and _flip_transpose flags are used to control + which case is applied by the rules of inheriting classes that change these class vars. + """ + + _pos: ClassVar = 1 + _flip_transpose_batch: ClassVar = False + _flip_transpose: ClassVar = False + + def check( + self, context, x, y, transposed: ir.Value, fused: ir.Value, **_ + ) -> orp.MatchResult: + check_result = orp.MatchResult() + fused_node = _get_node(fused, "FusedMatMul") + trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" + trans_batch = fused_node.attributes.get_int(trans_batch_property, 0) + transposed_node = _get_node(transposed, "Transpose") + perm = list(transposed_node.attributes["perm"].as_ints()) + if not perm: + return check_result.fail("Permutation values for Transpose are not correct.") + + list_perm = list(range(len(perm))) + if self._flip_transpose_batch and self._flip_transpose: + # Case 1: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-1, 0] + # or transBatchA/B is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2] + # - Then transBatchA/B and transA/B can be flipped in FusedMatMul when rewriting. + if trans_batch == 0: + expected_perm = [*list_perm[1:], list_perm[0]] + else: + expected_perm = [list_perm[-1], *list_perm[0:-1]] + if expected_perm == perm: + return check_result + elif self._flip_transpose_batch: + # Case 2: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1] + # or transBatchA/B is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1] + # - Then transBatchA/B can be flipped in FusedMatMul when rewriting. + if trans_batch == 0: + expected_perm = [*list_perm[1:-1], list_perm[0], list_perm[-1]] + else: + expected_perm = [list_perm[-2], *list_perm[0:-2], list_perm[-1]] + if expected_perm == perm: + return check_result + elif self._flip_transpose: + # Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0] + # - Then transA can be flipped in FusedMatMul when rewriting. + expected_perm = [list_perm[-1], *list_perm[1:-1], list_perm[0]] + if expected_perm == perm and trans_batch == 1: + return check_result + + return check_result.fail("Permutation values for Transpose are not correct.") + + def rewrite(self, op, x, y, fused: ir.Value, **_): + kwargs = {} + fused_node = _get_node(fused, "FusedMatMul") + kwargs = _get_kwargs(fused_node) + name = "A" if self._pos == 1 else "B" + if self._flip_transpose_batch: + trans_batch_property = f"transBatch{name}" + kwargs[trans_batch_property] = 1 - kwargs.get(trans_batch_property, 0) + if self._flip_transpose: + trans_property = f"trans{name}" + kwargs[trans_property] = 1 - kwargs.get(trans_property, 0) + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") + + def pattern(self, op, x, y): + if self._pos == 1: + return op.FusedMatMul( + op.Transpose(x, _outputs=["transposed"]), + y, + _domain="com.microsoft", + _outputs=["fused"], + ) + else: + return op.FusedMatMul( + x, + op.Transpose(y, _outputs=["transposed"]), + _domain="com.microsoft", + _outputs=["fused"], + ) + + +class TransposeFusedMatMulWithFlippedBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch): + _flip_transpose = True + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithFlippedBatchAndTranspose2(_TransposeFusedMatMulBaseWithBatch): + _pos = 2 + _flip_transpose = True + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithFlippedBatch1(_TransposeFusedMatMulBaseWithBatch): + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithFlippedBatch2(_TransposeFusedMatMulBaseWithBatch): + _pos = 2 + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch): + _flip_transpose = True + + +class TransposeFusedMatMulWithBatchAndTranspose2(_TransposeFusedMatMulBaseWithBatch): + _pos = 2 + _flip_transpose = True + + +class MatMulTranspose(orp.RewriteRuleClassBase): + """Replaces ``MatMul + Transpose`` with FusedMatMul.""" + + def pattern(self, op, x, y): + return op.Transpose(op.MatMul(x, y), _outputs=["transposed"]) + + def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult: + check_result = orp.MatchResult() + transpose_node = _get_node(transposed, "Transpose") + perm = transpose_node.attributes.get_ints("perm") + # transA/transB only work on the last two dimensions of the input, + # so we can only apply this rule if the inputs are rank 2. + if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2): + if perm: + # Check that the two dimensions are swapped + if tuple(perm) != (1, 0): + return check_result.fail( + "Permutation values for Transpose are not correct." + ) + # If perm is not defined, the default transpose behavior is to swap + # all dimensions, which is correct for MatMul with rank = 2. + else: + return check_result.fail("Rank must be 2 for MatMulTranspose rule.") + return check_result + + def rewrite(self, op, x, y, fused: ir.Value | None = None, **_): + kwargs = {} + if fused: + fused_node = _get_node(fused, "FusedMatMul") + kwargs = _get_kwargs(fused_node) + for name in ["transA", "transB"]: + kwargs[name] = 1 - kwargs.get(name, 0) + return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft") + + +class FusedMatMulTranspose(MatMulTranspose): + """Replaces ``FusedMatMul + Transpose`` with FusedMatMul.""" + + def pattern(self, op, x, y): + return op.Transpose( + op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), + _outputs=["transposed"], + ) + + +def fused_matmul_rule_sets() -> orp.RewriteRuleSet: + """Returns a set of rules introducing onnxruntime contrib ops. + This requires onnxruntime to run the model after it is rewritten. + + Returns: + RewriteRuleSet + """ + return orp.RewriteRuleSet( + [ + FusedMatMulDiv1.rule(), + FusedMatMulDiv2.rule(), + FusedMatMulTranspose.rule(), + MatMulTranspose.rule(), + TransposeMatMul1.rule(), + TransposeFusedMatMul1.rule(), + TransposeMatMul2.rule(), + TransposeFusedMatMul2.rule(), + TransposeFusedMatMulWithFlippedBatch1.rule(), + TransposeFusedMatMulWithFlippedBatch2.rule(), + TransposeFusedMatMulWithFlippedBatchAndTranspose1.rule(), + TransposeFusedMatMulWithFlippedBatchAndTranspose2.rule(), + TransposeFusedMatMulWithBatchAndTranspose1.rule(), + TransposeFusedMatMulWithBatchAndTranspose2.rule(), + ] + ) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py new file mode 100644 index 0000000000..f82702d557 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -0,0 +1,448 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest +from typing import Any, Tuple + +import numpy as np +import onnx +import onnx.reference +import onnx.reference.op_run +import onnx_ir.passes.common as common_passes +import parameterized + +import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets +from onnxscript import FLOAT, ir, script +from onnxscript.onnx_opset import opset18 as op +from onnxscript.values import Opset + +ms_op = Opset("com.microsoft", 1) + + +class FusedMatMul(onnx.reference.op_run.OpRun): + op_domain = "com.microsoft" + + def _run( + self, + A, + B, + alpha: float = 1, + transA: int = 0, + transB: int = 0, + transBatchA: int = 0, + transBatchB: int = 0, + ): + if transBatchA != 0 or transBatchB != 0: + assert len(A.shape) >= 3 and len(B.shape) >= 3, ( + f"Batch dimensions must be at least 3 for A: {A.shape} and B: {B.shape}" + ) + assert len(A.shape) == len(B.shape), ( + f"Batch dimensions must match for A: {A.shape} and B: {B.shape}" + ) + if transBatchA: + perm = list(range(len(A.shape))) + dim = len(perm) + perm = [*perm[1 : dim - 1], perm[0], perm[dim - 1]] + A = np.transpose(A, perm) + if transBatchB: + perm = list(range(len(B.shape))) + dim = len(perm) + perm = [*perm[1 : dim - 1], perm[0], perm[dim - 1]] + B = np.transpose(B, perm) + if transA: + perm = list(range(len(A.shape))) + dim = len(perm) + perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2] + A = np.transpose(A, perm) + if transB: + perm = list(range(len(B.shape))) + dim = len(perm) + perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2] + B = np.transpose(B, perm) + a = np.array(alpha, dtype=A.dtype) + return (np.matmul(A, B) * a,) + + +@script() +def _fused_matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + C = 0.6 + ab = ms_op.FusedMatMul(A, B, alpha=0.4, transA=1) + out = op.Div(ab, C) + return out + + +@script() +def _matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + C = 0.8 + ab = op.MatMul(A, B) + out = op.Div(ab, C) + return out + + +@script() +def _matmul_div_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + C = 0.6 + ab = op.MatMul(A, B) + abd = op.Div(ab, C) + out = op.Div(abd, C) + return out + + +@script() +def _fused_matmul_transpose(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + ab = ms_op.FusedMatMul(A, B, alpha=0.5) + out = op.Transpose(ab, perm=[1, 0]) + return out + + +@script() +def _matmul_transpose(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + ab = op.MatMul(A, B) + out = op.Transpose(ab, perm=[1, 0]) + return out + + +@script() +def _transpose_matmul_1(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + At = op.Transpose(A, perm=[1, 0]) + out = op.MatMul(At, B) + return out + + +@script() +def _transpose_fused_matmul_1(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + At = op.Transpose(A, perm=[1, 0]) + out = ms_op.FusedMatMul(At, B) + return out + + +@script() +def _transpose_matmul_2(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + Bt = op.Transpose(B, perm=[1, 0]) + out = op.MatMul(A, Bt) + return out + + +@script() +def _transpose_fused_matmul_2(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: + Bt = op.Transpose(B, perm=[1, 0]) + out = ms_op.FusedMatMul(A, Bt) + return out + + +@script() +def _should_not_match(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> Tuple[FLOAT[4, 4], FLOAT[4, 4]]: + At = op.Transpose(A, perm=[1, 0]) + ab = op.MatMul(At, B) + C = op.Transpose(At, perm=[1, 0]) + return ab, C + + +# Add unit tests to check that fusion rewrite can work even if MatMul is not the first node. +@script() +def _fused_matmul_with_identity_before_matmul(A: FLOAT[4, 4]) -> FLOAT[4, 4]: + B = op.Identity(A) + ab = op.MatMul(A, B) + out = op.Transpose(ab, perm=[1, 0]) + return out + + +@script() +def _fused_matmul_with_identity_before_transpose(A: FLOAT[4, 4]) -> FLOAT[4, 4]: + B = op.Identity(A) + ab = op.Transpose(A, perm=[1, 0]) + out = op.MatMul(ab, B) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_0_and_transA( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[1, 2, 3, 0]) + out = ms_op.FusedMatMul(Xt, Y, alpha=0.5, transA=0, transB=0, transBatchA=0, transBatchB=0) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_1_and_transA( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[3, 0, 1, 2]) + out = ms_op.FusedMatMul(Xt, Y, transBatchA=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_0( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[1, 2, 0, 3]) + out = ms_op.FusedMatMul(Xt, Y) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_1( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[2, 0, 1, 3]) + out = ms_op.FusedMatMul(Xt, Y, transBatchA=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transA( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[3, 1, 2, 0]) + out = ms_op.FusedMatMul(Xt, Y, transBatchA=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_0_and_transB( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[1, 2, 3, 0]) + out = ms_op.FusedMatMul(X, Yt) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_1_and_transB( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[3, 0, 1, 2]) + out = ms_op.FusedMatMul(X, Yt, transBatchB=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_0( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[1, 2, 0, 3]) + out = ms_op.FusedMatMul(X, Yt) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_1( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[2, 0, 1, 3]) + out = ms_op.FusedMatMul(X, Yt, transBatchB=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transB( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[3, 1, 2, 0]) + out = ms_op.FusedMatMul(X, Yt, transBatchB=1) + return out + + +class TestFusedMatmulRules(unittest.TestCase): + def _apply_fusion_rules(self, ir_model: ir.Model): + rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() + rule_set.apply_to_model(ir_model) + + def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]: + feeds: dict[str, Any] = {} + for i in model.graph.input: + ish = tuple(i.type.tensor_type.shape.dim) + # Creates an input tensor with a dimension defined by the onnx model + # or equals to i + 2 with i being the dimension index. + # The tensor is kept small to make the test fast. + shape = tuple( + (d.dim_value if d.dim_value > 0 else i + 2) for i, d in enumerate(ish) + ) + if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT: + if shape: + feeds[i.name] = np.random.randn(*shape).astype(np.float32) + else: + feeds[i.name] = np.random.randn(1).astype(np.float32) + else: + raise AssertionError(f"Not implemented for input {i}") + return feeds + + def _check_model( + self, + model: onnx.ModelProto, + optimized_model: onnx.ModelProto, + feeds: dict[str, Any] | None = None, + atol: float = 0.0, + rtol: float = 1e-7, + ): + if not feeds: + feeds = self._get_random_inputs(model) + ref = onnx.reference.ReferenceEvaluator(model, new_ops=[FusedMatMul]) + opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul]) + expected = ref.run(None, feeds) + got = opt.run(None, feeds) + self.assertEqual(len(got), len(expected)) + for a, b in zip(expected, got): + np.testing.assert_allclose(a, b, atol=atol, rtol=rtol) + + @parameterized.parameterized.expand( + [ + ( + "fused_matmul_div", + _fused_matmul_div, + [FLOAT[6, "a"], FLOAT[6, "b"]], + [FLOAT[None, None]], + ), + ( + "matmul_div", + _matmul_div, + [FLOAT["a", 6], FLOAT[6, "b"]], + [FLOAT[None, None]], + ), + ( + "matmul_div_div", + _matmul_div_div, + [FLOAT["a", 6], FLOAT[6, "b"]], + [FLOAT[None, None]], + ), + ] + ) + def test_fused_matmul_div_models(self, name, script_func, input_types, output_types): + model_proto = script_func.to_model_proto( + input_types=input_types, + output_types=output_types, + ) + ir_model = ir.serde.deserialize_model(model_proto) + rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() + rule_set.apply_to_model(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual([n.op_type for n in ir_model.graph], ["Constant", "FusedMatMul"]) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @parameterized.parameterized.expand( + [ + ( + "fused_matmul_transpose", + _fused_matmul_transpose, + ), + ( + "matmul_transpose", + _matmul_transpose, + ), + ( + "transpose_matmul_1", + _transpose_matmul_1, + ), + ( + "transpose_fused_matmul_1", + _transpose_fused_matmul_1, + ), + ("transpose_matmul_2", _transpose_matmul_2), + ( + "transpose_fused_matmul_2", + _transpose_fused_matmul_2, + ), + ] + ) + def test_fused_matmul_with_transpose(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4], FLOAT[4, 4]], output_types=[FLOAT[4, 4]] + ) + ir_model = ir.serde.deserialize_model(model_proto) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @parameterized.parameterized.expand([("should_not_match", _should_not_match)]) + def test_should_not_match(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4], FLOAT[4, 4]], output_types=[FLOAT[4, 4], FLOAT[4, 4]] + ) + ir_model = ir.serde.deserialize_model(model_proto) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual( + [n.op_type for n in ir_model.graph], + ["Transpose", "MatMul", "Transpose"], + ) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @parameterized.parameterized.expand( + [ + ( + "fused_matmul_with_identity_before_matmul", + _fused_matmul_with_identity_before_matmul, + ), + ( + "fused_matmul_with_identity_before_transpose", + _fused_matmul_with_identity_before_transpose, + ), + ] + ) + def test_fused_matmul_with_other_node_in_middle(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4]], output_types=[FLOAT[4, 4]] + ) + ir_model = ir.serde.deserialize_model(model_proto) + common_passes.ShapeInferencePass()(ir_model) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual([n.op_type for n in ir_model.graph], ["Identity", "FusedMatMul"]) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @parameterized.parameterized.expand( + [ + ( + "transpose_fused_matmul_flip_transBatchA_0_and_transA", + _transpose_fused_matmul_flip_transBatchA_0_and_transA, + ), + ( + "transpose_fused_matmul_flip_transBatchA_1_and_transA", + _transpose_fused_matmul_flip_transBatchA_1_and_transA, + ), + ( + "transpose_fused_matmul_flip_transBatchA_0", + _transpose_fused_matmul_flip_transBatchA_0, + ), + ( + "transpose_fused_matmul_flip_transBatchA_1", + _transpose_fused_matmul_flip_transBatchA_1, + ), + ("transpose_fused_matmul_flip_transA", _transpose_fused_matmul_flip_transA), + ( + "transpose_fused_matmul_flip_transBatchB_0_and_transB", + _transpose_fused_matmul_flip_transBatchB_0_and_transB, + ), + ( + "transpose_fused_matmul_flip_transBatchB_1_and_transB", + _transpose_fused_matmul_flip_transBatchB_1_and_transB, + ), + ( + "transpose_fused_matmul_flip_transBatchB_0", + _transpose_fused_matmul_flip_transBatchB_0, + ), + ( + "transpose_fused_matmul_flip_transBatchB_1", + _transpose_fused_matmul_flip_transBatchB_1, + ), + ("transpose_fused_matmul_flip_transB", _transpose_fused_matmul_flip_transB), + ] + ) + def test_transpose_fused_matmul_with_batch(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4, 4, 4], FLOAT[4, 4, 4, 4]], + output_types=[FLOAT[4, 4, 4, 4]], + ) + ir_model = ir.serde.deserialize_model(model_proto) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/ort_fusions/gelu.py b/onnxscript/rewriter/ort_fusions/gelu.py new file mode 100644 index 0000000000..f4f27a03b5 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gelu.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import math + +from onnxscript.rewriter import _fusion_utils, pattern + +_SQRT_TWO_OVER_PI = math.sqrt(2.0 / math.pi) +_SQRT_TWO = math.sqrt(2.0) + + +class GeluTanhFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x): + # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]} + t1 = op.Pow(x, 3) + t2 = op.Mul(0.044715, t1) + t3 = op.Add(x, t2) + + t4 = op.Mul(_SQRT_TWO_OVER_PI, t3) + t5 = op.Tanh(t4) + t6 = op.Add(t5, 1) + t7 = op.Mul(0.5, t6) + result = op.Mul(x, t7) + return result + + def rewrite(self, op, x): + return op.FastGelu(x, _domain="com.microsoft") + + +class GeluErfFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x): + # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + t1 = op.Div(x, _SQRT_TWO) + t2 = op.Erf(t1) + t3 = op.Add(t2, 1.0) + t4 = op.Mul(x, t3) + result = op.Mul(t4, 0.5) + return result + + def rewrite(self, op, x): + return op.Gelu(x, _domain="com.microsoft") + + +_tanh_rule = GeluTanhFusion.rule() +_erf_rule = GeluErfFusion.rule() + +gelu_rules = pattern.RewriteRuleSet([_tanh_rule, _erf_rule]) + +fuse_gelu = _fusion_utils.apply_fusion_rules(gelu_rules) diff --git a/onnxscript/rewriter/ort_fusions/gelu_test.py b/onnxscript/rewriter/ort_fusions/gelu_test.py new file mode 100644 index 0000000000..9726e39756 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gelu_test.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import math +import unittest + +import numpy as np +import onnx_ir as ir + +import onnxscript.rewriter.ort_fusions._test_utils as test_utils +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.optimizer import optimize, remove_unused_nodes +from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu + + +class GeluFusionTest(unittest.TestCase): + def test_gelu_fusion(self): + _sqrt_two_over_pi = math.sqrt(2.0 / math.pi) + + @script() + def gelu_model(x): + # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]} + t1 = op.Pow(x, 3) + t2 = op.Mul(0.044715, t1) + t3 = op.Add(x, t2) + + t4 = op.Mul(_sqrt_two_over_pi, t3) + t5 = op.Tanh(t4) + t6 = op.Add(t5, 1) + t7 = op.Mul(0.5, t6) + result = op.Mul(x, t7) + return result + + model_proto = gelu_model.to_model_proto( + input_types=[FLOAT[10]], output_types=[FLOAT[10]] + ) + model = ir.serde.deserialize_model(model_proto) + + # Eliminate redundant CastLike ops: + optimize(model) + + input = {"x": np.random.randn(10).astype(np.float32)} + original_output = test_utils.ort_run("Original", model, input) + + fuse_gelu(model) + remove_unused_nodes(model) + + self.assertEqual(len(model.graph), 1) + self.assertEqual(model.graph.node(0).op_type, "FastGelu") + + optimized_output = test_utils.ort_run("Optimized", model, input) + test_utils.assert_allclose(original_output, optimized_output) + + def test_gelu_erf_fusion(self): + _sqrt_two = math.sqrt(2.0) + + @script() + def gelu_erf_model(x): + # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + t1 = op.Div(x, _sqrt_two) + t2 = op.Erf(t1) + t3 = op.Add(t2, 1.0) + t4 = op.Mul(x, t3) + result = op.Mul(t4, 0.5) + return result + + model_proto = gelu_erf_model.to_model_proto( + input_types=[FLOAT[10]], output_types=[FLOAT[10]] + ) + model = ir.serde.deserialize_model(model_proto) + + # Eliminate redundant CastLike ops: + optimize(model) + + input = {"x": np.random.randn(10).astype(np.float32)} + original_output = test_utils.ort_run("Original", model, input) + + fuse_gelu(model) + remove_unused_nodes(model) + + self.assertEqual(len(model.graph), 1) + self.assertEqual(model.graph.node(0).op_type, "Gelu") + + optimized_output = test_utils.ort_run("Optimized", model, input) + test_utils.assert_allclose(original_output, optimized_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py new file mode 100644 index 0000000000..5fff910bcf --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -0,0 +1,362 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Sequence, Union + +import numpy as np +import onnx_ir as ir + +import onnxscript.rewriter._fusion_utils as _fusion_utils +from onnxscript.rewriter import _basics, _ir_utils, pattern + +""" +GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different +for query and key/value. + +We use the following abbreviations for the dimensions: +B: Batch size +S: Sequence length (for current query/key/value) + +Hkv: number of heads for key/value +G = number of groups +H: number of heads = G * Hkv + +Dh: head size or embedding dimension per head +D: input embedding dimension (hidden size) = H * Dh +Dkv: key/value hidden size = Hkv * Dh + +T: total sequence length (after concatenation of past and current key/value) +""" + +Dim = Union[int, ir.SymbolicDim] + + +def _is_model_input(value: ir.Value, name: str, model: ir.Model) -> bool: + return value in model.graph.inputs and value.name == name + + +def _causal_mask( + op, + input_ids, + past_kv_cache, + shape_B111, + min_val, + window_size, + dtype, +): + """Defines a pattern for a pure causal mask, with optional sliding window support.""" + seq_len = op.Shape(input_ids, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len = op.Shape(past_kv_cache, end=3, start=2) + past_seq_len_0D = op.Squeeze(past_seq_len) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len, axis=0) + mask_all_min_expand = op.Expand(min_val, mask_shape) + # The following Trilu is optional: not used in Phi models, but used in LLama. + mask_all_min_trilu = op.Trilu(mask_all_min_expand, 1, upper=1) + mask_all_min = pattern.OrValue([mask_all_min_expand, mask_all_min_trilu]) + total_range_as_row = op.Range(0, total_seq_len_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + + non_causal = op.Greater(total_range_as_row, current_range_as_column) + + # sliding window support: + current_range_minus_window = op.Sub(current_range_as_column, window_size) + out_of_sliding_window = op.LessOrEqual(total_range_as_row, current_range_minus_window) + non_causal_sliding_window = op.Or(non_causal, out_of_sliding_window) + + boolean_mask = pattern.OrValue([non_causal, non_causal_sliding_window]) + + float_0_1_mask = op.Cast(boolean_mask, to=dtype) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d_11ST = op.Unsqueeze(float_0_min_mask, [0, 1]) + mask_4d_B1ST = op.Expand(mask_4d_11ST, shape_B111) + + return mask_4d_B1ST + + +class _CausalMaskPattern(pattern.PatternBase): + def pattern( + self, + op, + input_ids, + past_kv_cache, + shape_B111, + min_val, + window_size, + dtype1, + attn_mask_2d, + dtype2, + ): + causal_mask = _causal_mask( + op, + input_ids, + past_kv_cache, + shape_B111, + min_val, + window_size, + dtype1, + ) + + attn_mask_4d = op.Unsqueeze(attn_mask_2d, [1, 2]) + attn_mask_4d_cast = op.Cast(attn_mask_4d, to=dtype2) + + sum = op.Add(causal_mask, attn_mask_4d_cast) + sum_fp32 = op.Cast(sum, to=ir.DataType.FLOAT) + # The cast is optional, and may be absent if the sum is already in float32. + sum_fp32 = pattern.OrValue([sum_fp32, sum]) + is_zero = op.Equal(sum_fp32, 0.0) + result = op.Where(is_zero, min_val, causal_mask) + return result + + def check(self, context, dtype1, dtype2, min_val, attn_mask_2d, sliding_window=None, **_): + # Check that attn_mask_2d is the model input "attention_mask" + if not _is_model_input(attn_mask_2d, "attention_mask", context.model): + return pattern.MatchResult().fail("Invalid attention_mask input", attn_mask_2d) + + if dtype1.as_int() != dtype2.as_int(): + return pattern.MatchResult().fail("Dtype mismatch", [dtype1, dtype2]) + + # Check that min_val is a constant and matches the expected minimum value for the dtype. + min_value = _ir_utils.get_singleton_value(min_val) + if min_value is None: + return pattern.MatchResult().fail("Minval is not a constant.", min_val) + expected_min_value = np.finfo(min_val.dtype.numpy()).min + if min_value != expected_min_value: + return pattern.MatchResult().fail( + f"Expected min value {expected_min_value}, got {min_value}", min_val + ) + + # TODO(rama) Sliding window: not yet supported. + if sliding_window: + return pattern.MatchResult().fail( + "Sliding window not yet supported", sliding_window + ) + return True + + +_causal_mask_pattern = _CausalMaskPattern() + + +class GroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("GQA", remove_nodes=False) + + def pattern( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + position_ids, + cos, + sin, + mask, + ): + # Reshape query from (B, S, D) to (B, S, H, D/H) + query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) + key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) + value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + query_BHSDh_rope = op.RotaryEmbedding( + query_BHSDh, + position_ids, + cos, + sin, + _domain="com.microsoft", + _outputs=["query_BHSDh_rope"], + ) + key_BHkvSDh_rope = op.RotaryEmbedding( + key_BHkvSDh, + position_ids, + cos, + sin, + _domain="com.microsoft", + _outputs=["key_BHkvSDh_rope"], + ) + + # Concatenate past_key cache and current key, expand across heads + # that share key/value. + + key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) + key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE) + key_seq_BHTDh = op.Reshape( + key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"] + ) + + # Concatenate past_value cache and current value, expand across heads + # that share key/value. + value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) + value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE) + value_seq_BHTDh = op.Reshape( + value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"] + ) + + attention_BHSDh = op.SDPA( + query_BHSDh_rope, + key_seq_BHTDh, + value_seq_BHTDh, + mask, + key_format="BHSd", + _domain="ai.onnxruntime._fusion", + ) + + # Transpose attention back to (B, S, H, D/H) + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_BSD = op.Reshape( + attention_BSHDh, pattern.ANY_VALUE, _outputs=["attention_BSD"] + ) + return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh + + def check( + self, + context: _basics.MatchContext, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + query_BHSDh_rope, + key_BHkvSDh_rope, + query_BSHDh, + key_BSHkvDh, + mask, + **_, + ): + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils.check_shape_bool(bindings, val, dims) + + if no_match(query_BSD, ["B", "S", "D"]): + return False + if no_match(key_BSDkv, ["B", "S", "Dkv"]): + return False + if no_match(value_BSDkv, ["B", "S", "Dkv"]): + return False + + if no_match(past_key, ["B", "Hkv", "P", "Dh"]): + return False + if no_match(past_value, ["B", "Hkv", "P", "Dv"]): + return False + + # TODO: verify Reshapes: + # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: + # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: + # or check Reshape's shape-input value + + result = pattern.MatchResult() + num_heads = _ir_utils.get_dim(query_BSHDh, 2) + kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) + if not isinstance(num_heads, int): + return result.fail("Unable to determine num_heads value", query_BSHDh) + if not isinstance(kv_num_heads, int): + return result.fail("Unable to determine kv_num_heads value", key_BSHkvDh) + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + + # Rotary embedding attributes + query_rotary_attributes = query_BHSDh_rope.producer().attributes + key_rotary_attributes = key_BHkvSDh_rope.producer().attributes + query_interleaved = query_rotary_attributes.get_int("interleaved", 0) + key_interleaved = key_rotary_attributes.get_int("interleaved", 0) + if query_interleaved != key_interleaved: + return pattern.MatchResult().fail( + "Rotary embedding interleaved attribute mismatch", + [query_BHSDh_rope.producer(), key_BHkvSDh_rope.producer()], + ) + self._interleaved = query_interleaved + + # Check mask: + mask_node = mask.producer() + if mask_node is None: + return pattern.MatchResult().fail("Unhandled mask pattern", mask) + mask_match_result = _causal_mask_pattern.match( + context.model, + context.graph_or_function, + mask_node, + check_nodes_are_removable=False, + ) + if mask_match_result is None: + return pattern.MatchResult().fail("Mask does not match causal mask pattern", mask) + # TODO: handle sliding window support in mask + + return True + + def rewrite( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + position_ids, + cos, + sin, + mask, + **_, + ): + # Note that the following optimization is specific to current ORT GenAI attention-mask + # usage. Specifically, it assumes that the model-input "attention_mask" is a 2D + # mask with shape [batch_size, sequence_length], and that the mask is a 0/1 mask + # that is used only to indicate the current tokens. Hence, the input attention_mask + # is redundant as long as past-sequence-length and current-sequence-length can be + # computed. + + # Construct seqlens_k and total_seq_length_int32 from position_ids + # seqlens_k : int32[batch_size] indicates total_sequence-length-1 for each batch + # position_ids: int64[batch_size, sequence_length] indicates the position of each token + one_int32_0d = op.Constant(value=ir.tensor(1, dtype=ir.DataType.INT32)) + one_int64_1d = op.Constant(value=ir.tensor([1], dtype=ir.DataType.INT64)) + zero_int64_1d = op.Constant(value=ir.tensor([0], dtype=ir.DataType.INT64)) + seqlens_k_int64 = op.ReduceMax(position_ids, one_int64_1d, keepdims=0) + seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32) + max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0) + total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d) + return op.GroupQueryAttention( + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + seqlens_k, + total_seq_length_int32, + cos, + sin, + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, + do_rotary=1, + rotary_interleaved=self._interleaved, + # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap + _domain="com.microsoft", + _outputs=3, + ) + + +_basic_gqa_rule = GroupQueryAttention.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) diff --git a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py new file mode 100644 index 0000000000..51355fc8cf --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Sequence, Union + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +Dim = Union[int, ir.SymbolicDim] + + +class PackedQKVForGQAFusion(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("PackedQKVForGQA", remove_nodes=False) + + def pattern( + self, + op, + packed_qkv, + past_key, + past_value, + seqlens_k, + total_seq_length, + cos, + sin, + q_num_heads, + kv_num_heads, + interleaved, + start1, + end1, + start2, + end2, + start3, + end3, + ): + """Pattern to detect sliced Q, K, V passed to GQA and replace with packed QKV.""" + + # Slice packed QKV into query, key, and value + query_BSD = op.Slice(packed_qkv, start1, end1, [2], [1], _outputs=["query_sliced"]) + key_BSDkv = op.Slice(packed_qkv, start2, end2, [2], [1], _outputs=["key_sliced"]) + value_BSDkv = op.Slice(packed_qkv, start3, end3, [2], [1], _outputs=["value_sliced"]) + + # Pass sliced Q, K, V to GroupQueryAttention + return op.GroupQueryAttention( + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + seqlens_k, + total_seq_length, + cos, + sin, + # mask, # TODO: this is not a valid input for GQA + num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + do_rotary=1, + rotary_interleaved=interleaved, + # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap + _domain="com.microsoft", + _outputs=3, + ) + + def check( + self, + op, + packed_qkv, + query_sliced, + key_sliced, + value_sliced, + q_num_heads, + kv_num_heads, + start1, + end1, + start2, + end2, + start3, + end3, + **_, + ): + check_result = pattern.MatchResult() + self.bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) + + # Check that if x is being split into q, k, v correctly + # based on hidden sizes + if packed_qkv is None or packed_qkv.shape is None or len(packed_qkv.shape) != 3: + return check_result.fail("packed_qkv is not a 3D tensor.", packed_qkv) + hidden_size = packed_qkv.shape[2] + if not isinstance(hidden_size, int): + return check_result.fail("Hidden size is not an integer.", packed_qkv) + q_nh = q_num_heads.value + kv_nh = kv_num_heads.value + if not isinstance(q_nh, int) or not isinstance(kv_nh, int): + return check_result.fail( + "Could not determine the number of heads for query, key and value.", + ) + head_size = hidden_size // (q_nh + (2 * kv_nh)) + q_hidden_size = head_size * q_nh + kv_hidden_size = head_size * kv_nh + if not ( + _ir_utils.is_singleton_value(start1, 0) + and _ir_utils.is_singleton_value(end1, q_hidden_size) + and _ir_utils.is_singleton_value(start2, q_hidden_size) + and _ir_utils.is_singleton_value(end2, (q_hidden_size + kv_hidden_size)) + and _ir_utils.is_singleton_value(start3, (q_hidden_size + kv_hidden_size)) + and _ir_utils.is_singleton_value(end3, lambda x: x >= hidden_size) + ): + return check_result.fail( + "packed_qkv is not being split into q, k, v correctly based on hidden sizes.", + packed_qkv, + ) + + # Check packed_qkv shape (B, S, D) + if no_match(packed_qkv, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {packed_qkv} does not match expected dimensions ['B', 'S', 'D']", + packed_qkv, + ) + + # Check query, key, and value shapes (B, S, Dh) + if no_match(query_sliced, ["B", "S", "Dq"]): + return check_result.fail( + f"Shape mismatch: {query_sliced} does not match expected dimensions ['B', 'S', 'Dq']", + query_sliced, + ) + if no_match(key_sliced, ["B", "S", "Dkv"]): + return check_result.fail( + f"Shape mismatch: {key_sliced} does not match expected dimensions ['B', 'S', 'Dkv']", + key_sliced, + ) + if no_match(value_sliced, ["B", "S", "Dkv"]): + return check_result.fail( + f"Shape mismatch: {value_sliced} does not match expected dimensions ['B', 'S', 'Dkv']", + value_sliced, + ) + + # Ensure Dh = Dg + 2*Dkv + D = self.bindings.get("D") + Dq = self.bindings.get("Dq") + Dkv = self.bindings.get("Dkv") + + if not isinstance(D, int) or not isinstance(Dq, int) or not isinstance(Dkv, int): + return check_result.fail( + "Could not determine the hidden sizes of query, key, and value.", + ) + + if Dq + (2 * Dkv) != D: # type: ignore[operator] + return check_result.fail( + f"Hidden size of query, key and value do not add up to hidden size: {D} != {Dq} + (2 * {Dkv})", + ) + + return True + + def rewrite( + self, + op, + packed_qkv, + past_key, + past_value, + seqlens_k, + total_seq_length, + cos, + sin, + q_num_heads, + kv_num_heads, + interleaved, + **_, + ): + """Rewrite the sliced Q, K, V into a packed QKV MatMul input for GQA.""" + + # Pass packed QKV directly to GroupQueryAttention + return op.GroupQueryAttention( + packed_qkv, + None, + None, + past_key, + past_value, + seqlens_k, + total_seq_length, + cos, + sin, + num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + do_rotary=1, + rotary_interleaved=interleaved, + _domain="com.microsoft", + _outputs=3, + ) + + +# Define the fusion rule +packed_qkv_for_gqa_rule = PackedQKVForGQAFusion.rule() + +# Add the rule to the GQA rewrite rule set +fuse_qkv_gqa_rules = pattern.RewriteRuleSet([packed_qkv_for_gqa_rule]) + +# Apply the fusion rules +fuse_qkv_gqa = _fusion_utils.apply_fusion_rules(fuse_qkv_gqa_rules) diff --git a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv_test.py b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv_test.py new file mode 100644 index 0000000000..d42ba83144 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv_test.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np +import onnx_ir as ir +import onnx_ir.passes.common.shape_inference as shape_inference +import onnxruntime as ort + +import onnxscript +import onnxscript.optimizer +from onnxscript import FLOAT, INT32, script +from onnxscript import opset18 as op +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose +from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa + +msft_op = onnxscript.values.Opset("com.microsoft", 1) + +# Test case for fusion of separate query, key and value inputs +# into a single packed QKV input for the GroupQueryAttention operator. + + +class PackedQKVforGQAFusionTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Config parameters + self.batchsize = 1 + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.past_seqlen = 16 + self.head_size = 16 + self.q_num_heads = 20 + self.kv_num_heads = 10 + + # Computed config parameters + self.q_hidden_size = self.head_size * self.q_num_heads + self.kv_hidden_size = self.head_size * self.kv_num_heads + self.hidden_size = self.q_hidden_size + self.kv_hidden_size + self.kv_hidden_size + + # Abbreviations + B = self.batchsize + S = self.seqlen + P = self.past_seqlen + D = self.hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + total_seqlen = S + P + max_seqlen = total_seqlen + + self.input_types = ( + FLOAT["B", "S", D], # packed_qkv + FLOAT["B", Hkv, "P", Dh], # past_key + FLOAT["B", Hkv, "P", Dh], # past_value + INT32["B"], # seqlens_k + INT32[1], # total_sequence_length + FLOAT["max_seqlen", Dh // 2], # cos + FLOAT["max_seqlen", Dh // 2], # sin + ) + self.output_types = ( + FLOAT["B", "S", D], # attention + FLOAT["B", Hkv, "T", Dh], # present_key + FLOAT["B", Hkv, "T", Dh], # present_value + ) + + self.inputs = { + "packed_qkv": np.random.rand(B, S, D).astype(np.float32), + "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "seqlens_k": np.full((B,), total_seqlen - 1, dtype=np.int32), + "total_sequence_length": np.array([total_seqlen], dtype=np.int32), + "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + } + + def source_model_script(self): + Hq = self.q_num_heads + Hkv = self.kv_num_heads + + @script() + def gqa(packed_qkv, past_key, past_value, seqlens_k, total_sequence_length, cos, sin): + # Slice packed_qkv into query, key and value + query_BSD = op.Slice(packed_qkv, [0], [320], [2], [1]) + key_BSDkv = op.Slice(packed_qkv, [320], [480], [2], [1]) + value_BSDkv = op.Slice(packed_qkv, [480], [640], [2], [1]) + + attn, past_key, past_value = msft_op.GroupQueryAttention( + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos, + sin, + num_heads=Hq, + kv_num_heads=Hkv, + do_rotary=1, + rotary_interleaved=0, + ) + return attn, past_key, past_value + + return gqa + + def test_fuse_packed_qkv_for_gqa(self): + """ + Test that fusion from query, key and value to a packed QKV for GQA + is successful on source model and produces an equivalent model. + """ + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + source_model_ir = ir.serde.from_proto(source_model) + inferred_model = shape_inference.infer_shapes(source_model_ir) + onnxscript.optimizer.optimize(inferred_model) + + count = fuse_qkv_gqa(inferred_model, debug=True) + self.assertEqual(count, 1) + + fused_model = ir.serde.to_proto(inferred_model) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + fused_model_outputs = session.run(None, inputs) + + self.assertEqual(len(fused_model_outputs), len(source_model_outputs)) + assert_allclose(fused_model_outputs, source_model_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py new file mode 100644 index 0000000000..64cb84d18e --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -0,0 +1,376 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import math +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +import onnx_ir.passes.common.shape_inference as shape_inference +import onnxruntime as ort +import torch + +import onnxscript +import onnxscript.optimizer +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.rewriter.models._phi4lm import phi4lm_test +from onnxscript.rewriter.ort_fusions import optimize_for_ort +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose +from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa +from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa + +msft_op = onnxscript.values.Opset("com.microsoft", 1) + +# Test case for GroupQueryAttention (GQA) fusion. + + +class GQAFusionTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Config parameters + self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.past_seqlen = 16 + self.head_size = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + + # Computed config parameters + self.hidden_size = self.head_size * self.num_heads + self.kv_hidden_size = self.head_size * self.kv_num_heads + assert (self.num_heads % self.kv_num_heads) == 0, ( + "num_heads must be divisible by kv_num_heads" + ) + self.num_groups = self.num_heads // self.kv_num_heads + self.total_seqlen = self.seqlen + self.past_seqlen + + # Abbreviations + B = self.batchsize + S = self.seqlen + P = self.past_seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + total_seqlen = S + P + max_seqlen = total_seqlen + + # Input/output types have some dimensions as dynamic (even though the + # test case instance has specific values above). + self.input_types = ( + FLOAT["B", "S", D], # query + FLOAT["B", "S", Dkv], # key + FLOAT["B", "S", Dkv], # value + FLOAT["B", Hkv, "P", Dh], # past_key + FLOAT["B", Hkv, "P", Dh], # past_value + FLOAT["max_seqlen", Dh // 2], # cos + FLOAT["max_seqlen", Dh // 2], # sin + ) + self.output_types = ( + FLOAT["B", "S", D], # attention + FLOAT["B", Hkv, "T", Dh], # present_key + FLOAT["B", Hkv, "T", Dh], # present_value + ) + + self.inputs = { + "query": np.random.rand(B, S, D).astype(np.float32), + "key": np.random.rand(B, S, Dkv).astype(np.float32), + "value": np.random.rand(B, S, Dkv).astype(np.float32), + "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + } + + def target_model_script(self): + H = self.num_heads + Hkv = self.kv_num_heads + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Generate seqlens_k and total_seqlen inputs for GQA: + # In this test case, all batch elements have same sequence length. + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + total_seqlen_int32 = op.Cast(total_seq_length, to=6) + total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) + batchsize = op.Shape(query, start=0, end=1) + seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) + + attn, past_key, past_value = msft_op.GroupQueryAttention( + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_seqlen_int32, + cos, + sin, + num_heads=H, + kv_num_heads=Hkv, + do_rotary=1, + ) + return attn, past_key, past_value + + return gqa + + def source_model_script(self): + scale_factor = math.sqrt(math.sqrt(self.head_size)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.head_size] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + plus_1 = [1] + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Shapes used for Reshape ops. Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. + B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + # past_seq_length = op.Squeeze(past_seq_length_1D, [0]) + # S_0D = op.Squeeze(S,[0]) + + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) + + shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). + query_BSHDh = op.Reshape(query, shape_BSHDh) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + # Concat past and do rotary embedding + position_ids_1d = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids_1d, [0]) + position_ids_k = op.Unsqueeze(position_ids_1d, [0]) + + # Note: The above code pattern for position-ids is from exported Phi model. + # However, for use with ORT's RotaryEmbedding it needs the following for batchsize > 1 + # But we currently target batchsize=1 since GQA requires it when there is a past key/value. + # + # position_ids_2d = op.Unsqueeze(position_ids_1d, [0]) + # tile_B_1 = op.Concat(B, plus_1, axis=0) + # position_ids = op.Tile(position_ids_2d, tile_B_1) + + query_BHSDh_rope = msft_op.RotaryEmbedding( + query_BHSDh, + position_ids_q, + cos, + sin, + ) + key_BHkvSDh_rope = msft_op.RotaryEmbedding( + key_BHkvSDh, + position_ids_k, + cos, + sin, + ) + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + + # Now, expand from shared heads to all heads + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate causal mask: + # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + seq_len = op.Shape(query, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len_0D = op.Squeeze(past_seq_length) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But duplicating same logic here. + total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_val = op.Constant(value=minval_tp) + mask_all_min = op.Expand(min_val, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + shape_B111 = op.Concat(B, plus_1, plus_1, plus_1, axis=0) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh_rope, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask_B1ST) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to BSD format + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh + + return gqa + + def test_equivalence(self): + """Test that the source and target models produce the same outputs.""" + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + target_model = self.target_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + target_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + target_model_outputs = session.run(None, inputs) + + self.assertEqual(len(source_model_outputs), len(target_model_outputs)) + assert_allclose(source_model_outputs, target_model_outputs) + + def test_fusion(self): + """Test that GQA fusion is successful on source model and produces an equivalent model.""" + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + # Some shapes need to be present in input model for fusion to be successful. + # (i) Shape inference doesn't handle handle ORT contrib ops. + # (ii) TODO: investigate if Reshape(..., ["B", "S", -1, Dh]) handled precisely + # by shape inference. + query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "query_BHSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.seqlen, self.head_size], + ) + key_BHkvSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "key_BHkvSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.kv_num_heads, self.seqlen, self.head_size], + ) + query_BSHDh_value_info = onnx.helper.make_tensor_value_info( + "query_BSHDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.num_heads, self.head_size], + ) + key_BHSDh_value_info = onnx.helper.make_tensor_value_info( + "key_BHSDh", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.total_seqlen, self.head_size], + ) + key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info( + "key_BSHkvDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.kv_num_heads, self.head_size], + ) + key_transposed_value_info = onnx.helper.make_tensor_value_info( + "key_transposed", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.head_size, self.total_seqlen], + ) + value_BHSDh_value_info = onnx.helper.make_tensor_value_info( + "value_BHSDh", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.total_seqlen, self.head_size], + ) + source_model.graph.value_info.extend( + [ + query_BHSDh_rope_value_info, + key_BHkvSDh_rope_value_info, + query_BSHDh_value_info, + key_BHSDh_value_info, + key_BSHkvDh_value_info, + key_transposed_value_info, + value_BHSDh_value_info, + ] + ) + + source_model_ir = ir.serde.from_proto(source_model) + inferred_model = shape_inference.infer_shapes(source_model_ir) + onnxscript.optimizer.optimize(inferred_model) + + count = fuse_sdpa(inferred_model, debug=True) + self.assertGreater(count, 0) + + count = fuse_gqa(inferred_model, debug=True) + self.assertGreater(count, 0) + + fused_model = ir.serde.to_proto(inferred_model) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs3 = session.run(None, inputs) + + self.assertEqual(len(outputs3), len(source_model_outputs)) + assert_allclose(outputs3, source_model_outputs) + + +class GQAFusionTest2(unittest.TestCase): + @unittest.skip("Needs too much memory.") + def test_phi4lm(self): + test_case = phi4lm_test() + model = test_case.get_onnx_model() + onnxscript.optimizer.optimize(model) + optimize_for_ort(model, debug=True) + gqa_nodes = [n for n in model.graph if n.op_type == "GQA"] + self.assertEqual(len(gqa_nodes), 2, "Expected 2i GQA nodes after fusion") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py b/onnxscript/rewriter/ort_fusions/group_normalization_merge_silu.py similarity index 64% rename from onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py rename to onnxscript/rewriter/ort_fusions/group_normalization_merge_silu.py index d4c60e59e1..4bac759ff7 100644 --- a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py +++ b/onnxscript/rewriter/ort_fusions/group_normalization_merge_silu.py @@ -1,24 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import logging -from onnxscript.rewriter import pattern - -op = pattern.onnxop -msft_op = pattern.msft_op -torch_module_op = pattern.torch_module_op +from onnxscript.rewriter._pattern_ir import torch_module_op +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet logger = logging.getLogger(__name__) def group_normalization_and_silu_submodule( + op, input, weight, bias, epsilon, groups, ): - group_norm = msft_op.GroupNorm( + group_norm = op.GroupNorm( input, weight, bias, @@ -26,9 +26,12 @@ def group_normalization_and_silu_submodule( channels_last=1, epsilon=epsilon, groups=groups, + _domain="com.microsoft", ) transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2]) - return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed) + return torch_module_op.submodule("torch_nn_modules_activation_SiLU")( + transposed + ) # TODO(rama) def group_normalization_with_silu( @@ -47,14 +50,14 @@ def group_normalization_with_silu( channels_last=1, epsilon=epsilon, groups=groups, - domain="com.microsoft", + _domain="com.microsoft", ) return op.Transpose(group_norm, perm=[0, 3, 1, 2]) -group_normalization_merge_silu_submodule_rule = pattern.RewriteRule( +group_normalization_merge_silu_submodule_rule = RewriteRule( group_normalization_and_silu_submodule, group_normalization_with_silu, ) -rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule]) +rules = RewriteRuleSet([group_normalization_merge_silu_submodule_rule]) diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py b/onnxscript/rewriter/ort_fusions/group_normalization_merge_silu_test.py similarity index 97% rename from onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py rename to onnxscript/rewriter/ort_fusions/group_normalization_merge_silu_test.py index ced611685b..dabeaf3851 100644 --- a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py +++ b/onnxscript/rewriter/ort_fusions/group_normalization_merge_silu_test.py @@ -1,10 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import numpy as np import onnx.parser from onnxscript import ir -from onnxscript.rewriter.onnxruntime import ( +from onnxscript.rewriter.ort_fusions import ( group_normalization_merge_silu, instance_to_group_normalization, ) diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py similarity index 79% rename from onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py rename to onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py index 1a53d59d3f..8ea43e4b84 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py +++ b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import logging @@ -5,16 +7,13 @@ import numpy as np import onnx -from onnxscript.rewriter import _ir_utils, pattern - -op = pattern.onnxop -msft_op = pattern.msft_op -torch_module_op = pattern.torch_module_op +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet logger = logging.getLogger(__name__) def check_if_simulated_instance_norm_is_used( + context, input_x, adjusted_input_shape, original_input_shape, @@ -41,11 +40,15 @@ def check_if_simulated_instance_norm_is_used( Returns: bool: True if the simulated instance normalization is used, False otherwise. """ - weight_for_norm = _ir_utils.propagate_const_value(weight_for_norm) - weight_for_norm = _ir_utils.get_numpy_from_ir_value(weight_for_norm) + weight_for_norm_const_value = weight_for_norm.const_value + if weight_for_norm_const_value is None: + return False + weight_for_norm = weight_for_norm_const_value.numpy() - bias_for_norm = _ir_utils.propagate_const_value(bias_for_norm) - bias_for_norm = _ir_utils.get_numpy_from_ir_value(bias_for_norm) + bias_for_norm_const_value = bias_for_norm.const_value + if bias_for_norm_const_value is None: + return False + bias_for_norm = bias_for_norm_const_value.numpy() if not np.all(weight_for_norm == 1): return False @@ -69,23 +72,28 @@ def check_if_simulated_instance_norm_is_used( if not all(dim == 1 for dim in bias_full_shape[1:]): return False - adjusted_input_shape = _ir_utils.propagate_const_value(adjusted_input_shape) - adjusted_input_shape = _ir_utils.get_numpy_from_ir_value(adjusted_input_shape) + adjusted_input_shape_const_value = adjusted_input_shape.const_value g = weight_for_norm.shape[0] - if adjusted_input_shape is None or adjusted_input_shape.tolist() != [0, g, -1]: + if ( + adjusted_input_shape_const_value is None + or adjusted_input_shape_const_value.numpy().tolist() != [0, g, -1] + ): return False # NOTE: Restrict the rule to only support constant shape - original_input_shape = _ir_utils.propagate_const_value(original_input_shape) - original_input_shape = _ir_utils.get_numpy_from_ir_value(original_input_shape) - if original_input_shape is None or original_input_shape.tolist() != input_x.shape: + original_input_shape_const_value = original_input_shape.const_value + if ( + original_input_shape_const_value is None + or original_input_shape_const_value.numpy().tolist() != input_x.shape + ): return False return True def instance_simulates_group_normalization_pattern( + op, input_x, adjusted_input_shape, original_input_shape, @@ -128,13 +136,13 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep channels_last=1, epsilon=epsilon, groups=groups, - domain="com.microsoft", + _domain="com.microsoft", ) return op.Transpose(output, perm=[0, 3, 1, 2]) # Register the rewrite rules -instance_norm_to_group_norm_rule = pattern.RewriteRule( +instance_norm_to_group_norm_rule = RewriteRule( instance_simulates_group_normalization_pattern, group_normalization, check_if_simulated_instance_norm_is_used, @@ -142,4 +150,4 @@ def group_normalization(op, input_x, weight_for_norm, weight_full, bias_full, ep # NOTE: instance_norm_to_group_norm_rule is subset of instance_norm_to_group_norm_with_silu_rule, # so we need to run instance_norm_to_group_norm_with_silu_rule first. -rules = pattern.RewriteRuleSet([instance_norm_to_group_norm_rule]) +rules = RewriteRuleSet([instance_norm_to_group_norm_rule]) diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py similarity index 99% rename from onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py rename to onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py index 991a3d44a0..e5754d78d6 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py +++ b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py @@ -1,10 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import numpy as np import onnx.parser from onnxscript import ir -from onnxscript.rewriter.onnxruntime import instance_to_group_normalization +from onnxscript.rewriter.ort_fusions import instance_to_group_normalization class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase): diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py new file mode 100644 index 0000000000..321e895f44 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -0,0 +1,376 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Sequence, Union + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +""" +The MultiHeadAttention pattern: generate an instance + MHA (query, key, value, None, None, mask, past_key, past_value) +where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv). +The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias) +must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh). + +We use the following abbreviations for the dimensions: +B: Batch size +S: Sequence length +D: input embedding dimension +Dv: value hidden size (usually, Dv = D) +H: number of heads +Dh: head size or embedding dimension per head (usually, D = H * Dh) +Skv: key/value sequence length +St: total sequence length + +In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh). +The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh). +""" + +Dim = Union[int, ir.SymbolicDim] + + +class MultiHeadAttention(pattern.RewriteRuleClassBase): + def __init__( + self, + name, + *, + is_rotary: bool, + has_past_present: bool, + is_cross_attention: bool, + ): + super().__init__(name) + self._is_rotary = is_rotary + self._has_past_present = has_past_present + self._is_cross_attention = is_cross_attention + + def pattern( + self, + op, + query_BSD, + key, + value, + past_key, + past_value, + position_ids, + cos, + sin, + ): + # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H) + + # Reshape from (B, S, D) to (B, S, H, D/H) + query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + if not self._is_cross_attention: + # Reshape from (B, S, D) to (B, S, H, D/H) + key = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"]) + # Key may or may not be transposed at this point, based on usage pattern + key = pattern.OrValue( + [op.Transpose(key, perm=[0, 2, 1, 3]), key], + tag_var="key_transposed", + tag_values=[True, False], + ) + + # Reshape from (B, S, D) to (B, S, H, D/H) + value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3]) + else: + # For cross-attention, key and value are not reshaped + value_BHSDh = value + + if self._is_rotary: + query_BHSDh_emb = op.RotaryEmbedding( + query_BHSDh, position_ids, cos, sin, _domain="com.microsoft" + ) + if not self._is_cross_attention: + key_BHSDh_emb = op.RotaryEmbedding( + key, position_ids, cos, sin, _domain="com.microsoft" + ) + else: + key_BHSDh_emb = key + else: + # If rotary embedding is not used, we fuse with positional_embeddings + query_BHSDh_emb = query_BHSDh + key_BHSDh_emb = key + + # Concatenate past_key cache and current key, and transpose to enable + # dot-product attention computation. + if self._has_past_present: + key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2) + else: + key_seq = key_BHSDh_emb + + # Concatenate past_value cache and current value + if self._has_past_present: + value_seq = op.Concat(past_value, value_BHSDh, axis=-2) + else: + value_seq = value_BHSDh + + # Key/value to be used for dot-product attention computation + key_seq_to_sdpa = key_seq + value_seq_to_sdpa = value_seq + + sdpa = op.SDPA( + query_BHSDh_emb, + key_seq_to_sdpa, + value_seq_to_sdpa, + _allow_other_inputs=True, + _outputs=["sdpa_output"], + _domain="ai.onnxruntime._fusion", + ) + + # Transpose attention back to (B, S, H, D/H) + attention_transposed = op.Transpose(sdpa, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention = op.Reshape( + attention_transposed, pattern.ANY_VALUE, _outputs=["attention_reshaped"] + ) + if self._has_past_present: + return attention, key_seq, value_seq + else: + return attention + + def check( + self, + op, + query_BSD, + key, + value, + sdpa_output, + past_key, + past_value, + query_BSHDh, + key_transposed=None, + key_BSHDh=None, + value_BSHDh=None, + **_, + ) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() + + sdpa_node = sdpa_output.producer() + + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils.check_shape_bool(bindings, val, dims) + + if no_match(query_BSD, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']", + query_BSD, + ) + + if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): + return check_result.fail( + f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']", + query_BSHDh, + ) + # If cross-attention, key/value shapes are 4D + if self._is_cross_attention: + if no_match(key, ["B", "H", "Skv", "Dh"]): + return check_result.fail( + f"Shape mismatch: {key} does not match expected dimensions ['B', 'H', 'Skv', 'Dh']", + key, + ) + if no_match(value, ["B", "H", "Skv", "Dv"]): + return check_result.fail( + f"Shape mismatch: {value} does not match expected dimensions ['B', 'H', 'Skv', 'Dv']", + value, + ) + # Ensure that no past_key/past_value is used in cross-attention + if past_key is not None: + return check_result.fail( + "past_key should be None in cross-attention.", + past_key, + ) + if past_value is not None: + return check_result.fail( + "past_value should be None in cross-attention.", + past_value, + ) + else: + if no_match(key, ["B", "Skv", "D"]): + return check_result.fail( + f"Shape mismatch: {key} does not match expected dimensions ['B', 'Skv', 'D']", + query_BSD, + ) + sdpa_key_format = sdpa_node.attributes.get_string("key_format") + expected_key_format = "BHSd" if key_transposed else "BSHd" + if sdpa_key_format != expected_key_format: + return check_result.fail( + f"Unexpected key format: {sdpa_key_format}. Expected: {expected_key_format}", + sdpa_node, + ) + if no_match(value, ["B", "Skv", "D"]): + return check_result.fail( + f"Shape mismatch: {value} does not match expected dimensions ['B', 'Skv', 'D']", + value, + ) + if self._has_past_present: + if no_match(past_key, ["B", "H", "Spast", "Dh"]): + return check_result.fail( + f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']", + past_key, + ) + if no_match(past_value, ["B", "H", "Spast", "Dv"]): + return check_result.fail( + f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']", + past_value, + ) + + # mask (aka attention_bias) shape check: + # ONNX's Attention op (named SDPA here) allows a mask broadcastable to (B, H, S, St) + # ORT's contrib ops (MHA, Attention) allow a mask of shape (1 or B, 1 or H, S, St) + # That is: broadcast allowed only for the first two dimensions. (Even that is not + # supported by some earlier versions of ORT, which are not supported here.) + mask = None + if len(sdpa_node.inputs) > 3: + mask = sdpa_node.inputs[3] + self.mask = mask + if mask is not None: + if (mask_shape := mask.shape) is None: + return check_result.fail( + "Mask shape cannot be determined.", + mask, + ) + if mask_shape.rank() == 4: + if no_match(mask, ["B_or_1", "H_or_1", "S_or_1", "St"]): + return check_result.fail( + f"Shape mismatch: {mask} does not match expected dimensions ['1 or B', '1 or H', '1 or S', 'St']", + mask, + ) + mask_dim_2 = bindings.get("S_or_1") + if mask_dim_2 == bindings.get("S"): + self._use_mask_broadcast = False + elif mask_dim_2 == 1: + self._use_mask_broadcast = True + else: + return check_result.fail( + "Mask dimension 2 cannot be verified to be 1 or S" + ) + elif mask_shape.rank() == 2: + if no_match(mask, ["S_or_1", "St"]): + return check_result.fail( + f"Shape mismatch: {mask} does not match expected dimensions ['1 or S', 'St']", + mask, + ) + self._use_mask_broadcast = True + else: + return check_result.fail( + f"Mask shape {mask_shape} is not supported. Expected 2D or 4D.", + mask, + ) + else: + self._use_mask_broadcast = False + + self._scale = sdpa_node.attributes.get_float("scale", None) + # TODO: verify Reshapes: + # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: + # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: + # or check Reshape's shape-input value + return check_result + + def rewrite( + self, + op, + query_BSD, + key, + value, + past_key, + past_value, + query_BSHDh, + position_ids, + cos, + sin, + **_, + ): + num_heads = _ir_utils.get_dim(query_BSHDh, 2) + if not isinstance(num_heads, int): + return None + + # TODO: forward other attributes + + if self._is_rotary: + query_BSD_emb = op.RotaryEmbedding( + query_BSD, position_ids, cos, sin, _domain="com.microsoft" + ) + if not self._is_cross_attention: + key_BSD_emb = op.RotaryEmbedding( + key, position_ids, cos, sin, _domain="com.microsoft" + ) + else: + key_BSD_emb = key + elif self._is_cross_attention: + query_BSD_emb = query_BSD + # Must convert key/value from 4D to 3D for use in MHA + key = op.Transpose(key, perm=[0, 2, 1, 3]) + key_BSD_emb = op.Reshape(key, op.Constant(value_ints=[0, 0, -1])) + value = op.Transpose(value, perm=[0, 2, 1, 3]) + value = op.Reshape(value, op.Constant(value_ints=[0, 0, -1])) + else: + query_BSD_emb = query_BSD + key_BSD_emb = key + + mask = self.mask + if self._use_mask_broadcast: + one = op.Constant(value_ints=[1]) + S = op.Shape(query_BSD, start=1, end=2) + shape_11S1 = op.Concat(one, one, S, one, axis=0) + mask = op.Expand(mask, shape_11S1) + + num_outputs = 1 + (2 * self._has_past_present) + return op.MultiHeadAttention( + query_BSD_emb, + key_BSD_emb, + value, + None, # bias + None, # key padding mask + mask, # attention mask/bias + past_key, + past_value, + num_heads=num_heads, + _domain="com.microsoft", + _outputs=num_outputs, + scale=self._scale, + ) + + +def _make_rule_set(has_past_present: bool): + parameter_combinations = [ + { + "is_rotary": is_rotary, + "has_past_present": has_past_present, + "is_cross_attention": is_cross_attention, + } + for is_rotary in [False, True] + for is_cross_attention in ([False] if has_past_present else [False, True]) + ] + + # Dynamically create the rules + mha_rules = pattern.RewriteRuleSet( + [ + MultiHeadAttention.rule( + f"MHA" + f"{'_Rotary' if params['is_rotary'] else ''}" + f"{'_Past' if params['has_past_present'] else ''}" + f"{'_CrossAttention' if params['is_cross_attention'] else ''}", + **params, + ) + for params in parameter_combinations + ] + ) + + return mha_rules + + +mha_rules_no_past = _make_rule_set(has_past_present=False) +mha_rules_with_past = _make_rule_set(has_past_present=True) + +# Try rules with past first, and then rules without past. +fuse_mha1 = _fusion_utils.apply_fusion_rules(mha_rules_with_past) +fuse_mha2 = _fusion_utils.apply_fusion_rules(mha_rules_no_past) diff --git a/onnxscript/rewriter/ort_fusions/mha_bias.py b/onnxscript/rewriter/ort_fusions/mha_bias.py new file mode 100644 index 0000000000..9ecf2ce017 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/mha_bias.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Sequence, Union + +import numpy +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, pattern + +valid_float_types = [ir.DataType.FLOAT, ir.DataType.FLOAT16] + +Dim = Union[int, ir.SymbolicDim] + + +class FuseBiasMHA(pattern.RewriteRuleClassBase): + def pattern( + self, + op, + query_matmul, + key_matmul, + value_matmul, + q_bias, + k_bias, + v_bias, + mask, + past_key, + past_value, + num_heads, + ): + query_BSD = pattern.OrValue( + [op.Add(query_matmul, q_bias), query_matmul], + tag_var="has_q_bias", + tag_values=[True, False], + ) + key_BSD = pattern.OrValue( + [op.Add(key_matmul, k_bias), key_matmul], + tag_var="has_k_bias", + tag_values=[True, False], + ) + value_BSD = pattern.OrValue( + [op.Add(value_matmul, v_bias), value_matmul], + tag_var="has_v_bias", + tag_values=[True, False], + ) + + return op.MultiHeadAttention( + query_BSD, + key_BSD, + value_BSD, + None, # bias + None, # key padding mask + pattern.Var("mask", can_match_none=True), # attention mask/bias + pattern.Var("past_key", can_match_none=True), + pattern.Var("past_value", can_match_none=True), + num_heads=num_heads, + scale=pattern.AttrVar("scale", can_match_none=True), + _domain="com.microsoft", + ) + + def check( + self, + context, + query_matmul, + key_matmul, + value_matmul, + has_q_bias, + has_k_bias, + has_v_bias, + **_, + ) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() + + if not (has_q_bias or has_k_bias or has_v_bias): + return check_result.fail("None of query, key, or value have a bias.") + + self.bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) + + if query_matmul.dtype not in valid_float_types: + return check_result.fail("Query is not a float or float16 type.", query_matmul) + if key_matmul.dtype not in valid_float_types: + return check_result.fail("Key is not a float or float16 type.", key_matmul) + if value_matmul.dtype not in valid_float_types: + return check_result.fail("Value is not a float or float16 type.", value_matmul) + + if no_match(query_matmul, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {query_matmul} does not match expected dimensions ['B', 'S', 'D']", + query_matmul, + ) + if no_match(key_matmul, ["B", "Skv", "Dk"]): + return check_result.fail( + f"Shape mismatch: {key_matmul} does not match expected dimensions ['B', 'Skv', 'Dk']", + key_matmul, + ) + if no_match(value_matmul, ["B", "Skv", "Dv"]): + return check_result.fail( + f"Shape mismatch: {value_matmul} does not match expected dimensions ['B', 'Skv', 'Dv']", + value_matmul, + ) + + self.Dh_q = self.bindings.get("D") + self.Dh_k = self.bindings.get("Dk") + self.Dh_v = self.bindings.get("Dv") + + if ( + not isinstance(self.Dh_q, int) + or not isinstance(self.Dh_k, int) + or not isinstance(self.Dh_v, int) + ): + return check_result.fail( + "Could not determine the hidden sizes of query, key, and value.", + ) + + return check_result + + def rewrite( + self, + op, + query_matmul, + key_matmul, + value_matmul, + q_bias, + k_bias, + v_bias, + mask, + past_key, + past_value, + num_heads, + scale, + **_, + ): + if q_bias is None: + q_bias = op.Constant( + value=ir.tensor(numpy.zeros((self.Dh_q,), dtype=query_matmul.dtype.numpy())) + ) + if k_bias is None: + k_bias = op.Constant( + value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=key_matmul.dtype.numpy())) + ) + if v_bias is None: + v_bias = op.Constant( + value=ir.tensor(numpy.zeros((self.Dh_v,), dtype=value_matmul.dtype.numpy())) + ) + bias = op.Concat(q_bias, k_bias, v_bias, axis=0) + return op.MultiHeadAttention( + query_matmul, + key_matmul, + value_matmul, + bias, + None, + mask, + past_key, + past_value, + num_heads=num_heads, + scale=scale, + _domain="com.microsoft", + ) + + +mha_bias_rules = pattern.RewriteRuleSet([FuseBiasMHA.rule()]) + + +fuse_mha_bias = _fusion_utils.apply_fusion_rules(mha_bias_rules) diff --git a/onnxscript/rewriter/ort_fusions/mha_scale.py b/onnxscript/rewriter/ort_fusions/mha_scale.py new file mode 100644 index 0000000000..e02e6c49e3 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/mha_scale.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import math + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +""" +Multi-Head Attention (MHA) pre-scaling fusion patterns. + +This module contains rewrite rules for fusing scale operations that occur before +Multi-Head Attention operations. The fusion optimizes patterns where a query tensor +is scaled before being passed to MHA by incorporating the scaling directly into +the MHA operation. + +Example pattern: + query -> Mul(scale) -> MultiHeadAttention -> output + +Gets rewritten to: + query -> MultiHeadAttention(with integrated scaling) -> output +""" + + +class FuseMHAScale(pattern.RewriteRuleClassBase): + def pattern(self, op, query, scale): + scaled_query = op.Mul(query, scale) + mha_output = op.MultiHeadAttention( + scaled_query, + _allow_other_inputs=True, + _domain="com.microsoft", + _outputs=["mha_output"], + ) + return mha_output + + def check(self, context, scale, **_): + scale_value = _ir_utils.get_singleton_value(scale) + if scale_value is None or not isinstance(scale_value, (int, float)): + return pattern.MatchResult().fail("Scale must be a constant numeric value.", scale) + self._scale = scale_value + return True + + def rewrite(self, op, query, mha_output, **_): + # Integrate the scale into the MHA operation + mha_node = mha_output.producer() + assert mha_node is not None + # Compute original scale factor for MHA: + attributes = mha_node.attributes + original_scale = attributes.get_float("scale", None) + if original_scale is None: + num_heads = attributes.get_int("num_heads", None) + if num_heads is None: + return None + head_size = query.shape[-1] // num_heads + original_scale = 1.0 / math.sqrt(head_size) + self._scale *= original_scale + inputs = list(mha_node.inputs) + inputs[0] = query + attributes = dict(attributes) + attributes["scale"] = self._scale + return op.MultiHeadAttention( + *inputs, **attributes, _domain="com.microsoft", _outputs=1 + ) + + +_mha_scale_rules = pattern.RewriteRuleSet([FuseMHAScale.rule()]) + +fuse_mha_scale = _fusion_utils.apply_fusion_rules(_mha_scale_rules) diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py new file mode 100644 index 0000000000..b3fbfafd3d --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx_ir.passes.common as common_passes +import packaging.version + +import onnxscript.optimizer +import onnxscript.rewriter.ort_fusions._core as xformers +from onnxscript.rewriter.models._phi2lm import phi2lm_test +from onnxscript.rewriter.models._smollm_2 import smollm_test_2 +from onnxscript.rewriter.models._whisper_decoder import whisper_decoder_test +from onnxscript.rewriter.models._whisper_encoder import whisper_encoder_test +from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run + + +class TestMultiHeadAttention(unittest.TestCase): + def test_smollm(self): + # Generate model + smollm_test = smollm_test_2() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + xformers.fuse_rms_normalization(model) + xformers.fuse_skip_rms_normalization(model) + xformers.fuse_rotary_embedding(model) + xformers.fuse_cos_sin_cache(model) + + test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION + if test_with_ort: + # Run model + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + # Fuse SDPA and MHA + sdpa_count = xformers.fuse_sdpa(model) + self.assertGreater(sdpa_count, 0) + mha_count = xformers.fuse_mha1(model) + mha_count += xformers.fuse_mha2(model) + self.assertGreater(mha_count, 0) + + if test_with_ort: + # Run model again + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_whisper_encoder(self): + # Generate model + whisper_encoder = whisper_encoder_test() + model = whisper_encoder.get_onnx_model() + onnxscript.optimizer.optimize(model) + + test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION + if test_with_ort: + # Run model + inputs = whisper_encoder.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + # Fuse SDPA and MHA + sdpa_count = xformers.fuse_sdpa(model, debug=True) + self.assertGreater(sdpa_count, 0) + model = common_passes.ShapeInferencePass()(model).model + mha_count = xformers.fuse_mha1(model) + mha_count += xformers.fuse_mha2(model) + self.assertGreater(mha_count, 0) + onnxscript.optimizer.optimize(model) + + if test_with_ort: + # Run model again + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_whisper_decoder(self): + # Generate model + whisper_decoder = whisper_decoder_test() + model = whisper_decoder.get_onnx_model() + onnxscript.optimizer.optimize(model) + + test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION + if test_with_ort: + # Run model + inputs = whisper_decoder.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + # Fuse SDPA and MHA + sdpa_count = xformers.fuse_sdpa(model) + self.assertGreater(sdpa_count, 0) + model = common_passes.ShapeInferencePass()(model).model + mha_count = xformers.fuse_mha1(model) + mha_count += xformers.fuse_mha2(model) + self.assertGreater(mha_count, 0) + onnxscript.optimizer.optimize(model) + + if test_with_ort: + # Run model again + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_phi2lm(self): + test_case = phi2lm_test() + model = test_case.get_onnx_model() + onnxscript.optimizer.optimize(model) + xformers.optimize_for_ort(model) + mha_nodes = [n for n in model.graph if n.op_type == "MultiHeadAttention"] + self.assertEqual( + len(mha_nodes), + 1, + "Expected exactly one MultiHeadAttention node after optimization", + ) + mha_node = mha_nodes[0] + # Check that the MHA node has past kv cache inputs + self.assertEqual(len(mha_node.inputs), 8, "Expected MHA node to have 8 inputs") + self.assertIsNotNone(mha_node.inputs[6], "Expected MHA node to have past key input") + self.assertIsNotNone(mha_node.inputs[7], "Expected MHA node to have past value input") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py new file mode 100644 index 0000000000..de6e51a5c0 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +""" +RMS Normalization: This is referred to as SimplifiedLayerNormalization in the ORT codebase. +See https://github.com/microsoft/onnxruntime/blob/6d9636f07cccdb6e4ac453087ad54c3bc9854d50/onnxruntime/core/graph/contrib_ops/contrib_defs.cc#L2981 + +Key points for the fusion optimization: +* Input and scale are allowed to be of different types. +* The normalization of the input can be done in a different precision than the input type, +which is also the precision of reciprocal_rms returned by operation. +* Input (x) must be: float or double or float16 or bfloat16 +* Scale must be: float or double or float16 or bfloat16 +* Normalization precision must be float or double +""" + +float_types = frozenset( + [ + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.DOUBLE, + ] +) +fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE]) + + +class RmsNormFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): + x = pattern.OrValue([op.Cast(x, to=compute_dtype), x]) + x_square = op.Pow(x, 2.0) + mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) + mean_square_plus_epsilon = op.Add(mean_square, epsilon) + rms = op.Sqrt(mean_square_plus_epsilon) + reciprocal_rms = op.Reciprocal(rms) + normalized = op.Mul(x, reciprocal_rms) + normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized]) + # To support float16, we need to ensure the scale is casted or not. + scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale]) + return op.Mul(scale, normalized) + + def check( + self, op, x, scale, epsilon, compute_dtype, target_dtype, **_ + ) -> pattern.MatchResult: # type: ignore[name-defined] + """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" + check_result = pattern.MatchResult() + # epsilon must be a scalar + epsilon_value = _ir_utils.get_singleton_value(epsilon) + if not isinstance(epsilon_value, float): # TODO: support other types + return check_result.fail("Epsilon is not a float value.", epsilon) + if x.dtype not in float_types: + return check_result.fail("Input is not a float type.", x) + if scale.dtype not in float_types: + return check_result.fail("Scale is not a float type.", scale) + self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype + if self._stash_dtype not in fp_float_types: + return check_result.fail("Normalization precision is not a float or double type.") + # target_dtype is guaranteed to be the same as scale type in a well-typed input + # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input. + # TODO (rama): Consider adding checks to protect against incorrectly typed models: + return check_result + + def rewrite(self, op, x, scale, epsilon, **_): + # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake. + # No need to use com.microsoft domain here; but this is a custom op in ORT. + return op.SimplifiedLayerNormalization( + x, + scale, + axis=-1, + epsilon=_ir_utils.get_singleton_value(epsilon), + stash_type=self._stash_dtype, + ) + + +_rule = RmsNormFusion.rule() +rms_normalization_rules = [_rule] +rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules) + + +fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset) diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization_test.py b/onnxscript/rewriter/ort_fusions/rms_normalization_test.py new file mode 100644 index 0000000000..89b9f71253 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/rms_normalization_test.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.models._smollm_1 import smollm_test_1 +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run +from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization + + +class TestRmsNormalization(unittest.TestCase): + def test_smollm(self): + smollm_test = smollm_test_1() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + fuse_rms_normalization(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("SimplifiedLayerNormalization", op_types) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py new file mode 100644 index 0000000000..b9d4015f06 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern +# for full rotation without interleaving. +# TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation). + +# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet. +# so it can't be tested by running against ORT. See cos_sin_cache.py for a transformation that +# rewrites the pattern into one that can be run against ORT. + + +def _rotate_half_pattern(op, x, start1, end1, start2, end2): + # Slice(input, starts, ends, axes, steps) + x1 = op.Slice(x, start1, end1, [3], [1]) + x2 = op.Slice(x, start2, end2, [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + return rotated_x + + +class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__(name="RotaryEmbedding", as_function=True) + + def pattern(self, op, x, cos, sin, start1, end1, start2, end2): + return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + + def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) + if x is None or x.shape is None or len(x.shape) != 4: + return check_result.fail("Input is not a 4D tensor.", x) + if not isinstance(x.shape[1], int): + return check_result.fail("Input dimension 1 is not an integer.", x) + head_size = x.shape[3] + if not isinstance(head_size, int): + return check_result.fail("Head size is not an integer.", x) + half_head_size = head_size // 2 + + # Check that x is being split into two equal halves of size half_head_size + if not ( + _ir_utils.is_singleton_value(start1, 0) + and _ir_utils.is_singleton_value(end1, half_head_size) + and _ir_utils.is_singleton_value(start2, half_head_size) + and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) + ): + return check_result.fail( + "x is not being split into two equal halves of size half_head_size." + ) + return check_result + + def rewrite(self, op, x, cos, sin, **_): + num_heads = x.shape[1] + return op.RotaryEmbedding( + x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime._fusion" + ) + + +class PartialRotaryEmbeddingFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, end1, start2): + x_part_1 = op.Slice(x, [0], end1, [3], [1]) + x_part_2 = op.Slice(x, start2, [9223372036854775807], [3], [1]) + x_part_1_rope = op.RotaryEmbedding( + x_part_1, + _allow_other_inputs=True, + _allow_other_attributes=True, + _domain="com.microsoft", + _outputs=["x_part_1_rope"], + ) + return op.Concat(x_part_1_rope, x_part_2, axis=-1) + + def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() + end1_value = _ir_utils.get_singleton_value(end1) + start2_value = _ir_utils.get_singleton_value(start2) + if not isinstance(end1_value, int) or not isinstance(start2_value, int): + return check_result.fail( + "The end1 value of first slice and start2 value of second slice are not integers." + ) + if end1_value != start2_value: + return check_result.fail( + "The end1 value of first slice and start2 value of second slice are not equal." + ) + rotary_embedding_attributes = x_part_1_rope.producer().attributes + if "rotary_embedding_dim" in rotary_embedding_attributes: + return check_result.fail("rotary_embedding_dim attribute already specified.") + if ( + "interleaved" in rotary_embedding_attributes + and rotary_embedding_attributes["interleaved"].value != 0 + ): + return check_result.fail("interleaved is not equal to 0.") + return check_result + + def rewrite(self, op, x, end1, x_part_1_rope, **_): + # Create a modified version of the RotaryEmbedding op: + rotary_embedding_dim = _ir_utils.get_singleton_value(end1) + original_node = x_part_1_rope.producer() + inputs = list(original_node.inputs) + inputs[0] = x + attrs = dict(original_node.attributes) + attrs["rotary_embedding_dim"] = rotary_embedding_dim + return op.RotaryEmbedding( + *inputs, + **attrs, + _domain="com.microsoft", + ) + + +_rule = RotaryEmbeddingFusion.rule() + +_partial_embedding_rule = PartialRotaryEmbeddingFusion.rule() + +rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) + +partial_embedding_rules = pattern.RewriteRuleSet([_partial_embedding_rule]) + + +fuse_rotary_embedding = _fusion_utils.apply_fusion_rules(rotary_embedding_rules) + + +fuse_partial_rotary_embedding = _fusion_utils.apply_fusion_rules(partial_embedding_rules) diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding_test.py b/onnxscript/rewriter/ort_fusions/rotary_embedding_test.py new file mode 100644 index 0000000000..4ab945f653 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding_test.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +from parameterized import parameterized + +import onnxscript.optimizer +from onnxscript.rewriter.models import _rotary_embedding_models, _smollm_1 +from onnxscript.rewriter.ort_fusions import rotary_embedding + + +class TestRotaryEmbedding(unittest.TestCase): + @parameterized.expand( + [ + ( + "test_case_1", + _rotary_embedding_models.test_case_1, + ), + ( + "smollm_test_1", + _smollm_1.smollm_test_1, + ), + ] + ) + def test_rotary_embedding_fusion(self, _: str, test_data_constructor): + test = test_data_constructor() + model = test.get_onnx_model() + onnxscript.optimizer.optimize(model) + rotary_embedding.fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py new file mode 100644 index 0000000000..55b38e9ad4 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -0,0 +1,189 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import math +from typing import Union + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern +from onnxscript.rewriter._basics import MatchFailureError + +Dim = Union[int, ir.SymbolicDim] + + +class SDPA(pattern.RewriteRuleClassBase): + _scale: float | None + + def pattern( + self, + op, + query, + key, + value, + mask, + query_scale, + key_scale, + qk_scale, + ): + # The last two axes of key must be transposed before computing the dot product with query. + # Three patterns are observed in practice: + + # Pattern 1: Transpose 4D key directly: BHSd => BHdS + key_transposed_1 = op.Transpose(key, perm=[0, 1, 3, 2]) + + # Pattern 2: Transpose key after converting to 3D and then convert back to 4D: BHSd => 3D => BHdS + key_3d = op.Reshape(key, pattern.ANY_VALUE) + key_3d_transposed = op.Transpose(key_3d, perm=[0, 2, 1]) + key_transposed_2 = op.Reshape(key_3d_transposed, pattern.ANY_VALUE) + + # Pattern 3: This transpose is sometimes composed with an earlier transpose to convert + # the key from BSHd format to BHSd format. + key_transposed_3 = op.Transpose(key, perm=[0, 2, 3, 1]) + + key_transposed = pattern.OrValue( + [key_transposed_1, key_transposed_2, key_transposed_3], + tag_var="key_format", + tag_values=["BHSd", "BHSd", "BSHd"], + ) + + # Some implementations scale the query and key before computing the dot product + query = pattern.OrValue( + [ + op.Mul(query, query_scale), + op.Div(query, query_scale), + query, + ], + tag_var="query_scaling", + tag_values=["Mul", "Div", "None"], + ) + key_transposed = pattern.OrValue( + [ + op.Mul(key_transposed, key_scale), + op.Div(key_transposed, key_scale), + key_transposed, + ], + tag_var="key_scaling", + tag_values=["Mul", "Div", "None"], + ) + + attn_score = op.MatMul(query, key_transposed) + + # Some implementations scale the dot product. + attn_score = pattern.OrValue( + [ + op.Mul(attn_score, qk_scale), + op.Div(attn_score, qk_scale), + attn_score, + ], + tag_var="qk_scaling", + tag_values=["Mul", "Div", "None"], + ) + + # Some implementations add a mask to the dot product. + masked_attn_score = op.Add(attn_score, mask) + attn_score = pattern.OrValue( + [masked_attn_score, attn_score], tag_var="has_mask", tag_values=[True, False] + ) + + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + adj_attn_weight = op.Where(is_nan, 0.0, attn_weight) + attn_weight = pattern.OrValue([adj_attn_weight, attn_weight]) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + def check( + self, + context, + query: ir.Value | None, + key: ir.Value | None, + value: ir.Value | None, + mask: ir.Value | None, + key_format: str, + **match_bindings, + ): + check_result = pattern.MatchResult() + + bindings: dict[str, Dim] = {} + + # Check that query/key/value have the expected shapes: + # They all should have same batch-size (B) and number of heads (H). Conceptually, it is + # different for Q and K/V, but the certain op implementations require them to be the same, + # which is usually achieved via tiling/expanding K/V num-heads to match Q num-heads. + # Query and Key should have same head-size (Dh) while value can have different head-size (Dv). + # Key and Value should have same sequence length (Skv), while Query can have different sequence length (S). + _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) + if key_format == "BHSd": + _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) + else: + assert key_format == "BSHd", f"Unexpected key format: {key_format}" + _fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"]) + _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) + + def get_scale_value(tag_name: str, scale_name: str) -> float: + scaling_type = match_bindings.get(tag_name, "None") + if scaling_type == "None": + return 1.0 + else: + scale = match_bindings.get(scale_name) + value = _ir_utils.get_singleton_value(scale) + if value is None: + raise MatchFailureError(f"{scale_name} is not a scalar.", scale) + if scaling_type == "Mul": + return value + else: + assert scaling_type == "Div", f"Unexpected {scale_name} scaling operation" + return 1.0 / value + + query_scale_value = get_scale_value("query_scaling", "query_scale") + key_scale_value = get_scale_value("key_scaling", "key_scale") + qk_scale_value = get_scale_value("qk_scaling", "qk_scale") + + self._scale = query_scale_value * key_scale_value * qk_scale_value + + # If the scaling factor is the default one, we can skip passing it to SDPA. + + head_size = bindings["Dh"] + if not isinstance(head_size, int): + return check_result + + default_scaling_factor = 1.0 / math.sqrt(head_size) + + if math.isclose(self._scale, default_scaling_factor, rel_tol=1e-5, abs_tol=1e-8): + # Pass no scaling factor to SDPA, SDPA will use the default scaling factor + self._scale = None + + return check_result + + def rewrite( + self, + op, + query: ir.Value | None, + key: ir.Value | None, + value: ir.Value | None, + mask: ir.Value | None, + key_format: str, + **_, + ): + sdpa_args = [query, key, value] + if mask is not None: + sdpa_args.append(mask) + # If the scale is None, SDPA will use the default scaling factor, which is 1/sqrt(head_size). + return op.SDPA( + *sdpa_args, + scale=self._scale, + key_format=key_format, + _domain="ai.onnxruntime._fusion", + ) + + +# Dynamically create the rules +sdpa_rules = pattern.RewriteRuleSet( + [ + SDPA.rule(), + ] +) + +fuse_sdpa = _fusion_utils.apply_fusion_rules(sdpa_rules) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py new file mode 100644 index 0000000000..c5326a77b9 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -0,0 +1,424 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""SDPA fusion test cases.""" + +from __future__ import annotations + +import math +import unittest + +import numpy +import onnx_ir as ir +import parameterized + +import onnxscript.optimizer +from onnxscript import script +from onnxscript.onnx_opset import opset18 as op +from onnxscript.onnx_types import FLOAT +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run +from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa +from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha + +B = 2 # batch size +N = 4 # number of heads +S = 8 # sequence length +H = 128 # head size +SCALE_FACTOR = math.sqrt(H) +MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR +SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR) +SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR) +# Custom scale factors for testing +CUSTOM_SCALE_FACTOR = 1.0 / math.sqrt(80) +CUSTOM_MUL_SCALE_FACTOR = CUSTOM_SCALE_FACTOR +CUSTOM_DIV_SCALE_FACTOR = 1.0 / CUSTOM_SCALE_FACTOR +SQRT_CUSTOM_MUL_SCALE_FACTOR = math.sqrt(CUSTOM_MUL_SCALE_FACTOR) +SQRT_CUSTOM_DIV_SCALE_FACTOR = math.sqrt(CUSTOM_DIV_SCALE_FACTOR) + + +@script() +def _unmasked_pre_div_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _unmasked_pre_mul_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR) + scaled_query = op.Mul(query, multiplier) + scaled_key = op.Mul(key_transposed, multiplier) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _unmasked_post_div_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Div(attn_score, divisor) + attn_weight = op.Softmax(scaled_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _unmasked_post_mul_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=MUL_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Mul(attn_score, multiplier) + attn_weight = op.Softmax(scaled_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _custom_scale_pre_div_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _custom_scale_pre_mul_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) + scaled_query = op.Mul(query, multiplier) + scaled_key = op.Mul(key_transposed, multiplier) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _custom_multi_scale_pre_mul_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier_q = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) + multiplier_k = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) + scaled_query = op.Mul(query, multiplier_q) + scaled_key = op.Mul(key_transposed, multiplier_k) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _custom_scale_post_div_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Div(attn_score, divisor) + attn_weight = op.Softmax(scaled_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _custom_scale_post_mul_sdpa_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Mul(attn_score, multiplier) + attn_weight = op.Softmax(scaled_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_pre_div_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_pre_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR) + scaled_query = op.Mul(query, multiplier) + scaled_key = op.Mul(key_transposed, multiplier) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_post_div_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Div(attn_score, divisor) + masked_attn_score = op.Add(scaled_attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_post_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=MUL_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Mul(attn_score, multiplier) + masked_attn_score = op.Add(scaled_attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR) + scaled_query = op.Mul(query, multiplier) + scaled_key = op.Mul(key_transposed, multiplier) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Div(attn_score, divisor) + masked_attn_score = op.Add(scaled_attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +@script() +def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask): + key_transposed = op.Transpose(key, perm=[0, 1, 3, 2]) + multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR) + attn_score = op.MatMul(query, key_transposed) + scaled_attn_score = op.Mul(attn_score, multiplier) + masked_attn_score = op.Add(scaled_attn_score, mask) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +class SDPATestCase: + def __init__(self, script_func, *, with_mask): + self.script_func = script_func + self.with_mask = with_mask + + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + qkv_type = FLOAT[B, N, S, H] + mask_type = FLOAT[B, N, S, S] + input_types = [qkv_type, qkv_type, qkv_type] + if self.with_mask: + input_types.append(mask_type) + model_proto = self.script_func.to_model_proto( + input_types=input_types, output_types=[qkv_type] + ) + self._onnx_model = ir.serde.deserialize_model(model_proto) + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "value": numpy.random.rand(B, N, S, H).astype(numpy.float32), + } + if self.with_mask: + inputs["mask"] = numpy.random.rand(B, N, S, S).astype(numpy.float32) + self._ort_inputs = inputs + return self._ort_inputs + + +class InvalidSDPATestCase: + def __init__(self, script_func): + self.script_func = script_func + + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + qk_type = FLOAT[B, N, S, H] + # We broadcast value in the batch dimension, which is not supported by SDPA fusion + v_type = FLOAT[1, N, S, H] + mask_type = FLOAT[B, N, S, S] + model_proto = self.script_func.to_model_proto( + input_types=[qk_type, qk_type, v_type, mask_type], output_types=[qk_type] + ) + self._onnx_model = ir.serde.deserialize_model(model_proto) + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "value": numpy.random.rand(1, N, S, H).astype(numpy.float32), + "mask": numpy.random.rand(B, N, S, S).astype(numpy.float32), + } + self._ort_inputs = inputs + return self._ort_inputs + + +class TestSDPAFusion(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ("pre_div", _unmasked_pre_div_sdpa_script), + ("pre_mul", _unmasked_pre_mul_sdpa_script), + ("post_div", _unmasked_post_div_sdpa_script), + ("post_mul", _unmasked_post_mul_sdpa_script), + ("masked_pre_div", _masked_pre_div_sdpa_script), + ("masked_pre_mul", _masked_pre_mul_sdpa_script), + ("masked_post_div", _masked_post_div_sdpa_script), + ("masked_post_mul", _masked_post_mul_sdpa_script), + ("custom_scale_post_mul", _custom_scale_post_mul_sdpa_script), + ("custom_scale_post_div", _custom_scale_post_div_sdpa_script), + ("custom_scale_pre_mul", _custom_scale_pre_mul_sdpa_script), + ("custom_scale_pre_div", _custom_scale_pre_div_sdpa_script), + ("masked_custom_scale_post_mul", _masked_custom_scale_post_mul_sdpa_script), + ("masked_custom_scale_post_div", _masked_custom_scale_post_div_sdpa_script), + ("masked_custom_scale_pre_mul", _masked_custom_scale_pre_mul_sdpa_script), + ("masked_custom_scale_pre_div", _masked_custom_scale_pre_div_sdpa_script), + ( + "_custom_multi_scale_pre_mul_sdpa_script", + _custom_multi_scale_pre_mul_sdpa_script, + ), + ] + ) + def test_sdpa_fusion(self, name, script_func): + test_case = SDPATestCase(script_func, with_mask="masked" in name) + model = test_case.get_onnx_model() + onnxscript.optimizer.optimize(model) + + inputs = test_case.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + count = fuse_sdpa(model, debug=True) + self.assertGreater(count, 0) + + # Check that the fusion was successful + op_types = [n.op_type for n in model.graph] + self.assertIn("SDPA", op_types) + + # Ensure that the scale of the SDPA node is set correctly + sdpa_node = next(n for n in model.graph if n.op_type == "SDPA") + self.assertEqual(sdpa_node.op_type, "SDPA") + + if "custom" in name: + self.assertIsNotNone(sdpa_node.attributes.get("scale")) + scale_factor = sdpa_node.attributes["scale"].value + self.assertAlmostEqual(scale_factor, CUSTOM_SCALE_FACTOR, delta=1e-8) + else: + # These tests are for the default scaling factors, no scale factor is passed to SDPA + # pattern rewriting check functions should be sufficient to check if expected value + # of scale_factor (is =default_scaling_factor) + self.assertIsNone(sdpa_node.attributes.get("scale")) + + replace_sdpa_by_mha(model, debug=True) + + self.assertNotIn("SDPA", [n.op_type for n in model.graph]) + + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_invalid_sdpa_fusion_value_batch_dim(self): + test_case = InvalidSDPATestCase(_masked_pre_mul_sdpa_script) + model = test_case.get_onnx_model() + onnxscript.optimizer.optimize(model) + count = fuse_sdpa(model) + self.assertEqual(count, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py new file mode 100644 index 0000000000..e6484406a9 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Union + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, pattern + +Dim = Union[int, ir.SymbolicDim] + + +class SDPAImplementation(pattern.RewriteRuleClassBase): + def pattern(self, op, query, key, value): + return op.SDPA( + query, + key, + value, + key_format="BHSd", + _allow_other_inputs=True, # Mask is optional + _outputs=["sdpa_output"], + _domain="ai.onnxruntime._fusion", + ) + + def check(self, context, query, key, value, sdpa_output): + bindings: dict[str, Dim] = {} + _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) + _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) + _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) + + self._num_heads = bindings["H"] + if not isinstance(self._num_heads, int): + return False + self._use_mask_broadcast = True # TODO: optimize to avoid broadcast if not needed + return isinstance(self._num_heads, int) + + def rewrite(self, op, query, key, value, sdpa_output): + sdpa_node = sdpa_output.producer() + scale = sdpa_node.attributes.get("scale", None) + to_3d_shape = op.Constant(value_ints=[0, 0, -1]) + to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1]) + query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape) + key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape) + value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape) + + inputs = [query_3d, key_3d, value_3d] + if len(sdpa_node.inputs) > 3: + mask = sdpa_node.inputs[3] + + if self._use_mask_broadcast: + one = op.Constant(value_ints=[1]) + query_length = op.Shape(query, start=2, end=3) + shape_11S1 = op.Concat(one, one, query_length, one, axis=0) + mask = op.Expand(mask, shape_11S1) + + inputs.extend([None, None, mask]) + + output = op.MultiHeadAttention( + *inputs, + num_heads=self._num_heads, + scale=scale, + _domain="com.microsoft", + ) + output_4d = op.Reshape(output, to_4d_shape) + output = op.Transpose(output_4d, perm=[0, 2, 1, 3]) + return output + + +_rules = pattern.RewriteRuleSet([SDPAImplementation.rule()]) + +replace_sdpa_by_mha = _fusion_utils.apply_fusion_rules(_rules) diff --git a/onnxscript/rewriter/ort_fusions/shape_optimization.py b/onnxscript/rewriter/ort_fusions/shape_optimization.py new file mode 100644 index 0000000000..521a32ed1e --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/shape_optimization.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Optimization for shape operations.""" + +from __future__ import annotations + +import onnx_ir as ir + +import onnxscript.rewriter._ir_utils as _ir_utils +import onnxscript.rewriter.pattern as pattern + + +class ExtractDim(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__(remove_nodes=False) + + """This is a pattern observed in causal mask generation that hinders fusion optimizations. + It can be simplified away. + """ + + def pattern(self, op, x, dim0, dim1, dim2, dim3, start, end): + shape = op.Concat(dim0, dim1, dim2, dim3, axis=0) + # Note: The allowzero=1 attribute enables us to infer that the shape of the + # reshaped tensor is the same as the value of the shape parameter below. + # Otherwise, we need to know that there are no zeros in the value of "shape" + # for this optimization to be valid. + reshaped = op.Reshape(x, shape, allowzero=1) + transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) + final_shape = op.Shape(transposed, _outputs=["final_shape"]) + final_dim = op.Slice(final_shape, start, end) + return final_dim + + def check(self, context, dim0, dim1, dim2, dim3, final_shape, start, end, **_) -> bool: + # All of the dimensions should have shape [1] + for dim in (dim0, dim1, dim2, dim3): + if dim.shape is None or dim.shape.dims != (1,): + return False + + # The Shape op should return the full shape, not a slice of the shape. + shape_node = final_shape.producer() + if "end" in shape_node.attributes: + return False + if "start" in shape_node.attributes: + start_attr = shape_node.attributes["start"] + if not (isinstance(start_attr, ir.Attr) and start_attr.value == 0): + return False + self._start_val = _ir_utils.get_singleton_value(start) + self._end_val = _ir_utils.get_singleton_value(end) + if self._start_val is None or self._end_val is None: + return False + return True + + def rewrite(self, op, dim0, dim1, dim2, dim3, **_): + transposed_dims = [dim0, dim2, dim1, dim3] + sliced_result = transposed_dims[self._start_val : self._end_val] + if len(sliced_result) == 0: + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [])) + if len(sliced_result) == 1: + return op.Identity(sliced_result[0]) + return op.Concat(*sliced_result, axis=0) + + +rules = pattern.RewriteRuleSet([ExtractDim.rule()]) diff --git a/onnxscript/rewriter/ort_fusions/shape_optimization_test.py b/onnxscript/rewriter/ort_fusions/shape_optimization_test.py new file mode 100644 index 0000000000..f563ef58d5 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/shape_optimization_test.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx +import parameterized + +from onnxscript import FLOAT, INT64, ir, opset18, script +from onnxscript.rewriter.ort_fusions import shape_optimization + + +def _make_model(starts: list[int], ends: list[int]) -> onnx.ModelProto: + @script() + def model_script( + x: FLOAT["N"], # noqa: F821 + dim0: INT64[1], + dim1: INT64[1], + dim2: INT64[1], + dim3: INT64[1], + ) -> INT64["M"]: # noqa: F821 + shape = opset18.Concat(dim0, dim1, dim2, dim3, axis=0) + reshaped = opset18.Reshape(x, shape, allowzero=1) + transposed = opset18.Transpose(reshaped, perm=[0, 2, 1, 3]) + final_shape = opset18.Shape(transposed) + final_dim = opset18.Slice(final_shape, starts, ends) + return opset18.Add(final_dim, final_dim) + + model_proto = model_script.to_model_proto() + return model_proto + + +# Example input data +_model_inputs = { + "x": np.zeros((24,), dtype=np.float32), + "dim0": np.array([2], dtype=np.int64), + "dim1": np.array([3], dtype=np.int64), + "dim2": np.array([4], dtype=np.int64), + "dim3": np.array([1], dtype=np.int64), +} + + +class ShapeOptimizationTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ([0], [1], "singleton"), + ([1], [3], "two_elements"), + ([1], [-1], "negative_index"), + ([-2], [1000], "out_of_bounds"), + ([-200], [-1], "negative_out_of_bounds"), + ([2], [2], "empty_slice"), + ] + ) + def test_shape_optimization(self, starts: list[int], ends: list[int], _name: str): + model_proto = _make_model(starts, ends) + model = ir.serde.deserialize_model(model_proto) + + count = shape_optimization.rules.apply_to_model(model) + self.assertEqual(count, 1) + optimized_proto = ir.serde.serialize_model(model) + + import onnxruntime as ort + + sess = ort.InferenceSession( + model_proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + outputs = sess.run(None, _model_inputs) + sess = ort.InferenceSession( + optimized_proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + optimized_outputs = sess.run(None, _model_inputs) + for orig, opt in zip(outputs, optimized_outputs): + np.testing.assert_array_equal(orig, opt) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py new file mode 100644 index 0000000000..c76a7454cb --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py @@ -0,0 +1,265 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Sequence, Union + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, pattern + +Dim = Union[int, ir.SymbolicDim] + +# Fusion rule for SkipRMSNormalization + + +class SkipRmsNormFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False): + """Fusion rule for SkipRMSNormalization.""" + super().__init__(name=name) + self._has_bias = has_bias + self._bias_pre_add = bias_pre_add + + def pattern(self, op, input, skip, gamma, bias, epsilon, stash_type): + if self._has_bias and self._bias_pre_add: + input = op.Add(input, bias) + + # Support different combinations of addition of input and skip + skip_sum_pattern_1 = op.Add(skip, input) + skip_sum_pattern_2 = op.Add(input, skip) + skip_sum = pattern.OrValue( + [skip_sum_pattern_1, skip_sum_pattern_2], + name="skip_sum", + ) + + if self._has_bias and not self._bias_pre_add: + skip_sum = op.Add(skip_sum, bias) + # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake. + # No need to use com.microsoft domain here; but this is a custom op in ORT. + normalized = op.SimplifiedLayerNormalization( + skip_sum, + gamma, + axis=-1, + _allow_other_attributes=True, + _outputs=["simplified_layer_norm"], + ) + return normalized, skip_sum + + def check( + self, + op, + input, + skip, + gamma, + bias, + simplified_layer_norm, + **_, + ) -> pattern.MatchResult: # type: ignore[name-defined] + """Check if the pattern matches conditions for use of SkipSimplifiedLayerNormalization op.""" + check_result = pattern.MatchResult() + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils.check_shape_bool(bindings, val, dims) + + if no_match(input, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']", + input, + ) + if no_match(skip, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']", + skip, + ) + if no_match(gamma, ["D"]): + return check_result.fail( + f"Shape mismatch: {gamma} does not match expected dimensions ['D']", + gamma, + ) + if self._has_bias: + if no_match(bias, ["D"]): + return check_result.fail( + f"Shape mismatch: {bias} does not match expected dimensions ['D']", + bias, + ) + + stash_type = simplified_layer_norm.producer().attributes.get_int("stash_type", 1) + if stash_type != 1: + return check_result.fail("Stash type is not supported.") + + return check_result + + def rewrite( + self, + op, + input, + skip, + gamma, + bias, + simplified_layer_norm, + **_, + ): + epsilon = simplified_layer_norm.producer().attributes.get_float("epsilon", 1e-5) + + if self._has_bias: + normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( + input, + skip, + gamma, + bias, + epsilon=epsilon, + _outputs=4, + _domain="com.microsoft", + ) + else: + normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( + input, + skip, + gamma, + epsilon=epsilon, + _outputs=4, + _domain="com.microsoft", + ) + return normalized, skip_sum + + +_skip_rms_add_bias_rule = SkipRmsNormFusion.rule( + "SkipRmsNormBias", has_bias=True, bias_pre_add=False +) +_skip_rms_pre_add_bias_rule = SkipRmsNormFusion.rule( + "SkipRmsNormPreBias", has_bias=True, bias_pre_add=True +) +_skip_rms_rule = SkipRmsNormFusion.rule("SkipRmsNorm", has_bias=False) + +skip_rms_normalization_ruleset = pattern.RewriteRuleSet( + [_skip_rms_pre_add_bias_rule, _skip_rms_add_bias_rule, _skip_rms_rule] +) +fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset) + + +# Fusion rule for SkipLayerNormalization +class SkipLayerNormFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False): + """Fusion rule for SkipLayerNormalization.""" + super().__init__(name=name) + self._has_bias = has_bias + self._bias_pre_add = bias_pre_add + + def pattern(self, op, input, skip, gamma, beta, bias): + if self._has_bias and self._bias_pre_add: + input = op.Add(input, bias) + + # Support different combinations of addition of input and skip + skip_sum_pattern_1 = op.Add(skip, input) + skip_sum_pattern_2 = op.Add(input, skip) + skip_sum = pattern.OrValue([skip_sum_pattern_1, skip_sum_pattern_2], name="skip_sum") + + if self._has_bias and not self._bias_pre_add: + skip_sum = op.Add(skip_sum, bias) + + normalized = op.LayerNormalization( + skip_sum, + gamma, + beta, + axis=-1, + _allow_other_attributes=True, + _outputs=["layer_norm"], + ) + return normalized, skip_sum + + def check( + self, + op, + input, + skip, + gamma, + beta, + bias, + layer_norm, + **_, + ) -> pattern.MatchResult: # type: ignore[name-defined] + """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" + check_result = pattern.MatchResult() + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils.check_shape_bool(bindings, val, dims) + + if no_match(input, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']", + input, + ) + if no_match(skip, ["B", "S", "D"]): + return check_result.fail( + f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']", + skip, + ) + if no_match(gamma, ["D"]): + return check_result.fail( + f"Shape mismatch: {gamma} does not match expected dimensions ['D']", + gamma, + ) + if no_match(beta, ["D"]): + return check_result.fail( + f"Shape mismatch: {beta} does not match expected dimensions ['D']", + beta, + ) + if self._has_bias: + if no_match(bias, ["D"]): + return check_result.fail( + f"Shape mismatch: {bias} does not match expected dimensions ['D']", + bias, + ) + + stash_type = layer_norm.producer().attributes.get_int("stash_type", 1) + if stash_type != 1: + return check_result.fail("Stash type is not supported.") + return check_result + + def rewrite( + self, + op, + input, + skip, + gamma, + beta, + bias, + layer_norm, + **_, + ): + epsilon = layer_norm.producer().attributes.get_float("epsilon", 1e-5) + + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( + input, + skip, + gamma, + beta, + bias, + epsilon=epsilon, + _outputs=4, + _domain="com.microsoft", + ) + return normalized, skip_sum + + +_skip_layer_add_bias_rule = SkipLayerNormFusion.rule( + "SkipLayerNormBias", has_bias=True, bias_pre_add=False +) +_skip_layer_pre_add_bias_rule = SkipLayerNormFusion.rule( + "SkipLayerNormPreBias", has_bias=True, bias_pre_add=True +) +_skip_layer_rule = SkipLayerNormFusion.rule("SkipLayerNorm", has_bias=False) + +skip_layer_normalization_ruleset = pattern.RewriteRuleSet( + [ + _skip_layer_pre_add_bias_rule, + _skip_layer_add_bias_rule, + _skip_layer_rule, + ] +) + +fuse_skip_layer_normalization = _fusion_utils.apply_fusion_rules( + skip_layer_normalization_ruleset +) diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization_test.py b/onnxscript/rewriter/ort_fusions/skip_normalization_test.py new file mode 100644 index 0000000000..6ee80ce5dc --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/skip_normalization_test.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.models._bart_encoder import bart_encoder_test +from onnxscript.rewriter.models._smollm_1 import smollm_test_1 +from onnxscript.rewriter.models._whisper_decoder import whisper_decoder_test +from onnxscript.rewriter.models._whisper_encoder import whisper_encoder_test +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run +from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization +from onnxscript.rewriter.ort_fusions.skip_normalization import ( + fuse_skip_layer_normalization, + fuse_skip_rms_normalization, +) + + +class TestSkipNormalization(unittest.TestCase): + def test_smollm(self): + smollm_test = smollm_test_1() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + fuse_rms_normalization(model) + fuse_skip_rms_normalization(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("SkipSimplifiedLayerNormalization", op_types) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + @unittest.skip("fixme: accuracy is not high") + def test_whisper_encoder(self): + whisper_encoder = whisper_encoder_test() + model = whisper_encoder.get_onnx_model() + onnxscript.optimizer.optimize(model) + + inputs = whisper_encoder.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + fuse_skip_layer_normalization(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("SkipLayerNormalization", op_types) + + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_whisper_decoder(self): + whisper_decoder = whisper_decoder_test() + model = whisper_decoder.get_onnx_model() + onnxscript.optimizer.optimize(model) + + inputs = whisper_decoder.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + fuse_skip_layer_normalization(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("SkipLayerNormalization", op_types) + + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + def test_bart_encoder(self): + bart_encoder = bart_encoder_test() + model = bart_encoder.get_onnx_model() + onnxscript.optimizer.optimize(model) + + inputs = bart_encoder.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + + fuse_skip_layer_normalization(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("SkipLayerNormalization", op_types) + self.assertEqual(op_types.count("SkipLayerNormalization"), 5) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/ort_fusions/softmax.py similarity index 77% rename from onnxscript/rewriter/onnxruntime/softmax.py rename to onnxscript/rewriter/ort_fusions/softmax.py index df868f1348..10535f57f4 100644 --- a/onnxscript/rewriter/onnxruntime/softmax.py +++ b/onnxscript/rewriter/ort_fusions/softmax.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import logging @@ -5,13 +7,12 @@ import onnx from onnxscript import ir -from onnxscript.rewriter import pattern +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet -op = pattern.onnxop logger = logging.getLogger(__name__) -def softmax_with_fp32_upcast(input, axis): +def softmax_with_fp32_upcast(op, input, axis): upcast = op.Cast(input, to=onnx.TensorProto.FLOAT) softmax = op.Softmax(upcast, axis=axis) # pylint: disable=redefined-outer-name return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) @@ -21,7 +22,7 @@ def softmax(op, input, axis): return op.Softmax(input, axis=axis) -def softmax_with_fp32_upcast_without_axis(input): +def softmax_with_fp32_upcast_without_axis(op, input): upcast = op.Cast(input, to=onnx.TensorProto.FLOAT) softmax = op.Softmax(upcast) # pylint: disable=redefined-outer-name return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) @@ -31,7 +32,7 @@ def softmax_without_axis(op, input): return op.Softmax(input) -def check_if_fp16_input(input, **_) -> bool: +def check_if_fp16_input(context, input, **_) -> bool: if input is None: logger.warning( "Cannot perform softmax upcast removal: " @@ -50,10 +51,10 @@ def check_if_fp16_input(input, **_) -> bool: to free up memory as well as saving performance. """ # pylint: enable=pointless-string-statement -rules = pattern.RewriteRuleSet( +rules = RewriteRuleSet( [ - pattern.RewriteRule(softmax_with_fp32_upcast, softmax, check_if_fp16_input), - pattern.RewriteRule( + RewriteRule(softmax_with_fp32_upcast, softmax, check_if_fp16_input), + RewriteRule( softmax_with_fp32_upcast_without_axis, softmax_without_axis, check_if_fp16_input, diff --git a/onnxscript/rewriter/onnxruntime/softmax_test.py b/onnxscript/rewriter/ort_fusions/softmax_test.py similarity index 95% rename from onnxscript/rewriter/onnxruntime/softmax_test.py rename to onnxscript/rewriter/ort_fusions/softmax_test.py index 8c26adbe0e..e94657d573 100644 --- a/onnxscript/rewriter/onnxruntime/softmax_test.py +++ b/onnxscript/rewriter/ort_fusions/softmax_test.py @@ -1,10 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import onnx.parser import parameterized from onnxscript import ir -from onnxscript.rewriter.onnxruntime import softmax +from onnxscript.rewriter.ort_fusions import softmax class SoftmaxUpcastRemovalTest(unittest.TestCase): diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index e531f7c81f..c4fd6e9161 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1,916 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations -import dataclasses -import inspect -import itertools -import math -from typing import ( - Any, - Callable, - List, - MutableSequence, - Optional, - Protocol, - Sequence, - Tuple, - TypeVar, - Union, +from onnxscript.ir import _tape +from onnxscript.rewriter._basics import MatchingTracer, MatchResult, MatchStatus +from onnxscript.rewriter._matcher import PatternMatcher, SimplePatternMatcher +from onnxscript.rewriter._pattern_ir import ( + ANY_VALUE, + AttrVar, + Constant, + OpsetPatternBuilder, + OrValue, + Var, + pattern_builder, + torch_module_op, +) +from onnxscript.rewriter._rewrite_rule import ( + Pattern, + PatternBase, + RewriteRule, + RewriteRuleClassBase, + RewriteRuleSet, ) -import onnx - -from onnxscript import ir -from onnxscript.ir import _convenience -from onnxscript.rewriter import _ir_utils, _tape - -T = TypeVar("T") - - -class Pattern(Protocol[T]): # type: ignore[misc] - """This is essentially a Predicate[T], that is, a Callable[[T], bool] bound to the name "matches".""" - - def matches(self, item: T) -> bool: ... - - -class StringConstantPattern(Pattern[str]): - """Matches strings with given value.""" - - def __init__(self, value: str): - self._value = value - - def matches(self, item: str) -> bool: - return item == self._value - - -class PrefixPattern(Pattern[str]): - """Matches strings with a given prefix.""" - - def __init__(self, value: str) -> None: - self._value = value - - def matches(self, value: str) -> bool: - return value.startswith(self._value) - - -class AttrPattern(Pattern[Union[ir.Attr, ir.RefAttr]]): - """Base class for an attribute pattern. Matches any attribute value by default.""" - - def __init__(self, name: str | None): - self.name = name - - def matches(self, attr: ir.Attr | ir.RefAttr) -> bool: - return True - - -# TODO: Support tensors. Align with usage elsewhere. -SupportedAttrTypes = Union[ - int, - float, - str, - Sequence[int], - Sequence[float], - Sequence[str], +RewriterContext = _tape.Builder + +__all__ = [ + "ANY_VALUE", + "AttrVar", + "OrValue", + "Constant", + "OpsetPatternBuilder", + "pattern_builder", + "PatternBase", + "Pattern", + "RewriteRule", + "RewriteRuleClassBase", + "RewriteRuleSet", + "RewriterContext", + "MatchingTracer", + "MatchResult", + "MatchStatus", + "PatternMatcher", + "SimplePatternMatcher", + "torch_module_op", + "Var", ] - - -class AttrConstantPattern(AttrPattern): - """Matches attributes with given value. - - Uses standard equality for matching. For list-valued attributes, the order of elements matters. - If order is immaterial, we need to define a separate pattern for that. - """ - - def __init__(self, value: SupportedAttrTypes): - super().__init__(None) - self._value = value - - def matches(self, attr: ir.Attr | ir.RefAttr) -> bool: - return isinstance(attr, ir.Attr) and attr.value == self._value - - -def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> AttrPattern: - """Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern.""" - if isinstance(value, AttrPattern): - return value - if type(value) == ValuePattern: - # This is a hack. Currently, when we create pattern-variables, we create them as ValuePattern, - # and change them to AttrPattern if/when used in an attribute context. We could use type - # annotations to distinguish between ValuePattern and AttrPattern, but forces users to - # use these type annotations. - # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.) - return AttrPattern(value.name) - if isinstance(value, (int, float, str)): - return AttrConstantPattern(value) - if isinstance(value, Sequence): - if all(isinstance(i, (int, float)) for i in value): - return AttrConstantPattern(value) - if all(isinstance(i, str) for i in value): - return AttrConstantPattern(value) - raise ValueError("Only lists of int/float/str can be used as an AttrPattern") - raise TypeError(f"Cannot convert {type(value)} to AttrPattern") - - -class OpsetPatternBuilder(Pattern[str]): - """Represents an opset pattern. - - (i) It is used to create a NodePattern (via OpPatternBuilder). - Example usage: - :: - - z = op.Matmul(x, y) - - Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance - of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. - - (ii) An opset pattern is also matched against the actual opset domain used in the - input model. - """ - - def __init__(self, domain_pattern: Pattern[str] | str) -> None: - if isinstance(domain_pattern, str): - domain_pattern = StringConstantPattern(domain_pattern) - self.domain_pattern = domain_pattern - - @classmethod - def domain_prefix(cls, domain: str) -> OpsetPatternBuilder: - return cls(PrefixPattern(domain)) - - def matches(self, domain): - return self.domain_pattern.matches(domain) - - def __getattr__(self, name: str) -> OpPatternBuilder: - return OpPatternBuilder(self, StringConstantPattern(name)) - - def submodule(self, name: str) -> OpPatternBuilder: - """This method is used to match against submodule ops with prefix.""" - return OpPatternBuilder(self, PrefixPattern(name)) - - -onnxop = OpsetPatternBuilder("") - -msft_op = OpsetPatternBuilder("com.microsoft") - -torch_module_op = OpsetPatternBuilder.domain_prefix("pkg.torch") - - -class OpPatternBuilder: - """A utility class to build a NodePattern. - - It is used primarily to create a NodePattern. - Example usage: - :: - - z = op.Matmul(x, y) - - Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance - of OpPatternBuilder, and `op.Matmul(x, y)` is an instance of NodePattern. - - """ - - def __init__( - self, - opset_pattern: Pattern[str], - op_name_pattern: Pattern[str], - ) -> None: - self.opset_pattern = opset_pattern - self.op_name_pattern = op_name_pattern - - def __call__(self, *args, **kwargs): - # TODO(rama): Unify with convention used elsewhere. - if "_num_outputs" in kwargs: - num_outputs = kwargs["_num_outputs"] - del kwargs["_num_outputs"] - else: - num_outputs = 1 - inputs = [_to_value_pattern(x) for x in args] - attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} - node_pattern = NodePattern( - self.opset_pattern, self.op_name_pattern, inputs, attributes - ) - if num_outputs == 1: - return NodeOutputPattern(node_pattern, 0) - else: - return [NodeOutputPattern(node_pattern, i) for i in range(num_outputs)] - - -def _to_value_pattern( - x: ValuePattern | int | float | None, -) -> ValuePattern | None: - """Promotes an input-value used to construct a NodePattern to a ValuePattern. - - Example usage: - :: - x = op.MatMul(a, b) - z = op.Add(x, 0) - - In this example, `a, `b`, and `x` are ValuePatterns used to construct a NodePattern. - `0` is a constant (int) value, and is automatically promoted to a ValuePattern. - - Note that this is a shorthand for creating a Constant pattern. The user can more - explicitly write this as: - :: - z = op.Add(x, op.Constant(0)) - """ - if x is None or isinstance(x, ValuePattern): - return x - if isinstance(x, (int, float)): - return Constant(x) - # TODO(rama): support lists of int/float - # if isinstance(x, list): - # if all(isinstance(i, (int, float)) for i in x): - # return Constant(x) - # raise ValueError("Only lists of int/float can be used as a ValuePattern") - # TODO(titaiwang): Could this be wrapped Constant? - raise TypeError(f"Cannot convert {type(x)} to ValuePattern") - - -class MatchResult: - """Represents the result of a match operation. - - A match can either succeed or fail. - If it succeeds, it returns a list of nodes that matched the pattern - and a set of bindings for the variables in the pattern. - - Example: - :: - def pattern(x, shape1, shape2): - t1 = op.Reshape(x, shape1) - t2 = op.Reshape(t1, shape2) - return t2 - The above pattern matches a sequence of two Reshape ops. - The matched_nodes will contain the two Reshape ops, and the bindings will - contain the values that are bound to the variables `x`, `shape1`, and `shape2`. - """ - - def __init__(self, success: bool) -> None: - self.success: bool = success - # For a successful match, matched_nodes is a list of values that matched the pattern. - # These include the internal nodes of the pattern that were matched, but not - # the leaves (sub-trees) that match against the variables in the pattern. - # These represent the values that will be replaced by the replacement pattern. - self.matched_nodes: MutableSequence[ir.Node] = [] - # For a successful match, bindings is a dictionary of mapping pattern-variable-names - # to values. - self.bindings: dict[str, Any] = {} - self.outputs: MutableSequence[ir.Value] = [] - - def __bool__(self): - return self.success - - @classmethod - def FAIL(cls): - return cls(False) - - @property - def nodes(self) -> MutableSequence[ir.Node]: - return self.matched_nodes - - def bind(self, var: str, value: Any) -> bool: - """Binds a pattern variable name to a value from the matched IR. - - Returns True if the binding is successful, False otherwise (when the binding is inconsistent). - """ - if var in self.bindings: - # TODO(rama): Use appropriate equality-check here. - if self.bindings[var] == value: - return True - self.success = False - return False - self.bindings[var] = value - return True - - def extend(self, other: MatchResult | bool): - if not self.success: - return - if not other: - self.success = False - return - if isinstance(other, bool): - return - for var, val in other.bindings.items(): - if var in self.bindings: - # TODO: handle attribute var bindings - if self.bindings[var] != val: - self.success = False - return - else: - self.bindings[var] = val - assert self.matched_nodes is not None, "matched_nodes should not be None." - self.matched_nodes.extend(other.matched_nodes) # type: ignore[attr-defined] - - -class ValuePattern: - """Base class for all patterns that match against IR values. - - This is used primarily to provide operator overloadings for arithmetic - operations, so that we can write patterns like `x + 1` and `1 + x`. - """ - - def __init__(self, name: str | None) -> None: - self.name = name - - def __repr__(self) -> str: - return f"ValuePattern({self.name!r})" - - def matches(self, value: ir.Value): - result = MatchResult(success=True) - if self.name is not None: - result.bind(self.name, value) - return result - - def commute(self) -> Sequence[ValuePattern]: - """Return a list of commuted patterns. - - This is used to handle commutative operations like addition and multiplication. - A single pattern is converted into a list of equivalent patterns by swapping - the parameters of commutative operations. - """ - return [self] - - def __add__(self, other): - return onnxop.Add(self, other) - - def __radd__(self, other): - return onnxop.Add(other, self) - - def __sub__(self, other): - return onnxop.Sub(self, other) - - def __rsub__(self, other): - return onnxop.Sub(other, self) - - def __mul__(self, other): - return onnxop.Mul(self, other) - - def __rmul__(self, other): - return onnxop.Mul(other, self) - - def __truediv__(self, other): - return onnxop.Div(self, other) - - def __rtruediv__(self, other): - return onnxop.Div(other, self) - - def __pow__(self, other): - return onnxop.Pow(self, other) - - -class NodePattern: - """Represents a pattern that matches against a Node. - - This differs from a NodeOutputPattern in that it matches against a node (which - may produce 1 or more outputs), whereas a NodeOutputPattern matches against - a specific output of a node. - """ - - def __init__( - self, - domain: Pattern[str], - op: Pattern[str], - inputs: Sequence[int | float | ValuePattern | None], - attributes: dict[str, AttrPattern], - ): - self.domain = domain - self.op = op - self.inputs = [_to_value_pattern(x) for x in inputs] - self.attributes = attributes - - def matches_node(self, node: ir.Node) -> MatchResult: - """Examine if the IR node matches the self pattern.""" - if not self.domain.matches(node.domain): - return MatchResult.FAIL() - if not self.op.matches(node.op_type): - return MatchResult.FAIL() - match = MatchResult(success=True) - # TODO: We should add filtered logging starting from here to emit why - # matching failed. This should cut a lot of noises compared to logging everything, - # because at least the starting node op_type is already matched. - for arg_value, previous_node_output_pattern in zip(node.inputs, self.inputs): - # previous_node_output_pattern could be a Var, if it's the original arg. - if arg_value is None and previous_node_output_pattern is None: - continue - if arg_value is None or previous_node_output_pattern is None: - return MatchResult.FAIL() - sub_match = previous_node_output_pattern.matches(arg_value) - match.extend(sub_match) - if not match: # If sub-match failed, - return match - # Sub-graphs not handled yet. - for name, attr_pattern in self.attributes.items(): - attr_value = node.attributes.get(name) - if attr_value is None: - return MatchResult.FAIL() - if not attr_pattern.matches(attr_value): - return MatchResult.FAIL() - if attr_pattern.name is not None: - if not match.bind(attr_pattern.name, attr_value): - return match - for name in node.attributes: - # TODO: Support matching default nodes for attributes. - if name not in self.attributes: - return MatchResult.FAIL() - match.nodes.append(node) - return match - - def commute(self) -> Sequence[NodePattern]: - list_of_lists = [ - [None] if pattern is None else pattern.commute() for pattern in self.inputs - ] # type: ignore[attr-defined] - - def enumerate_inputs(inputs, index): - if index >= len(inputs): - yield [] - else: - for pattern in inputs[index]: - for rest in enumerate_inputs(inputs, index + 1): - yield [pattern, *rest] - - inputs = list(enumerate_inputs(list_of_lists, 0)) - if self.domain.matches("") and (self.op.matches("Add") or self.op.matches("Mul")): - # TODO: handle cases where number of inputs is not 2. - swapped = [[x[1], x[0]] for x in inputs] - inputs.extend(swapped) - return [NodePattern(self.domain, self.op, input, self.attributes) for input in inputs] - - -class NodeOutputPattern(ValuePattern): - """Represents a pattern that matches against a specific output of a Node. - - This is the primary pattern used to match against computed values, that - is values computed using a specific op. - """ - - def __init__( - self, node_pattern: NodePattern, output_index: int, name: str | None = None - ) -> None: - super().__init__(name) - self.node_pattern = node_pattern - self.output_index = output_index - - def matches(self, value: ir.Value): - """Match the StaticValueInfo from IR with the `matches_node()` in node pattern.""" - node = value.producer() - if node is None: - return MatchResult.FAIL() - if value.index() != self.output_index: - return MatchResult.FAIL() - return self.node_pattern.matches_node(node) - - def commute(self) -> Sequence[ValuePattern]: - # TODO - return [ - NodeOutputPattern(pattern, self.output_index, self.name) - for pattern in self.node_pattern.commute() - ] - - -Var = ValuePattern - - -class Constant(ValuePattern): - """Represents a pattern that matches against a scalar constant value.""" - - def __init__( - self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8 - ) -> None: - super().__init__(None) - self.value = value - self.rel_tol = rel_tol - self.abs_tol = abs_tol - - def match_scalar(self, scalar_value): - status = math.isclose( - scalar_value, self.value, rel_tol=self.rel_tol, abs_tol=self.abs_tol - ) - # Note: If the value is produced by a Constant node, we could include - # the Constant node in the return_value list. However, we don't do that. - # Instead, we will rely on DCE to remove the constant node if it is not - # used elsewhere. - return MatchResult(success=status) - - def matches(self, value: ir.Value): - value = _ir_utils.propagate_const_value(value) - constant_value = _ir_utils.get_numpy_from_ir_value(value) - if constant_value is None: - return MatchResult.FAIL() - - # TODO (rama): allow users to specify shape requirement, if desired. - if constant_value.size != 1: - return MatchResult.FAIL() - - return self.match_scalar(constant_value.item()) - - def commute(self) -> list[ValuePattern]: - return [self] - - -class GraphPattern: - """Represents a pattern that can be matched against a subgraph.""" - - def __init__(self, outputs: Sequence[ValuePattern]) -> None: - self.outputs = outputs - if len(outputs) == 0: - raise ValueError("GraphPattern must have at least one output") - # Check if all outputs are produced by the same node. - output_node = None - for i, value_pattern in enumerate(outputs): - if not isinstance(value_pattern, ValuePattern): - raise TypeError( - f"Invalid type {type(value_pattern)} for graph pattern output." - ) - if not isinstance(value_pattern, NodeOutputPattern) or ( - value_pattern.output_index != i - ): - output_node = None - elif i == 0: - output_node = value_pattern.node_pattern - elif value_pattern.node_pattern is not output_node: - output_node = None - self._output_node = output_node - - @property - def num_outputs(self) -> int: - return len(self.outputs) - - def matches_node(self, node: ir.Node) -> MatchResult: - if self._output_node is None: - return MatchResult.FAIL() - return self._output_node.matches_node(node) - - def commute(self) -> Sequence[GraphPattern]: - if self._output_node is None: - raise NotImplementedError( - "Cannot commute a graph pattern with multiple output nodes." - ) - nodes = self._output_node.commute() - return [ - GraphPattern([NodeOutputPattern(n, i) for i in range(self.num_outputs)]) - for n in nodes - ] - - -def _to_graph_pattern(pattern_constructor: Callable) -> GraphPattern: - """Convert a pattern-construction function to a GraphPattern. - - A pattern-construction function will return values as below: - :: - def pattern(x: Var, shape1: Var, shape2: Var): - ... - return outputs - - We create a pattern graph by creating pattern-variables for each parameter of the function, - and calling the function. The returned values are normalized to a list of ValuePatterns, - which represent the outputs of the pattern graph. - - Args: - pattern_constructor: Callable - - Returns: - GraphPattern: A representation of the pattern that can be matched against a subgraph. - """ - _pattern_vars = inspect.signature(pattern_constructor).parameters - vars = [Var(v) for v in _pattern_vars] - pattern_outputs = pattern_constructor(*vars) - # Returned value could be a single ValuePattern or a list of ValuePatterns. - # Normalize representation to a list of ValuePatterns. - if isinstance(pattern_outputs, ValuePattern): - pattern_outputs = [pattern_outputs] - return GraphPattern(pattern_outputs) - - -def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: - """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes.""" - # * Must check that all values matched by pattern are used only by pattern, - # except for the value that is replaced. - # * Must ensure that replacement subgraph does not use any of the deleted - # (intermediate) values. (Not necessary for now. Guaranteed.) - deleted_nodes = matched_nodes[:-1] - for n in deleted_nodes: - for v in n.outputs: - if v.is_graph_output(): - # value is an output-value of the graph/function. - return False - for consumer, _ in v.uses(): - if consumer not in matched_nodes: - return False - return True - - -# A type representing the domains/versions used in creating a replacement subgraph -UsedOpsets = List[Tuple[str, Optional[int]]] - - -class RewriterContext: - """Context parameter used to build the replacement pattern.""" - - # TODO(justinchuby): Merge with the rest of pattern building methods - def __init__(self): - self._tape = _tape.Tape() - self._used_opsets: UsedOpsets = [] - - def __getattr__(self, op_type: str) -> Any: - return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) - - def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): - # TODO(rama): some of the following logic should move into the tape. - domain = kwargs.pop("domain", "") - version = kwargs.pop("version", None) - self._used_opsets.append((domain, version)) - outputs = kwargs.pop("outputs", 1) - if isinstance(outputs, Sequence): - num_outputs = len(outputs) - else: - assert isinstance(outputs, int) - num_outputs = outputs - if num_outputs == 1: - value = self._tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain) - if isinstance(outputs, Sequence): - value.name = outputs[0] - return value - values = self._tape.op_multi_output( - op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs - ) - if isinstance(outputs, Sequence): - for value, name in zip(values, outputs): - value.name = name - return values - - @property - def nodes(self) -> Sequence[ir.Node]: - # TODO(rama): The current tape-based implementation will not track nodes added - # via overloaded operators, eg., `x + y`. One possible way to fix this is to - # have values/nodes know which tape they belong to (instead of a graph/function). - # However, it is unclear we need this feature for rewriting: we could also - # identify the nodes to be inserted from the replacement values (by tracing back). - return self._tape.nodes - - @property - def used_opsets(self) -> UsedOpsets: - return self._used_opsets - - -@dataclasses.dataclass -class ReplacementSubgraph: - """A subgraph that will replace the matched pattern.""" - - match: MatchResult - new_outputs: Sequence[ir.Value] - new_nodes: Sequence[ir.Node] - used_opsets: UsedOpsets - - -class ReplacementPatternFunction: - """The replacement pattern that will replace the targeted pattern. - - Attributes: - function (Callable): The replacement function that will be used to replace the matched pattern. - """ - - def __init__(self, function) -> None: - self._function = function - - def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: - context = RewriterContext() - new_outputs = self._function(context, **match.bindings) - if new_outputs is None: - return None # Failed to create replacement subgraph - if not isinstance(new_outputs, Sequence): - new_outputs = [new_outputs] - return ReplacementSubgraph(match, new_outputs, context.nodes, context.used_opsets) - - -def _update_opset_imports( - graph_or_function: ir.Graph | ir.Function, delta: ReplacementSubgraph -): - imports = graph_or_function.opset_imports - for domain, version in delta.used_opsets: - if domain not in imports: - # use 1 as default version if not explicitly specified - imports[domain] = version if version is not None else 1 - elif version is not None and version != imports[domain]: - raise ValueError( - f"Multiple versions of opset {domain} used. " - f"Expected version {imports[domain]}, but got {version}." - ) - - -class RewriteRule: - def __init__( - self, - target_pattern: GraphPattern | Callable | None = None, - replacement_pattern: ReplacementPatternFunction | Callable | None = None, - condition_function: Callable | None = None, - ) -> None: - """Create a rewrite rule. - - Args: - target_pattern: The pattern function that will be - matched against the IR. - replacement_pattern: The replacement function that - will be used to replace the matched pattern. - condition_function: The condition function that - will be used to check if the pattern matches the IR with ir.Values - constraints in consideration. - - """ - if target_pattern is None: - # NOTE: this is a default-constructor. Caller responsible for filling in the fields. - assert replacement_pattern is None - assert condition_function is None - return - elif replacement_pattern is None: - raise ValueError( - "replacement_pattern must be provided if target_pattern is provided" - ) - - if not isinstance(target_pattern, GraphPattern): - target_pattern = _to_graph_pattern(target_pattern) - self._target_pattern = target_pattern - - if not isinstance(replacement_pattern, ReplacementPatternFunction): - replacement_pattern = ReplacementPatternFunction(replacement_pattern) - self._replacement_pattern = replacement_pattern - self._condition_function = condition_function - - def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: - """Check if the node from IR matches the pattern.""" - if len(node.outputs) != self._target_pattern.num_outputs: - return MatchResult.FAIL() - match = self._target_pattern.matches_node(node) - if ( - self._condition_function is not None - and match - and not self._condition_function(**match.bindings) - ): - return MatchResult.FAIL() - match.outputs.extend(node.outputs) - return match - - def try_rewrite( - self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node - ) -> ReplacementSubgraph | None: - """If the node matches the pattern, then replace the node with the replacement pattern.""" - match = self.matches(node, model) - if match: - assert match.nodes is not None, "Matched values should not be None." - if _valid_to_replace(match.nodes): - replacement_subgraph = self._replacement_pattern.get_replacement(match) - if replacement_subgraph is None: - return None - if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs: - raise ValueError( - f"Number of outputs from replacement function does not match the number of outputs from the target pattern. " - f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}." - ) - # TODO(rama): Check/update opset-imports - # (i) Following is required by multi-output matcher too; move this. - # (ii) Remove the opset imports from deleted nodes? - _update_opset_imports(graph_or_function, replacement_subgraph) - _update_opset_imports(model.graph, replacement_subgraph) - return replacement_subgraph - return None - - def apply_to_model(self, model: ir.Model, *, commute: bool = False): - # TODO(titaiwang): Why do we need RewriteRuleSet? - return RewriteRuleSet([self], commute=commute).apply_to_model(model) - - def count_matches(self, model: ir.Model, *, commute: bool = False): - return RewriteRuleSet([self], commute=commute).count_matches(model) - - def commute(self) -> Sequence[RewriteRule]: - def replace_pattern(new_pattern): - """Return a shallow copy of self with node_pattern replaced by new_pattern.""" - rule = RewriteRule() - rule._condition_function = self._condition_function - rule._target_pattern = new_pattern - rule._replacement_pattern = self._replacement_pattern - return rule - - return [replace_pattern(p) for p in self._target_pattern.commute()] - - -def _apply_delta( - graph_or_function: ir.Graph | ir.Function, - node: ir.Node, - # TODO(jutinchuby): Use a more descriptive data structure to store deltas - delta, -): - """Applies delta. - - This code is valid is the considered pattern has only one output. - In case of multi output replacements, there is not need to rename - the outputs. - - In case of multi-output design, the nodes may not be necessary inserted - all at the same position. To be convinced, you can take a pattern - producing two outputs, but the second one needs the first one and - another input appeared after the first outputs. What could be - the right place to inserted all of the node. - - The current implementation insert all the nodes at the same position - but checks there is not inconsistency. In that case, it fails. - We could reorder (long) or do more clever changes. - The reordering would probably happen not very often. - """ - - if isinstance(delta, tuple): - # multi-output strategy - n_matches, matched_nodes, inserted_nodes = delta - - # TODO(rama): Was "assert i not in to_insert"; seems wrong. - # What is this trying to check? Best effort correction below. - assert node not in inserted_nodes # conflicts should avoid that case - - graph_or_function.insert_after(node, inserted_nodes) - # TODO: improve this - # This is updating the graph/function outputs to use the new outputs - for inserted_node in inserted_nodes: - for new_output in inserted_node.outputs: - if (index := new_output.meta.get(_ir_utils.GRAPH_OUTPUT_META_KEY)) is not None: # type: ignore[assignment] - graph_or_function.outputs[index] = new_output - - for d in matched_nodes: - assert d in graph_or_function - graph_or_function.remove(matched_nodes, safe=True) - else: - assert isinstance(delta, ReplacementSubgraph) - # Replace matched nodes with new nodes, matched values with new values - old_values = delta.match.outputs - new_values = delta.new_outputs - - for old_value, new_value in zip(old_values, new_values): - # Propagate relevant info from old value to new value - # TODO(Rama): Perhaps we should merge old and new types. As of now, new - # values don't have type information. Note that this could be a problem - # for semantics-altering rewrite-rules: we should allow users to override - # this for such rules. - new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name - - # Reconnect the users of the deleted node to use the new outputs - _convenience.replace_all_uses_with(old_values, new_values) - # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(graph_or_function.outputs): - if graph_or_function_output in replacement_mapping: - graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] - - # insert new nodes after the index node - graph_or_function.insert_after(node, delta.new_nodes) - graph_or_function.remove(delta.match.nodes, safe=True) - - -class RewriteRuleSet: - def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: - if commute: - rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules])) - self.rules = rules - - def _apply_to_graph_or_function( - self, - model: ir.Model, - graph_or_function: ir.Graph | ir.Function, - ) -> int: - count = 0 - - # NOTE: Rules should be prioritized in the order they are added to the RewriteRuleSet. - # And the graph is applied in order. - for rule in self.rules: - for node in graph_or_function: - delta = rule.try_rewrite(model, graph_or_function, node) - if delta is None: - continue - _apply_delta(graph_or_function, node, delta) - count += 1 - - return count - - def apply_to_model(self, model: ir.Model) -> int: - assert isinstance(model, ir.Model) - count = self._apply_to_graph_or_function(model, model.graph) - for function in model.functions.values(): - count += self._apply_to_graph_or_function(model, function) - return count - - def _count_matches_in_graph_or_function( - self, model: ir.Model, graph_or_function: ir.Graph | ir.Function - ) -> int: - count = 0 - for node in graph_or_function: - for rule in self.rules: - if rule.matches(node, model): - count += 1 - break - return count - - def count_matches(self, model: onnx.ModelProto | ir.Model): - if isinstance(model, onnx.ModelProto): - model = ir.serde.deserialize_model(model) - else: - assert isinstance(model, ir.Model) - count = self._count_matches_in_graph_or_function(model, model.graph) - for function in model.functions.values(): - count += self._count_matches_in_graph_or_function(model, function) - return count diff --git a/onnxscript/rewriter/pattern_base_test.py b/onnxscript/rewriter/pattern_base_test.py new file mode 100644 index 0000000000..8893d762b6 --- /dev/null +++ b/onnxscript/rewriter/pattern_base_test.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Test for the new Pattern and PatternBase classes.""" + +import unittest + +from onnxscript import ir +from onnxscript.rewriter import pattern + + +class PatternTest(unittest.TestCase): + """Test Pattern functionality.""" + + def test_pattern_impl_basic_functionality(self): + """Test that Pattern can be created and used independently.""" + + def simple_pattern(op, x): + return op.Identity(x) + + # Create a Pattern + pattern_impl = pattern.Pattern(simple_pattern, name="SimpleIdentity") + + # Verify basic properties + self.assertEqual(pattern_impl.name, "SimpleIdentity") + self.assertIsNotNone(pattern_impl._target_pattern) + self.assertIsNotNone(pattern_impl._matcher) + self.assertIsNotNone(pattern_impl._condition_function) + + def test_pattern_impl_match_method(self): + """Test that Pattern.match method works correctly.""" + + def identity_pattern(op, x): + return op.Identity(x) + + pattern_impl = pattern.Pattern(identity_pattern, name="IdentityPattern") + + # Create a model with an Identity node + model = ir.from_onnx_text( + """ + + agraph (float[N] x) => (float[N] z) + { + z = Identity(x) + } + """ + ) + + # Find the Identity node + identity_node = None + for node in model.graph: + if node.op_type == "Identity": + identity_node = node + break + + self.assertIsNotNone(identity_node) + + # Test pattern matching + match_result = pattern_impl.match(model, model.graph, identity_node) + + # The match might succeed or fail depending on how the pattern matching works + # The important thing is that the method runs without error + self.assertIsInstance(match_result, (pattern.MatchResult, type(None))) + + def test_pattern_impl_with_condition_function(self): + """Test Pattern with a custom condition function.""" + + def identity_pattern(op, x): + return op.Identity(x) + + def always_fail_condition(context, x): + return False + + pattern_impl = pattern.Pattern( + identity_pattern, condition_function=always_fail_condition, name="FailingIdentity" + ) + + # Create a model with an Identity node + model = ir.from_onnx_text( + """ + + agraph (float[N] x) => (float[N] z) + { + z = Identity(x) + } + """ + ) + + # Find the Identity node + identity_node = None + for node in model.graph: + if node.op_type == "Identity": + identity_node = node + break + + self.assertIsNotNone(identity_node) + + # Test pattern matching - should fail due to condition function + match_result = pattern_impl.match(model, model.graph, identity_node) + + # Should return None due to failing condition + self.assertIsNone(match_result) + + def test_pattern_impl_no_match_returns_match_object(self): + """Test that Pattern.match returns match object (not always None) when available.""" + + def identity_pattern(op, x): + return op.Identity(x) + + pattern_impl = pattern.Pattern(identity_pattern, name="IdentityPattern") + + # Create a model with an Add node (should not match Identity pattern) + model = ir.from_onnx_text( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + z = Add(x, y) + } + """ + ) + + # Find the Add node + add_node = None + for node in model.graph: + if node.op_type == "Add": + add_node = node + break + + self.assertIsNotNone(add_node) + + # Test pattern matching - should fail because Add != Identity + match_result = pattern_impl.match(model, model.graph, add_node) + + # The result should be falsy (either None or a failed MatchResult) + self.assertFalse(bool(match_result)) + + +class PatternBaseTest(unittest.TestCase): + """Test PatternBase functionality.""" + + def test_pattern_base_creation(self): + """Test that PatternBase can be subclassed and used.""" + + class TestPattern(pattern.PatternBase): + def pattern(self, op, x): + return op.Identity(x) + + test_pattern = TestPattern(name="TestPattern") + self.assertEqual(test_pattern.name, "TestPattern") + + def test_pattern_base_compiled_pattern_access(self): + """Test that PatternBase has an internal Pattern that is created on demand.""" + + class TestPattern(pattern.PatternBase): + def pattern(self, op, x): + return op.Identity(x) + + def check(self, context, x): + return pattern.MatchResult() # Always succeeds + + test_pattern = TestPattern(name="TestPattern") + + # Initially, the Pattern should not be created (lazy initialization) + self.assertIsNone(test_pattern._compiled_pattern) + + # Create a simple model to trigger pattern creation + model = ir.from_onnx_text( + """ + + agraph (float[N] x) => (float[N] z) + { + z = Identity(x) + } + """ + ) + graph = model.graph + node = next(iter(graph)) + + # Calling match() should trigger the creation of _compiled_pattern + test_pattern.match(model, graph, node) + + # Now the Pattern should be created + self.assertIsInstance(test_pattern._compiled_pattern, pattern.Pattern) + self.assertEqual(test_pattern._compiled_pattern.name, "TestPattern") + + def test_pattern_base_default_name(self): + """Test that PatternBase uses class name as default.""" + + class MyCustomPattern(pattern.PatternBase): + def pattern(self, op, x): + return op.Identity(x) + + test_pattern = MyCustomPattern() + self.assertEqual(test_pattern.name, "MyCustomPattern") + + +class RewriteRuleInheritanceTest(unittest.TestCase): + """Test that RewriteRule still works after inheriting from Pattern.""" + + def test_rewrite_rule_still_works(self): + """Test that existing RewriteRule functionality is preserved.""" + + def reciprocal_mul_pattern(op, x, y): + return (1 / x) * y + + def div_replacement(op, x, y): + return op.Div(y, x) + + rule = pattern.RewriteRule(reciprocal_mul_pattern, div_replacement) + + # Create a model that should match + model = ir.from_onnx_text( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + c1 = Constant() + t1 = Div(c1, x) + z1 = Mul(t1, y) + z = Identity(z1) + } + """ + ) + + # Apply the rule + count = rule.apply_to_model(model) + + # The rule should either apply or not, but the method should work + self.assertIsInstance(count, int) + self.assertGreaterEqual(count, 0) + + def test_rewrite_rule_class_base_still_works(self): + """Test that RewriteRuleClassBase still works after inheriting from PatternBase.""" + + class SimpleIdentityRule(pattern.RewriteRuleClassBase): + def pattern(self, op, x): + return op.Identity(x) + + def check(self, context, x): + return pattern.MatchResult() # Always succeeds + + def rewrite(self, op, x): + return op.Identity(x) # No-op replacement + + # Create a rule instance + rule = SimpleIdentityRule.rule() + + self.assertIsInstance(rule, pattern.RewriteRule) + self.assertEqual(rule.name, "SimpleIdentityRule") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 45bdcd6ad9..0a29080b4d 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -1,3 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import contextlib +import io import logging import unittest @@ -5,17 +9,18 @@ import onnx.checker import onnx.parser -from onnxscript import ir -from onnxscript.rewriter import _ir_utils, cast_constant_of_shape, pattern +import onnxscript.optimizer +from onnxscript import FLOAT, ir, script +from onnxscript import opset17 as op +from onnxscript.rewriter import pattern +from onnxscript.rewriter.rules.common import _cast_constant_of_shape logger = logging.getLogger(__name__) -op = pattern.onnxop -msft_op = pattern.msft_op class ReciprocalMulTest(unittest.TestCase): def rule(self) -> pattern.RewriteRule: - def reciprocal_mul_pattern(x, y): + def reciprocal_mul_pattern(op, x, y): return (1 / x) * y def div(op, x, y): @@ -59,6 +64,16 @@ def test_failed_match(self): self.assertEqual(count, 0) self.assertEqual(len(model.graph), 4) + # Test verbose output produces something: + # TODO(rama): Need a better way to test this. + # Well-defined error-codes and messages would be helpful. + + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + self.rule().apply_to_model(model, verbose=5) + out = buffer.getvalue() + self.assertIn("Match failed", out) + def test_multiple_matches(self): model_proto = onnx.parser.parse_model( """ @@ -91,19 +106,19 @@ def test_multiple_matches(self): class FastGeluTest(unittest.TestCase): def rule(self) -> pattern.RewriteRule: - def fast_gelu_pattern1(x): + def fast_gelu_pattern1(op, x): b = 0.044715 c = 0.79788 tanh = op.Tanh(c * (x + (x**3) * b)) return (1.0 + tanh) * (0.5 * x) def fast_gelu(op, x): - return op.FastGelu(x, domain="com.microsoft") + return op.FastGelu(x, _domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu) def long_form_rule(self) -> pattern.RewriteRule: - def fast_gelu_pattern1_long(x): + def fast_gelu_pattern1_long(op, x): three = pattern.Constant(3) x_cube = op.Pow(x, three) b = pattern.Constant(0.044715) @@ -119,7 +134,7 @@ def fast_gelu_pattern1_long(x): return op.Mul(one_plus_tanh, half_x) def fast_gelu(op, x): - return op.FastGelu(x, domain="com.microsoft") + return op.FastGelu(x, _domain="com.microsoft") return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu) @@ -160,7 +175,7 @@ def test_long_rule(self): class ConcatTest(unittest.TestCase): def rule(self) -> pattern.RewriteRule: - def concat_pattern(x, y, axis): + def concat_pattern(op, x, y, axis): seq = op.SequenceConstruct(x, y) return op.ConcatFromSequence(seq, axis=axis) @@ -211,7 +226,7 @@ def test_concat_in_function(self): class RewriteRuleTest(unittest.TestCase): def test_commute(self): - def add_0(x): + def add_0(op, x): return x + 0 def identity(op, x): @@ -238,19 +253,20 @@ def identity(op, x): self.assertEqual(nodes[1].op_type, "Identity") def test_const_value(self): - def reshape(x, newshape): + def reshape(op, x, newshape): return op.Reshape(x, newshape) def identity(op, x, newshape): del newshape # Unused return op.Identity(x) - def check_for_redundant_reshape(x, newshape): + def check_for_redundant_reshape(context, x, newshape): oldshape = x.shape - newshape = _ir_utils.propagate_const_value(newshape) - newshape = _ir_utils.get_numpy_from_ir_value(newshape) - if not isinstance(newshape, np.ndarray): + newshape_const_value = newshape.const_value + if newshape_const_value is None: return False + + newshape = newshape_const_value.numpy() newshape = newshape.tolist() if len(oldshape) != len(newshape): @@ -291,18 +307,18 @@ def test_delayed_run_provides_correct_bindings_for_multiple_matches(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 2) self.assertEqual(len(model.graph), 2) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) self.assertEqual(model.graph[1].attributes["value"].value.dtype, 1) def test_opset_import(self): - def add_same(x): + def add_same(op, x): return x + x def double(op, x): - return op.Double(x, domain="custom.domain", version=10) + return op.Double(x, _domain="custom.domain", _version=10) rule = pattern.RewriteRule(add_same, double) @@ -322,11 +338,11 @@ def double(op, x): self.assertEqual(model.graph.opset_imports["custom.domain"], 10) def test_opset_import_in_function(self): - def add_same(x): + def add_same(op, x): return x + x def double(op, x): - return op.Double(x, domain="custom.domain", version=10) + return op.Double(x, _domain="custom.domain", _version=10) rule = pattern.RewriteRule(add_same, double) @@ -355,6 +371,584 @@ def double(op, x): ) onnx.checker.check_model(ir.serde.serialize_model(model)) + def test_optional_attribute(self): + """Test rules with optional attributes.""" + + def concat_pattern(op, x, y): + seq = op.SequenceConstruct(x, y) + result = op.ConcatFromSequence(seq, _outputs=["result"]) + return result + + def concat(op, x, y, result: ir.Value): + node = result.producer() + assert node is not None + axis = node.attributes.get("axis", None) + return op.Concat(x, y, axis=axis) + + rule = pattern.RewriteRule(concat_pattern, concat) + + # Case 1: a model with attribute axis present + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[M] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + self.assertEqual(model.graph[0].op_type, "Concat") + self.assertEqual(model.graph[0].attributes["axis"].value, 0) + + # Case 2: a model with attribute axis absent + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[M] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + self.assertEqual(model.graph[0].op_type, "Concat") + self.assertNotIn("axis", model.graph[0].attributes) + + def test_match_none_input(self): + def none_pattern(op, x): + # match against a call to Original where the first input is None + return op.Original(None, x) + + def replacement(op, x): + return op.Replaced(x) + + rule = pattern.RewriteRule(none_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call + t1 = op.Original(None, x) + # Pattern should not match following call + z = op.Original(t1, x) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 2) + self.assertEqual(model.graph.node(0).op_type, "Replaced") + self.assertEqual(model.graph.node(1).op_type, "Original") + + def test_match_optional_input(self): + def none_pattern(op, x): + # match against a call to Original where the first input may or may not be None + optional_input = pattern.Var("optional_input", can_match_none=True) + return op.Original(optional_input, x) + + def replacement(op, optional_input, x): + if optional_input is None: + return op.ReplacedNone(x) + return op.ReplacedNotNone(x) + + rule = pattern.RewriteRule(none_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call + t1 = op.Original(None, x) + # as well as this one + z = op.Original(t1, x) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 2) + self.assertEqual(len(model.graph), 2) + self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") + self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + + def test_mismatched_number_of_inputs(self): + def var_length_pattern(op): + # match against a call to Original where the first input may or may not be None + input1 = pattern.Var("input1", can_match_none=False) + input2 = pattern.Var("input2", can_match_none=True) + return op.Original(input1, input2) + + def replacement(op, input1, input2): + return op.Replaced(input1, input2) + + rule = pattern.RewriteRule(var_length_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should NOT match following 2 calls, since pattern requires first input to be non-None + t0 = op.Original() + t1 = op.Original(None, x) + + # Pattern should match following 3 calls, since second input can be None + t2 = op.Original(x) + t3 = op.Original(x, None) + t4 = op.Original(x, y) + + # Pattern should NOT match following call, since it has more than 2 inputs + t5 = op.Original(x, y, z) + return op.All(t0, t1, t2, t3, t4, t5) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 3) + self.assertEqual(len(model.graph), 7) + self.assertEqual( + [n.op_type for n in model.graph], + ["Original", "Original", "Replaced", "Replaced", "Replaced", "Original", "All"], + ) + + def test_graph_visitor(self): + class ReplaceFoo(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__() + self.replacement = None + + def pattern(self, op): + return op.Foo() + + def rewrite(self, op): + if self.replacement is None: + self.replacement = op.Bar() + return self.replacement + + rule = ReplaceFoo.rule() + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call + t1 = op.Foo() + # as well as this one + t2 = op.Foo() + z = op.Add(t1, t2) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 2) + self.assertEqual(len(model.graph), 2) + self.assertEqual(model.graph.node(0).op_type, "Bar") + self.assertEqual(model.graph.node(1).op_type, "Add") + + def test_debug_mode(self): + def source_pattern(op, x): + t1 = op.Abs(x) + t2 = op.Neg(t1) + t3 = op.Exp(t2) + return t3 + + def replacement(op, x): + return op.Something(x) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + a2 = op.Abs(x) # match-1 fails here + a3 = op.Exp(a2) # match-1 starts here + b1 = op.Neg(a3) # match-2 fails here + b2 = op.Neg(b1) # match-2 (partially) succeeds here + b3 = op.Exp(b2) # match-2 starts here + return b3 + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + tracer = pattern.MatchingTracer() + count = rule.apply_to_model(model, tracer=tracer) + self.assertEqual(count, 0) + best_matches = tracer.best_matches_map[rule] + self.assertEqual(len(best_matches), 1) + best_match = best_matches[0] + self.assertEqual(best_match.status.value, pattern.MatchStatus.NO_MATCH) + self.assertIn("OpType mismatch: expected Abs, got Neg", best_match.match_result.reason) + + def test_new_initializer(self): + def source_pattern(op, x, y): + return op.Gemm(x, op.Transpose(y)) + + def check(context, x, y): + return y.const_value is not None + + def replacement(op, x, y): + tensor = y.const_value + name = y.name + "_transposed" + transposed = ir.tensor(tensor.numpy().T, name=name) + initializer = op.initializer(transposed) + return op.Gemm(x, initializer) + + rule = pattern.RewriteRule(source_pattern, replacement, check) + + y_value = np.random.rand(8, 4).astype(np.float32) + + @script() + def test_model(x: FLOAT[16, 8]) -> FLOAT[16, 4]: + y = op.Constant(value=y_value) + return op.Gemm(x, op.Transpose(y)) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual(len(model.graph.initializers), 1) + last_node = model.graph[-1] + self.assertEqual(len(last_node.inputs), 2) + init_name = last_node.inputs[1].name + self.assertIn(init_name, model.graph.initializers) + self.assertIs(last_node.inputs[1], model.graph.initializers[init_name]) + + def test_extract_function(self): + def source_pattern(op, x, y, z): + sum = op.Add(x, y) + return op.Mul(sum, z) + + def replacement(op, x, y, z): + return op.AddMul(x, y, z, _domain="some.domain") + + rule = pattern.RewriteRule(source_pattern, replacement, as_function=True) + + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: + return op.Mul(op.Add(x, y), z) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.graph), 1) + call_node = model.graph.node(0) + self.assertEqual(call_node.domain, "some.domain") + self.assertEqual(call_node.op_type, "AddMul") + function_id = call_node.op_identifier() + self.assertIn(function_id, model.functions) + function = model.functions[function_id] + self.assertEqual([x.op_type for x in function], ["Add", "Mul"]) + onnxscript.optimizer.inline(model) + self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul"]) + + def test_extract_function_with_attr(self): + def source_pattern(op, x, y): + sum = op.Add(x, y) + return op.Transpose(sum, perm=[1, 0]) + + def replacement(op, x, y): + return op.AddTranspose(x, y, _domain="some.domain") + + rule = pattern.RewriteRule(source_pattern, replacement, as_function=True) + + @script() + def test_model(x: FLOAT[1024, 512], y: FLOAT[1024, 512]) -> FLOAT[512, 1024]: + return op.Transpose(op.Add(x, y), perm=[1, 0]) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.graph), 1) + call_node = model.graph.node(0) + self.assertEqual(call_node.domain, "some.domain") + self.assertEqual(call_node.op_type, "AddTranspose") + function_id = call_node.op_identifier() + self.assertIn(function_id, model.functions) + function = model.functions[function_id] + self.assertEqual([x.op_type for x in function], ["Add", "Transpose"]) + transpose_node = function[1] + self.assertEqual(list(transpose_node.attributes["perm"].value), [1, 0]) + onnxscript.optimizer.inline(model) + self.assertEqual([x.op_type for x in model.graph], ["Add", "Transpose"]) + + def test_extract_repeated_function(self): + def source_pattern(op, x, y, z): + sum = op.Add(x, y) + return op.Mul(sum, z) + + def replacement(op, x, y, z): + return op.AddMul(x, y, z, _domain="some.domain") + + rule = pattern.RewriteRule(source_pattern, replacement, as_function=True) + + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: + t1 = op.Mul(op.Add(x, y), z) + t2 = op.Mul(op.Add(t1, y), z) + return t2 + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual(len(model.functions), 2) + self.assertEqual(len(model.graph), 2) + for call_node in model.graph: + self.assertEqual(call_node.domain, "some.domain") + self.assertEqual(call_node.op_type, "AddMul") + function_id = call_node.op_identifier() + self.assertIn(function_id, model.functions) + onnxscript.optimizer.inline(model) + self.assertEqual([x.op_type for x in model.graph], ["Add", "Mul", "Add", "Mul"]) + + def test_any_value(self): + def source_pattern(op, x): + return op.Add(x, op.Mul(0, pattern.ANY_VALUE)) + + def replacement(op, x): + return op.Identity(x) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: + zero = op.Constant(value_float=0.0) + return op.Add(x, op.Mul(zero, y)) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + self.assertEqual([x.op_type for x in model.graph], ["Constant", "Mul", "Add"]) + rule.apply_to_model(model) + self.assertEqual(len(model.graph), 2) + self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"]) + + def test_or_pattern(self): + def source_pattern(op, x, y, bias): + t1 = op.MatMul(x, y) + t2 = op.Add(t1, bias) + t1_or_t2 = pattern.OrValue([t1, t2], tag_var="has_bias", tag_values=[False, True]) + return op.Relu(t1_or_t2) + + def replacement(op, x, y, bias, has_bias): + if has_bias: + return op.WithBias(x, y, bias) + else: + return op.WithoutBias(x, y) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model1(x: FLOAT[16, 32], y: FLOAT[32, 16]) -> FLOAT[16, 16]: + return op.Relu(op.MatMul(x, y)) + + model_proto = test_model1.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["WithoutBias"]) + + @script() + def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]: + return op.Relu(op.Add(op.MatMul(x, y), bias)) + + model_proto = test_model2.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["WithBias"]) + + def test_backtracking_pattern(self): + def source_pattern(op, x, y, bias): + t1 = op.MatMul(x, y) + choice1 = op.Add(t1, bias) + choice2 = op.Add(bias, t1) + t2 = pattern.OrValue([choice1, choice2]) + return op.Relu(t2) + + def replacement(op, x, y, bias): + return op.GemmRelu(x, y, bias) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model1(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]: + return op.Relu(op.Add(op.MatMul(x, y), bias)) + + model_proto = test_model1.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["GemmRelu"]) + self.assertEqual([x.name for x in model.graph.node(0).inputs], ["x", "y", "bias"]) + + @script() + def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]: + return op.Relu(op.Add(bias, op.MatMul(x, y))) + + model_proto = test_model2.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["GemmRelu"]) + self.assertEqual([x.name for x in model.graph.node(0).inputs], ["x", "y", "bias"]) + + def test_or_pattern_return_value(self): + """Test that an OrValue can be used as a return value from the source pattern.""" + + def source_pattern(op, x, y): + choice1 = op.Add(x, y) + choice2 = op.Mul(x, y) + t = pattern.OrValue([choice1, choice2]) + z = op.Relu(t) + return z, t + + def replacement(op, x, y): + z, t = op.ReluPlus(x, y, _outputs=2) + return z, t + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model1(x: FLOAT[16, 32], y: FLOAT[16, 32]) -> FLOAT[16, 32]: + return op.Relu(op.Add(x, y)) + + model_proto = test_model1.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["ReluPlus"]) + + +class ValueNodeCheckersTest(unittest.TestCase): + """Test value/node level checkers functionality.""" + + def test_pattern_match_with_node_checker(self): + """Test Pattern.match with node-level checker.""" + + def shape_node_checker(context, node): + return node.attributes.get_int("start", 0) == 0 + + # Create a pattern that matches Shape operations with a node checker + def shape_pattern(op, x): + return op.Shape(x, _check=shape_node_checker) + + # Create the pattern + rule_pattern = pattern.Pattern(shape_pattern) + + # Create a model with multiple Shape nodes with different start attributes + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3) + { + z1 = Shape(x) + z2 = Shape (x) + z3 = Shape (x) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + + # Find the Shape nodes in the model + nodes = list(model.graph) + shape_node_no_attr = nodes[0] # Shape without start attribute + shape_node_start_0 = nodes[1] # Shape with start=0 + shape_node_start_1 = nodes[2] # Shape with start=1 + + self.assertEqual(shape_node_no_attr.op_type, "Shape") + self.assertEqual(shape_node_start_0.op_type, "Shape") + self.assertEqual(shape_node_start_1.op_type, "Shape") + + # Test case 1: Shape without start attribute (should match, default is 0) + match_result = rule_pattern.match(model, model.graph, shape_node_no_attr) + self.assertTrue(bool(match_result)) + + # Test case 2: Shape with start=0 (should match) + match_result = rule_pattern.match(model, model.graph, shape_node_start_0) + self.assertTrue(bool(match_result)) + + # Test case 3: Shape with start=1 (should not match) + match_result = rule_pattern.match(model, model.graph, shape_node_start_1) + self.assertFalse(bool(match_result)) + + def test_pattern_match_with_value_checker(self): + """Test Pattern.match with value-level checker.""" + + def is_positive_constant(context, value: ir.Value): + if value.const_value is not None: + # Get the numpy array from const_value + numpy_array = value.const_value.numpy() + + # Check if it represents a single value and is positive + if numpy_array.size != 1: + return False + + return float(numpy_array.item()) > 0 + + return False + + # Create a pattern with value checker using callable directly + def add_pattern(op, x, y): + # Use callable as input to create ValuePattern with checker + return op.Add(is_positive_constant, y) + + # Create the pattern + rule_pattern = pattern.Pattern(add_pattern) + + # Create a model with several calls to Add: + # - one with first parameter non-constant + # - one with first parameter a positive constant + # - one with first parameter a negative constant + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z1, float[N] z2, float[N] z3) + { + pos_const = Constant () + neg_const = Constant () + z1 = Add(x, y) # non-constant first parameter + z2 = Add(pos_const, y) # positive constant first parameter + z3 = Add(neg_const, y) # negative constant first parameter + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + + # Apply constant propagation to set const_value fields + onnxscript.optimizer.basic_constant_propagation(model.graph.all_nodes()) + + # Find the Add nodes in the model + add_nodes = [node for node in model.graph if node.op_type == "Add"] + self.assertEqual(len(add_nodes), 3) + + # Test case 1: Non-constant first parameter - should not match + match_result = rule_pattern.match(model, model.graph, add_nodes[0]) + self.assertFalse(bool(match_result)) + + # Test case 2: Positive constant first parameter - should match + match_result = rule_pattern.match(model, model.graph, add_nodes[1]) + self.assertTrue(bool(match_result)) + self.assertEqual(len(match_result.nodes), 1) + self.assertGreaterEqual(len(match_result.value_bindings), 1) + + # Test case 3: Negative constant first parameter - should not match + match_result = rule_pattern.match(model, model.graph, add_nodes[2]) + self.assertFalse(bool(match_result)) + + +class PatternBuilderTest(unittest.TestCase): + def test_pattern_builder_context(self): + builder = pattern.OpsetPatternBuilder("", True) + with pattern.pattern_builder(builder): + x = builder.Op1() + y = builder.Op2(x) + z = x + y + w = builder.Op3(z) + _ = z * w + ops = [x.op_type for x in builder.nodes()] + self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/rules/__init__.py b/onnxscript/rewriter/rules/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py new file mode 100644 index 0000000000..14ed3587f3 --- /dev/null +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +__all__ = [ + "add_0_rule", + "affine_conv_fusion_rule", + "cast_cast_rule", + "cast_constant_of_shape_rule", + "cast_constant_of_shape_without_value_rule", + "collapse_slice_rule", + "collapse_slice2_rule", + "conv_affine_fusion_rule", + "div_by_1_rule", + "dropout_inference_rule", + "dropout_zero_rule", + "flatten_to_reshape_rule", + "fuse_batchnorm_into_conv_rule", + "fuse_batchnorm_into_conv_transpose_rule", + "fuse_batchnorm_into_gemm_rule", + "fuse_hardswish_rules", + "fuse_pad_into_conv_integer_rule", + "fuse_pad_into_conv_rule", + "min_min_rule", + "max_max_rule", + "min_max_rule", + "max_min_rule", + "gemm_to_matmul_add_rule", + "matmul_add_to_gemm_rule", + "mul_by_1_rule", + "no_op_cast_rule", + "no_op_dynamic_scatter_nd_rule", + "no_op_expand_rule", + "no_op_static_scatter_nd_rule", + "no_op_transpose_rule", + "normalize_pad_format_conv_integer_rule", + "normalize_pad_format_conv_rule", + "one_reshape_matmul_reshape_rule", + "reshape_reshape_rule", + "slice_split_rule", + "squeeze_reshape_1d_rule", + "sub_0_rule", + "successive_clip_relu_rule", + "successive_clip_rule", + "successive_relu_clip_rule", + "successive_relu_rule", + "transpose_a_matmul_add_to_gemm_rule", + "transpose_ab_matmul_add_to_gemm_rule", + "transpose_b_matmul_add_to_gemm_rule", + "transpose_transpose_rule", + "two_reshapes_matmul_reshape_rule", + "unsqueeze_unsqueeze_rule", +] + +from onnxscript.rewriter.rules.common._basic_rules import ( + cast_cast_rule, + flatten_to_reshape_rule, + no_op_cast_rule, + no_op_expand_rule, + no_op_transpose_rule, + reshape_reshape_rule, + slice_split_rule, + squeeze_reshape_1d_rule, + transpose_transpose_rule, + unsqueeze_unsqueeze_rule, +) +from onnxscript.rewriter.rules.common._broadcast_to_matmul import ( + one_reshape_matmul_reshape_rule, + two_reshapes_matmul_reshape_rule, +) +from onnxscript.rewriter.rules.common._cast_constant_of_shape import ( + cast_constant_of_shape_rule, + cast_constant_of_shape_without_value_rule, +) +from onnxscript.rewriter.rules.common._collapse_slices import ( + collapse_slice2_rule, + collapse_slice_rule, +) +from onnxscript.rewriter.rules.common._fuse_batchnorm import ( + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, +) +from onnxscript.rewriter.rules.common._fuse_conv_affine import ( + affine_conv_fusion_rule, + conv_affine_fusion_rule, +) +from onnxscript.rewriter.rules.common._fuse_hardswish import fuse_hardswish_rules +from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( + fuse_pad_into_conv_integer_rule, + fuse_pad_into_conv_rule, + normalize_pad_format_conv_integer_rule, + normalize_pad_format_conv_rule, +) +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, + successive_relu_rule, +) +from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule +from onnxscript.rewriter.rules.common._matmul_add_to_gemm import ( + matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_ab_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, +) +from onnxscript.rewriter.rules.common._min_max_to_clip import ( + max_max_rule, + max_min_rule, + min_max_rule, + min_min_rule, +) +from onnxscript.rewriter.rules.common._no_op import ( + add_0_rule, + div_by_1_rule, + dropout_inference_rule, + dropout_zero_rule, + mul_by_1_rule, + sub_0_rule, +) +from onnxscript.rewriter.rules.common._redundant_scatter_nd import ( + no_op_dynamic_scatter_nd_rule, + no_op_static_scatter_nd_rule, +) diff --git a/onnxscript/rewriter/rules/common/_basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py new file mode 100644 index 0000000000..b7a648880a --- /dev/null +++ b/onnxscript/rewriter/rules/common/_basic_rules.py @@ -0,0 +1,396 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Basic rewrite rules for general optimization patterns. + +This module contains fundamental optimization rules that are generally applicable +to most ONNX models, including cast elimination, transpose simplification, +shape operation fusion, and other common patterns. +""" + +from __future__ import annotations + +from typing import ClassVar, Sequence + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter import _ir_utils as ir_utils +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class SqueezeReshape(RewriteRuleClassBase): + """Replaces ``Reshape(Squeeze(x), [-1]])`` with ``Identity(x)`` for 1D x. + + This pattern arises from the translation of pytorch symints. + """ + + def __init__(self): + super().__init__("SqueezeReshape1d", remove_nodes=False) + + def pattern(self, op, x): + return op.Reshape(op.Squeeze(x), [-1]) + + def rewrite(self, op, x: ir.Value): + return op.Identity(x) + + def check(self, context, x) -> MatchResult: + del context # Unused + check_result = MatchResult() + if not ir_utils.has_rank(x, 1): + return check_result.fail("Input is not 1D") + return check_result + + +class CastIdentity(RewriteRuleClassBase): + """Replaces ``Cast(., to=to)`` by ``Identity`` if possible.""" + + def pattern(self, op, x, to): + return op.Cast(x, to=to) + + def rewrite(self, op, x: ir.Value, to: ir.Attr): + return op.Identity(x) + + def check(self, context, x, to) -> MatchResult: + check_result = MatchResult() + if x.dtype != to.as_int(): + return check_result.fail("Input and output types are not the same") + return check_result + + +class CastCast(RewriteRuleClassBase): + """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``.""" + + # Simplify "cast type1 => type2 => type3" to "cast type1 => type3". + # This rule is not valid for all combinations of types: e.g., + # it is not valid for float32 => float16 => float32 or float32 => int32 => string. + # TODO: fill out the list of allowed combinations: the following is just a couple + # that shows up in practice where it is valid + _allowed_type2_type3: ClassVar = frozenset( + { + (ir.DataType.FLOAT, ir.DataType.FLOAT16), + (ir.DataType.FLOAT, ir.DataType.BFLOAT16), + } + ) + + def pattern(self, op, x, to, to_ignored): + return op.Cast(op.Cast(x, to=to_ignored), to=to) + + def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: + check_result = MatchResult() + type2 = to_ignored.as_int() + type3 = to.as_int() + if (type2, type3) not in self._allowed_type2_type3: + return check_result.fail( + f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. " + f"Cast-Cast rule may be incomplete for this combination." + ) + return check_result + + def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr): + return op.Cast(x, to=to) + + +class ExpandIdentity(RewriteRuleClassBase): + """Replaces ``Expand(..., shape)`` by ``Identity`` if possible.""" + + def pattern(self, op, x, shape): + return op.Expand(x, shape) + + def rewrite(self, op, x: ir.Value, shape: ir.Value): + return op.Identity(x) + + def check(self, context, x, shape) -> MatchResult: + check_result = MatchResult() + if shape.const_value is None: + # Shape is not a constant and cannot be guessed. + return check_result.fail("Shape is not a constant and cannot be guessed.") + if (x_shape := x.shape) is None: + # We don't know the shape of the input + return check_result.fail("Input shape is not known.") + if x_shape.dims != tuple(shape.const_value.numpy().tolist()): + return check_result.fail( + f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}." + ) + return check_result + + +class ReshapeReshape(RewriteRuleClassBase): + """Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``. + The pattern matches only if second reshape reshapes into a shape + with positive values. + """ + + def pattern(self, op, x, shape_ignored, shape): + return op.Reshape(op.Reshape(x, shape_ignored), shape) + + def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value): + new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name)) + return op.Reshape(x, new_shape, allowzero=self._allowzero) + + def check(self, context, x, shape_ignored, shape) -> MatchResult: + check_result = MatchResult() + + # Shape must be a constant. + if (np_shape := ir_utils.get_numpy_value(shape)) is None: + return check_result.fail("Shape is not a constant.") + # Convert to array to support assignment destination. + self._new_shape = np.array(np_shape, np_shape.dtype) + + # Try to replace {0,-1} values in shape if reshape output is known. + if (reshape_output := context.output_values[0].shape) is not None: + for i, dim in enumerate(reshape_output): + if isinstance(dim, int) and dim > 0: + self._new_shape[i] = dim + + # Constraints for shape. + self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0) + if self._allowzero == 1 and any(self._new_shape == 0): + return check_result + if any(self._new_shape == 0) and any(self._new_shape < 0): + return check_result.fail("Shape cannot contain both 0 and -1 dimensions.") + elif np.count_nonzero(self._new_shape == 0) > 1: + return check_result.fail("Shape cannot contain more than one 0 dimension.") + + # At this point, we can safely replace '0' with '-1'. + # Note allowzero is removed since at this point it does not have any effect. + self._allowzero = None + self._new_shape = np.where(self._new_shape == 0, -1, self._new_shape) + return check_result + + +class SlicesSplit(RewriteRuleClassBase): + """Replaces ``Slice(x, ...), Slice(x, ...)`` + by ``Split(x, ...)`` if possible. + """ + + def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1): + return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1) + + def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> MatchResult: + check_result = MatchResult() + if ( + axes0.const_value is None + or axes1.const_value is None + or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist() + ): + return check_result.fail("Axes are not equal or not constant.") + axes = axes0.const_value.numpy().tolist() + if len(axes) != 1: + return check_result.fail("Axes has more than one dimension.") + if x.shape: + rk = len(x.shape) + else: + rk = x.rank + if axes[0] != -1 and axes[0] != rk - 1: + return check_result.fail("Axes is not -1 or last dimension.") + if ( + begin0.const_value is None + or end0.const_value is None + or begin1.const_value is None + or end1.const_value is None + ): + return check_result.fail("Begin or end are not constant values.") + if begin0.const_value.numpy().tolist() != [0]: + return check_result.fail("First begin value is not 0.") + e0, b1, e1 = ( + end0.const_value.numpy().tolist(), + begin1.const_value.numpy().tolist(), + end1.const_value.numpy().tolist(), + ) + if e0[0] != b1[0]: + return check_result.fail("End0 is not equal to Begin1.") + shape = x.shape + if shape is None: + return check_result.fail("Shape is not known.") + last_dim = shape[-1] + if not isinstance(last_dim, int): + return check_result.fail("Last dimension is not known.") + if last_dim != e1[0]: + return check_result.fail("Last dimension is not equal to End1.") + if last_dim // 2 != b1[0]: + return check_result.fail("Last dimension is not equal to Begin1.") + return check_result + + def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1): + return op.Split(x, num_outputs=2, axis=-1, _outputs=2) + + +class TransposeIdentity(RewriteRuleClassBase): + """Replaces ``Transpose(. perm=perm)`` + when the permutation is identity. + """ + + def pattern(self, op, x, perm): + return op.Transpose(x, perm=perm) + + def check(self, context, x: ir.Value, perm: ir.Attr) -> MatchResult: + check_result = MatchResult() + if perm.is_ref(): + return check_result.fail("Permutation is a reference attribute.") + if perm.type == ir.AttributeType.INTS: + perm_ints = tuple(perm.as_ints()) + if perm_ints == tuple(range(len(perm_ints))): + return check_result + return check_result.fail("Permutation is not identity.") + + def rewrite(self, op, x: ir.Value, perm: ir.Attr): + return op.Identity(x) + + +class TransposeTranspose(RewriteRuleClassBase): + """Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)`` + when both permutations are inverse. + """ + + def pattern(self, op, x, perm1, perm2): + return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2) + + def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> MatchResult: + check_result = MatchResult() + if perm1.is_ref() or perm2.is_ref(): + return check_result.fail("Permutation is a reference attribute.") + return check_result + + def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]: + assert len(perm) == len(on), "length mismatch" + res = [-1 for i in on] + for i, p in enumerate(perm): + res[i] = on[p] + return res + + def _apply_transposes( + self, perms: list[Sequence[int]], on: list[int] | None = None + ) -> list[int]: + if on is None: + on = list(range(len(perms[0]))) + for p in perms: + on = self._apply_transpose(p, on) + return on + + def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr): + first = list(range(len(perm1.as_ints()))) + last = self._apply_transposes([perm1.as_ints(), perm2.as_ints()]) + if first == last: + return op.Identity(x) + return op.Transpose(x, perm=last) + + +class UnsqueezeUnsqueeze(RewriteRuleClassBase): + """Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze.""" + + def pattern(self, op, x, axes1, axes2): + return op.Unsqueeze(op.Unsqueeze(x, axes1), axes2) + + def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value): + v1 = ir_utils.get_singleton_value(axes1) + v2 = ir_utils.get_singleton_value(axes2) + axes = [v1, v2] if v1 < v2 else [v2, v1 + 1] + return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64))) + + def check(self, context, x, axes1, axes2) -> MatchResult: + check_result = MatchResult() + del context # Unused + del x # Unused + # Currently restricted to single element positive axis + v1 = ir_utils.get_singleton_value(axes1) + v2 = ir_utils.get_singleton_value(axes2) + if v1 is None or v2 is None: + return check_result.fail("Axes are not constant.") + if (v1 < 0) or (v2 < 0): + return check_result.fail("Axes are negative.") + return check_result + + +class Flatten2Reshape(RewriteRuleClassBase): + """Convert ``Flatten(x)`` to Reshape.""" + + def pattern(self, op, x: ir.Value): + return op.Flatten(x) + + def rewrite(self, op, x: ir.Value): + new_shape = op.initializer(ir.Tensor(self._new_shape, name=f"{x.name}/shape")) + return op.Reshape(x, new_shape) + + def check(self, context, x: ir.Value) -> MatchResult: + check_result = MatchResult() + self._new_shape = np.array([-1, -1], "int64") + + # Convert axis in a positive value if possible. + axis = context.root.attributes.get_int("axis", 1) + input_rank = None + if (input_shape := x.shape) is not None: + input_rank = len(input_shape) + if axis < 0: + axis += input_rank + + # Compute reshape shape following axis attribute. + if axis == 0: + self._new_shape[0] = 1 + elif axis == 1: + self._new_shape[0] = 0 + elif axis == input_rank: + self._new_shape[1] = 1 + + # Try to update shape if output is known. + if (output_shape := context.output_values[0].shape) is not None: + for i, dim in enumerate(output_shape): + if isinstance(dim, int): + self._new_shape[i] = dim + + # Try to update shape if input is known. + if input_shape is not None: + if all(isinstance(dim, int) for dim in input_shape[:axis]): + self._new_shape[0] = np.prod(input_shape[:axis]) + if all(isinstance(dim, int) for dim in input_shape[axis:]): + self._new_shape[1] = np.prod(input_shape[axis:]) + + # Verify if it is possible to apply rule. + if np.count_nonzero(self._new_shape == -1) > 1: + return check_result.fail("Impossible to compute new shape.") + return check_result + + +# Create rule instances +cast_cast_rule = CastCast.rule() +no_op_cast_rule = CastIdentity.rule() +no_op_expand_rule = ExpandIdentity.rule() +reshape_reshape_rule = ReshapeReshape.rule() +slice_split_rule = SlicesSplit.rule() +no_op_transpose_rule = TransposeIdentity.rule() +transpose_transpose_rule = TransposeTranspose.rule() +unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule() +squeeze_reshape_1d_rule = SqueezeReshape.rule() +flatten_to_reshape_rule = Flatten2Reshape.rule() + + +def basic_optimization_rules() -> RewriteRuleSet: + """Returns a set of basic optimization rules. + + These rules perform fundamental optimizations such as: + - Eliminating redundant cast operations + - Simplifying consecutive operations of the same type + - Removing identity operations + - Optimizing shape manipulation operations + + These rules are generally safe to apply as a first optimization pass + before other more specialized optimizations. + + Returns: + RewriteRuleSet: A collection of basic optimization rules + """ + return RewriteRuleSet( + [ + cast_cast_rule, + no_op_cast_rule, + no_op_expand_rule, + # flatten_to_reshape_rule is order sensitive to reshape_reshape_rule + flatten_to_reshape_rule, + reshape_reshape_rule, + slice_split_rule, + no_op_transpose_rule, + transpose_transpose_rule, + unsqueeze_unsqueeze_rule, + squeeze_reshape_1d_rule, + ] + ) diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py new file mode 100644 index 0000000000..7d4e9d9b33 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -0,0 +1,615 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest +from typing import Any + +import numpy as np +import onnx +import onnx.reference +import parameterized + +import onnxscript +import onnxscript.onnx_types as ot +from onnxscript import ir +from onnxscript.onnx_opset import opset18 +from onnxscript.rewriter import MatchingTracer, testing +from onnxscript.rewriter import pattern as orp +from onnxscript.rewriter.rules.common import _basic_rules + +FLOAT = onnx.TensorProto.FLOAT + + +@onnxscript.script() +def cast_identity_model(x: ot.FLOAT["a", "b", "c"]) -> ot.FLOAT["a", "b", "c"]: # noqa: F821, UP037 + y = opset18.Cast(x, to=onnx.TensorProto.FLOAT) + return y + + +def _make_model(*args, **kwargs) -> ir.Model: + return ir.serde.deserialize_model(onnx.helper.make_model(*args, **kwargs)) + + +def clone_model(model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + +class BasicRulesTest(unittest.TestCase): + def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]: + feeds: dict[str, Any] = {} + for i in model.graph.input: + ish = tuple(i.type.tensor_type.shape.dim) + # Creates an input tensor with a dimension defined by the onnx model + # or equals to i + 2 with i being the dimension index. + # The tensor is kept small to make the test fast. + shape = tuple( + (d.dim_value if d.dim_value > 0 else i + 2) for i, d in enumerate(ish) + ) + if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT: + feeds[i.name] = np.random.randn(*shape).astype(np.float32) + else: + raise AssertionError(f"Not implemented for input {i}") + return feeds + + def _check_model( + self, + model: onnx.ModelProto, + optimized_model: onnx.ModelProto, + feeds: dict[str, Any] | None = None, + atol: float = 0.0, + rtol: float = 1e-7, + ): + if not feeds: + feeds = self._get_random_inputs(model) + ref = onnx.reference.ReferenceEvaluator(model) + opt = onnx.reference.ReferenceEvaluator(optimized_model) + expected = ref.run(None, feeds) + got = opt.run(None, feeds) + self.assertEqual(len(expected), len(got)) + for a, b in zip(expected, got): + np.testing.assert_allclose(a, b, atol=atol, rtol=rtol) + + @parameterized.parameterized.expand( + [ + ( + "no_op_transpose", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 1, 2]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ), + ( + "canceled_out_transposes", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 0]), + onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 0]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None])], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ), + ] + ) + def test_basic_optimization_rules_identity(self, _: str, model: ir.Model): + rule_set = _basic_rules.basic_optimization_rules() + model_proto = ir.serde.serialize_model(model) + rule_set.apply_to_model(model) + rewritten_model = ir.serde.serialize_model(model) + + self.assertEqual(["Identity"], [n.op_type for n in model.graph]) + self._check_model(model_proto, rewritten_model) + + @parameterized.parameterized.expand( + [ + ( + "consecutive_transposes", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]), + onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ), + ] + ) + def test_basic_optimization_rules_transpose_transpose(self, _: str, model: ir.Model): + rule_set = _basic_rules.basic_optimization_rules() + model_proto = ir.serde.serialize_model(model) + rule_set.apply_to_model(model) + rewritten_model = ir.serde.serialize_model(model) + self.assertEqual(["Transpose"], [n.op_type for n in model.graph]) + self._check_model(model_proto, rewritten_model) + + def _double_cast_model(self, ostype1, ostype2, ostype3): + dtype2 = ostype2.dtype + dtype3 = ostype3.dtype + + @onnxscript.script() + def cast_cast_model(x): + intermediate = opset18.Cast(x, to=dtype2) + y = opset18.Cast(intermediate, to=dtype3) + return y + + return cast_cast_model.to_model_proto( + input_types=[ostype1[10]], output_types=[ostype3[10]] + ) + + @parameterized.parameterized.expand( + [ + ("float16_float_float16", ot.FLOAT16, ot.FLOAT, ot.FLOAT16), + ] + ) + def test_cast_cast_rule(self, _: str, type1, type2, type3): + rule = _basic_rules.cast_cast_rule + model_proto = self._double_cast_model(type1, type2, type3) + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + _rewritten_model = ir.serde.serialize_model(model) + + self.assertEqual(["Cast"], [n.op_type for n in model.graph]) + # TODO: (random) fp16 inputs + # self._check_model(model_proto, rewritten_model, atol=1e-2) + + @parameterized.parameterized.expand( + [ + ( + "cast_identity", + ir.serde.deserialize_model(cast_identity_model.to_model_proto()), + ), + ] + ) + def test_cast_identity_rule(self, _: str, model: ir.Model): + rule_set = _basic_rules.basic_optimization_rules() + model_proto = ir.serde.serialize_model(model) + rule_set.apply_to_model(model) + rewritten_model = ir.serde.serialize_model(model) + + self.assertEqual(["Identity"], [n.op_type for n in model.graph]) + self._check_model(model_proto, rewritten_model) + + @parameterized.parameterized.expand( + [ + ( + "normal_case", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Expand", ["X", "shape"], ["Y"]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [3, 4, 5])], + [ + onnx.numpy_helper.from_array( + np.array([3, 4, 5], dtype=np.int64), name="shape" + ) + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ("Identity",), + ), + ( + "input_no_shape", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Identity", ["X"], ["Y"]), + onnx.helper.make_node("Expand", ["Y", "shape"], ["Z"]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])], + [onnx.helper.make_tensor_value_info("Z", FLOAT, [3, 4, 5])], + [ + onnx.numpy_helper.from_array( + np.array([3, 4, 5], dtype=np.int64), name="shape" + ) + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ("Identity", "Expand"), + ), + ] + ) + def test_expand_identity_rule( + self, _: str, model: ir.Model, expected_nodes: tuple[str, ...] + ): + rule_set = _basic_rules.basic_optimization_rules() + model_proto = ir.serde.serialize_model(model) + rule_set.apply_to_model(model) + rewritten_model = ir.serde.serialize_model(model) + + self.assertEqual(tuple(n.op_type for n in model.graph), expected_nodes) + self._check_model(model_proto, rewritten_model) + + @parameterized.parameterized.expand( + [ + ( + "double_unsqueezes_1", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Unsqueeze", ["X", "axes1"], ["Xu"]), + onnx.helper.make_node("Unsqueeze", ["Xu", "axes2"], ["Y"]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [1, 3, 1])], + [ + onnx.numpy_helper.from_array( + np.array([1], dtype=np.int64), name="axes1" + ), + onnx.numpy_helper.from_array( + np.array([0], dtype=np.int64), name="axes2" + ), + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ), + ( + "double_unsqueezes_2", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Unsqueeze", ["X", "axes1"], ["Xu"]), + onnx.helper.make_node("Unsqueeze", ["Xu", "axes2"], ["Y"]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [1, 3, 1])], + [ + onnx.numpy_helper.from_array( + np.array([0], dtype=np.int64), name="axes1" + ), + onnx.numpy_helper.from_array( + np.array([1], dtype=np.int64), name="axes2" + ), + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ), + ( + "double_unsqueezes_3", + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Unsqueeze", ["X", "axes1"], ["Xu"]), + onnx.helper.make_node("Unsqueeze", ["Xu", "axes2"], ["Y"]), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3])], + [onnx.helper.make_tensor_value_info("Y", FLOAT, [1, 3, 1])], + [ + onnx.numpy_helper.from_array( + np.array(0, dtype=np.int64), name="axes1" + ), + onnx.numpy_helper.from_array( + np.array(1, dtype=np.int64), name="axes2" + ), + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ), + ] + ) + def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): + rule_set = _basic_rules.basic_optimization_rules() + model_proto = ir.serde.serialize_model(model) + rule_set.apply_to_model(model) + rewritten_model = ir.serde.serialize_model(model) + + self.assertEqual(["Constant", "Unsqueeze"], [n.op_type for n in model.graph]) + self._check_model(model_proto, rewritten_model) + + @classmethod + def _slices_split_models(cls): + models = [ + _make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node( + "Slice", ["X", "zero", "half", "axis"], ["spl1"] + ), + onnx.helper.make_node( + "Slice", ["X", "half", "last", "axis"], ["spl2"] + ), + ], + "name", + [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 6])], + [ + onnx.helper.make_tensor_value_info("spl1", FLOAT, [3, 4, 3]), + onnx.helper.make_tensor_value_info("spl2", FLOAT, [3, 4, 3]), + ], + [ + onnx.numpy_helper.from_array( + np.array([0], dtype=np.int64), name="zero" + ), + onnx.numpy_helper.from_array( + np.array([3], dtype=np.int64), name="half" + ), + onnx.numpy_helper.from_array( + np.array([6], dtype=np.int64), name="last" + ), + onnx.numpy_helper.from_array( + np.array([2], dtype=np.int64), name="axis" + ), + ], + ), + opset_imports=[onnx.helper.make_opsetid("", 18)], + ), + ] + return models + + @unittest.skipIf(True, reason="see https://github.com/microsoft/onnxscript/issues/1642") + def test_slices_split_rule(self): + for model_proto in self._slices_split_models(): + ir_model = ir.serde.deserialize_model(model_proto) + rule_set = _basic_rules.basic_optimization_rules() + rule_set.apply_to_model(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + + self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node]) + self._check_model(model_proto, rewritten_model) + + def test_squeeze_reshape_1d_rule(self): + rule = _basic_rules.squeeze_reshape_1d_rule + + def check(model_script, expected_count) -> None: + model_proto = model_script.to_model_proto() + ir_model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(ir_model) + self.assertEqual(count, expected_count) + if count > 0: + self.assertEqual([x.op_type for x in ir_model.graph], ["Identity"]) + rewritten_proto = ir.serde.serialize_model(ir_model) + self._check_model(model_proto, rewritten_proto) + + op = onnxscript.opset17 + + # input of shape [12] + @onnxscript.script() + def model1(X: ot.FLOAT[12]): + return op.Reshape(op.Squeeze(X), [-1]) + + check(model1, 1) + + # input of shape [1] + @onnxscript.script() + def model2(X: ot.FLOAT[1]): + return op.Reshape(op.Squeeze(X), [-1]) + + check(model2, 1) + + # input of shape [1, 1] + # This should NOT be optimized to Identity + @onnxscript.script() + def model3(X: ot.FLOAT[1, 1]): + return op.Reshape(op.Squeeze(X), [-1]) + + check(model3, 0) + + +class ReshapeReshapeTest(unittest.TestCase): + @staticmethod + def create_model( + input_shape, shape1, shape2, allowzero1=0, allowzero2=0, infer_shape=False + ): + def _convert_shape(shape, name): + if isinstance(shape, np.ndarray): + shape = tape.initializer(ir.Tensor(shape, name=name)) + elif isinstance(shape, (list, tuple)): + shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape)) + tape.graph_like.inputs.append(shape) + else: + raise TypeError(f"Unsupported type {type(shape)} for shape.") + return shape + + x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape)) + y = ir.val("Y", ir.DataType.FLOAT) + tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) + + # Build the graph. + reshape = tape.op( + "Reshape", + inputs=[x, _convert_shape(shape1, "shape_")], + attributes={"allowzero": allowzero1}, + ) + tape.op( + "Reshape", + inputs=[reshape, _convert_shape(shape2, "shape")], + attributes={"allowzero": allowzero2}, + output=y, + ) + model = ir.Model(tape.graph_like, ir_version=10) + + # Infer shapes. + if infer_shape: + model = ir.passes.common.ShapeInferencePass()(model).model + return model + + @parameterized.parameterized.expand( + [ + ((3, 4, 5), [4, 5, 3], [5, 4, 3]), + ((3, 4, 5), [4, 5, 3], [5, 4, 3]), + ((3, 4, 8), [2, 0, 3, -1], [0, 3, 2, 8]), + ((3, 4, 8), [3, 4, -1], [-1, 12], 1), + ((3, 4, 2), [0, 4, -1], [12, -1], 0, 1), + ((3, 0, 8), [4, 2, 0, 0], [3, 0], 1, 1), + ] + ) + def test_reshape_reshape_rule( + self, input_shape, shape1, shape2, allowzero1=0, allowzero2=0 + ): + model = self.create_model( + input_shape, + np.array(shape1, dtype="int64"), + np.array(shape2, dtype="int64"), + allowzero1=allowzero1, + allowzero2=allowzero2, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(10).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand([([3, 2, 3, 3, 3], 1), ([0, -1, 3, 2], 0)]) + def test_reshape_dynamic_reshape_rule(self, shape1, allowzero1=0): + input_shape = (3, 6, 9) + shape1 = np.array(shape1, dtype="int64") + # Build the model with unknown shape1. + model = self.create_model( + input_shape, + (shape1.size,), + np.array((1, 6, 27), dtype="int64"), + allowzero1=allowzero1, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + feeds = { + "X": np.random.default_rng(2).random(input_shape, dtype="float32"), + "shape_": shape1, + } + testing.assert_numerically_equal(model, updated_model, feeds, atol=0, rtol=0) + + @parameterized.parameterized.expand( + [((3, 6, 9), [0, 3, 2, -1]), ((0, 6, 2), [0, 0, 3], 1)] + ) + def test_reshape_reshape_dynamic_rule(self, input_shape, shape2, allowzero2=0): + # Note that shape inference is required for this test to be valid. + shape2 = np.array(shape2, dtype="int64") + model = self.create_model( + input_shape, + np.array((3, 2, -1), dtype="int64"), + shape2, + allowzero2=allowzero2, + infer_shape=True, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(7).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand( + [ + ((3,), "is not a constant"), + (np.array([0, -1], dtype="int64"), "both 0 and -1 dimensions"), + (np.array([0, 0, 3], dtype="int64"), "more than one 0 dimension"), + ] + ) + def test_unsupported_reshape_reshape(self, shape2, error_msg): + model = self.create_model((1, 2, 3), np.array([1, 6], dtype="int64"), shape2) + + # Check rewrite approach. + tracer = MatchingTracer() + count = _basic_rules.reshape_reshape_rule.apply_to_model(model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[_basic_rules.reshape_reshape_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, error_msg) + + +class Flatten2ReshapeTest(unittest.TestCase): + @staticmethod + def create_model(input_shape, axis=1): + x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape)) + y = ir.val("Y", ir.DataType.FLOAT) + tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) + + # Build the graph. + tape.op("Flatten", inputs=[x], attributes={"axis": axis}, output=y) + model = ir.Model(tape.graph_like, ir_version=10) + return model + + @parameterized.parameterized.expand(list(range(-5, 6))) + def test_flatten_to_reshape_rule(self, axis): + input_shape = (1, 4, 8, 7, 5) + model = self.create_model(input_shape=input_shape, axis=axis) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(13).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand(list(range(-4, 5))) + def test_flatten_to_reshape_dynamic_input(self, axis): + model = self.create_model(input_shape=("N", "C1", "C2", "C3"), axis=axis) + # Rule is supported in all cases if the output shape is known for non-special cases. + input_shape = (1, 2, 3, 4) + if axis not in {-3, 0, 1, 4}: + out_shape = ir.Shape((np.prod(input_shape[:axis]), np.prod(input_shape[axis:]))) + model.graph.outputs[0].shape = out_shape + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(17).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + def test_unsupported_flatten_to_reshape(self): + model = self.create_model(input_shape=("N", "C1", "C2"), axis=2) + + # Check rewrite approach. + tracer = MatchingTracer() + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[_basic_rules.flatten_to_reshape_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, "Impossible to compute new shape") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul.py similarity index 58% rename from onnxscript/rewriter/broadcast_to_matmul.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul.py index bc45e06b50..ddf00bc327 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/rules/common/_broadcast_to_matmul.py @@ -1,18 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import logging -import numpy as np +from onnxscript import ir +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet -from onnxscript.rewriter import _ir_utils, pattern - -op = pattern.onnxop logger = logging.getLogger(__name__) -# condition to check if we need to replace the pattern -def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool: - """If matmul broadcasting is enough, then we don't need the reshapes. +def check_if_not_need_reshape( + context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_ +) -> bool: + """Condition to check if we need to replace the pattern. + + If matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: 1. Input shapes check: input_a and input_b should be broadcastable @@ -21,65 +24,73 @@ def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool: If the above are true, then we don't need the reshapes. Returns: - bool: True if we need to replace the pattern, False otherwise. - + True if we need to replace the pattern, False otherwise. """ + del context # Reserved for future extensions + input_a_shape = input_a.shape input_b_shape = input_b.shape - # TODO: Get a helper func to get const_value - shape_c_value = _ir_utils.propagate_const_value(shape_c) - shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr] - if shape_c is None: - return False - if not isinstance(shape_c, np.ndarray): - logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c)) + shape_c_tensor = shape_c.const_value + if shape_c_tensor is None: + logger.info("The value 'shape_c' is not statically known.") return False - if len(shape_c.shape) != 1: + + if len(shape_c_tensor.shape) != 1: logger.info( "Unexpected final shape. The shape of 'shape' value is %s", - shape_c.shape, + shape_c_tensor.shape, ) return False - shape_c = shape_c.tolist() # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape # information. So, we need to check if the shape is None and return False. - if input_a_shape is None or input_b_shape is None or shape_c is None: + if input_a_shape is None or input_b_shape is None: logger.info("Shape information is not available for the inputs and outputs.") return False - input_a_shape = list(input_a_shape) - input_b_shape = list(input_b_shape) + if any(isinstance(dim, ir.SymbolicDim) for dim in input_a_shape): + logger.info("Symbolic dimensions are not yet supported.") + return False + if any(isinstance(dim, ir.SymbolicDim) for dim in input_b_shape): + logger.info("Symbolic dimensions are not yet supported.") + return False + input_a_shape = input_a_shape.numpy() # type: ignore[assignment] + input_b_shape = input_b_shape.numpy() # type: ignore[assignment] + shape_c = shape_c_tensor.numpy().tolist() # type: ignore[assignment] - dim_a = len(input_a_shape) - dim_b = len(input_b_shape) + a_rank = len(input_a_shape) + b_rank = len(input_b_shape) # 1. Check if input shapes are broadcastable # 1.a. If the first input is 1-D, check whether # the dim matches the last second dim of the second input. - mimic_matmul_broadcast_behavior = False - if dim_a < 2: + mimic_matmul_broadcast_behavior_a = False + mimic_matmul_broadcast_behavior_b = False + if a_rank < 2: + if b_rank < 2: + logger.info("Optimization of dot product is not supported yet.") + return False if input_a_shape[-1] != input_b_shape[-2]: logger.info("Original shape is not MatMul compatible.") return False else: - input_a_shape = [1, *input_a_shape] - dim_a = len(input_a_shape) - mimic_matmul_broadcast_behavior = True + input_a_shape = [1, *input_a_shape] # type: ignore[assignment] + a_rank = len(input_a_shape) + mimic_matmul_broadcast_behavior_a = True # 1.b. If the second input is 1-D, check whether # the dim matches the last dim of the first input. - if dim_b < 2: + if b_rank < 2: if input_b_shape[-1] != input_a_shape[-1]: logger.info("Original shape is not MatMul compatible.") return False else: - input_b_shape = [*input_b_shape, 1] - dim_b = len(input_b_shape) - mimic_matmul_broadcast_behavior = True + input_b_shape = [*input_b_shape, 1] # type: ignore[assignment] + b_rank = len(input_b_shape) + mimic_matmul_broadcast_behavior_b = True # 1.c. If both inputs are at least 2-D, check whether # the last dimension of the first input matches the second # last dimension of the second input, and shape[:-2] are # broadcastable. - input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]] + input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]] input_b_shape_except_last_dim = input_b_shape[:-1] broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]] for idx, (dim_from_a, dim_from_b) in enumerate( @@ -93,25 +104,26 @@ def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool: return False elif idx > 0: broadcast_matmul_output_shape = [ - max(dim_from_a, dim_from_b), + max(dim_from_a, dim_from_b), # type: ignore[type-var] *broadcast_matmul_output_shape, ] # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b) # Prepend the broadcast_matmul_output_shape with the longer shape of input - if dim_a > dim_b: + if a_rank > b_rank: longer_shape = input_a_shape shorter_shape = input_b_shape else: longer_shape = input_b_shape shorter_shape = input_a_shape - broadcast_matmul_output_shape = ( - longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape - ) - if mimic_matmul_broadcast_behavior and dim_b == 2 and input_b_shape[-1] == 1: + broadcast_matmul_output_shape = [ + *longer_shape[: -len(shorter_shape)], + *broadcast_matmul_output_shape, + ] + if mimic_matmul_broadcast_behavior_b and b_rank == 2 and input_b_shape[-1] == 1: # If input_b is expanded to 2-D, then we need to remove the last dimension broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] - if mimic_matmul_broadcast_behavior and dim_a == 2 and input_a_shape[0] == 1: + if mimic_matmul_broadcast_behavior_a and a_rank == 2 and input_a_shape[0] == 1: # If input_a is expanded to 2-D, then we need to remove the first dimension # of input_a, which would be the -2nd dimension of the output shape. broadcast_matmul_output_shape.pop(-2) @@ -126,7 +138,7 @@ def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool: return True -def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c): +def _two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c): # TODO: Modified from `value_ints` to `value` to match pattern in benchmark models. # This implementation misses pattern of Constants with `value_ints` attribute. # See more at https://github.com/microsoft/onnx-rewriter/issues/191. @@ -138,31 +150,29 @@ def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shap return op.Reshape(matmul, shape_c) -def matmul(op, input_a, input_b, **_): +def _matmul(op, input_a, input_b, **_): return op.MatMul(input_a, input_b) -def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c): +def _one_reshape_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_c): reshape_a = op.Reshape(input_a, shape_a) matmul = op.MatMul(reshape_a, input_b) return op.Reshape(matmul, shape_c) # Register the rewrite rules -two_reshapes_matmul_reshape_rule = pattern.RewriteRule( - two_reshapes_matmul_reshape_pattern, - matmul, +two_reshapes_matmul_reshape_rule = RewriteRule( + _two_reshapes_matmul_reshape_pattern, + _matmul, check_if_not_need_reshape, ) -one_reshape_matmul_reshape_rule = pattern.RewriteRule( - one_reshape_matmul_reshape_pattern, - matmul, +one_reshape_matmul_reshape_rule = RewriteRule( + _one_reshape_matmul_reshape_pattern, + _matmul, # We can use the same check_if_not_need_reshape function for both the rules, - # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern. + # as one_reshape_matmul_reshape_pattern is a subset of _two_reshapes_matmul_reshape_pattern. check_if_not_need_reshape, ) # NOTE: The order of the rules is important. Larger pattern should be checked first. -rules = pattern.RewriteRuleSet( - [two_reshapes_matmul_reshape_rule, one_reshape_matmul_reshape_rule] -) +rules = RewriteRuleSet([two_reshapes_matmul_reshape_rule, one_reshape_matmul_reshape_rule]) diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py similarity index 72% rename from onnxscript/rewriter/broadcast_to_matmul_test.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py index a654a5734d..4e33544986 100644 --- a/onnxscript/rewriter/broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py @@ -1,10 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + import unittest import onnx.parser import onnx.shape_inference +import parameterized from onnxscript import ir -from onnxscript.rewriter import broadcast_to_matmul +from onnxscript.rewriter.rules.common import _broadcast_to_matmul + + +def _infer_shapes(model: ir.Model) -> ir.Model: + """Run shape inference on the IR model.""" + # TODO: Update when shape inference is supported on the IR + return ir.serde.deserialize_model( + onnx.shape_inference.infer_shapes(ir.serde.serialize_model(model)) + ) class TwoReshapesMatMulReshapeTest(unittest.TestCase): @@ -25,10 +38,82 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) + @parameterized.parameterized.expand( + [ + ( + "0d", + [], + [1, 1], + [], + [1, 1], + [1, 1], + [1, 1], + ), + ( + "x_1d", + [4], + [1, 4], + [4, 2], + [4, 2], + [1, 2], + [1, 2], + ), + ( + "y_1d", + [1, 4], + [1, 4], + [2], + [4, 2], + [1, 2], + [1, 2], + ), + ( + "both_1d", + [2], + [1, 2], + [2], + [2, 1], + [], + [], + ), + ] + ) + def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match( + self, + _: str, + input_x_shape: list[int], + shape_a: list[int], + input_y_shape: list[int], + shape_b: list[int], + output_shape: list[int], + shape_c: list[int], + ): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float{input_x_shape} input_x, float{input_y_shape} input_y) => (float{output_shape} output) + {{ + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = _broadcast_to_matmul.rules.apply_to_model(model) + self.assertEqual(count, 0) + self.assertEqual(len(model.graph), 7) + model = _infer_shapes(model) + self.assertEqual(model.graph.outputs[0].shape, output_shape) + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nested_function( self, ): @@ -66,7 +151,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest ) ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -93,7 +178,7 @@ def test_reshape_matmul_reshape_remain_when_input_last_dim_and_second_last_dim_n """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -117,7 +202,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_inputs_are_not_broadcast ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -141,7 +226,7 @@ def test_reshape_matmul_reshape_replace_when_inputs_are_broadcastable_with_one_i """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -164,7 +249,30 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_br """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 4) + + def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_second_isexpanded_alike_and_broadcastable( + self, + ): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[5] input_x, float[5, 1] input_y) => (float[1] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -187,7 +295,7 @@ def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -210,7 +318,7 @@ def test_reshape_matmul_reshape_replace_when_second_input_is_one_dimension_and_b """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -234,7 +342,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_second_input_is_one_dime ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -258,7 +366,7 @@ def test_reshape_matmul_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -279,7 +387,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) # The constant nodes are not removed. They should be removed by a subsequent DCE in optimizer. self.assertEqual(len(model.graph), 3) diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape.py similarity index 60% rename from onnxscript/rewriter/cast_constant_of_shape.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape.py index ce5c8b8f2e..030302f722 100644 --- a/onnxscript/rewriter/cast_constant_of_shape.py +++ b/onnxscript/rewriter/rules/common/_cast_constant_of_shape.py @@ -1,17 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import logging -import onnx.helper - from onnxscript import ir -from onnxscript.rewriter import pattern +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet -op = pattern.onnxop logger = logging.getLogger(__name__) -def cast_constant_of_shape(shape, scalar, dtype): +def cast_constant_of_shape(op, shape, scalar, dtype): constant = op.ConstantOfShape(shape, value=scalar) return op.Cast(constant, to=dtype) @@ -19,29 +18,27 @@ def cast_constant_of_shape(shape, scalar, dtype): def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_): # Cast scalar (a TensorProto attribute) to the specified dtype scalar_value = scalar.value.numpy().item() - cast_value = onnx.helper.make_tensor("value", dtype.value, (1,), [scalar_value]) + cast_value = ir.tensor([scalar_value], dtype=ir.DataType(dtype.as_int())) return op.ConstantOfShape(shape, value=cast_value) -def cast_constant_of_shape_without_value(shape, dtype): +def cast_constant_of_shape_without_value(op, shape, dtype): constant = op.ConstantOfShape(shape) return op.Cast(constant, to=dtype) def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_): - zero = onnx.helper.make_tensor("value", dtype.value, (1,), [0]) + zero = ir.tensor([0], dtype=ir.DataType(dtype.as_int())) return op.ConstantOfShape(shape, value=zero) -cast_constant_of_shape_rule = pattern.RewriteRule( - cast_constant_of_shape, fused_cast_constant_of_shape -) +cast_constant_of_shape_rule = RewriteRule(cast_constant_of_shape, fused_cast_constant_of_shape) -cast_constant_of_shape_without_value_rule = pattern.RewriteRule( +cast_constant_of_shape_without_value_rule = RewriteRule( cast_constant_of_shape_without_value, fused_cast_constant_of_shape_without_value ) -rules = pattern.RewriteRuleSet( +rules = RewriteRuleSet( [ cast_constant_of_shape_rule, cast_constant_of_shape_without_value_rule, diff --git a/onnxscript/rewriter/cast_constant_of_shape_test.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py similarity index 86% rename from onnxscript/rewriter/cast_constant_of_shape_test.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py index c16ac082d6..794491024b 100644 --- a/onnxscript/rewriter/cast_constant_of_shape_test.py +++ b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py @@ -1,10 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import onnx.checker import onnx.parser from onnxscript import ir -from onnxscript.rewriter import cast_constant_of_shape +from onnxscript.rewriter.rules.common import _cast_constant_of_shape class CastConstantOfShapeTest(unittest.TestCase): @@ -21,7 +23,7 @@ def test_cast_after_constant_of_shape_is_fused(self): ) onnx.checker.check_model(input_model_proto, True) model = ir.serde.deserialize_model(input_model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) @@ -40,7 +42,7 @@ def test_cast_after_constant_of_shape_without_value_is_fused(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) diff --git a/onnxscript/rewriter/rules/common/_collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py new file mode 100644 index 0000000000..21b2694b82 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import logging + +from onnxscript import ir +from onnxscript.rewriter import _ir_utils +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet + +logger = logging.getLogger(__name__) +_INT64_MAX = 9223372036854775807 + + +def _check_if_redundant_slice( + context, + data: ir.Value, + starts: ir.Value, + ends: ir.Value, + axes: ir.Value, + steps: ir.Value, + **_, +) -> bool: + """If the starts is 0, and the ends is equal to or grater than the shape of the specified axis, then the slice is redundant.""" + del context # Reserved for future extensions + + starts_const = starts.const_value + ends_const = ends.const_value + axes_const = axes.const_value + steps_const = steps.const_value + + if starts_const is None or ends_const is None or axes_const is None or steps_const is None: + logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.") + return False + + # Check if the values are scalar + if starts_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'start' is not a scalar.") + return False + if ends_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'end' is not a scalar.") + return False + if axes_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'axis' is not a scalar.") + return False + if steps_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'step' is not a scalar.") + return False + + if steps_const.numpy().item() != 1: + logger.info("The value 'step' is not 1.") + return False + # starts is 0 + if starts_const.numpy().item() != 0: + logger.info("The value 'start' is not 0.") + return False + # In case data.shape is not statically known, we still can tell the slice is redundant if ends is sys.maxsize + if ends_const.numpy().item() == _INT64_MAX: + return True + if data.shape is None or data.shape.is_dynamic(axes_const.numpy().item()): + logger.info("The value 'data' shape is not statically known.") + return False + if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]: + logger.info("The value 'end' is less than the shape of the specified axis.") + return False + + return True + + +def _identity_to_itself(op, data, **_): + """Return the input data as the output.""" + return op.Identity(data) + + +def _potential_redundant_slice(op, data, starts, ends, axes, steps): + """To identify a slice op""" + return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"]) + + +def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_): + """Check if the shape of the slice output is the same as the data.""" + if data.shape is None or slice_output.shape is None: + return False + + if not _ir_utils.is_singleton_value(steps, 1): + return False + + return _ir_utils.same_shape(data.shape, slice_output.shape) + + +# Register the rewrite rules +collapse_slice_rule = RewriteRule( + _potential_redundant_slice, + _identity_to_itself, + _check_if_redundant_slice, +) + +collapse_slice2_rule = RewriteRule( + _potential_redundant_slice, + _identity_to_itself, + _same_shape, +) + +# NOTE: The second rule subsumes the first one. So, we may be able to remove the first one, +# provided shape-inference is run before the rewriter and computes the shape of the slice output. + +rules = RewriteRuleSet([collapse_slice_rule, collapse_slice2_rule]) diff --git a/onnxscript/rewriter/rules/common/_collapse_slices_test.py b/onnxscript/rewriter/rules/common/_collapse_slices_test.py new file mode 100644 index 0000000000..727240344d --- /dev/null +++ b/onnxscript/rewriter/rules/common/_collapse_slices_test.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np +import onnx.parser + +from onnxscript import ir +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _collapse_slices + +_INT64_MAX = 9223372036854775807 + + +class TwoReshapesMatMulReshapeTest(unittest.TestCase): + def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[512, 16, 112] data) => (float[512, 16, 112] output) + { + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = _collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 5) + self.assertIn("Identity", [node.op_type for node in model.graph]) + testing.assert_numerically_equal( + model_proto, + model, + (np.random.rand(512, 16, 112).astype(np.float32),), + ) + + def test_slice_is_redundant_when_ends_reaches_int64_max(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[512, 16, 112] data) => (float[512, 16, 112] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = _collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 5) + self.assertIn("Identity", [node.op_type for node in model.graph]) + testing.assert_numerically_equal( + model_proto, + model, + (np.random.rand(512, 16, 112).astype(np.float32),), + ) + + def test_slice_unequal_dynamic_shape(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[L, M, N] data) => (float[P, M, N] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = _collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 0) + + def test_slice_equal_dynamic_shape(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[L, M, N] data) => (float[L, M, N] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = _collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1) + + def test_slice_equal_dynamic_shape_but_step_reverse(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[L, M, N] data) => (float[L, M, N] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = _collapse_slices.rules.apply_to_model(model) + # Should not change the output shape if we did not use the default step of 1 + self.assertEqual(count, 0) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py new file mode 100644 index 0000000000..9d8b8f23f4 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns: +- BatchNormalization ∘ Conv -> Conv +- BatchNormalization ∘ ConvTranpose -> ConvTranpose +- BatchNormalization ∘ Gemm -> Gemm + +Approach: + Given an inbound operation output: Y = W * X + B + And a BatchNormalization outputs: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps) + + The fusion updates the inbound weights as follows: + - W_fused = W * (gamma / std) + - B_fused = (B - μ) * (gamma / std) + β +""" + +from abc import ABC, abstractmethod +from typing import ClassVar, Mapping + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray: + # Build shape: 1s everywhere except -1 at the target axis + broadcast_shape = [1 if axis != i else -1 for i in range(rank)] + return np.reshape(x, broadcast_shape) + + +class _FuseBatchNormBase(RewriteRuleClassBase, ABC): + """Interface for BatchNormalization nodes fusion.""" + + @abstractmethod + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: + """Return the axis along which BatchNorm scale should be broadcasted.""" + + def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value): + batchnorm_node = batchnorm_out.producer() + # Get BatchNorm parameters + gamma, beta, input_mean, input_var = [ + inp.const_value.numpy() for inp in batchnorm_node.inputs[1:] + ] + + # 1e-5 is the default value for epsilon according to + # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes + default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5) + eps = batchnorm_node.attributes.get("epsilon", default_eps).as_float() + + # Compute the scale_factor to update the inbound weights and bias + scale_factor = gamma / np.sqrt(input_var + eps) + + # Update inbound weights + inbound_node = inbound_out.producer() + weights = inbound_node.inputs[1].const_value.numpy() + + # Reshape scale factor so it is broadcastable + axis = self.get_filters_axis(inbound_node.attributes) + fused_weights = ir.tensor( + weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis) + ) + + # Update bias + if len(inbound_node.inputs) > 2: + original_bias = inbound_node.inputs[2].const_value.numpy() + bias_name = inbound_node.inputs[2].name + else: + original_bias = np.zeros_like(input_mean) + bias_name = x.name + "_bias" + fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta) + + return op.op( + self.op_type, + inputs=[ + x, + op.initializer(fused_weights, name=inbound_node.inputs[1].name), + op.initializer(fused_bias, name=bias_name), + ], + attributes=inbound_node.attributes, + ) + + def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult: + del context # Unused + check_result = MatchResult() + + inbound_node = inbound_out.producer() + batchnorm_node = batchnorm_out.producer() + + # Check that inbound weights + (inbound bias) + batchnorm params are initializers + # and that they are not graph inputs + initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]] + if len(inbound_node.inputs) > 2: + initializers.append(inbound_node.inputs[2]) + + for initializer in initializers: + if not initializer.is_initializer() or initializer.const_value is None: + return check_result.fail(f"{initializer.name} is not a constant initializer.") + if initializer.is_graph_input(): + return check_result.fail(f"{initializer.name} is a graph input.") + + return check_result + + +class FuseBatchNormIntoConv(_FuseBatchNormBase): + """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``.""" + + op_type: ClassVar = "Conv" + + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: + return 0 + + def pattern(self, op, x): + return op.BatchNormalization( + op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"]), + _allow_other_inputs=True, + _outputs=["batchnorm_out"], + ) + + +class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase): + """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.""" + + op_type: ClassVar = "ConvTranspose" + + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: + return 1 + + def pattern(self, op, x): + return op.BatchNormalization( + op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"]), + _allow_other_inputs=True, + _outputs=["batchnorm_out"], + ) + + +class FuseBatchNormIntoGemm(_FuseBatchNormBase): + """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" + + op_type: ClassVar = "Gemm" + + def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: + return ( + 0 if attributes.get("transB") is not None and attributes["transB"].as_int() else 1 + ) + + def pattern(self, op, x): + return op.BatchNormalization( + op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"]), + _allow_other_inputs=True, + _outputs=["batchnorm_out"], + ) + + +fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule() +fuse_batchnorm_into_conv_transpose_rule = FuseBatchNormIntoConvTranspose().rule() +fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule() + + +rules = RewriteRuleSet( + [ + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py new file mode 100644 index 0000000000..3e617340ff --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx.checker +import onnx.parser +import parameterized + +from onnxscript import ir +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _fuse_batchnorm + + +class FuseBatchnormTest(unittest.TestCase): + def _create_batchnorm_params(self, size: int): + return [ + onnx.numpy_helper.from_array( + np.random.randn(size).astype(np.float32), name="gamma" + ), + onnx.numpy_helper.from_array( + np.random.randn(size).astype(np.float32), name="beta" + ), + onnx.numpy_helper.from_array( + np.random.randn(size).astype(np.float32), name="input_mean" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(size)).astype(np.float32), name="input_var" + ), + ] + + @parameterized.parameterized.expand( + [ + ("bias_false", False), + ("bias_true", True), + ] + ) + def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): + convtranspose_inputs = "X, W" + parameters = ( + "float[32, 64, 3, 3] W, " + "float[64] gamma, " + "float[64] beta, " + "float[64] input_mean, " + "float[64] input_var" + ) + if convtranspose_bias: + parameters += ", float[64] B" + convtranspose_inputs += ", B" + + model_proto = onnx.parser.parse_model(f""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y) + <{parameters}> + {{ + X1 = ConvTranspose({convtranspose_inputs}) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + }} + """) + # Add initializers + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(32, 64, 3, 3).astype(np.float32), name="W" + ), + *self._create_batchnorm_params(size=64), + ] + if convtranspose_bias: + initializers.append( + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") + ) + model_proto.graph.initializer.extend(initializers) + + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + + # Apply rule + count = _fuse_batchnorm.rules.apply_to_model(model) + + # Check that BatchNorm was fused + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + + # Check inference + testing.assert_numerically_equal( + model_proto, model, (np.random.rand(1, 32, 14, 16).astype(np.float32),) + ) + + output_model_proto = ir.serde.serialize_model(model) + onnx.checker.check_model(output_model_proto, True) + + @parameterized.parameterized.expand( + [ + ("bias_false", False), + ("bias_true", True), + ] + ) + def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): + conv_inputs = "X, W" + parameters = ( + "float[64, 32, 3, 3] W, " + "float[64] gamma, " + "float[64] beta, " + "float[64] input_mean, " + "float[64] input_var" + ) + if conv_bias: + parameters += ", float[64] B" + conv_inputs += ", B" + + model_proto = onnx.parser.parse_model(f""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y) + <{parameters}> + {{ + X1 = Conv({conv_inputs}) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + }} + """) + # Add initializers + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(64, 32, 3, 3).astype(np.float32), name="W" + ), + *self._create_batchnorm_params(size=64), + ] + if conv_bias: + initializers.append( + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") + ) + model_proto.graph.initializer.extend(initializers) + + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + + # Apply rule + count = _fuse_batchnorm.rules.apply_to_model(model) + + # Check that BatchNorm was fused + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + + # Check inference + testing.assert_numerically_equal( + model_proto, model, (np.random.rand(1, 32, 14, 16).astype(np.float32),) + ) + + output_model_proto = ir.serde.serialize_model(model) + onnx.checker.check_model(output_model_proto, True) + + @parameterized.parameterized.expand( + [ + ("bias_false_transB_0", False, 0), + ("bias_true_transB_0", True, 0), + ("bias_false_transB_1", False, 1), + ("bias_true_transB_1", True, 1), + ] + ) + def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): + gemm_inputs = "X, W" + parameters = ( + f"float{'[64, 32]' if transB else '[32, 64]'} W, " + "float[64] gamma, " + "float[64] beta, " + "float[64] input_mean, " + "float[64] input_var" + ) + + if gemm_bias: + parameters += ", float[64] B" + gemm_inputs += ", B" + + model_proto = onnx.parser.parse_model(f""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32] X) => (float [N, ?] Y) + <{parameters}> + {{ + X1 = Gemm({gemm_inputs}) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + }} + """) + weights = np.random.randn(32, 64).astype(np.float32) + if transB: + weights = weights.T + + # Add initializers + initializers = [ + onnx.numpy_helper.from_array(weights, name="W"), + *self._create_batchnorm_params(size=64), + ] + if gemm_bias: + initializers.append( + onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") + ) + model_proto.graph.initializer.extend(initializers) + + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + + # Apply rule + count = _fuse_batchnorm.rules.apply_to_model(model) + + # Check that BatchNorm was fused + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + + # Check inference + testing.assert_numerically_equal( + model_proto, model, (np.random.rand(1, 32).astype(np.float32),) + ) + + output_model_proto = ir.serde.serialize_model(model) + onnx.checker.check_model(output_model_proto, True) + + def test_fuse_batchnorm_non_initializers(self): + model_proto = onnx.parser.parse_model(""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X, float[64, 32, 3, 3] W, float[64] B, + float[64] gamma, float[64] beta, float[64] input_var, + float[64] input_mean) => (float [N, ?, ?, ?] Y) + { + X1 = Conv(X, W, B) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + } + """) + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + count = _fuse_batchnorm.rules.apply_to_model(model) + + # No changes were applied + self.assertEqual(count, 0) + + def test_fuse_batchnorm_graph_inputs(self): + model_proto = onnx.parser.parse_model(""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X, float[64, 32, 3, 3] W) => (float [N, ?, ?, ?] Y) + { + X1 = Conv(X, W) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + } + """) + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(64, 32, 3, 3).astype(np.float32), name="W" + ), + *self._create_batchnorm_params(size=64), + ] + model_proto.graph.initializer.extend(initializers) + onnx.checker.check_model(model_proto, True) + + model = ir.serde.deserialize_model(model_proto) + count = _fuse_batchnorm.rules.apply_to_model(model) + + # No changes were applied as W is a graph input + self.assertEqual(count, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py new file mode 100644 index 0000000000..2aaba5cd73 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Absorbs affine operation into convolution (best effort): +- Conv(Mul(Add(x))) -> Conv (only conv without padding can be fused) +- Add(Mul(Conv)) -> Conv (for all convolutions) +""" + +from __future__ import annotations + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._ir_utils import get_const_value, get_singleton_value + + +class _ConvAffineFusionBase(pattern.RewriteRuleClassBase): + def check( + self, + context, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> MatchResult: + check_result = MatchResult() + if get_const_value(w) is None: + return check_result.fail("The weight of Conv should be constant") + if get_const_value(b) is None: + return check_result.fail("The bias of Conv should be constant") + if get_singleton_value(scale) is None: + return check_result.fail("Operand for Mul should be constant scalar value") + if get_singleton_value(offset) is None: + return check_result.fail("Operand for Add should be constant scalar value") + return check_result + + +class AffineConvFusion(_ConvAffineFusionBase): + """Pattern: scalar Mul + scalar Add + Conv (1x1) --> Conv(1x1)""" + + def pattern( + self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value + ) -> ir.Value: + return op.Conv( + x * scale + offset, + w, + b, + pads=[0, 0, 0, 0], + _allow_other_attributes=True, + _outputs=["conv_out"], + ) + + def rewrite( + self, + op: ir.tape.Tape, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> ir.Value: + scale_value = scale.const_value.numpy() + offset_value = offset.const_value.numpy() + w_value = w.const_value.numpy() + b_value = b.const_value.numpy() + scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled") + offset_bias = ir.tensor( + b_value + np.sum(w_value * offset_value, axis=(1, 2, 3), keepdims=False) + ) + offset_bias = op.initializer(offset_bias, b.name + "_offset") + conv_attributes = conv_out.producer().attributes + return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes) + + +class ConvAffineFusion(_ConvAffineFusionBase): + """Pattern: Conv + scalar Mul + scalar Add --> Conv(1x1)""" + + def pattern( + self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value + ) -> ir.Value: + return ( + op.Conv(x, w, b, _allow_other_attributes=True, _outputs=["conv_out"]) * scale + + offset + ) + + def rewrite( + self, + op: ir.tape.Tape, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> ir.Value: + scale_value = scale.const_value.numpy() + offset_value = offset.const_value.numpy() + w_value = w.const_value.numpy() + b_value = b.const_value.numpy() + scaled_w_weight = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled") + offset_bias = ir.tensor(b_value * scale_value + offset_value) + offset_bias = op.initializer(offset_bias, b.name + "_offset") + conv_attributes = conv_out.producer().attributes + return op.Conv(x, scaled_w_weight, offset_bias, **conv_attributes) + + +affine_conv_fusion_rule = AffineConvFusion().rule() +conv_affine_fusion_rule = ConvAffineFusion().rule() diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py new file mode 100644 index 0000000000..d456cab76b --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter import rewrite, testing +from onnxscript.rewriter.rules.common import ( + affine_conv_fusion_rule, + conv_affine_fusion_rule, +) + + +class FuseConvAffineTest(unittest.TestCase): + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def test_conv_affine_fusion(self): + tape = ir.tape.Tape() + x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32])) + w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w")) + b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) + scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) + offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset")) + + conv_out = tape.op("Conv", [x, w, b], attributes={"pads": [1, 1, 1, 1]}) + mul_out = tape.op("Mul", [conv_out, scale]) + z = tape.op( + "Add", + [mul_out, offset], + output=ir.val( + "z", + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1, 3, 32, 32]), + ), + ) + + model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[z], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 17}, + ), + ir_version=8, + ) + rewritten_model = self.clone_model(model) + rewritten_model = rewrite( + rewritten_model, + pattern_rewrite_rules=[conv_affine_fusion_rule], + ) + # Check that Mul and Add are fused into Conv + self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes()) + + # Check that the results are numerically equal + rng = np.random.default_rng(42) + inputs = [ + rng.random((1, 3, 32, 32), dtype=np.float32), + ] + testing.assert_numerically_equal(model, rewritten_model, inputs) + + def test_affine_conv_fusion_without_pad(self): + tape = ir.tape.Tape() + x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32])) + w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w")) + b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) + scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) + offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset")) + + mul_out = tape.op("Mul", [x, scale]) + z = tape.op( + "Add", + [mul_out, offset], + output=ir.val( + "z", + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1, 3, 32, 32]), + ), + ) + conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]}) + + model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[conv_out], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 17}, + ), + ir_version=8, + ) + rewritten_model = self.clone_model(model) + rewritten_model = rewrite( + rewritten_model, + pattern_rewrite_rules=[affine_conv_fusion_rule], + ) + # Check that Mul and Add are fused into Conv + self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes()) + + # Check that the results are numerically equal + rng = np.random.default_rng(42) + inputs = [ + rng.random((1, 3, 32, 32), dtype=np.float32), + ] + testing.assert_numerically_equal(model, rewritten_model, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish.py b/onnxscript/rewriter/rules/common/_fuse_hardswish.py new file mode 100644 index 0000000000..6d2e8c84e1 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_hardswish.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Does the following transformation: +- Div(Clip(Add(x))) -> HardSigmoid +- Mul(HardSigmoid(x), x) -> HardSwish +""" + +from __future__ import annotations + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._ir_utils import is_singleton_value +from onnxscript.rewriter._rewrite_rule import RewriteRuleSet + + +class _HardSigmoidFusionBase(pattern.RewriteRuleClassBase): + """HardSwish requires constant values so we check in base class.""" + + def check( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> MatchResult: + check_result = MatchResult() + + if not is_singleton_value(clip_min, 0.0, rtol=1e-4): + return check_result.fail("Swish requires min value of 0 for clip") + if not is_singleton_value(clip_max, 6.0, rtol=1e-4): + return check_result.fail("Swish requires max value of 6 for clip") + if not is_singleton_value(bias, 3.0, rtol=1e-4): + return check_result.fail("Swish requires bias value of 3") + if not is_singleton_value(divisor, 6.0, rtol=1e-4): + return check_result.fail("Swish requires divisor value of 6") + return check_result + + +class HardSwishFusion(_HardSigmoidFusionBase): + """Fuse Add(_, 3) + Clip<0, 6>(_) + Mul + Div(_, 6) into HardSwish + + In this case we can't make HardSigmoid fusion first. The Mul + is placed before Div while HardSigmoid requires Add+Clip+Div. + """ + + def pattern( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + out = op.Clip(x + bias, clip_min, clip_max) * x + out = out / divisor + return out + + def rewrite( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + return op.HardSwish(x) + + +class HardSwishFusionFromHardSigmoid(pattern.RewriteRuleClassBase): + """Fuse HardSigmoid + Mul into HardSwish""" + + def pattern(self, op, x: ir.Value) -> ir.Value: + # Floating point matching for 1/6 is not exact, so we use isclose below + out = op.HardSigmoid(x, _allow_other_attributes=True, _outputs=["hardsigmoid_out"]) + out = out * x + return out + + def check(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> MatchResult: + check_result = MatchResult() + hardsigmoid = hardsigmoid_out.producer() + # Use getter to protect when 'alpha' / 'beta' is not in attributes + alpha = hardsigmoid.attributes.get_float("alpha", -1) + beta = hardsigmoid.attributes.get_float("beta", -1) + if not np.isclose(alpha, 1 / 6): + return check_result.fail( + "HardSigmoid alpha must be 1/6 to get fused into HardSwish" + ) + if not np.isclose(beta, 0.5): + return check_result.fail( + "HardSigmoid beta must be 0.5 to get fused into HardSwish" + ) + return check_result + + def rewrite(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> ir.Value: + return op.HardSwish(x) + + +class HardSigmoidFusion(_HardSigmoidFusionBase): + """Fuse HardSigmoid only for HardSwish hyper-parameters: alpha=1/6, beta=0.5""" + + def pattern( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + out = op.Clip(x + bias, clip_min, clip_max) + out = out / divisor + return out + + def rewrite( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + return op.HardSigmoid(x, alpha=1 / 6, beta=0.5) + + +def fuse_hardswish_rules() -> RewriteRuleSet: + """Returns the rewrite rules for fusing HardSwish and HardSigmoid.""" + return RewriteRuleSet( + [ + HardSwishFusion().rule(), + HardSigmoidFusion().rule(), + HardSwishFusionFromHardSigmoid().rule(), + ], + commute=True, + ) diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py new file mode 100644 index 0000000000..36556e9cff --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +import onnxruntime as ort +from onnx_ir.passes.common import onnx_checker, shape_inference + +from onnxscript import optimizer +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import fuse_hardswish_rules + + +class FuseHardSwishTest(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250621) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + dtype: str = "float", + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + _ = fuse_hardswish_rules().apply_to_model(updated_model) + + # Polish model to remove unused constants + updated_model = optimizer.optimize(updated_model) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = (self.rng.integers(low=-10, high=10, size=(2 * 32), dtype=np.int32),) + if dtype == "float": + inputs = (inputs[0].astype(np.float32),) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ort_optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + + # Validate serialized model + output_model_proto = ir.to_proto(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def test_hardsigmoid_fusion(self): + model_text = """ + + hardsigmoid (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + y = Div(clipped, six) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSigmoid"]) + + def test_hardswish_fusion(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + mul_x = Mul(clipped, x) + y = Div(mul_x, six) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + def test_hardswish_fusion_mul_last(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + div_x = Div(clipped, six) + y = Mul(div_x, x) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + def test_hardswish_fusion_from_sigmoid(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + hardsigmoid_out = HardSigmoid(x) + y = Mul(hardsigmoid_out, x) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py new file mode 100644 index 0000000000..39aab00eda --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py @@ -0,0 +1,343 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fuses Pad nodes into preceding nodes. Supported fusion patterns: +- Conv ∘ Pad -> Conv +- ConvInteger ∘ Pad -> ConvInteger + +To make some rules possible, we implicitly transform `auto_pad` attribute into its explicit list. +""" + +from __future__ import annotations + +from typing import List, Sequence + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import pattern as orp + + +def fill_pads_with_axes(pads: Sequence[int], axes: Sequence[int], rank: int) -> List[int]: + """Converts the parameters of the ONNX Pad operator into an explicit list of values. + + A filled list of pads will be returned following the format: + [x1_begin, x2_begin, ..., x{rank}_begin, x1_end, x2_end, ..., x{rank}_end] + + Args: + pads: list of integers indicating the number of padding elements to add at + the beginning and end of each axis. + axes: list of axes that pads apply to. + rank: value to compute the size of the filled list (2 * rank). + + Returns: + The filled list of pads. + """ + new_pads = [0] * 2 * rank + N = len(axes) + for start_idx, axis in enumerate(axes): + new_pads[axis] = pads[start_idx] + new_pads[axis + rank] = pads[start_idx + N] + return new_pads + + +def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]: + # Read attributes + attributes = {} + ir_attributes = ir_conv.attributes + attributes["kernel_shape"] = ir_attributes.get_ints( + "kernel_shape", ir_conv.inputs[1].shape[2:] + ) + attributes["strides"] = ir_attributes.get_ints( + "strides", [1] * len(ir_conv.inputs[0].shape[2:]) + ) + attributes["auto_pad"] = ir_attributes.get_string("auto_pad", "NOTSET") + if "pads" in ir_attributes: + attributes["pads"] = ir_attributes.get_ints("pads") + return attributes + + +class _FuseConvPadBase(orp.RewriteRuleClassBase): + """Interface for PadConv nodes fusion.""" + + def __init__(self, as_function: bool = False): + # Remove nodes is set to False to remove unused nodes after the rewrite, since + # Pad or Conv inputs can come from constant nodes. + # With remove_nodes=False these nodes are removed if these nodes are no longer needed. + super().__init__(remove_nodes=False, as_function=as_function) + + def rewrite( + self, op: ir.tape.Tape, x: ir.Value, pad: ir.Value, conv: ir.Value + ) -> ir.Value: + conv_node = conv.producer() + + # Retrieve the padding and axes + x_rank = len(x.shape) + + # Get computed pads in check() + pad_pads = self._pads_list + + # Get only spatial pads + new_pads = pad_pads[2:x_rank] + pad_pads[x_rank + 2 :] + + # Replace conv pads = new + old + conv_attr = conv_node.attributes.copy() + if "pads" in conv_attr: + new_pads = [x + y for x, y in zip(conv_attr["pads"].as_ints(), new_pads)] + conv_attr.add(ir.AttrInt64s("pads", new_pads)) + + return op.op( + conv_node.op_type, + inputs=(x, *conv_node.inputs[1:]), + attributes=conv_attr, + domain=conv_node.domain, + name=conv_node.name, + ) + + def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult: + """Condition to check if we need to replace the pattern. + + If Pad inputs can be added in 'pads' attribute of the Conv operator. + + To validate this, we need to check the following: + 1. `Pad` attribute has 'constant' as value + 2. `Pad` operator inputs are constants ('pads', 'constant_value', 'axes') + 3. 'constant_value' is equal to 0.0. + 4. `Pad` operator is only used for the spatial dimensions (batch dimension and channels + remain unchanged). + + If the above are true, then we don't need the reshapes. + + Returns: + True if we need to replace the pattern, False otherwise. + """ + del context # Unused + check_result = orp.MatchResult() + pad_node = pad.producer() + if x.shape is None: + return check_result.fail( + f"Input shapes are not defined on {pad_node.name} ({pad_node.op_type})." + ) + x_rank = len(x.shape) + + # Pad constraints: attributes + if (mode := pad_node.attributes.get("mode", None)) and mode.as_string() != "constant": + return check_result.fail( + f"{pad_node.name} ({pad_node.op_type}) mode must be 'constant'." + ) + + # Pad constraints: inputs + if (pads := pad_node.inputs[1]).const_value is None: + return check_result.fail(f"{pads.name} is not a constant/initializer.") + if len(pad_node.inputs) > 2 and (constant_value := pad_node.inputs[2]) is not None: + if constant_value.const_value is None: + return check_result.fail( + f"{constant_value.name} is not a constant/initializer." + ) + elif constant_value.const_value.numpy().item() != 0: + return check_result.fail(f"{constant_value.name} must be equal to 0.") + if len(pad_node.inputs) > 3 and (axes := pad_node.inputs[3]) is not None: + if axes.const_value is None: + return check_result.fail(f"{axes.name} is not a constant/initializer.") + axes_list = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()] + else: + axes_list = list(range(x_rank)) + + # Pad constraints: values + self._pads_list = fill_pads_with_axes(pads.const_value.numpy(), axes_list, x_rank) + if np.any(self._pads_list[:2] + self._pads_list[x_rank : x_rank + 2]): + self._pads_list = None + return check_result.fail(f"{pads.name} must be zero in non-spatial dimensions.") + + return check_result + + +class FuseConvPad(_FuseConvPadBase): + """Replaces ``Conv(Pad(x))`` with ``Conv(x)``.""" + + def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value: + return op.Conv( + op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]), + _allow_other_inputs=True, + _outputs=["conv"], + ) + + def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult: + check_result = super().check(context, x, pad, conv) + if not check_result: + return check_result + + # Conv constraints: attributes + conv_node = conv.producer() + if conv_node.attributes.get_string("auto_pad", "NOTSET") != "NOTSET": + return check_result.fail( + f"{conv_node.name} ({conv_node.op_type}) auto_pad must be 'NOTSET'." + ) + return check_result + + +class FuseConvIntegerPad(FuseConvPad): + """Replaces ``ConvInteger(Pad(x))`` with ``ConvInteger(x)``.""" + + def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value: + return op.ConvInteger( + op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]), + _allow_other_inputs=True, + _outputs=["conv"], + ) + + +class _NormalizePadFormatBase(orp.RewriteRuleClassBase): + """Interface to normalize pad attributes in conv nodes.""" + + @staticmethod + def compute_pads( + input_shape: Sequence[int], + output_shape: Sequence[int], + attributes: dict[str, Sequence[int] | str], + ) -> Sequence[int]: + raise NotImplementedError("Child have to implement this function") + + def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value: + conv_node = conv.producer() + + # Read spatial dimensions and attributes + input_shape = conv_node.inputs[0].shape[2:] + output_shape = conv_node.outputs[0].shape[2:] + attributes = read_conv_attributes(conv_node) + + # Convert auto_pad mode into an explicit list + pads = self.compute_pads(input_shape, output_shape, attributes) + + # Replace auto_pad, forcing to the explicit list + conv_attr = conv_node.attributes.copy() + conv_attr.add(ir.AttrString("auto_pad", "NOTSET")) + if any(x != 0 for x in pads): + conv_attr.add(ir.AttrInt64s("pads", pads)) + + return op.op( + conv_node.op_type, + inputs=conv_node.inputs, + attributes=conv_attr, + domain=conv_node.domain, + name=conv_node.name, + ) + + def check(self, context, conv: ir.Value, **__) -> orp.MatchResult: + """Condition to check if we need to replace the pattern. + + If it is possible to deduce 'pads'. + + To validate this, we need to check the following: + 1. `Conv` (nothing to do in this case, since 'pads' are + already explicit) + 2. it is possible to deduce the input rank when `Conv` + 3. When `Conv`: + * spatial input/output shapes are static + * it is possible to infer `kernel_shape` either from the `Conv` operator attribute + or from the kernel input + + If the above are true, then we don't need the reshapes. + + Returns: + True if we need to replace the pattern, False otherwise. + """ + del context + check_result = orp.MatchResult() + + # Conv constraints: attributes + conv_node = conv.producer() + auto_pad = conv_node.attributes.get_string("auto_pad", None) + if auto_pad in {None, "NOTSET"}: + return check_result.fail( + f"{conv_node.name} ({conv_node.op_type}) auto_pad must be different to 'NOTSET'." + ) + + # Conv constraints: inputs/outputs + input_shape = conv_node.inputs[0].shape + output_shape = conv_node.outputs[0].shape + if input_shape is None or len(input_shape) <= 2: + return check_result.fail( + f"Input shapes are not defined on {conv_node.name} ({conv_node.op_type})." + ) + if output_shape is None or len(output_shape) <= 2: + return check_result.fail( + f"Output shapes are not defined on {conv_node.name} ({conv_node.op_type})." + ) + + # Conv constraints: values + if auto_pad != "VALID": + error_msg = ( + "Expected static spatial {} shapes on " + + conv_node.name + + f" ({conv_node.op_type})." + ) + if not all(isinstance(x, int) for x in input_shape[2:]): + return check_result.fail(error_msg.format("input")) + if not all(isinstance(x, int) for x in output_shape[2:]): + return check_result.fail(error_msg.format("output")) + attributes = read_conv_attributes(conv_node) + if len(attributes["kernel_shape"]) != len(attributes["strides"]): + return check_result.fail( + "strides must have the same length than kernel_shape on " + f"{conv_node.name} ({conv_node.op_type})." + ) + return check_result + + +class NormalizePadFormatConv(_NormalizePadFormatBase): + """Convert auto_pad attribute into 'NOTSET' in Conv nodes .""" + + @staticmethod + def compute_pads( + input_shape: Sequence[int], + output_shape: Sequence[int], + attributes: dict[str, Sequence[int] | str], + ) -> Sequence[int]: + # Compute pads, following auto_pad/pads attributes + if attributes["auto_pad"] in {"NOTSET", "VALID"}: + assert len(input_shape) > 0 + return attributes.get("pads", [0] * len(input_shape) * 2) + + bottom_pads, top_pads = [], [] + kernel_shape, strides = attributes["kernel_shape"], attributes["strides"] + assert len(kernel_shape) == len(strides) == len(input_shape) == len(output_shape) + for x, y, k, s in zip(input_shape, output_shape, kernel_shape, strides): + # Compute the output shape and the total padding to apply + total_pads = max(0, (y - 1) * s + k - x) + + # Depending of mode, apply the padding to the upper or lower part + pad1 = total_pads // 2 + pad2 = total_pads - pad1 + if attributes["auto_pad"] == "SAME_UPPER": + bottom_pads.append(pad1) + top_pads.append(pad2) + else: + top_pads.append(pad1) + bottom_pads.append(pad2) + return bottom_pads + top_pads + + def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value: + return op.Conv(x, _allow_other_inputs=True, _outputs=["conv"]) + + +class NormalizePadFormatConvInteger(NormalizePadFormatConv): + """Convert auto_pad attribute into 'NOTSET' in ConvInteger nodes .""" + + def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value: + return op.ConvInteger(x, _allow_other_inputs=True, _outputs=["conv"]) + + +normalize_pad_format_conv_rule = NormalizePadFormatConv.rule() +normalize_pad_format_conv_integer_rule = NormalizePadFormatConvInteger.rule() +fuse_pad_into_conv_rule = FuseConvPad.rule() +fuse_pad_into_conv_integer_rule = FuseConvIntegerPad.rule() + + +rules = orp.RewriteRuleSet( + [ + normalize_pad_format_conv_rule, + normalize_pad_format_conv_integer_rule, + fuse_pad_into_conv_rule, + fuse_pad_into_conv_integer_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py new file mode 100644 index 0000000000..ded57fe023 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py @@ -0,0 +1,406 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest +from typing import Mapping, Sequence + +import numpy as np +import onnx_ir as ir +import parameterized +from onnx_ir.passes.common import onnx_checker, shape_inference + +from onnxscript.rewriter import pattern as orp +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _fuse_pad_into_conv +from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( + fuse_pad_into_conv_rule, + normalize_pad_format_conv_rule, +) + + +def _clone_model(model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + +class FuseConvPadBaseTest(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250522) + + def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None): + w = ir.tensor(self.rng.uniform(-0.5, 0.5, shape).astype("float32"), name="W") + if tape is not None: + w = tape.initializer(w) + return w + + def build_model( + self, + op_type: str, + input_shape: ir.Shape, + weight_shape: Sequence[int], + pad_inputs: Sequence[ir.TensorProtocol | ir.Value | None], + pad_attributes: Mapping[str, ir.Attr] | None = None, + conv_attributes: Mapping[str, ir.Attr] | None = None, + ) -> ir.Model: + tape = ir.tape.Tape() + inputs = [] + output_shape = ir.Shape((input_shape[0],) + ("?",) * (len(input_shape) - 1)) + + # Convert pad_inputs to initializers (if needed) + pad_inputs = list(pad_inputs) + for idx, x in enumerate(pad_inputs): + if isinstance(x, ir.TensorProtocol): + pad_inputs[idx] = tape.initializer(x) + elif isinstance(x, ir.Value): + inputs.append(x) + elif isinstance(x, float): + pad_inputs[idx] = tape.op("Constant", inputs=[], attributes={"value_float": x}) + elif x is not None: + raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.") + + # Register operations in the tape + idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT + x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype)) + y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes) + y = tape.op( + op_type, + inputs=[y, self.get_conv_weights(weight_shape, tape)], + attributes=conv_attributes, + output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)), + ) + if op_type == "ConvInteger": + y.dtype = ir.DataType.INT32 + + # Build the model + ir_model = ir.Model( + ir.Graph( + inputs=[x, *inputs], + outputs=[y], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 20}, + name="model", + ), + ir_version=10, + ) + onnx_checker.CheckerPass(True)(ir_model) + ir_model = shape_inference.infer_shapes(ir_model) + return ir_model + + +class FuseConvPadTest(FuseConvPadBaseTest): + @parameterized.parameterized.expand( + [ + (pad_pads, const_value, axes, conv_pads, conv_auto_pad) + for pad_pads, axes, conv_pads, conv_auto_pad in [ + ([0, 0, 2, 2, 0, 0, 2, 2], None, None, None), + ([0, 2, 2, 0, 2, 2], ir.tensor([1, -2, -1], name="axes"), [2, 0, 2, 0], None), + ([1, 1, 1, 1], ir.tensor([-2, 3], name="axes"), [0, 1, 0, 1], None), + ([1, 3, 1, 3], ir.tensor([3, 2], name="axes"), None, "VALID"), + ] + for const_value in [None, 0.0] + ] + ) + def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads, conv_auto_pad): + pad_inputs = [ir.tensor(pad_pads, name="pads")] + if const_value is not None or axes is not None: + pad_inputs.append(const_value) + if axes is not None: + pad_inputs.append(axes) + base_model = self.build_model( + op_type="Conv", + input_shape=ir.Shape(("N", 32, 14, 16)), + weight_shape=(10, 32, 3, 3), + pad_inputs=pad_inputs, + conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad}, + ) + updated_model = _clone_model(base_model) + + # Apply rule + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) + + # Check that Pad was fused + self.assertEqual(count, 1 if conv_auto_pad is None else 2) + self.assertEqual(updated_model.graph.num_nodes(), 1) + onnx_checker.CheckerPass(True)(updated_model) + + # Check inference + inputs = self.rng.random((1, 32, 14, 16), dtype="float32") + testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand( + [ + ( + "constant", + ir.tensor([1] * 10, name="pads"), + ir.tensor([0.0], name="const_value"), + None, + "NOTSET", + "must be zero in non-spatial dimensions", + ), + ( + "constant", + ir.tensor([0, 0, 0, 0], name="pads"), + ir.tensor([1.0], name="const_value"), + ir.tensor([0, -1], name="axes"), + "NOTSET", + "must be equal to 0.", + ), + ( + "edge", + ir.tensor([0, 0, 0, 0], name="pads"), + ir.tensor([0.0], name="const_value"), + ir.tensor([0, -1], name="axes"), + "NOTSET", + "mode must be 'constant'.", + ), + ( + "constant", + ir.Value( + name="pads", shape=ir.Shape([4]), type=ir.TensorType(ir.DataType.INT64) + ), + None, + ir.tensor([0, -1], name="axes"), + "NOTSET", + "pads is not a constant/initializer.", + ), + ( + "constant", + ir.tensor([0] * 10, name="pads"), + ir.Value( + name="cval", shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.FLOAT) + ), + None, + "NOTSET", + "cval is not a constant", + ), + ( + "constant", + ir.tensor([0, 0, 0, 0], name="pads"), + None, + ir.Value( + name="axes", shape=ir.Shape([2]), type=ir.TensorType(ir.DataType.INT64) + ), + "NOTSET", + "axes is not a constant", + ), + ( + "constant", + ir.tensor([0, 0, 0, 0], name="pads"), + ir.tensor([0.0], name="const_value"), + ir.tensor([0, -1], name="axes"), + "VALID", + "auto_pad must be 'NOTSET'.", + ), + ] + ) + def test_unsupported_fuse_pad_into_conv( + self, mode, pads, const_value, axes, auto_pad, err_msg + ): + base_model = self.build_model( + op_type="Conv", + input_shape=ir.Shape(("N", 32, 14, 16, 12)), + weight_shape=(10, 32, 3, 4, 5), + pad_inputs=[pads, const_value, axes], + pad_attributes={"mode": mode}, + conv_attributes={"auto_pad": auto_pad}, + ) + + # Apply rule and check it was not applied + tracer = orp.MatchingTracer() + count = fuse_pad_into_conv_rule.apply_to_model(base_model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[fuse_pad_into_conv_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, err_msg) + + +class FuseConvIntegerPadTest(FuseConvPadBaseTest): + def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None): + w = ir.tensor(self.rng.integers(0, 256, shape).astype("uint8"), name="W") + if tape is not None: + w = tape.initializer(w) + return w + + @parameterized.parameterized.expand( + [ + (pad_pads, const_value, axes, conv_pads, conv_auto_pad) + for pad_pads, axes, conv_pads, conv_auto_pad in [ + ([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1], None), + ([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None, None), + ([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1], None), + ([3, 3], ir.tensor([2], name="axes"), None, "SAME_UPPER"), + ] + for const_value in [None, ir.tensor(np.array([0], "uint8"), name="const_value")] + ] + ) + def test_fuse_pad_into_conv_integer( + self, pad_pads, const_value, axes, conv_pads, conv_auto_pad + ): + pad_inputs = [ir.tensor(pad_pads, name="pads")] + if const_value is not None or axes is not None: + pad_inputs.append(const_value) + if axes is not None: + pad_inputs.append(axes) + base_model = self.build_model( + op_type="ConvInteger", + input_shape=ir.Shape(("N", 24, 19, 23)), + weight_shape=(8, 24, 3, 3), + pad_inputs=pad_inputs, + conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad}, + ) + updated_model = _clone_model(base_model) + + # Apply rule + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) + + # Check that Pad was fused + self.assertEqual(count, 1 if conv_auto_pad is None else 2) + self.assertEqual(updated_model.graph.num_nodes(), 1) + onnx_checker.CheckerPass(True)(updated_model) + + # Check inference + inputs = self.rng.integers(0, 255, (1, 24, 19, 23), dtype="uint8") + testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0) + + +class NormalizePadFormatTest(FuseConvPadBaseTest): + def build_model( + self, + input_shape: ir.Shape, + conv_inputs: Sequence[int], + conv_attributes: Mapping[str, ir.Attr] | None = None, + infer_shapes=True, + ) -> ir.Model: + tape = ir.tape.Tape() + inputs = [] + output_shape = ir.Shape(("?",) * len(input_shape)) + + # Convert conv_inputs to initializers (if needed) + conv_inputs = list(conv_inputs) + for idx, x in enumerate(conv_inputs): + if isinstance(x, ir.TensorProtocol): + conv_inputs[idx] = tape.initializer(x) + elif isinstance(x, ir.Value): + inputs.append(x) + elif x is not None: + raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.") + + # Register operations in the tape + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + y = tape.op( + "Conv", + inputs=[x, *conv_inputs], + attributes=conv_attributes, + output=ir.val("Y", shape=output_shape, type=x.type), + ) + + # Build the model + ir_model = ir.Model( + ir.Graph( + inputs=[x, *inputs], + outputs=[y], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 20}, + name="model", + ), + ir_version=10, + ) + if len(input_shape) > 0 and infer_shapes: + onnx_checker.CheckerPass(True)(ir_model) + ir_model = shape_inference.infer_shapes(ir_model) + else: + onnx_checker.CheckerPass(False)(ir_model) + return ir_model + + @parameterized.parameterized.expand( + [ + (dynamic_shape, strides, kernel_shape, auto_pad) + for strides, kernel_shape in [((2, 3), (1, 4)), ((2, 1), (5, 2))] + for dynamic_shape, auto_pad in [ + (False, "SAME_UPPER"), + (False, "SAME_LOWER"), + (True, "VALID"), + ] + ] + ) + def test_normalize_pad_format(self, dynamic_shape, strides, kernel_shape, auto_pad): + input_shape = ( + ir.Shape(("N", "A", "B", "C")) if dynamic_shape else ir.Shape(("N", 32, 22, 27)) + ) + base_model = self.build_model( + input_shape=input_shape, + conv_inputs=[ir.tensor(self.get_conv_weights((32, 32, *kernel_shape)), name="W")], + conv_attributes={ + "strides": strides, + "auto_pad": auto_pad, + "kernel_shape": kernel_shape, + }, + ) + updated_model = _clone_model(base_model) + + # Apply rule + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) + onnx_checker.CheckerPass(True)(updated_model) + + # Check conv has changed + self.assertEqual(count, 1) + self.assertEqual(updated_model.graph[0].attributes.get_string("auto_pad"), "NOTSET") + + # Check inference + inputs = self.rng.random((1, 32, 22, 27), dtype="float32") + testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand( + [ + (ir.Shape([]), False, "Input shapes are not defined"), + (ir.Shape(("N", "C", "A")), False, "Expected static spatial input shapes"), + (ir.Shape(("N", "C", 32)), False, "Expected static spatial output shapes"), + ] + ) + def test_unsupported_normalize_pad_format(self, input_shape, infer_shapes, error_msg): + base_model = self.build_model( + input_shape=input_shape, + conv_inputs=[ir.tensor(np.ones((32, 11, 4)), name="W")], + conv_attributes={"auto_pad": "SAME_UPPER"}, + infer_shapes=infer_shapes, + ) + + # Apply rule and check it was not applied + tracer = orp.MatchingTracer() + count = normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, error_msg) + + def test_unsupported_normalize_pad_format_on_weights(self): + W = ir.Value(name="W", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.FLOAT)) + base_model = self.build_model( + input_shape=ir.Shape(("N", 2, 32)), + conv_inputs=[W], + conv_attributes={"auto_pad": "SAME_UPPER"}, + infer_shapes=False, + ) + # Set output shape to analyze error due to weights + base_model.graph[0].outputs[0].shape = ir.Shape(("N", 10, 32)) + + # Apply rule and check it was not applied + tracer = orp.MatchingTracer() + count = normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, "same length than kernel_shape") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_fuse_relus_clips.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py new file mode 100644 index 0000000000..5d294cdbd7 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Does the following transformation: +- Relu(Relu(X)) -> Relu +- Relu(Clip(X)) -> Clip +- Clip(Relu(X)) -> Clip +- Clip(Clip(X)) -> Clip +""" + +from __future__ import annotations + +import abc + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class FuseSuccessiveRelu(RewriteRuleClassBase): + """Replaces ``Relu(Relu(X))`` with ``Relu(X)``.""" + + def rewrite(self, op, x): + return op.Relu(x) + + def pattern(self, op, x): + return op.Relu(op.Relu(x)) + + +class _FuseReluClipBase(RewriteRuleClassBase, abc.ABC): + def rewrite(self, op, x, **kwargs): + first_clip_node = kwargs.get("out_first_clip").producer() + second_clip_node = None + + if out_second_clip := kwargs.get("out_second_clip"): + second_clip_node = out_second_clip.producer() + + min_clip, max_clip = self.compute_clip_min_max(first_clip_node, second_clip_node) + clip_min_max = [] + + if min_clip is not None: + clip_min_max.append( + op.initializer(min_clip, name=f"{first_clip_node.inputs[0].name}_min") + ) + + if max_clip is not None: + # ONNX Clip expects min and max inputs in order. + # If min is not provided, we insert None to maintain correct argument positions. + if min_clip is None: + clip_min_max.append(None) + + clip_min_max.append( + op.initializer(max_clip, name=f"{first_clip_node.inputs[0].name}_max") + ) + + return op.Clip(x, *clip_min_max) + + @abc.abstractmethod + def compute_clip_min_max( + self, first_clip_node: ir.Node, second_clip_node: ir.Node | None = None + ): + pass + + def extract_min_max(self, node: ir.Node): + # Infer dtype from node first input + dtype = node.inputs[0].dtype.numpy() + min_clip, max_clip = None, None + + if len(node.inputs) > 1: + min_input = node.inputs[1] + # If only a max is provided, min is implicitly None, so we check that + if min_input is not None: + min_clip = min_input.const_value.numpy() + + if len(node.inputs) > 2: + max_clip = node.inputs[2].const_value.numpy() + + return min_clip, max_clip, dtype + + def check(self, context, **kwargs): + """Condition to check if we need to replace the pattern. + + The pattern is applied only when the min and max inputs of the Clip nodes are + not graph inputs and are constant values (i.e., provided by Constant nodes or initializers). + + Returns: + MatchResult: + Success if we need to replace the pattern, Failure otherwise. + """ + del context # Unused + check_result = MatchResult() + + # Check if Clip min/max are not graph inputs and are constant values + clip_min_max = [] + + first_clip_node = kwargs.get("out_first_clip").producer() + clip_min_max.extend([inp for inp in first_clip_node.inputs[1:] if inp is not None]) + + if out_second_clip := kwargs.get("out_second_clip"): + second_clip_node = out_second_clip.producer() + clip_min_max.extend( + [inp for inp in second_clip_node.inputs[1:] if inp is not None] + ) + + for m in clip_min_max: + if m.is_graph_input(): + return check_result.fail(f"{m.name} is a graph input.") + + if ir.convenience.get_const_tensor(m) is None: + return check_result.fail(f"{m.name} is not a constant.") + + return check_result + + +class FuseSuccessiveClip(_FuseReluClipBase): + """Replaces ``Clip(Clip(X))`` with ``Clip(X)``.""" + + def pattern(self, op, x): + return op.Clip( + op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"]), + _allow_other_inputs=True, + _outputs=["out_second_clip"], + ) + + def compute_clip_min_max(self, first_clip_node: ir.Node, second_clip_node: ir.Node): + min_clip1, max_clip1, dtype = self.extract_min_max(first_clip_node) + min_clip2, max_clip2, _ = self.extract_min_max(second_clip_node) + + def combine(val1, val2, op): + if val1 is not None and val2 is not None: + return ir.tensor(np.array(op(val1, val2), dtype=dtype)) + elif val1 is not None: + return ir.tensor(val1) + elif val2 is not None: + return ir.tensor(val2) + return None + + min_clip = combine(min_clip1, min_clip2, np.maximum) + max_clip = combine(max_clip1, max_clip2, np.minimum) + + return min_clip, max_clip + + +class FuseSuccessiveClipRelu(_FuseReluClipBase): + """Replaces ``Clip(Relu(X))`` with ``Clip(X)``.""" + + def pattern(self, op, x): + return op.Clip(op.Relu(x), _allow_other_inputs=True, _outputs=["out_first_clip"]) + + def compute_clip_min_max(self, first_clip_node: ir.Node, _): + min_clip, max_clip, dtype = self.extract_min_max(first_clip_node) + + if min_clip is None: + # The minimum clipping value is implicitly 0 (Relu clamps at 0) + min_clip = 0 + + min_clip = ir.tensor(np.array(np.maximum(0.0, min_clip), dtype=dtype)) + + if max_clip is not None: + max_clip = ir.tensor(max_clip) + return min_clip, max_clip + + +class FuseSuccessiveReluClip(FuseSuccessiveClipRelu): + """Replaces ``Relu(Clip(X))`` with ``Clip(X)``.""" + + def pattern(self, op, x): + return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"])) + + +successive_relu_rule = FuseSuccessiveRelu().rule() +successive_clip_rule = FuseSuccessiveClip().rule() +successive_clip_relu_rule = FuseSuccessiveClipRelu().rule() +successive_relu_clip_rule = FuseSuccessiveReluClip().rule() + + +rules = RewriteRuleSet( + [ + successive_clip_relu_rule, + successive_relu_clip_rule, + successive_relu_rule, + successive_clip_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py new file mode 100644 index 0000000000..df2d669930 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py @@ -0,0 +1,371 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +import onnxruntime as ort +import parameterized +from onnx_ir.passes.common import onnx_checker, shape_inference + +from onnxscript.rewriter import ( + MatchingTracer, + MatchStatus, + RewriteRule, + testing, +) +from onnxscript.rewriter.rules.common import _fuse_relus_clips +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, +) + + +class _FuseReluClipTestBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250621) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + dtype: str = "float", + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + _ = _fuse_relus_clips.rules.apply_to_model(updated_model) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = (self.rng.integers(low=-10, high=10, size=(2, 32, 14), dtype=np.int32),) + if dtype == "float": + inputs = (inputs[0].astype(np.float32),) + + # onnxruntime has an optimization that fuses Clip(Relu) and + # it doesn't support int data, that's why we disable ort optimization + # see https://github.com/microsoft/onnxruntime/blob/c98a0e014b641e289ed25f42b792bca1893ccb03/onnxruntime/core/optimizer/relu_clip_fusion.cc#L60 + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ort_optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rewrite_rule: RewriteRule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rewrite_rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rewrite_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class FuseSuccessiveReluTest(_FuseReluClipTestBase): + def test_successful_fuse_successive_relus(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + { + x1 = Relu(X) + x2 = Relu(x1) + Y = Relu(x2) + } + """) + self.run_test(model, expected_op_types=["Relu"]) + + +class FuseSuccessiveReluClipTest(_FuseReluClipTestBase): + @parameterized.parameterized.expand( + [ + ( + "relu_then_clip", + """ + x1 = Relu(X) + Y = Clip(x1, min, max) + """, + "float", + ), + ( + "clip_then_relu", + """ + x1 = Clip(X, min, max) + Y = Relu(x1) + """, + "float", + ), + ( + "int_relu_then_clip", + """ + x1 = Relu(X) + Y = Clip(x1, min, max) + """, + "int32", + ), + ( + "int_clip_then_relu", + """ + x1 = Clip(X, min, max) + Y = Relu(x1) + """, + "int32", + ), + ] + ) + def test_successful_fuse_successive_relu_clip(self, _, nodes, dtype): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y) + <{dtype} min = {{1}}, {dtype} max = {{6}}> + {{ + {nodes} + }} + """) + self.run_test(model, expected_op_types=["Clip"], dtype=dtype) + + @parameterized.parameterized.expand( + [ + ( + "relu_then_clip", + """ + x1 = Relu(X) + min = Constant() + Y = Clip(x1, min) + """, + ), + ( + "clip_then_relu", + """ + min = Constant() + x1 = Clip(X, min) + Y = Relu(x1) + """, + ), + ] + ) + def test_successful_fuse_successive_relu_clip_constant_nodes(self, _, nodes): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float[N, ?, ?] Y) + {{ + {nodes} + }} + """) + self.run_test(model, expected_op_types=["Constant", "Clip"]) + + @parameterized.parameterized.expand( + [ + ( + "relu_then_clip", + """ + x1 = Relu(X) + Y = Clip(x1,,max) + """, + ), + ( + "clip_then_relu", + """ + x1 = Clip(X,,max) + Y = Relu(x1) + """, + ), + ] + ) + def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + + {{ + {nodes} + }} + """) + self.run_test(model, expected_op_types=["Clip"]) + + @parameterized.parameterized.expand( + [ + ( + "relu_then_clip", + """ + x1 = Relu(X) + Y = Clip(x1, min) + """, + successive_clip_relu_rule, + ), + ( + "clip_then_relu", + """ + x1 = Clip(X, min) + Y = Relu(x1) + """, + successive_relu_clip_rule, + ), + ] + ) + def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite_rule): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + {{ + min = ReduceMean(X) + {nodes} + }} + """) + self.run_failed_condition_test(model, rewrite_rule, "is not a constant.") + + @parameterized.parameterized.expand( + [ + ( + "relu_then_clip", + """ + x1 = Relu(X) + Y = Clip(x1, min) + """, + successive_clip_relu_rule, + ), + ( + "clip_then_relu", + """ + x1 = Clip(X, min) + Y = Relu(x1) + """, + successive_relu_clip_rule, + ), + ] + ) + def test_fail_fuse_successive_relu_clip_graph_inputs(self, _, nodes, rewrite_rule): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X, float min) => (float [N, ?, ?] Y) + {{ + {nodes} + }} + """) + self.run_failed_condition_test(model, rewrite_rule, "is a graph input.") + + +class FuseSuccessiveClipTest(_FuseReluClipTestBase): + @parameterized.parameterized.expand( + [ + ("float", "float"), + ("int32", "int32"), + ] + ) + def test_successful_fuse_successive_clips(self, _, dtype): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 32, 14] X) => ({dtype} [N, ?, ?] Y) + <{dtype} max1 = {{4}}, {dtype} min2 = {{0}}, + {dtype} max2 = {{11}}, {dtype} min3 = {{1}}, + {dtype} max3 = {{7}}, {dtype} max4 = {{13}}> + {{ + x1 = Clip(X) + x2 = Clip(x1,,max1) + x3 = Clip(x2, min2, max2) + x4 = Clip(x3, min3, max3) + x5 = Clip(x4,,max4) + Y = Clip(x5) + }} + """) + self.run_test(model, expected_op_types=["Clip"], dtype=dtype) + + def test_successful_fuse_successive_clips_node_constants(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + { + min1 = Constant() + max1 = Constant() + min2 = Constant() + max2 = Constant() + x1 = Clip(X, min1, max1) + Y = Clip(x1, min2, max2) + } + """) + self.run_test( + model, expected_op_types=["Constant", "Constant", "Constant", "Constant", "Clip"] + ) + + def test_successful_fuse_successive_clips_no_min(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + + { + x1 = Clip(X,, max1) + Y = Clip(x1,, max2) + } + """) + self.run_test(model, expected_op_types=["Clip"]) + + def test_fail_fuse_successive_clips_non_initializers(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + + { + min1 = ReduceMean(X) + min2 = ReduceMax(X) + x1 = Clip(X, min1) + Y = Clip(x1, min2) + } + """) + self.run_failed_condition_test(model, successive_clip_rule, "is not a constant.") + + def test_fail_fuse_successive_clips_graph_inputs(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X, float min1, float min2) => (float [N, ?, ?] Y) + + { + x1 = Clip(X, min1) + Y = Clip(x1, min2) + } + """) + self.run_failed_condition_test(model, successive_clip_rule, "is a graph input.") + + +class FuseReluClipIntegrationTest(_FuseReluClipTestBase): + def test_successful_full_chain_fusion(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + { + x1 = Relu(X) + x2 = Relu(x1) + x3 = Relu(x2) + x4 = Relu(x3) + x5 = Clip(x4) + x6 = Relu(x5) + Y = Clip(x6) + } + """) + self.run_test(model, expected_op_types=["Clip"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py similarity index 51% rename from onnxscript/rewriter/gemm_to_matmul_add.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py index 95cb82e300..e51b4b22fa 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py @@ -1,11 +1,11 @@ -from onnxscript.rewriter import pattern -from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape - -op = pattern.onnxop +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from onnxscript.rewriter._rewrite_rule import RewriteRule +from onnxscript.rewriter.rules.common._broadcast_to_matmul import check_if_not_need_reshape # Pattern to match against -def reshape_gemm_reshape_pattern(input_a, input_b, input_c, shape_a, shape_c): +def reshape_gemm_reshape_pattern(op, input_a, input_b, input_c, shape_a, shape_c): reshape_a = op.Reshape(input_a, shape_a) # TODO: Temporary workaround to support benchmodels. # Tracked by https://github.com/microsoft/onnx-rewriter/issues/197. @@ -18,4 +18,6 @@ def matmul_add(op, input_a, input_b, input_c, **_): return op.Add(matmul, input_c) -rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) +gemm_to_matmul_add_rule = RewriteRule( + reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape +) diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py similarity index 91% rename from onnxscript/rewriter/gemm_to_matmul_add_test.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py index cb285036b6..90551d8d3b 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py @@ -1,9 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import onnx.parser from onnxscript import ir -from onnxscript.rewriter import gemm_to_matmul_add +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add class ReshapeGemmReshapeTest(unittest.TestCase): @@ -23,7 +25,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable(self): ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -68,7 +70,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable_in_nested ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -92,7 +94,7 @@ def test_reshape_gemm_reshape_remain_when_input_last_dim_and_second_last_dim_not """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -113,7 +115,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_not_broadcastable( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -134,7 +136,7 @@ def test_reshape_gemm_reshape_replace_when_inputs_are_broadcastable_with_one_in_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -157,7 +159,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -180,7 +182,7 @@ def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_b """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -201,7 +203,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -224,7 +226,7 @@ def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -245,7 +247,7 @@ def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broad """ ) model = ir.serde.deserialize_model(model_proto) - replacement_count = gemm_to_matmul_add.rule.apply_to_model(model) + replacement_count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(replacement_count, 1) self.assertEqual(len(model.graph), 4) @@ -266,7 +268,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broad """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -287,7 +289,7 @@ def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py new file mode 100644 index 0000000000..fe7a4a6cd8 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Does the following transformation: +- Add(MatMul(X, W), B) -> Gemm +- Add(MatMul(Transpose(X), W), B) -> Gemm +- Add(MatMul(X, Transpose(W)), B) -> Gemm +- Add(MatMul(Transpose(X), Transpose(W)), B) -> Gemm +""" + +import abc +from typing import ClassVar + +from onnxscript.rewriter import _ir_utils +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _MatMulAddToGemmBase(RewriteRuleClassBase, abc.ABC): + trans_a: ClassVar = False + trans_b: ClassVar = False + + def rewrite(self, op, input_a, input_b, input_c): + attributes = {} + if self.trans_a: + attributes["transA"] = 1 + if self.trans_b: + attributes["transB"] = 1 + return op.Gemm(input_a, input_b, input_c, **attributes) + + def check(self, context, input_a, input_b, **_): + del context # Not used + check_result = MatchResult() + # Rank of input_a and input_b must be 2 + if not (_ir_utils.has_rank(input_a, 2) and _ir_utils.has_rank(input_b, 2)): + return check_result.fail("Rank of input_a and input_b must be 2") + return check_result + + +class MatMulAddToGemm(_MatMulAddToGemmBase): + """Replaces ``Add(MatMul(a, b), c)`` with ``Gemm(a, b, c)``.""" + + def pattern(self, op, input_a, input_b, input_c): + matmul = op.MatMul(input_a, input_b) + return op.Add(matmul, input_c) + + +class TransAMatMulAddToGemm(_MatMulAddToGemmBase): + """Replaces ``Add(MatMul(Transpose(a), b), c)`` with ``Gemm(a, b, c)``.""" + + trans_a: ClassVar = True + + def pattern(self, op, input_a, input_b, input_c): + matmul = op.MatMul(op.Transpose(input_a, perm=[1, 0]), input_b) + return op.Add(matmul, input_c) + + +class TransBMatMulAddToGemm(_MatMulAddToGemmBase): + """Replaces ``Add(MatMul(a, Transpose(b)), c)`` with ``Gemm(a, b, c)``.""" + + trans_b: ClassVar = True + + def pattern(self, op, input_a, input_b, input_c): + matmul = op.MatMul(input_a, op.Transpose(input_b, perm=[1, 0])) + return op.Add(matmul, input_c) + + +class TransABMatMulAddToGemm(_MatMulAddToGemmBase): + """Replaces ``Add(MatMul(Transpose(a), Transpose(b)), c)`` with ``Gemm(a, b, c)``.""" + + trans_a: ClassVar = True + trans_b: ClassVar = True + + def pattern(self, op, input_a, input_b, input_c): + matmul = op.MatMul( + op.Transpose(input_a, perm=[1, 0]), + op.Transpose(input_b, perm=[1, 0]), + ) + return op.Add(matmul, input_c) + + +matmul_add_to_gemm_rule = MatMulAddToGemm().rule() +transpose_a_matmul_add_to_gemm_rule = TransAMatMulAddToGemm().rule() +transpose_b_matmul_add_to_gemm_rule = TransBMatMulAddToGemm().rule() +transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule() + + +rules = RewriteRuleSet( + [ + transpose_ab_matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, + matmul_add_to_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py new file mode 100644 index 0000000000..4c643801fc --- /dev/null +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py @@ -0,0 +1,316 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest +from typing import Sequence + +import numpy as np +import onnx +from onnx_ir.passes.common import onnx_checker, shape_inference +from parameterized import parameterized + +from onnxscript import ir +from onnxscript.rewriter import MatchingTracer, MatchStatus, testing +from onnxscript.rewriter.rules.common import _matmul_add_to_gemm + + +class _MatMulAddToGemmTestBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250607) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def get_test_model( + self, + input_shape: ir.Shape, + weight_shape: ir.Shape, + transA: bool = False, + transB: bool = False, + permA: Sequence[int] = [1, 0], + permB: Sequence[int] = [1, 0], + weight_as_inputs: bool = False, + bias_as_inputs: bool = False, + ): + """Returns the following model: + + Y = Add(MatMul(Transpose(X), Transpose(W)), B) + + Where: + - Transpose(X) is applied only if `transA=True` + - Transpose(W) is applied only if `transB=True` + - W and B can be graph inputs or initializers + """ + tape = ir.tape.Tape() + inputs = [] + bias_shape = weight_shape[0] if transB else weight_shape[-1] + output_shape = ir.Shape(("?",) * input_shape.rank()) + + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + + if weight_as_inputs: + w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT)) + inputs.append(w) + else: + w = ir.tensor( + self.rng.uniform(-0.5, 0.5, weight_shape).astype("float32"), name="W" + ) + w = tape.initializer(w) + + if bias_as_inputs: + b = ir.val( + "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT) + ) + inputs.append(b) + else: + b = ir.tensor(self.rng.uniform(-0.5, 0.5, bias_shape).astype("float32"), name="B") + b = tape.initializer(b) + + x_t, w_t = None, None + if transA: + x_t = tape.op("Transpose", inputs=[x], attributes={"perm": permA}) + + if transB: + w_t = tape.op("Transpose", inputs=[w], attributes={"perm": permB}) + + y = tape.op("MatMul", inputs=[x_t if transA else x, w_t if transB else w]) + y = tape.op( + "Add", + inputs=[y, b], + output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), + ) + + # Build the model + ir_model = ir.Model( + ir.Graph( + inputs=[x, *inputs], + outputs=[y], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 20}, + name="test_model", + ), + ir_version=10, + ) + onnx_checker.CheckerPass(True)(ir_model) + ir_model = shape_inference.infer_shapes(ir_model) + return ir_model + + def check_matmul_add_to_gemm_incompatible_shapes(self, **kwargs): + base_model = self.get_test_model(**kwargs) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = _matmul_add_to_gemm.matmul_add_to_gemm_rule.apply_to_model( + updated_model, tracer=tracer + ) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[_matmul_add_to_gemm.matmul_add_to_gemm_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex( + tracer_match.match_result.reason, "Rank of input_a and input_b must be 2" + ) + + +class MatMulAddToGemmTest(_MatMulAddToGemmTestBase): + @parameterized.expand( + [ + ("initializers", False, False), + ("inputs", True, True), + ] + ) + def test_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): + base_model = self.get_test_model( + input_shape=ir.Shape((512, 256)), + weight_shape=ir.Shape((256, 64)), + weight_as_inputs=weight_as_inputs, + bias_as_inputs=bias_as_inputs, + ) + updated_model = self.clone_model(base_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) + + # Check MatMul + Add are fused into Gemm + self.assertEqual(count, 1) + self.assertEqual(len(updated_model.graph), 1) + + # Prepare inputs + if weight_as_inputs and bias_as_inputs: + inputs = ( + self.rng.random((512, 256), dtype=np.float32), + self.rng.random((256, 64), dtype=np.float32), + self.rng.random((64), dtype=np.float32), + ) + else: + inputs = (self.rng.random((512, 256), dtype=np.float32),) + + # Check inference + testing.assert_numerically_equal(base_model, updated_model, inputs) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def test_matmul_add_to_gemm_incompatible_shapes(self): + kwargs = { + "input_shape": ir.Shape((1, 256, 512)), + "weight_shape": ir.Shape((1, 512, 64)), + } + return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs) + + +class TransAMatMulAddToGemmTest(_MatMulAddToGemmTestBase): + @parameterized.expand( + [ + ("initializers", False, False), + ("inputs", True, True), + ] + ) + def test_transpose_a_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): + base_model = self.get_test_model( + input_shape=ir.Shape((256, 512)), + weight_shape=ir.Shape((256, 64)), + weight_as_inputs=weight_as_inputs, + bias_as_inputs=bias_as_inputs, + transA=True, + ) + updated_model = self.clone_model(base_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) + + # Check MatMul(Transpose, W) + Add are fused into Gemm + self.assertEqual(count, 1) + self.assertEqual(len(updated_model.graph), 1) + + # Prepare inputs + if weight_as_inputs and bias_as_inputs: + inputs = ( + self.rng.random((256, 512), dtype=np.float32), + self.rng.random((256, 64), dtype=np.float32), + self.rng.random((64,), dtype=np.float32), + ) + else: + inputs = (self.rng.random((256, 512), dtype=np.float32),) + + # Check inference + testing.assert_numerically_equal(base_model, updated_model, inputs) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def test_transpose_a_matmul_add_to_gemm_incompatible_shapes(self): + kwargs = { + "input_shape": ir.Shape((1, 256, 512)), + "weight_shape": ir.Shape((1, 256, 64)), + "transA": True, + "permA": [0, 2, 1], + } + return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs) + + +class TransBMatMulAddToGemmTest(_MatMulAddToGemmTestBase): + @parameterized.expand( + [ + ("initializers", False, False), + ("inputs", True, True), + ] + ) + def test_transpose_b_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): + base_model = self.get_test_model( + input_shape=ir.Shape((512, 256)), + weight_shape=ir.Shape((64, 256)), + weight_as_inputs=weight_as_inputs, + bias_as_inputs=bias_as_inputs, + transB=True, + ) + updated_model = self.clone_model(base_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) + + # Check MatMul(X, Transpose) + Add are fused into Gemm + self.assertEqual(count, 1) + self.assertEqual(len(updated_model.graph), 1) + + # Prepare inputs + if weight_as_inputs and bias_as_inputs: + inputs = ( + self.rng.random((512, 256), dtype=np.float32), + self.rng.random((64, 256), dtype=np.float32), + self.rng.random((64,), dtype=np.float32), + ) + else: + inputs = (self.rng.random((512, 256), dtype=np.float32),) + + # Check inference + testing.assert_numerically_equal(base_model, updated_model, inputs) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def test_transpose_b_matmul_add_to_gemm_incompatible_shapes(self): + kwargs = { + "input_shape": ir.Shape((1, 512, 256)), + "weight_shape": ir.Shape((1, 64, 256)), + "transB": True, + "permB": [0, 2, 1], + } + return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs) + + +class TransABMatMulAddToGemmTest(_MatMulAddToGemmTestBase): + @parameterized.expand( + [ + ("initializers", False, False), + ("inputs", True, True), + ] + ) + def test_transpose_ab_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): + base_model = self.get_test_model( + input_shape=ir.Shape((256, 512)), + weight_shape=ir.Shape((64, 256)), + weight_as_inputs=weight_as_inputs, + bias_as_inputs=bias_as_inputs, + transA=True, + transB=True, + ) + updated_model = self.clone_model(base_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) + + # Check MatMul(Transpose, Transpose) + Add are fused into Gemm + self.assertEqual(count, 1) + self.assertEqual(len(updated_model.graph), 1) + + # Prepare inputs + if weight_as_inputs and bias_as_inputs: + inputs = ( + self.rng.random((256, 512), dtype=np.float32), + self.rng.random((64, 256), dtype=np.float32), + self.rng.random((64), dtype=np.float32), + ) + else: + inputs = (self.rng.random((256, 512), dtype=np.float32),) + + # Check inference + testing.assert_numerically_equal(base_model, updated_model, inputs) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def test_transpose_ab_matmul_add_to_gemm_incompatible_shapes(self): + kwargs = { + "input_shape": ir.Shape((1, 256, 512)), + "weight_shape": ir.Shape((1, 64, 256)), + "transA": True, + "transB": True, + "permA": [0, 2, 1], + "permB": [0, 2, 1], + } + return super().check_matmul_add_to_gemm_incompatible_shapes(**kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip.py b/onnxscript/rewriter/rules/common/_min_max_to_clip.py new file mode 100644 index 0000000000..88ae495dbc --- /dev/null +++ b/onnxscript/rewriter/rules/common/_min_max_to_clip.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fuses successive Min/Max patterns in ONNX graphs. + +Supported transformations: +- Min(Min(X, c1, c2, ...), d1, d2, ...) → Min(X, fused_const) +- Max(Max(X, c1, c2, ...), d1, d2, ...) → Max(X, fused_const) +- Min(Max(X, lb1, lb2, ...), ub1, ub2, ...) → Clip(X, lb, ub) +- Max(Min(X, ub1, ub2, ...), lb1, lb2, ...) → Clip(X, lb, ub) + +Where: + - fused_const is the reduction (min or max) over all constant inputs. + - For Clip fusion: + * All constant inputs must be scalars. + * The effective lower bound is the maximum of all lower-bound constants. + * The effective upper bound is the minimum of all upper-bound constants. + + For the case of Max(Min(X, upper_bound), lower_bound): + * The rule applies only if lower_bound ≤ upper_bound. + +General constraints: + - The first input may be any tensor. + - All other inputs must be constant tensors (from Constant nodes or initializers). +""" + +import abc +import functools +from typing import ClassVar + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC): + """Base class for Min/Max fusion rewrites. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - If ``need_scalars`` is True (Clip fusion), all constants must be scalars. + - If ``check_bounds`` is True (Clip fusion in the pattern Max(Min(X, upper_bound), lower_bound)), lower_bound ≤ upper_bound. + """ + + need_scalars: ClassVar = False + check_bounds: ClassVar = False + + @abc.abstractmethod + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: ... + + def rewrite(self, op, x, out1, out2): + first_node = out1.producer() + second_node = out2.producer() + + # Compute new constants for the fused op + constants = self.compute_constants(first_node, second_node, x.name) + + initializers = [op.initializer(constant, name=name) for constant, name in constants] + + return op.op( + self.op_type, + inputs=[x, *initializers], + ) + + def _is_scalar(self, v: np.ndarray) -> bool: + return np.isscalar(v) or np.size(v) == 1 + + def check(self, context, out1, out2, **_): + """Condition to check if we need to replace the pattern. + + Conditions: + - The min and max input nodes must not be graph inputs. + - These inputs (except the first) must be constant values (from Constant nodes or initializers). + - In the case of Min(Max) and Max(Min) patterns: + * All inputs must be scalars (as Clip requires scalars). + For Max(Min) pattern: + * The lower bound must be less than or equal to the upper bound. + + Returns: + MatchResult: + Success if we need to replace the pattern, Failure otherwise. + """ + del context # Not used + check_result = MatchResult() + + first_node = out1.producer() + second_node = out2.producer() + + # Ensure all inputs except the first are constants + for input_ in first_node.inputs[1:] + second_node.inputs[1:]: + if ir.convenience.get_const_tensor(input_) is None: + return check_result.fail(f"{input_.name} is not a constant.") + + # If scalars are required (Clip fusion), enforce scalar-ness + if self.need_scalars and not self._is_scalar(input_.const_value.numpy()): + return check_result.fail(f"{input_.name} is not a scalar.") + + if self.need_scalars and self.check_bounds: + # For Clip fusion in the case of Max(Min(X, upper_bound), lower_bound): check that lower_bound <= upper_bound + lower_bound, upper_bound = self.compute_constants(first_node, second_node) + if lower_bound[0].numpy() > upper_bound[0].numpy(): + return check_result.fail( + f"Invalid bounds: lower bound ({lower_bound[0].numpy()}) is greater " + f"than upper bound ({upper_bound[0].numpy()})." + ) + + return check_result + + +class FuseSuccessiveMin(_FuseMinMaxBase): + """Replaces ``Min(Min(X, c1, c2, ...), d1, d2, ...)`` with ``Min(X, fused_const)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + """ + + op_type: ClassVar = "Min" + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + inputs = first_node.inputs[1:] + second_node.inputs[1:] + values = [input_.const_value.numpy() for input_ in inputs] + return [(ir.tensor(functools.reduce(np.minimum, values)), f"{input_name}_min")] + + def pattern(self, op, x): + return op.Min( + op.Min(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseSuccessiveMax(_FuseMinMaxBase): + """Replaces ``Max(Max(X, c1, c2, ...), d1, d2, ...)`` with ``Max(X, fused_const)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + """ + + op_type: ClassVar = "Max" + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + inputs = first_node.inputs[1:] + second_node.inputs[1:] + values = [input_.const_value.numpy() for input_ in inputs] + return [(ir.tensor(functools.reduce(np.maximum, values)), f"{input_name}_max")] + + def pattern(self, op, x): + return op.Max( + op.Max(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseMaxMinToClip(_FuseMinMaxBase): + """Replaces ``Min(Max(X, lb1, lb2, ...), ub1, ub2, ...)`` with ``Clip(X, lb, ub)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - All constant inputs must be scalars. + - The effective lower bound is ``max(lb1, lb2, ...)``. + - The effective upper bound is ``min(ub1, ub2, ...)``. + """ + + op_type: ClassVar = "Clip" + need_scalars: ClassVar = True + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + lower_bound = np.max([input_.const_value.numpy() for input_ in first_node.inputs[1:]]) + upper_bound = np.min([input_.const_value.numpy() for input_ in second_node.inputs[1:]]) + return [ + (ir.tensor(lower_bound), f"{input_name}_min"), + (ir.tensor(upper_bound), f"{input_name}_max"), + ] + + def pattern(self, op, x): + return op.Min( + op.Max(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseMinMaxToClip(_FuseMinMaxBase): + """Replaces ``Max(Min(X, ub1, ub2, ...), lb1, lb2, ...)`` with ``Clip(X, lb, ub)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - All constant inputs must be scalars. + - The effective lower bound is ``max(lb1, lb2, ...)``. + - The effective upper bound is ``min(ub1, ub2, ...)``. + - Requires ``lower_bound <= upper_bound``. + """ + + op_type: ClassVar = "Clip" + need_scalars: ClassVar = True + check_bounds: ClassVar = True + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + upper_bound = np.min([input_.const_value.numpy() for input_ in first_node.inputs[1:]]) + lower_bound = np.max([input_.const_value.numpy() for input_ in second_node.inputs[1:]]) + return [ + (ir.tensor(lower_bound), f"{input_name}_min"), + (ir.tensor(upper_bound), f"{input_name}_max"), + ] + + def pattern(self, op, x): + return op.Max( + op.Min(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +min_min_rule = FuseSuccessiveMin().rule() +max_max_rule = FuseSuccessiveMax().rule() +min_max_rule = FuseMinMaxToClip().rule() +max_min_rule = FuseMaxMinToClip().rule() + + +rules = RewriteRuleSet( + [ + min_min_rule, + max_max_rule, + min_max_rule, + max_min_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py new file mode 100644 index 0000000000..dd09078a9e --- /dev/null +++ b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py @@ -0,0 +1,367 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker, shape_inference +from parameterized import parameterized + +from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing +from onnxscript.rewriter.rules.common._min_max_to_clip import ( + max_max_rule, + max_min_rule, + min_max_rule, + min_min_rule, + rules, +) + + +class _TestMinMaxToClipBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250817) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + dtype: str = "float", + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + _ = rules.apply_to_model(updated_model) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = ( + self.rng.integers( + low=-10, + high=10, + size=(2, *updated_model.graph.inputs[0].shape[1:]), + dtype=np.int32, + ), + ) + if dtype == "float": + inputs = (inputs[0].astype(np.float32),) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rewrite_rule: RewriteRule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rewrite_rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rewrite_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class TestFuseSuccessiveMinOrMax(_TestMinMaxToClipBase): + @parameterized.expand( + [ + ("int32_min", "int32", "Min"), + ("int32_max", "int32", "Max"), + ("float32_min", "float", "Min"), + ("float32_max", "float", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max(self, _, dtype, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 32, 14, 17] X) => ({dtype} [N, ?, ?, ?] Y) + <{dtype}[1] cst1 = {{3}}, {dtype}[1] cst2 = {{6}}> + {{ + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=[op_type], dtype=dtype) + + @parameterized.expand( + [ + ("int32_min_multi", "int32", "Min"), + ("int32_max_multi", "int32", "Max"), + ("float32_min_multi", "float", "Min"), + ("float32_max_multi", "float", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_multiple_inputs(self, _, dtype, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 3, 3] X) => ({dtype}[N, 3, 3] Y) + <{dtype}[3] cst1 = {{2, 5, 8}}, + {dtype}[1] cst2 = {{4}}, + {dtype}[3] cst3 = {{3, 1, -6}}, + {dtype}[1] cst4 = {{10}}, + {dtype}[3] cst5 = {{-2, 7, 9}}, + {dtype}[1] cst6 = {{0}}, + {dtype}[3] cst7 = {{11, -3, 4}}> + {{ + x1 = {op_type}(X, cst1, cst2, cst3, cst4) + Y = {op_type}(x1, cst5, cst6, cst7) + }} + """) + self.run_test(base_model, expected_op_types=[op_type], dtype=dtype) + + @parameterized.expand( + [ + ("int32_min", "Min"), + ("int32_max", "Max"), + ("float32_min", "Min"), + ("float32_max", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_constants(self, _, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + {{ + x1 = {op_type}(X, cst1) + cst2 = Constant() + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=["Constant", op_type]) + + @parameterized.expand( + [ + ("min_nonconst", "Min", min_min_rule), + ("max_nonconst", "Max", max_max_rule), + ] + ) + def test_failure_fuse_successive_min_or_max_non_constant(self, _, op_type, rewrite_rule): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] Y) + + {{ + cst1 = ReduceMean(X) + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_failed_condition_test(model, rewrite_rule, "is not a constant.") + + @parameterized.expand( + [ + ("min_graph_input", "Min"), + ("max_graph_input", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_graph_inputs_as_constants(self, _, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] cst1, float[1] cst2) => (float[N, ?, ?, ?] Y) + + {{ + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=[op_type]) + + +class TestMinMaxToClip(_TestMinMaxToClipBase): + def test_successful_min_max_to_clip(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_min_max_to_clip_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + max = Constant() + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Constant", "Clip"]) + + def test_successful_min_max_to_clip_graph_inputs_as_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_failure_min_max_to_clip_invalid_bounds(self): + """Min node should have the max value and Max node should have the min value.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(base_model, min_max_rule, "Invalid bounds:") + + def test_failure_fuse_min_max_to_clip_non_constant(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + min = ReduceMean(X) + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(model, min_max_rule, "is not a constant.") + + def test_failure_min_max_to_clip_need_scalars(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 4, 4] X) => (float [N, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(base_model, min_max_rule, "is not a scalar") + + +class TestMaxMinToClip(_TestMinMaxToClipBase): + def test_successful_max_min_to_clip(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_max_min_to_clip_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + min = Constant() + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Constant", "Clip"]) + + def test_successful_max_min_to_clip_graph_inputs_as_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_max_min_to_clip_check_bounds(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_failure_fuse_max_min_to_clip_non_constant(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + min = ReduceMean(X) + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_failed_condition_test(model, max_min_rule, "is not a constant.") + + def test_failure_max_min_to_clip_need_scalars(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 4, 4] X) => (float [N, ?, ?] Y) + + { + x1 = Max(X, min) + Y = Min(x1, max) + } + """) + self.run_failed_condition_test(base_model, max_min_rule, "is not a scalar") + + +class TestIntegrationMinMaxToClip(_TestMinMaxToClipBase): + def test_successful_full_chain_fusion(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + + { + x1 = Min(X, min1) + x2 = Min(x1, min2) + x3 = Max(x2, max1) + x4 = Max(x3, max2) + x5 = Min(x4, min3) + x6 = Max(x5, max3) + Y = Min(x6, min4) + } + """) + self.run_test(model, expected_op_types=["Clip", "Clip", "Clip"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_no_op.py b/onnxscript/rewriter/rules/common/_no_op.py new file mode 100644 index 0000000000..d75338bf03 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_no_op.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet + +# TODO: Support 1-D constant tensors +# https://github.com/microsoft/onnx-rewriter/issues/186 + + +# Pattern to match against +def mul_by_1(op, x): + return x * 1 + + +def add_0(op, x): + return x + 0 + + +def sub_0(op, x): + return x - 0 + + +def div_by_1(op, x): + return x / 1 + + +def dropout_zero(op, x): + return op.Dropout(x, ratio=0.0) + + +def dropout_inference(op, x): + return op.Dropout(x, training_mode=False) + + +# Replacement +def identity(op, x, **_): + return op.Identity(x) + + +mul_by_1_rule = RewriteRule(mul_by_1, identity) +add_0_rule = RewriteRule(add_0, identity) +sub_0_rule = RewriteRule(sub_0, identity) +div_by_1_rule = RewriteRule(div_by_1, identity) +dropout_zero_rule = RewriteRule(dropout_zero, identity) +dropout_inference_rule = RewriteRule(dropout_inference, identity) +# TODO: Include Mul by 0, 0 by Mul, 0 by Div? Those would be 0s, but not no-ops + +rules = RewriteRuleSet( + [ + *mul_by_1_rule.commute(), + *add_0_rule.commute(), + sub_0_rule, + div_by_1_rule, + dropout_zero_rule, + dropout_inference_rule, + ] +) diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/rules/common/_no_op_test.py similarity index 85% rename from onnxscript/rewriter/no_op_test.py rename to onnxscript/rewriter/rules/common/_no_op_test.py index 1cc1a47cfa..7815473e34 100644 --- a/onnxscript/rewriter/no_op_test.py +++ b/onnxscript/rewriter/rules/common/_no_op_test.py @@ -1,17 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest -import onnx.parser import parameterized from onnxscript import ir -from onnxscript.rewriter import no_op +from onnxscript.rewriter.rules.common import _no_op class NoOpTest(unittest.TestCase): def _check(self, model_text: str) -> None: - model_proto = onnx.parser.parse_model(model_text) - model = ir.serde.deserialize_model(model_proto) - count = no_op.rules.apply_to_model(model) + model = ir.from_onnx_text(model_text) + count = _no_op.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(model.graph[-1].op_type, "Identity") @@ -175,6 +175,26 @@ def test_div_one_should_become_no_op_with_initializer( """ ) + @parameterized.parameterized.expand( + [ + ("dropout zero ratio", "ratio=0.0"), + ("dropout inference", "training_mode=0"), + ("dropout inference with positive ratio", "ratio=0.42, training_mode=0"), + ("dropout training with zero ratio", "ratio=0.0, training_mode=1"), + ] + ) + def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: str): + self._check( + f""" + + agraph (float16[M] input) => (float16[M] output) + {{ + output = Dropout<{attribute}>(input) + }} + """ + ) + # TODO: Test the negative cases + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py new file mode 100644 index 0000000000..09c5db7735 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Rewrite rules to eliminate redundant ScatterND operations. + +This module contains two rewrite rules: + +1. ScatterAllDynamic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates) + when the indices are computed dynamically using Range operations but represent a complete update + of an entire axis. This is generated by the translation of `x[:, ...] = y` in PyTorch. + +2. ScatterAllStatic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates) + when the indices are statically known constants in the form [[0], [1], ..., [n-1]] covering + the entire first dimension of the data tensor. + +Both rules detect when the scatter-update ends up being an assignment of a new value to the entire tensor. +""" + +from __future__ import annotations + +import onnx_ir as ir + +import onnxscript.rewriter +from onnxscript.rewriter import _ir_utils +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class ScatterAllDynamic(RewriteRuleClassBase): + def __init__(self): + super().__init__(remove_nodes=False) + + def pattern(self, op, data, axis, transposed_data, updates): + # Construct update-indices spanning an entire axis: + shape = op.Shape(data, start=0) + dim = op.Gather(shape, axis, axis=0) + full_range = op.Range(0, dim, 1) + full_range_2d = op.Unsqueeze(full_range, [-1]) + # The update is applied to the data transposed to bring the updated axis to the front: + return op.ScatterND(transposed_data, full_range_2d, updates, reduction="none") + + def check(self, context, data, axis, transposed_data, **_): + # Check that updated-indices represent the full range of the first dimension of the transposed data. + # That is: check that the data.shape[axis] matches transposed_data.shape[0]. + result = onnxscript.rewriter.MatchResult() + axis_value = _ir_utils.get_singleton_value(axis) + if not isinstance(axis_value, int): + return result.fail("Axis value must be a constant integer.", axis) + shape: ir.Shape | None = data.shape + if shape is None: + return result.fail("Data shape is not statically known.", data) + updated_dim_value = shape[axis_value] + transposed_data_shape: ir.Shape | None = transposed_data.shape + if transposed_data_shape is None: + return result.fail( + "Transposed data shape is not statically known.", transposed_data + ) + actual_dim_value = transposed_data_shape[0] + if not _ir_utils.same_dim(updated_dim_value, actual_dim_value): + # The first dimension of the transposed data does not match the updated dimension, + # so we cannot apply this rule. + return result.fail( + "The first dimension of the transposed data does not match the updated dimension.", + [data, transposed_data], + ) + return True + + def rewrite(self, op, updates, **_): + return op.Identity(updates) + + +class ScatterAllStatic(RewriteRuleClassBase): + """Rewrite rule for eliminating redundant ScatterND with statically known indices. + + This handles the case where indices are constant values in the form [[0], [1], ..., [n-1]] + that update the entire first dimension of the data tensor. + """ + + def pattern(self, op, data, indices, updates): + """Pattern to match ScatterND with static indices.""" + return op.ScatterND(data, indices, updates) + + def check(self, context, data, indices, updates, **_): + """Check if the ScatterND is redundant due to static indices covering entire tensor.""" + # To validate data can be replaced directly by updates, we need to check the following: + # 1. they have the same shape + result = onnxscript.rewriter.MatchResult() + if data.shape is None: + return result.fail("The value 'data' shape is not statically known.", data) + if updates.shape is None: + return result.fail("The value 'updates' shape is not statically known.", updates) + if not _ir_utils.same_shape(data.shape, updates.shape): + return result.fail( + "The shape of 'data' and 'updates' are different.", [data, updates] + ) + + # 2. the indices is referring to the whole data, which is from 0 to data.shape[0] + if indices.const_value is None: + return result.fail("The value 'indices' is not statically known.", indices) + expected_indices = [[i] for i in range(data.shape[0])] + actual_indices = indices.const_value.numpy().tolist() + if actual_indices != expected_indices: + return result.fail("The 'indices' is not referring to the whole data.", indices) + + return True + + def rewrite(self, op, updates, **_): + """Replace ScatterND with Identity since updates covers entire tensor.""" + return op.Identity(updates) + + +no_op_dynamic_scatter_nd_rule = ScatterAllDynamic.rule() +no_op_static_scatter_nd_rule = ScatterAllStatic.rule() + +rules = RewriteRuleSet([no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule]) diff --git a/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py new file mode 100644 index 0000000000..96e3bcc80c --- /dev/null +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ruff: noqa: F821 + +import unittest + +import numpy as np +import onnx.parser +import onnx_ir as ir +import onnxruntime +from onnx_ir.passes.common import CheckerPass, ShapeInferencePass + +import onnxscript.optimizer +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.rewriter.rules.common import _redundant_scatter_nd + +shape_inference = ShapeInferencePass() +onnx_check = CheckerPass(True) + + +class RedundantScatterNdTest(unittest.TestCase): + def test_redundant_scatter_nd_dynamic_indices(self): + """Test redundant ScatterND with dynamically constructed indices.""" + + @script() + def model_script( + data: FLOAT[8, "N", 16], updates: FLOAT[8, "N", 16] + ) -> FLOAT[8, "N", 16]: + # Construct update-indices spanning an entire axis: + axis = op.Constant(value_int=1) + shape = op.Shape(data, start=0) + dim = op.Gather(shape, axis, axis=0) + full_range = op.Range(0, dim, 1) + full_range_2d = op.Unsqueeze(full_range, [-1]) + # The update is applied to the data transposed to bring the updated axis to the front: + transposed_data = op.Transpose(data, perm=[1, 0, 2]) + transposed_updates = op.Transpose(updates, perm=[1, 0, 2]) + scattered = op.ScatterND( + transposed_data, full_range_2d, transposed_updates, reduction="none" + ) + # Transpose the result back to the original shape: + output = op.Transpose(scattered, perm=[1, 0, 2]) + return output + + input_model_proto = model_script.to_model_proto() + model = ir.serde.deserialize_model(input_model_proto) + onnx_check(model) + shape_inference(model) + onnxscript.optimizer.fold_constants(model) + count = _redundant_scatter_nd.rules.apply_to_model(model) + self.assertEqual(count, 1) + onnx_check(model) + optimized_model_proto = ir.serde.serialize_model(model) + # Test that both models are equivalent: + inputs = { + "data": np.random.rand(8, 4, 16).astype(np.float32), + "updates": np.random.rand(8, 4, 16).astype(np.float32), + } + session = onnxruntime.InferenceSession( + input_model_proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + outputs = session.run(None, inputs) + optimized_session = onnxruntime.InferenceSession( + optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + optimized_outputs = optimized_session.run(None, inputs) + # Compare outputs + for output, optimized_output in zip(outputs, optimized_outputs): + np.testing.assert_allclose(output, optimized_output, rtol=1e-6, atol=1e-6) + + def test_redundant_scatter_nd_static_indices(self): + """Test redundant ScatterND with static indices (moved from collapse_slices_test.py).""" + model_proto = onnx.parser.parse_model( + """ + + agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output) + { + output = ScatterND (data, indices, updates) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + indices = np.arange(112).reshape(112, 1).astype(np.int64) + model = ir.serde.deserialize_model(model_proto) + # from numpy to ir.Tensor + indices_ir_tensor = ir.Tensor( + name="indices", + value=indices, + ) + # assign the tensor to a value + indices_value = model.graph[0].inputs[1] + indices_value.const_value = indices_ir_tensor + model.graph.initializers["indices"] = indices_value + original_model_proto = ir.serde.serialize_model(model) + + count = _redundant_scatter_nd.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + self.assertIn("Identity", [node.op_type for node in model.graph]) + + # Test numerical equivalence + input_data = np.random.rand(112, 16, 512).astype(np.float32) + inputs = {"data": input_data, "updates": input_data} + + # Run original model + session = onnxruntime.InferenceSession( + original_model_proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + original_outputs = session.run(None, inputs) + + # Run optimized model + optimized_model_proto = ir.serde.serialize_model(model) + optimized_session = onnxruntime.InferenceSession( + optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + optimized_outputs = optimized_session.run(None, inputs) + + # Compare outputs + for original_output, optimized_output in zip(original_outputs, optimized_outputs): + np.testing.assert_allclose(original_output, optimized_output, rtol=1e-6, atol=1e-6) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/__init__.py b/onnxscript/rewriter/rules/fusion/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/rules/fusion/_gqa.py b/onnxscript/rewriter/rules/fusion/_gqa.py new file mode 100644 index 0000000000..8d6f156ed5 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_gqa.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Union + +import onnx_ir as ir + +import onnxscript.rewriter._fusion_utils as _fusion_utils +from onnxscript.rewriter import _basics, pattern + +Dim = Union[int, ir.SymbolicDim] + + +class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("ONNXGQA", remove_nodes=False) + + def pattern( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + ): + # Concatenate past_key cache and current key, expand across heads + # that share key/value. + + present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2) + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE) + present_key_BHStD = op.Reshape( + present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"] + ) + + # Concatenate past_value cache and current value, expand across heads + # that share key/value. + present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2) + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) + present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE) + present_value_BHStD = op.Reshape( + present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"] + ) + + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + pattern.Var("mask", can_match_none=True), + _outputs=["attention_BHSDh"], + ) + + return attention_BHSDh + + def check( + self, + context: _basics.MatchContext, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + present_key_BHStD, + present_value_BHStD, + **_, + ): + bindings: dict[str, Dim] = {} + # Check that inputs to new Attention node have expected shapes + _fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"]) + _fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"]) + _fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"]) + # We need to check that the Expand/Reshape arguments are as expected. + # As a substitute, we check that the outputs of Expand=>Reshape have expected shapes. + # TODO (rama): May be better to check the actual Expand/Reshape arguments. + _fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"]) + _fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"]) + + return True + + def rewrite( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + mask, + attention_BHSDh, + **_, + ): + original_attention_node = attention_BHSDh.producer() + original_attrs = original_attention_node.attributes + return op.Attention( + query_BHSD, + key_BHkvSD, + value_BHkvSD, + mask, + past_key_BHkvSpD, + past_value_BHkvSpD, + **original_attrs, + ) + + +_basic_gqa_rule = OnnxGroupQueryAttention.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) diff --git a/onnxscript/rewriter/rules/fusion/_gqa_test.py b/onnxscript/rewriter/rules/fusion/_gqa_test.py new file mode 100644 index 0000000000..baf80c4b8c --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_gqa_test.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import onnx +import onnx_ir as ir +from packaging import version + +import onnxscript +import onnxscript.optimizer +import onnxscript.rewriter.testing +from onnxscript import FLOAT, script +from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa + +op = onnxscript.values.Opset("", 23) + +H = [8] # Number of attention heads +Hkv = [4] # Number of key/value heads (H should be divisible by Hkv) +D = [64] # Head size +G = [2] # Number of groups + + +@script(ir_version=10) +def _gqa_script( + query_BHSD: FLOAT[2, 8, 4, 64], # B=2, H=8, S=4, D=64 + key_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + value_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + past_key_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 + past_value_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 +) -> FLOAT[2, 8, 4, 64]: + """Basic GQA pattern that should be fused into an Attention op.""" + + # Concatenate past_key cache and current key + present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2) # [B, Hkv, S+P, D] + + # Unsqueeze to add group dimension + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + + # Calculate shapes dynamically + B = op.Shape(query_BHSD, start=0, end=1) # [B] + T = op.Shape(present_key_BHkvStD, start=2, end=3) # [S+P] + + # Create expand shape [B, Hkv, G, S+P, D] + expand_shape = op.Concat(B, Hkv, G, T, D, axis=0) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D] + + # Create reshape shape [B, H, S+P, D] + reshape_shape = op.Concat(B, H, T, D, axis=0) + present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Same for value + present_value_BHkvStD = op.Concat( + past_value_BHkvPD, value_BHkvSD, axis=-2 + ) # [B, Hkv, S+P, D] + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + present_value_BHkvGStD = op.Expand( + present_value_BHkv1StD, expand_shape + ) # [B, Hkv, G, S+P, D] + present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Attention computation + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + ) + + return attention_BHSDh + + +class GQAFusionTest(unittest.TestCase): + def test_basic_gqa_fusion(self): + """Test basic GQA fusion pattern.""" + model_proto = _gqa_script.to_model_proto() + + # Apply GQA fusion + model = ir.serde.deserialize_model(model_proto) + onnxscript.optimizer.optimize(model) + count = fuse_gqa(model) + self.assertGreater(count, 0, "GQA fusion should have occurred") + + # We can't yet test numerical equivalence because of a bug in the op spec/implementation. + onnx_ver = version.parse(onnx.__version__) + if onnx_ver >= version.parse("1.19.1") and not ( + onnx_ver.is_prerelease or onnx_ver.is_devrelease + ): + # Only official releases >= 1.19.1 + onnxscript.optimizer.remove_unused_nodes(model) + rewritten_model_proto = ir.serde.serialize_model(model) + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, rewritten_model_proto, use_reference=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/_layer_norm.py b/onnxscript/rewriter/rules/fusion/_layer_norm.py new file mode 100644 index 0000000000..30a3428d15 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_layer_norm.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnx_ir as ir + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +""" +Layer Normalization fusion optimization. + +This module contains rewrite rules for fusing Layer Normalization patterns into the +ONNX LayerNormalization operator. + +Layer Normalization performs normalization over the last D dimensions as specified by the axis. +The computation follows: Y = scale * (X - mean) / sqrt(variance + epsilon) + bias + +Key points for the fusion optimization: +* Following restrictions from opset 17 LayerNormalization: +* Input, scale, and bias must be of same type T in {float16, bfloat16, float, double} +* The normalization can be done in a different precision than the input type (bfloat16 or float), +which is also the precision of the output mean/invstddev +""" + +# input types permitted by LayerNormalization op (ONNX Opset 17) +LAYER_NORM_INPUT_TYPES = frozenset( + [ + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.DOUBLE, + ] +) + +# Compute types permitted by LayerNormalization op (ONNX Opset 17), aka stash_type. +LAYER_NORM_COMPUTE_TYPES = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE]) + + +class LayerNormFusion(pattern.RewriteRuleClassBase): + """Fuse LayerNorm pattern into LayerNormalization op.""" + + def pattern(self, op, x, scale, epsilon): + # Compute mean: Mean = ReduceMean(X, axes=normalized_axes) + # TODO: support axes attribute too + mean = op.ReduceMean(x, [-1], keepdims=1) + + # Compute deviation: D = Sub(X, Mean) + deviation = op.Sub(x, mean) + + # Compute squared deviation: DD = Mul(D, D) + deviation_squared = pattern.OrValue( + [ + op.Mul(deviation, deviation), + op.Pow(deviation, 2), + ] + ) + + # Compute variance: Var = ReduceMean(DD, axes=normalized_axes) + variance = op.ReduceMean(deviation_squared, [-1], keepdims=1) + + # Add epsilon: VarEps = Add(Var, epsilon) + variance_plus_epsilon = op.Add(variance, epsilon) + + # Compute standard deviation: StdDev = Sqrt(VarEps) + std_dev = op.Sqrt(variance_plus_epsilon) + + # Compute reciprocal: InvStdDev = Reciprocal(StdDev) + # Normalize: Normalized = Mul(D, InvStdDev) + + inv_std_dev = op.Reciprocal(std_dev) + normalized = pattern.OrValue( + [op.Mul(deviation, inv_std_dev), op.Div(deviation, std_dev)] + ) + + # Scale: NormalizedScaled = Mul(Normalized, Scale) + normalized_scaled = op.Mul(normalized, scale) + + return normalized_scaled + + def check(self, context, x, epsilon, **_) -> pattern.MatchResult: # type: ignore[name-defined] + """Check if the pattern matches conditions for use of LayerNormalization op.""" + check_result = pattern.MatchResult() + + # Type validation: + if x.dtype not in LAYER_NORM_COMPUTE_TYPES: + return check_result.fail("Input is not a float type.", x) + self._stash_type = x.dtype + + # Check that epsilon is a scalar constant + epsilon_value = _ir_utils.get_singleton_value(epsilon) + if epsilon_value is None: + return check_result.fail("Epsilon is not a constant scalar.", epsilon) + # Epsilon is guaranteed to be same type as x (float or double, in this pattern) + self._epsilon = float(epsilon_value) + + return check_result + + def rewrite(self, op, x, scale, epsilon, **_): + return op.LayerNormalization( + x, + scale, + axis=-1, + epsilon=self._epsilon, + stash_type=self._stash_type, + ) + + +class LayerNormBiasFusion(pattern.RewriteRuleClassBase): + """Fuse LayerNorm => Add into LayerNorm with bias.""" + + def pattern(self, op, x, scale, bias): + return op.LayerNormalization(x, scale, _outputs=["normalized"]) + bias + + def rewrite(self, op, x, scale, bias, normalized): + layernorm_node = normalized.producer() + attributes = layernorm_node.attributes + num_outputs = len(layernorm_node.outputs) + return op.LayerNormalization(x, scale, bias, _outputs=num_outputs, **attributes) + + +# Create rules for both with and without bias +_layer_norm_rule = LayerNormFusion.rule() +_layer_norm_with_bias_rule = LayerNormBiasFusion.rule() + +layer_normalization_rules = [_layer_norm_rule, _layer_norm_with_bias_rule] +layer_normalization_ruleset = pattern.RewriteRuleSet(layer_normalization_rules) + +fuse_layer_normalization = _fusion_utils.apply_fusion_rules(layer_normalization_ruleset) diff --git a/onnxscript/rewriter/rules/fusion/_layer_norm_test.py b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py new file mode 100644 index 0000000000..6ea7f116fb --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import onnx_ir as ir + +import onnxscript +import onnxscript.optimizer +import onnxscript.rewriter.testing +from onnxscript import FLOAT, OnnxFunction, script +from onnxscript import opset18 as op +from onnxscript.rewriter.rules.fusion._layer_norm import fuse_layer_normalization + + +@script() +def _test_layer_norm_without_bias(x: FLOAT[2, 4, 8], scale: FLOAT[8]) -> FLOAT[2, 4, 8]: + """LayerNorm pattern without bias.""" + # Compute mean: Mean = ReduceMean(X, axes=normalized_axes) + mean = op.ReduceMean(x, [-1], keepdims=1) + + # Compute deviation: D = Sub(X, Mean) + deviation = op.Sub(x, mean) + + # Compute squared deviation: DD = Mul(D, D) + deviation_squared = op.Mul(deviation, deviation) + + # Compute variance: Var = ReduceMean(DD, axes=normalized_axes) + variance = op.ReduceMean(deviation_squared, [-1], keepdims=1) + + # Add epsilon: VarEps = Add(Var, epsilon) + epsilon = op.Constant(value_float=1e-5) + variance_plus_epsilon = op.Add(variance, epsilon) + + # Compute standard deviation: StdDev = Sqrt(VarEps) + std_dev = op.Sqrt(variance_plus_epsilon) + + # Compute reciprocal: InvStdDev = Reciprocal(StdDev) + inv_std_dev = op.Reciprocal(std_dev) + + # Normalize: Normalized = Mul(D, InvStdDev) + normalized = op.Mul(deviation, inv_std_dev) + + # Scale: NormalizedScaled = Mul(Normalized, Scale) + normalized_scaled = op.Mul(normalized, scale) + + return normalized_scaled + + +@script() +def _test_layer_norm_with_bias( + x: FLOAT[2, 4, 8], scale: FLOAT[8], bias: FLOAT[8] +) -> FLOAT[2, 4, 8]: + """LayerNorm pattern with bias.""" + # Compute mean: Mean = ReduceMean(X, axes=normalized_axes) + mean = op.ReduceMean(x, [-1], keepdims=1) + + # Compute deviation: D = Sub(X, Mean) + deviation = op.Sub(x, mean) + + # Compute squared deviation: DD = Mul(D, D) + deviation_squared = op.Mul(deviation, deviation) + + # Compute variance: Var = ReduceMean(DD, axes=normalized_axes) + variance = op.ReduceMean(deviation_squared, [-1], keepdims=1) + + # Add epsilon: VarEps = Add(Var, epsilon) + epsilon = op.Constant(value_float=1e-5) + variance_plus_epsilon = op.Add(variance, epsilon) + + # Compute standard deviation: StdDev = Sqrt(VarEps) + std_dev = op.Sqrt(variance_plus_epsilon) + + # Compute reciprocal: InvStdDev = Reciprocal(StdDev) + inv_std_dev = op.Reciprocal(std_dev) + + # Normalize: Normalized = Mul(D, InvStdDev) + normalized = op.Mul(deviation, inv_std_dev) + + # Scale: NormalizedScaled = Mul(Normalized, Scale) + normalized_scaled = op.Mul(normalized, scale) + + # Add bias: Y = Add(NormalizedScaled, B) + result = op.Add(normalized_scaled, bias) + + return result + + +class LayerNormFusionTest(unittest.TestCase): + def _check(self, test_script: OnnxFunction): + """Helper method to run a fusion test scenario.""" + model_proto = test_script.to_model_proto() + # Create test inputs + input_data = onnxscript.rewriter.testing.generate_random_inputs(model_proto) + + model = ir.serde.deserialize_model(model_proto) + fuse_layer_normalization(model) + + onnxscript.optimizer.remove_unused_nodes(model) + + # Check that a LayerNormalization node was created + self.assertEqual(["LayerNormalization"], [n.op_type for n in model.graph]) + + fused_model_proto = ir.serde.serialize_model(model) + + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, fused_model_proto, input_data + ) + + def test_layer_norm_fusion_without_bias(self): + """Test LayerNorm fusion without bias.""" + self._check(_test_layer_norm_without_bias) + + def test_layer_norm_fusion_with_bias(self): + """Test LayerNorm fusion with bias.""" + self._check(_test_layer_norm_with_bias) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/_rms_normalization.py b/onnxscript/rewriter/rules/fusion/_rms_normalization.py new file mode 100644 index 0000000000..f4892b4918 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rms_normalization.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnxscript.ir as ir +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +""" +RMS Normalization: ONNX Opset 23 op +See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization + + +Key points for the fusion optimization: +* Input and scale are allowed to be of different types. +* The normalization of the input can be done in a different precision than the input type, +indicated by stash_type. +* Input (x) must be: float or double or float16 or bfloat16 +* Scale must be: float or double or float16 or bfloat16 +""" + +float_types = frozenset( + [ + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.DOUBLE, + ] +) +fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE]) + + +class RmsNormFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, mul_order: bool): + super().__init__(name) + self._mul_order = mul_order + + def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): + x = pattern.OrValue([op.Cast(x, to=compute_dtype), x]) + x_square = op.Pow(x, 2.0) + mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) + mean_square_plus_epsilon = op.Add(mean_square, epsilon) + rms = op.Sqrt(mean_square_plus_epsilon) + reciprocal_rms = op.Reciprocal(rms) + normalized = op.Mul(x, reciprocal_rms) + normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized]) + # Workaround: limitation in pattern matcher doesn't support OrValue for return value (last node in pattern) + if self._mul_order: + return op.Mul(normalized, scale) + else: + return op.Mul(scale, normalized) + + def check( + self, op, x, scale, epsilon, compute_dtype, target_dtype, **_ + ) -> pattern.MatchResult: # type: ignore[name-defined] + """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" + check_result = pattern.MatchResult() + # epsilon must be a scalar + epsilon_value = _ir_utils.get_singleton_value(epsilon) + if not isinstance(epsilon_value, float): # TODO: support other types + return check_result.fail("Epsilon is not a float value.", epsilon) + if x.dtype not in float_types: + return check_result.fail("Input is not a supported float type.", x) + if scale.dtype not in float_types: + return check_result.fail("Scale is not a supported float type.", scale) + self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype + if self._stash_dtype not in fp_float_types: + # TODO: ONNX documentation does not specify restrictions on stash_type, though + # ORT's SimplifiedLayerNormalization requires it to be float or double. + return check_result.fail("Normalization precision is not a float or double type.") + # target_dtype is guaranteed to be the same as scale type in a well-typed input + # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input. + # TODO (rama): Consider adding checks to protect against incorrectly typed models: + return check_result + + def rewrite(self, op, x, scale, epsilon, **_): + # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake. + # No need to use com.microsoft domain here; but this is a custom op in ORT. + return op.RMSNormalization( + x, + scale, + axis=-1, + epsilon=_ir_utils.get_singleton_value(epsilon), + stash_type=self._stash_dtype, + ) + + +_rule1 = RmsNormFusion.rule("RmsNormFusion1", mul_order=True) +_rule2 = RmsNormFusion.rule("RmsNormFusion2", mul_order=False) + +rms_normalization_rules = [_rule1, _rule2] + +rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules) + +fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset) diff --git a/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py new file mode 100644 index 0000000000..e70c4ec7a0 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx_ir as ir + +import onnxscript +from onnxscript.rewriter.rules.fusion import _rms_normalization + + +class RmsNormOnnxFusionsTest(unittest.TestCase): + def test_rms_normalization_fusion(self): + opset23 = onnxscript.values.Opset("", 23) + + @onnxscript.script() + def rms_norm_script(embedding, layernorm_weight): + two = opset23.Constant(value_float=2.0) + pow_1 = opset23.Pow(embedding, two) + mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0) + epsilon = opset23.Constant(value_float=1e-05) + add_1 = opset23.Add(mean, epsilon) + val_244 = opset23.Sqrt(add_1) + rsqrt = opset23.Reciprocal(val_244) + mul_3 = opset23.Mul(embedding, rsqrt) + mul_4 = opset23.Mul(layernorm_weight, mul_3) + return mul_4 + + rms_norm_model_proto = rms_norm_script.to_model_proto( + input_types=[onnxscript.FLOAT[128], onnxscript.FLOAT[128]], + output_types=[onnxscript.FLOAT[128]], + ) + model = ir.serde.deserialize_model(rms_norm_model_proto) + count = _rms_normalization.fuse_rms_normalization(model) + self.assertEqual(count, 1) + self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py new file mode 100644 index 0000000000..b659afdbc0 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +# Fusions for RotaryEmbedding: +# Fuse computation patterns seen in HF transformer models for RotaryEmbedding +# and map them to ONNX opset 23 RotaryEmbedding op. + +# Basic pattern: For example, see +# https://github.com/huggingface/transformers/blob/541bed22d6e4f97946a3a7d74f7e1a353e58643b/src/transformers/models/llama/modeling_llama.py#L104 +# def rotate_half(x): +# """Rotates half the hidden dims of the input.""" +# x1 = x[..., : x.shape[-1] // 2] +# x2 = x[..., x.shape[-1] // 2 :] +# return torch.cat((-x2, x1), dim=-1) +# and +# q_embed = (q * cos) + (rotate_half(q) * sin) + + +def _rotate_half_pattern(op, x, start1, end1, start2, end2): + # Slice(input, starts, ends, axes, steps) + x1 = op.Slice(x, start1, end1, [3], [1]) + x2 = op.Slice(x, start2, end2, [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + return rotated_x + + +class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__(name="RotaryEmbedding23", remove_nodes=False) + + def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2): + freqs_repeated = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(freqs_repeated) + sin = op.Sin(freqs_repeated) + cos_4d = op.Unsqueeze(cos, one1) + sin_4d = op.Unsqueeze(sin, one2) + return x * cos_4d + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin_4d + + def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() + + if not _ir_utils.is_singleton_value(one1, 1): + return check_result.fail("Unsqueeze axes is not [1]", one1) + if not _ir_utils.is_singleton_value(one2, 1): + return check_result.fail("Unsqueeze axes is not [1]", one2) + + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) + if x is None or x.shape is None or len(x.shape) != 4: + return check_result.fail("Input is not known to be a 4D tensor.", x) + if not isinstance(x.shape[1], int): + return check_result.fail("Input dimension 1 (num_heads) is not static.", x) + head_size = x.shape[3] + if not isinstance(head_size, int): + return check_result.fail("Head size is not static.", x) + half_head_size = head_size // 2 + + # Check that x is being split into two equal halves of size half_head_size + if not ( + _ir_utils.is_singleton_value(start1, 0) + and _ir_utils.is_singleton_value(end1, half_head_size) + and _ir_utils.is_singleton_value(start2, half_head_size) + and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) + ): + return check_result.fail( + "x is not being split into two equal halves of size half_head_size." + ) + return check_result + + def rewrite(self, op, x, freqs, **_): + num_heads = x.shape[1] + cos = op.Cos(freqs) + sin = op.Sin(freqs) + return op.RotaryEmbedding( + x, + cos, + sin, + interleaved=0, + num_heads=num_heads, + ) + + +# Extensions for partial rotary embedding fusion: with partial rotary embedding, +# embedding is applied only to the first part of the input, and the second part is left unchanged, +# as captured in the pattern below. + +MAX_INT64 = 9223372036854775807 + + +class PartialRotaryEmbedding23Fusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, end1, start2): + x_part_1 = op.Slice(x, [0], end1, [3], [1]) + x_part_2 = op.Slice(x, start2, [MAX_INT64], [3], [1]) + x_part_1_rope = op.RotaryEmbedding( + x_part_1, + _allow_other_inputs=True, + _allow_other_attributes=True, + _outputs=["x_part_1_rope"], + ) + return op.Concat(x_part_1_rope, x_part_2, axis=-1) + + def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() + end1_value = _ir_utils.get_singleton_value(end1) + start2_value = _ir_utils.get_singleton_value(start2) + if not isinstance(end1_value, int) or not isinstance(start2_value, int): + return check_result.fail("Unable to validate slice start/end values.") + if end1_value != start2_value: + return check_result.fail( + "The end1 value of first slice and start2 value of second slice are not equal." + ) + rotary_embedding_attributes = x_part_1_rope.producer().attributes + if "rotary_embedding_dim" in rotary_embedding_attributes: + return check_result.fail("rotary_embedding_dim attribute already specified.") + if ( + "interleaved" in rotary_embedding_attributes + and rotary_embedding_attributes["interleaved"].value != 0 + ): + return check_result.fail("interleaved is not equal to 0.") + return check_result + + def rewrite(self, op, x, end1, x_part_1_rope, **_): + # Create a modified version of the RotaryEmbedding op: + rotary_embedding_dim = _ir_utils.get_singleton_value(end1) + original_node = x_part_1_rope.producer() + inputs = list(original_node.inputs) + inputs[0] = x + attrs = dict(original_node.attributes) + attrs["rotary_embedding_dim"] = rotary_embedding_dim + return op.RotaryEmbedding( + *inputs, + **attrs, + ) + + +_rule = RotaryEmbedding23Fusion.rule() + +_partial_embedding_rule = PartialRotaryEmbedding23Fusion.rule() + +rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) + +partial_embedding_rules = pattern.RewriteRuleSet([_partial_embedding_rule]) + +fuse_rotary_embedding = _fusion_utils.apply_fusion_rules(rotary_embedding_rules) + +fuse_partial_rotary_embedding = _fusion_utils.apply_fusion_rules(partial_embedding_rules) diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py new file mode 100644 index 0000000000..b8ffe95cac --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx +import onnx_ir as ir +from packaging.version import Version +from parameterized import parameterized + +import onnxscript +import onnxscript.rewriter.testing +from onnxscript.rewriter.models import _rotary_embedding_models +from onnxscript.rewriter.rules.fusion import _rotary_embedding + + +class RotaryEmbeddingOnnxFusionTest(unittest.TestCase): + @parameterized.expand( + [ + ( + "test_case_1", + _rotary_embedding_models.test_case_1, + ), + ( + "test_case_2", + _rotary_embedding_models.test_case_2, + ), + ] + ) + def test_rotary_embedding_fusion(self, _: str, test_data_constructor): + test = test_data_constructor() + model: ir.Model = test.get_onnx_model() + model.graph.opset_imports[""] = 23 + model_proto = ir.serde.serialize_model(model) + onnxscript.optimizer.optimize(model) + _rotary_embedding.fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + rewritten_model_proto = ir.serde.serialize_model(model) + inputs = test.get_ort_inputs() + + onnx_version = Version(onnx.__version__) + min_version = Version("1.19.1") + is_stable = not (onnx_version.is_devrelease or onnx_version.is_prerelease) + if onnx_version >= min_version and is_stable: + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, rewritten_model_proto, args=inputs, use_reference=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py new file mode 100644 index 0000000000..2a9d24ee01 --- /dev/null +++ b/onnxscript/rewriter/testing.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Any + +import numpy as np +import onnx +import onnx.reference +import onnxruntime as ort + +from onnxscript import ir + + +def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]: + feeds: dict[str, Any] = {} + for input in model.graph.input: + input_type = input.type.tensor_type + shape = tuple(input_type.shape.dim) + if not all(hasattr(d, "dim_value") for d in shape): + raise ValueError(f"Input {input.name} has dynamic shape dimensions.") + shape = tuple(d.dim_value for d in shape) + if input_type.elem_type == onnx.TensorProto.FLOAT: + if shape: + feeds[input.name] = np.random.randn(*shape).astype(np.float32) + else: + feeds[input.name] = np.random.randn(1).astype(np.float32) + else: + raise ValueError(f"Not implemented for input type {input_type.elem_type}") + return feeds + + +def assert_numerically_equal( + original_model_proto: onnx.ModelProto | ir.Model, + rewritten_model_proto: onnx.ModelProto | ir.Model, + args: tuple[Any, ...] | dict[str, Any] | None = None, + ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL, + rtol: float = 1, + atol: float = 1e-3, + use_reference: bool = False, +): + """Assert that the two models are numerically equal. + + Args: + original_model_proto: The original model proto or ir.Model. + rewritten_model_proto: The rewritten by the rules model proto or ir.Model. + args: The positional arguments to pass to the model. + ort_optimization_level: Onnxruntime optimization level. + rtol: Relative tolerance. + atol: Absolute tolerance. + use_reference: If True, use ONNX reference implementation instead of ONNXRuntime. + """ + + if isinstance(original_model_proto, ir.Model): + original_model_proto = ir.serde.serialize_model(original_model_proto) + if isinstance(rewritten_model_proto, ir.Model): + rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto) + + if args is None: + original_proto_ort_inputs = generate_random_inputs(original_model_proto) + the_rewritten_proto_ort_inputs = original_proto_ort_inputs + elif isinstance(args, dict): + original_proto_ort_inputs = args + the_rewritten_proto_ort_inputs = args + else: + original_proto_ort_inputs = { + k.name: v for k, v in zip(original_model_proto.graph.input, args) + } + the_rewritten_proto_ort_inputs = { + k.name: v for k, v in zip(rewritten_model_proto.graph.input, args) + } + + if use_reference: + # Use ONNX reference implementation + original_evaluator = _reference_session( + original_model_proto.SerializeToString(), ort_optimization_level + ) + original_outputs = original_evaluator.run(None, original_proto_ort_inputs) + + rewritten_evaluator = _reference_session( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = rewritten_evaluator.run(None, the_rewritten_proto_ort_inputs) + else: + # Use ONNXRuntime + original_proto_ort_inference_session = _ort_session_initializer( + original_model_proto.SerializeToString(), ort_optimization_level + ) + run_options = ort.RunOptions() + run_options.log_severity_level = 3 # 3: Error + original_outputs = original_proto_ort_inference_session.run( + None, original_proto_ort_inputs, run_options=run_options + ) + + the_rewritten_proto_ort_inference_session = _ort_session_initializer( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( + None, the_rewritten_proto_ort_inputs, run_options=run_options + ) + + np.testing.assert_allclose( + original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True + ) + + +def _ort_session_initializer( + model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel +) -> ort.InferenceSession: + """Initialize an ONNX Runtime inference session with the specified model.""" + import onnxruntime as ort + + session_options = ort.SessionOptions() + session_options.log_severity_level = 3 # 3: Error + session_options.graph_optimization_level = ort_optimization_level + possible_providers = ( + "CUDAExecutionProvider", + "CPUExecutionProvider", + ) + available_providers = set(ort.get_available_providers()) + providers = [ + provider for provider in possible_providers if provider in available_providers + ] + return ort.InferenceSession(model, providers=providers, sess_options=session_options) + + +def _reference_session( + model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel +) -> onnx.reference.ReferenceEvaluator: + """Initialize an ONNX reference evaluator with the specified model.""" + # Parse the model from bytes if needed + if isinstance(model, (str, bytes)): + model_proto = onnx.load_from_string(model) + else: + model_proto = model + + # Note: ort_optimization_level is ignored for reference implementation + # as it doesn't have equivalent optimization levels + return onnx.reference.ReferenceEvaluator(model_proto) diff --git a/onnxscript/sourceinfo.py b/onnxscript/sourceinfo.py index 1e02551f27..b1e19eff73 100644 --- a/onnxscript/sourceinfo.py +++ b/onnxscript/sourceinfo.py @@ -33,7 +33,7 @@ def msg(self, error_message: str) -> str: if self.function_name: source_loc = f"Function '{self.function_name}', line {lineno}" else: - source_loc = "Line {lineno}" + source_loc = f"Line {lineno}" if self.code: lines = self.code.split("\n") diff --git a/onnxscript/tensor.py b/onnxscript/tensor.py index 9acb80467b..f1d781b808 100644 --- a/onnxscript/tensor.py +++ b/onnxscript/tensor.py @@ -1,17 +1,13 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations from typing import Any, Optional import numpy as np -import onnx.helper -from onnx import TensorProto -from onnxscript import onnx_opset +from onnxscript import ir, onnx_opset from onnxscript._internal import autocast @@ -54,7 +50,7 @@ def dtype(self) -> np.dtype: @property def onnx_dtype(self) -> int: - return onnx.helper.np_dtype_to_tensor_dtype(self.dtype) + return ir.DataType.from_numpy(self.dtype) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value!r})" @@ -162,10 +158,10 @@ def __getitem__(self, index): def __mod__(self, other): if self.onnx_dtype in { - TensorProto.FLOAT, - TensorProto.DOUBLE, - TensorProto.FLOAT16, - TensorProto.BFLOAT16, + ir.DataType.FLOAT, + ir.DataType.DOUBLE, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, }: return self._opset.Mod(self, other, fmod=1) return self._opset.Mod(self, other) diff --git a/onnxscript/tensor_test.py b/onnxscript/tensor_test.py index e81d01472e..afe490e8dc 100644 --- a/onnxscript/tensor_test.py +++ b/onnxscript/tensor_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Unit tests for the tensor module.""" import unittest diff --git a/onnxscript/testing/__init__.py b/onnxscript/testing/__init__.py index bacfe97773..048b45e7e8 100644 --- a/onnxscript/testing/__init__.py +++ b/onnxscript/testing/__init__.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations __all__ = [ @@ -12,10 +14,12 @@ from typing import Any, Collection, Sequence import google.protobuf.message +import numpy as np import onnx from onnx import parser import onnxscript +from onnxscript import ir def assert_isomorphic(graph_or_function_1, graph_or_function_2): @@ -64,7 +68,7 @@ def to_map(proto): return to_map(proto1) == to_map(proto2) -def _same_tensor(tp1, tp2): +def _same_tensor(tp1: onnx.TensorProto, tp2: onnx.TensorProto): if tp1.dims != tp2.dims: return False if not _same_optional("data_type", tp1, tp2): @@ -72,18 +76,11 @@ def _same_tensor(tp1, tp2): # Segmented representation not supported yet if tp1.HasField("segment") or tp2.HasField("segment"): return False - if tp1.float_data != tp2.float_data: - return False - if tp1.int32_data != tp2.int32_data: - return False - if tp1.string_data != tp2.string_data: - return False - if tp1.int64_data != tp2.int64_data: - return False - if tp1.uint64_data != tp2.uint64_data: - return False - if tp1.double_data != tp2.double_data: - return False + if tp1.data_location == tp2.data_location == tp1.DataLocation.DEFAULT: + tensor1 = ir.from_proto(tp1) + tensor2 = ir.from_proto(tp2) + if not np.array_equal(tensor1.numpy(), tensor2.numpy(), equal_nan=True): + return False # Ignore name for comparison: # if not _same_optional("name", tp1, tp2): return False if not _same_optional("doc_string", tp1, tp2): @@ -372,7 +369,9 @@ def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]: def assert_onnx_proto_equal( - a: google.protobuf.message.Message | Any, b: google.protobuf.message.Message | Any + actual: google.protobuf.message.Message | Any, + expected: google.protobuf.message.Message | Any, + ignore_initializer_value_proto: bool = False, ) -> None: """Assert that two ONNX protos are equal. @@ -384,18 +383,31 @@ def assert_onnx_proto_equal( compared disregarding the order of their elements. Args: - a: The first ONNX proto. - b: The second ONNX proto. + actual: The first ONNX proto. + expected: The second ONNX proto. + ignore_initializer_value_proto: Ignore value protos for initializers if there + are extra ones in the actual proto. """ - assert type(a) == type(b), f"Type not equal: {type(a)} != {type(b)}" # pylint: disable=unidiomatic-typecheck + assert type(actual) is type(expected), ( + f"Type not equal: {type(actual)} != {type(expected)}" + ) - a_fields = {field.name: value for field, value in a.ListFields()} - b_fields = {field.name: value for field, value in b.ListFields()} + a_fields = {field.name: value for field, value in actual.ListFields()} + b_fields = {field.name: value for field, value in expected.ListFields()} all_fields = sorted(set(a_fields.keys()) | set(b_fields.keys())) - for field in all_fields: + if isinstance(actual, onnx.GraphProto) and isinstance(expected, onnx.GraphProto): + actual_initializer_names = {i.name for i in actual.initializer} + expected_initializer_names = {i.name for i in expected.initializer} + else: + actual_initializer_names = set() + expected_initializer_names = set() + + # Record and report all errors + errors = [] + for field in all_fields: # pylint: disable=too-many-nested-blocks # Obtain the default value if the field is not set. This way we can compare the two fields. - a_value = getattr(a, field) - b_value = getattr(b, field) + a_value = getattr(actual, field) + b_value = getattr(expected, field) if ( isinstance(a_value, Sequence) and isinstance(b_value, Sequence) @@ -411,6 +423,22 @@ def assert_onnx_proto_equal( a_keys = [_opset_import_key(opset_import) for opset_import in a_value] b_keys = [_opset_import_key(opset_import) for opset_import in b_value] elif field == "value_info": + if ( + ignore_initializer_value_proto + and isinstance(actual, onnx.GraphProto) + and isinstance(expected, onnx.GraphProto) + ): + # Filter out initializers from the value_info list + a_value = [ + value_info + for value_info in a_value + if value_info.name not in actual_initializer_names + ] + b_value = [ + value_info + for value_info in b_value + if value_info.name not in expected_initializer_names + ] a_value = sorted(a_value, key=_value_info_key) b_value = sorted(b_value, key=_value_info_key) a_keys = [_value_info_key(value_info) for value_info in a_value] @@ -422,51 +450,62 @@ def assert_onnx_proto_equal( b_keys = [_function_key(functions) for functions in b_value] if a_keys != b_keys: - keys_only_in_a = set(a_keys) - set(b_keys) - keys_only_in_b = set(b_keys) - set(a_keys) + keys_only_in_actual = set(a_keys) - set(b_keys) + keys_only_in_expected = set(b_keys) - set(a_keys) error_message = ( - f"Field {field} not equal: keys_only_in_a={keys_only_in_a}, keys_only_in_b={keys_only_in_b}. " + f"Field {field} not equal: keys_only_in_actual={keys_only_in_actual}, keys_only_in_expected={keys_only_in_expected}. " f"Field type: {type(a_value)}. " f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}" ) - raise AssertionError(error_message) - if len(a_value) != len(b_value): + errors.append(error_message) + elif len(a_value) != len(b_value): error_message = ( f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} " f"Field type: {type(a_value)}" ) - raise AssertionError(error_message) - # Check every element - for i in range(len(a_value)): # pylint: disable=consider-using-enumerate - a_value_i = a_value[i] - b_value_i = b_value[i] - if isinstance(a_value_i, google.protobuf.message.Message) and isinstance( - b_value_i, google.protobuf.message.Message - ): - try: - assert_onnx_proto_equal(a_value_i, b_value_i) - except AssertionError as e: - error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}" - raise AssertionError(error_message) from e - elif a_value_i != b_value_i: - if ( - isinstance(a_value_i, float) - and isinstance(b_value_i, float) - and math.isnan(a_value_i) - and math.isnan(b_value_i) - ): - # Consider NaNs equal - continue - error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}" - for line in difflib.ndiff( - str(a_value_i).splitlines(), str(b_value_i).splitlines() - ): - error_message += "\n" + line - raise AssertionError(error_message) + errors.append(error_message) + else: + # Check every element + for i in range(len(a_value)): # pylint: disable=consider-using-enumerate + actual_value_i = a_value[i] + expected_value_i = b_value[i] + if isinstance( + actual_value_i, google.protobuf.message.Message + ) and isinstance(expected_value_i, google.protobuf.message.Message): + try: + assert_onnx_proto_equal( + actual_value_i, + expected_value_i, + ignore_initializer_value_proto=ignore_initializer_value_proto, + ) + except AssertionError as e: + error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}, actual_value_i: {actual_value_i}, expected_value_i: {expected_value_i}" + error_message = ( + str(e) + "\n\nCaused by the above error\n\n" + error_message + ) + errors.append(error_message) + elif actual_value_i != expected_value_i: + if ( + isinstance(actual_value_i, float) + and isinstance(expected_value_i, float) + and math.isnan(actual_value_i) + and math.isnan(expected_value_i) + ): + # Consider NaNs equal + continue + error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}" + for line in difflib.ndiff( + str(actual_value_i).splitlines(), + str(expected_value_i).splitlines(), + ): + error_message += "\n" + line + errors.append(error_message) elif isinstance(a_value, google.protobuf.message.Message) and isinstance( b_value, google.protobuf.message.Message ): - assert_onnx_proto_equal(a_value, b_value) + assert_onnx_proto_equal( + a_value, b_value, ignore_initializer_value_proto=ignore_initializer_value_proto + ) elif a_value != b_value: if ( isinstance(a_value, float) @@ -476,5 +515,11 @@ def assert_onnx_proto_equal( ): # Consider NaNs equal continue - error_message = f"Field {field} not equal. field_a: {a_value}, field_b: {b_value}" - raise AssertionError(error_message) + error_message = ( + f"Field {field} not equal. field_actual: {a_value}, field_expected: {b_value}" + ) + errors.append(error_message) + if errors: + raise AssertionError( + f"Protos not equal: {type(actual)} != {type(expected)}\n" + "\n".join(errors) + ) diff --git a/onnxscript/tools/__init__.py b/onnxscript/tools/__init__.py new file mode 100644 index 0000000000..862c45ce31 --- /dev/null +++ b/onnxscript/tools/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- diff --git a/onnxscript/tools/memory_peak.py b/onnxscript/tools/memory_peak.py new file mode 100644 index 0000000000..1f9a7e319a --- /dev/null +++ b/onnxscript/tools/memory_peak.py @@ -0,0 +1,244 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import multiprocessing +import os + + +def get_memory_rss(pid: int) -> int: + """ + Returns the physical memory used by a process. + + Args: + pid: Process id, current one is `os.getpid()`. + + Returns: + Physical memory. + + It relies on the module *psutil*. + """ + import psutil + + process = psutil.Process(pid) + mem = process.memory_info().rss + return mem + + +class Monitor: + def __init__(self): + self.max_peak: float = 0 + self.average: float = 0 + self.n_measures: int = 0 + self.begin: float = 0 + self.end: float = 0 + + def to_dict(self, unit: int = 1) -> dict[str, float]: + funit = float(unit) + return dict( + peak=self.max_peak / funit, + mean=self.average * 1.0 / self.n_measures / funit, + n=self.n_measures / funit, + begin=self.begin / funit, + end=self.end / funit, + ) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(peak={self.max_peak}, " + f"average={self.average}, n={self.n_measures})" + ) + + def update(self, mem: float): + if self.n_measures == 0: + self.begin = mem + self.max_peak = max(mem, self.max_peak) + self.average += mem + self.end = mem + self.n_measures += 1 + + def send(self, conn): + conn.send(self.max_peak) + conn.send(self.average) + conn.send(self.n_measures) + conn.send(self.begin) + conn.send(self.end) + + @classmethod + def recv(cls, conn) -> Monitor: + m = cls() + m.max_peak = conn.recv() + m.average = conn.recv() + m.n_measures = conn.recv() + m.begin = conn.recv() + m.end = conn.recv() + return m + + +def _process_memory_spy(conn): + # Sends the value it started. + conn.send(-2) + + # process id to spy on + pid = conn.recv() + + # delay between two measures + timeout = conn.recv() + + # do CUDA + cuda = conn.recv() + + import psutil + + process = psutil.Process(pid) + + if cuda: + from pynvml import ( # type: ignore[import-not-found] + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlInit, + nvmlShutdown, + ) + + nvmlInit() + n_gpus = nvmlDeviceGetCount() + handles = [nvmlDeviceGetHandleByIndex(i) for i in range(n_gpus)] + + def gpu_used(): + return [nvmlDeviceGetMemoryInfo(h).used for h in handles] + + gpus = [Monitor() for i in range(n_gpus)] + else: + gpus = [] + + cpu = Monitor() + + conn.send(-2) + + # loop + while True: + mem = process.memory_info().rss + cpu.update(mem) + if cuda: + for r, g in zip(gpu_used(), gpus): + g.update(r) + if conn.poll(timeout=timeout): + code = conn.recv() + if code == -3: + break + + # final iteration + end = process.memory_info().rss + cpu.update(end) + if cuda: + for r, g in zip(gpu_used(), gpus): + g.update(r) + + # send + cpu.send(conn) + conn.send(len(gpus)) + for g in gpus: + g.send(conn) + if cuda: + nvmlShutdown() + conn.close() + + +class MemorySpy: + """ + Information about the spy. It class method `start`. + Method `stop` can be called to end the measure. + + Args: + pid: process id of the process to spy on + delay: spy on every delay seconds + cuda: enable cuda monitoring + """ + + def __init__(self, pid: int, delay: float = 0.01, cuda: bool = False): + self.pid = pid + self.delay = delay + self.cuda = cuda + self.start() + + def start(self) -> MemorySpy: + """Starts another process and tells it to spy.""" + self.parent_conn, self.child_conn = multiprocessing.Pipe() + self.child_process = multiprocessing.Process( + target=_process_memory_spy, args=(self.child_conn,) + ) + self.child_process.start() + data = self.parent_conn.recv() + if data != -2: + raise RuntimeError(f"The child processing is supposed to send -2 not {data}.") + self.parent_conn.send(self.pid) + self.parent_conn.send(self.delay) + self.parent_conn.send(1 if self.cuda else 0) + data = self.parent_conn.recv() + if data != -2: + raise RuntimeError( + f"The child processing is supposed to send -2 again not {data}." + ) + return self + + def stop(self) -> dict[str, list[Monitor]]: + """Stops spying on.""" + self.parent_conn.send(-3) + + cpu = [Monitor.recv(self.parent_conn)] + + n_gpus = self.parent_conn.recv() + gpus = [] + for _ in range(n_gpus): + gpus.append(Monitor.recv(self.parent_conn)) + + self.parent_conn.close() + self.child_process.join() + res = dict(cpu=cpu) + if self.cuda: + res["gpus"] = gpus + return res + + +def start_spying_on( + pid: int | None = None, delay: float = 0.01, cuda: bool = False +) -> MemorySpy: + """Starts the memory spy. The function starts another + process spying on the one sent as an argument. + + Example:: + + .. code-block:: python + + from onnxscript.tools.memory_peak import start_spying_on, flatten + + p = start_spying_on() + # ... + # code to measure + # ... + stat = p.stop() + print(stat) + print(flatten(stat)) + + Args: + pid: process id to spy or the the current one. + delay: delay between two measures. + cuda: True or False to get memory for cuda devices + """ + if pid is None: + pid = os.getpid() + return MemorySpy(pid, delay, cuda) + + +def flatten(ps: dict[str, list[Monitor]], prefix: str = "") -> dict[str, float]: + """Flattens a dictionary produced by :meth:`MemorySpy.stop`.""" + obs = ps["cpu"][0].to_dict(unit=2**20) + if "gpus" in ps: + for i, g in enumerate(ps["gpus"]): + for k, v in g.to_dict(unit=2**20).items(): + obs[f"gpu{i}_{k}"] = v + if prefix: + obs = {f"{prefix}{k}": v for k, v in obs.items()} + return obs diff --git a/onnxscript/tools/memory_peak_test.py b/onnxscript/tools/memory_peak_test.py new file mode 100644 index 0000000000..71bbc75c8f --- /dev/null +++ b/onnxscript/tools/memory_peak_test.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import os +import sys +import time +import unittest + +import numpy as np +import torch + +import onnxscript.tools.memory_peak + + +class TestMemoryPeak(unittest.TestCase): + @unittest.skipIf(sys.platform == "win32", reason="other test are failing") + def test_memory(self): + mem = onnxscript.tools.memory_peak.get_memory_rss(os.getpid()) + self.assertIsInstance(mem, int) + + @unittest.skipIf(sys.platform == "win32", reason="other test are failing") + def test_spy(self): + p = onnxscript.tools.memory_peak.start_spying_on() + res = [] + for i in range(10): + time.sleep(0.005) + res.append(np.empty(i * 1000000)) + del res + time.sleep(0.02) + pres = p.stop() + self.assertIsInstance(pres, dict) + self.assertLessEqual(pres["cpu"][0].end, pres["cpu"][0].max_peak) + self.assertLessEqual(pres["cpu"][0].begin, pres["cpu"][0].max_peak) + self.assertIsInstance(pres["cpu"][0].to_dict(), dict) + + @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not here") + def test_spy_cuda(self): + p = onnxscript.tools.memory_peak.start_spying_on(cuda=True) + res = [] + for i in range(10): + time.sleep(0.005) + res.append(np.empty(i * 1000000)) + del res + time.sleep(0.02) + pres = p.stop() + self.assertIsInstance(pres, dict) + self.assertIsInstance(pres["cpu"], list) + self.assertEqual(len(pres["cpu"]), 1) + self.assertIsInstance(pres["gpus"], list) + self.assertLessEqual(pres["cpu"][0].end, pres["cpu"][0].max_peak) + self.assertLessEqual(pres["cpu"][0].begin, pres["cpu"][0].max_peak) + self.assertIn("gpus", pres) + self.assertLessEqual(pres["gpus"][0].end, pres["gpus"][0].max_peak) + self.assertLessEqual(pres["gpus"][0].begin, pres["gpus"][0].max_peak) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py new file mode 100644 index 0000000000..ed4648916b --- /dev/null +++ b/onnxscript/tools/transformers_models/__init__.py @@ -0,0 +1,190 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import random +from typing import Any, Sequence + +import onnx +import onnx.inliner +import torch + +import onnxscript.optimizer +import onnxscript.rewriter + + +def export_to_onnx( + model: Any, + *args: Sequence[Any], + optimize: bool = True, + export_api: bool = True, + no_grad: bool = False, +) -> onnx.ModelProto: + """ + Export a model to ONNX. + If optimize is True, it calls *onnxscript.optimizer.optimize*, + *onnxscript.rewriter.rewriter*, *onnx.inliner.inline_local_functions*. + If *export_api* is True, the function uses ``torch.onnx.export`` + and not ``torch.onnx.dynamo_export``. + """ + if no_grad: + with torch.no_grad(): + if export_api: + prog = torch.onnx.export(model, args, dynamo=True) # pylint: disable=no-value-for-parameter + else: + prog = torch.onnx.dynamo_export(model, *args) + else: + if export_api: + prog = torch.onnx.export(model, args, dynamo=True) # pylint: disable=no-value-for-parameter + else: + prog = torch.onnx.dynamo_export(model, *args) + assert prog is not None + model = prog.model + if optimize: + model = onnxscript.optimizer.optimize( + model, + num_iterations=2, + ) + model = onnxscript.rewriter.rewrite(model) + model_proto = onnxscript.ir.to_proto(model) + model_proto = onnx.inliner.inline_local_functions(model_proto) + return model_proto + + +def ids_tensor( + shape: Sequence[int], + vocab_size: int, + rng: random.Random | None = None, + name: str | None = None, +): + """Creates a random int32 tensor of the shape within the vocab size.""" + del name # unused + + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() + + +def get_input_dims_for_llm( + dynamic_shapes: bool, warmup: int, repeat: int +) -> list[tuple[int, int]]: + """Returns input dimensions for model such as llama, phi, ...""" + if not dynamic_shapes: + return [(2, 1024)] * (warmup + repeat) + w = [(2, 1024), (3, 1024), (2, 1096)] * warmup + w = w[:warmup] + r = [(2, 1024), (3, 1024), (4, 1024), (2, 1096), (2, 1112)] * repeat + r = r[:repeat] + return w + r + + +def get_model_and_inputs( + model: str, + config: str, + dynamic_shapes: bool, + device: str = "cpu", + num_hidden_layers: int = 1, + with_mask: bool = True, + implementation: str = "eager", + dtype: str | None = None, + warmup: int = 5, + repeat: int = 10, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict | None]: + """ + Returns a model and a couple of dummy inputs. + + Args: + model: model name, 'phi', 'llama', 'phi3', ... + config: 'small', 'medium', 'large', ... + dynamic_shapes: dynamic or static shapes + device: 'cpu' or 'cuda' + num_hidden_layers: Number of hidden layers. + with_mask: One input or two inputs. + implementation: eager or sdpa + warmup: Number of inputs to generate. + repeat: Number of inputs to generate for repeat. + dtype: If specified, cast the model and the inputs into this type. + + Returns: + model and list of inputs + """ + if model == "llama": + import onnxscript.tools.transformers_models.llama as m_llama + + tmodel, inputs, dynamic_shapes_def = m_llama.get_llama_model_from_config( + warmup=warmup, + repeat=repeat, + implementation=implementation, + with_mask=with_mask, + num_hidden_layers=num_hidden_layers, + dynamic_shapes=dynamic_shapes, + config=config, + ) + + elif model == "mistral": + import onnxscript.tools.transformers_models.mistral as m_mistral + + tmodel, inputs, dynamic_shapes_def = m_mistral.get_mistral_model_from_config( + warmup=warmup, + repeat=repeat, + implementation=implementation, + with_mask=with_mask, + num_hidden_layers=num_hidden_layers, + dynamic_shapes=dynamic_shapes, + config=config, + ) + + elif model == "phi": + import onnxscript.tools.transformers_models.phi as m_phi + + tmodel, inputs, dynamic_shapes_def = m_phi.get_phi_model_from_config( + warmup=warmup, + repeat=repeat, + implementation=implementation, + with_mask=with_mask, + num_hidden_layers=num_hidden_layers, + dynamic_shapes=dynamic_shapes, + config=config, + ) + + elif model == "phi3": + import onnxscript.tools.transformers_models.phi3 as m_phi3 + + tmodel, inputs, dynamic_shapes_def = m_phi3.get_phi3_model_from_config( + warmup=warmup, + repeat=repeat, + implementation=implementation, + with_mask=with_mask, + num_hidden_layers=num_hidden_layers, + dynamic_shapes=dynamic_shapes, + config=config, + ) + + else: + raise ValueError(f"Model {model!r} is unknown.") + + if dtype is not None: + dt = getattr(torch, dtype) + tmodel = tmodel.to(dt) + inputs = [ + tuple((i if i.dtype in {torch.int64, torch.int32} else i.to(dt)) for i in inp) + for inp in inputs + ] + + if device == "cuda": + tmodel = tmodel.to("cuda") + inputs = [tuple(i.to("cuda") for i in inp) for inp in inputs] + + return tmodel, inputs, dynamic_shapes_def diff --git a/onnxscript/tools/transformers_models/llama.py b/onnxscript/tools/transformers_models/llama.py new file mode 100644 index 0000000000..9b1337167f --- /dev/null +++ b/onnxscript/tools/transformers_models/llama.py @@ -0,0 +1,168 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +from typing import Any, Sequence + +import torch + +import onnxscript.tools.transformers_models + + +def get_llama_model( + input_dims: Sequence[tuple[int, int]] = ((2, 8), (4, 7), (9, 15)), + hidden_size: int = 16, + num_hidden_layers: int = 1, + vocab_size: int = 1024, + intermediate_size: int = 16, + max_position_embeddings: int = 1024, + num_attention_heads: int = 2, + _attn_implementation: str = "eager", # needed value to remove graph breaks + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model. + See `LlamaConfig + `_. + The parameters are chosen for a unit test configuration. + """ + from transformers import LlamaConfig + from transformers.models.llama.modeling_llama import LlamaModel + + dynamic_shapes = {0: {0: "batch", 1: "length"}} + if with_mask: + dynamic_shapes.update({1: {0: "batch", 1: "length"}}) + + config = LlamaConfig( + num_hidden_layers=num_hidden_layers, + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + num_attention_heads=num_attention_heads, + ) + if _attn_implementation: + config._attn_implementation = _attn_implementation # pylint: disable=protected-access + + if with_mask: + + class LlamaModelWrapperMask(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = LlamaModel(config) + + def forward(self, input_ids, attention_mask): + model_output = self.model( + input_ids, attention_mask=attention_mask, use_cache=False + ) + return model_output.to_tuple() + + def generate_example_inputs_mask(batch: int, seq: int, vocab_size: int): + input_ids = onnxscript.tools.transformers_models.ids_tensor( + [batch, seq], vocab_size + ) + input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32)) + assert input_mask.dtype == torch.float32 + return input_ids, input_mask + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append(generate_example_inputs_mask(b, s, vocab_size)) + + return LlamaModelWrapperMask(config), example_args_collection, dynamic_shapes + + # no mask + + class LlamaModelWrapper(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = LlamaModel(config) + + def forward(self, input_ids): + model_output = self.model(input_ids, use_cache=False) + return model_output.to_tuple() + + def generate_example_inputs(batch: int, seq: int, vocab_size: int): + input_ids = onnxscript.tools.transformers_models.ids_tensor([batch, seq], vocab_size) + return (input_ids,) + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append(generate_example_inputs(b, s, vocab_size)) + + return LlamaModelWrapper(config), example_args_collection, dynamic_shapes + + +def get_llama_model_from_config( + warmup: int = 5, + repeat: int = 10, + config: str = "small", + num_hidden_layers: int = 1, + implementation: str = "eager", + dynamic_shapes: bool = False, + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model Phi to test or benchmark. + + Args: + warmup: Number of inputs to generate. + repeat: Number of inputs to generate for repeat. + config: small, medium or large + num_hidden_layers: Number of hidden layers. + implementation: eager or sdpa + with_mask: One or two inputs. + dynamic_shapes: dynamic shapes or not + + Returns: + Model and list of inputs. + """ + if config == "small": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=16, + num_hidden_layers=num_hidden_layers, + vocab_size=1024, + intermediate_size=16, + max_position_embeddings=1024, + num_attention_heads=2, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config == "medium": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=1024, + num_hidden_layers=num_hidden_layers, + vocab_size=1024, + intermediate_size=1024, + max_position_embeddings=1024, + num_attention_heads=2, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config in ("large", "default"): + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=4096, + num_hidden_layers=num_hidden_layers, + vocab_size=32000, + intermediate_size=11008, + max_position_embeddings=2048, + num_attention_heads=32, + _attn_implementation=implementation, + with_mask=with_mask, + ) + else: + raise ValueError(f"Unexpected configuration {config!r}.") + + return get_llama_model(**conf_dict) # type: ignore[arg-type] diff --git a/onnxscript/tools/transformers_models/llama_test.py b/onnxscript/tools/transformers_models/llama_test.py new file mode 100644 index 0000000000..5cb3159600 --- /dev/null +++ b/onnxscript/tools/transformers_models/llama_test.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pylint: disable=not-callable + +import sys +import unittest + +import numpy as np +import onnxruntime +import torch + +import onnxscript.tools.transformers_models +import onnxscript.tools.transformers_models.llama +from onnxscript._internal.version_utils import ( + has_transformers, + ignore_warnings, + torch_older_than, + transformers_older_than, +) + + +class TestExportLlama(unittest.TestCase): + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.5"), reason="fails to export") + @unittest.skipIf( + transformers_older_than("4.41"), reason="cannot mutate tensors with frozen storage" + ) + @ignore_warnings(UserWarning) + def test_llama_export_cpu(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.llama.get_llama_model() + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.5"), reason="fails to export") + @unittest.skipIf( + transformers_older_than("4.41"), reason="cannot mutate tensors with frozen storage" + ) + @ignore_warnings(UserWarning) + def test_llama_export_cpu_export_api(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.llama.get_llama_model() + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx( + model, *input_tensors, export_api=True + ) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.4"), reason="fails to export") + @ignore_warnings(UserWarning) + def test_llama_export_cuda(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.llama.get_llama_model() + ) + input_tensors_cpu = input_tensors_many[0] + model = model.to("cuda") + input_tensors = [i.to("cuda") for i in input_tensors_cpu] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CUDAExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/mistral.py b/onnxscript/tools/transformers_models/mistral.py new file mode 100644 index 0000000000..d053b90571 --- /dev/null +++ b/onnxscript/tools/transformers_models/mistral.py @@ -0,0 +1,238 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +from typing import Any, Sequence + +import torch + +import onnxscript.tools.transformers_models + + +def _prepare_config_and_inputs( + batch_size: int, + seq_length: int, + vocab_size: int, + type_sequence_label_size: int = 2, + type_vocab_size: int = 16, + num_labels: int = 3, + num_choices: int = 4, + use_input_mask: bool = False, + use_token_type_ids: bool = False, + use_labels: bool = False, +) -> tuple[Any, ...]: + input_ids = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], vocab_size + ) + + input_mask = None + if use_input_mask: + input_mask = torch.tril(torch.ones(batch_size, seq_length)) + + token_type_ids = None + if use_token_type_ids: + assert type_vocab_size > 0, "type_vocab_size is null" + token_type_ids = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], type_vocab_size + ) + + sequence_labels = None + token_labels = None + choice_labels = None + if use_labels: + assert type_sequence_label_size > 0, "type_sequence_label_size is null" + assert num_labels > 0, "num_labels is null" + assert num_choices > 0, "num_choices is null" + sequence_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size], type_sequence_label_size + ) + token_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], num_labels + ) + choice_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size], num_choices + ) + + return ( + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) + + +def get_mistral_model( + input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)), + hidden_size=32, + num_hidden_layers=2, + vocab_size=99, + intermediate_size=16, + max_position_embeddings=512, + num_attention_heads=2, + num_key_value_heads=2, + sliding_window=4096, + _attn_implementation="eager", # needed value to remove graph breaks + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model. + See `MistralConfig + `_. + The parameters are chosen for a unit test configuration. + """ + from transformers import MistralConfig + from transformers.models.mistral.modeling_mistral import MistralModel + + config = MistralConfig( + num_hidden_layers=num_hidden_layers, + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + sliding_window=sliding_window, + ) + + dynamic_shapes = {0: {0: "batch", 1: "length"}} + if with_mask: + dynamic_shapes.update({1: {0: "batch", 1: "length"}}) + + if _attn_implementation: + config._attn_implementation = _attn_implementation # pylint: disable=protected-access + + def generate_example_inputs(batch: int, seq: int, vocab_size: int, with_mask: bool): + ( + input_ids, + _, # token_type_ids, + input_mask, + _, # sequence_labels, + _, # token_labels, + _, # choice_labels, + ) = _prepare_config_and_inputs( + batch_size=batch, + seq_length=seq, + vocab_size=vocab_size, + use_input_mask=with_mask, + ) + if with_mask: + return input_ids, input_mask + return (input_ids,) + + if with_mask: + + class MistralModelWrapperWithMask(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = MistralModel(config) + + def forward(self, input_ids, attention_mask): + model_output = self.model( + input_ids, attention_mask=attention_mask, use_cache=False + ) + return model_output.to_tuple() + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append( + generate_example_inputs(b, s, vocab_size, with_mask) + ) + + return MistralModelWrapperWithMask(config), example_args_collection, dynamic_shapes + + class MistralModelWrapper(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = MistralModel(config) + + def forward(self, input_ids): + model_output = self.model(input_ids, use_cache=False) + return model_output.to_tuple() + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append(generate_example_inputs(b, s, vocab_size, with_mask)) + + return MistralModelWrapper(config), example_args_collection, dynamic_shapes + + +def get_mistral_model_from_config( + warmup: int = 5, + repeat: int = 10, + config: str = "small", + num_hidden_layers: int = 1, + implementation: str = "eager", + dynamic_shapes: bool = False, + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model Phi to test or benchmark. + + Args: + warmup: Number of inputs to generate. + repeat: Number of inputs to generate for repeat. + config: small, medium or large + num_hidden_layers: number of hidden layers + implementation: eager or sdpa + with_mask: One or two inputs. + dynamic_shapes: dynamic shapes or not + + Returns: + Model and list of inputs. + """ + if config == "small": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=32, + num_hidden_layers=num_hidden_layers, + vocab_size=99, + intermediate_size=16, + max_position_embeddings=512, + num_attention_heads=4, + num_key_value_heads=2, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config == "medium": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=1024, + num_hidden_layers=num_hidden_layers, + vocab_size=1024, + intermediate_size=1024, + num_attention_heads=4, + num_key_value_heads=4, + max_position_embeddings=1024, + sliding_window=4096, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config in ("large", "default"): + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=4096, + num_hidden_layers=num_hidden_layers, + vocab_size=32000, + intermediate_size=14336, + num_attention_heads=32, + num_key_value_heads=8, + max_position_embeddings=131072, + sliding_window=4096, + _attn_implementation=implementation, + with_mask=with_mask, + ) + else: + raise ValueError(f"Unexpected configuration {config!r}.") + + return get_mistral_model(**conf_dict) # type: ignore[arg-type] diff --git a/onnxscript/tools/transformers_models/mistral_test.py b/onnxscript/tools/transformers_models/mistral_test.py new file mode 100644 index 0000000000..2883fbd32e --- /dev/null +++ b/onnxscript/tools/transformers_models/mistral_test.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pylint: disable=not-callable + +import sys +import unittest + +import numpy as np +import onnxruntime +import torch + +import onnxscript.tools.transformers_models +import onnxscript.tools.transformers_models.mistral +from onnxscript._internal.version_utils import ( + has_transformers, + ignore_warnings, + torch_older_than, + transformers_older_than, +) + + +class TestExportMistral(unittest.TestCase): + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.4"), reason="fails to export") + @unittest.skipIf( + transformers_older_than("4.42"), reason="cannot mutate tensors with frozen storage" + ) + @ignore_warnings(UserWarning) + def test_mistral_export_cpu(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.mistral.get_mistral_model() + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.5"), reason="fails to export") + @unittest.skipIf( + transformers_older_than("4.42"), reason="cannot mutate tensors with frozen storage" + ) + @ignore_warnings(UserWarning) + def test_mistral_export_cpu_export_api(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.mistral.get_mistral_model() + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx( + model, *input_tensors, export_api=True + ) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @ignore_warnings(UserWarning) + def test_phi_export_cuda(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.mistral.get_mistral_model() + ) + input_tensors_cpu = input_tensors_many[0] + model = model.to("cuda") + input_tensors = [i.to("cuda") for i in input_tensors_cpu] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CUDAExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/phi.py b/onnxscript/tools/transformers_models/phi.py new file mode 100644 index 0000000000..f1cb88edd0 --- /dev/null +++ b/onnxscript/tools/transformers_models/phi.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +from typing import Any, Sequence + +import torch + +import onnxscript.tools.transformers_models + + +def _prepare_config_and_inputs( + batch_size: int, + seq_length: int, + vocab_size: int, + type_sequence_label_size: int = 2, + type_vocab_size: int = 16, + num_labels: int = 3, + num_choices: int = 4, + use_input_mask: bool = False, + use_token_type_ids: bool = False, + use_labels: bool = False, +) -> tuple[Any, ...]: + input_ids = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], vocab_size + ) + + input_mask = None + if use_input_mask: + input_mask = torch.tril(torch.ones(batch_size, seq_length)) + + token_type_ids = None + if use_token_type_ids: + assert type_vocab_size > 0, "type_vocab_size is null" + token_type_ids = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], type_vocab_size + ) + + sequence_labels = None + token_labels = None + choice_labels = None + if use_labels: + assert type_sequence_label_size > 0, "type_sequence_label_size is null" + assert num_labels > 0, "num_labels is null" + assert num_choices > 0, "num_choices is null" + sequence_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size], type_sequence_label_size + ) + token_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], num_labels + ) + choice_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size], num_choices + ) + + return ( + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) + + +def get_phi_model( + input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)), + hidden_size: int = 32, + num_hidden_layers: int = 2, + vocab_size: int = 99, + intermediate_size: int = 16, + max_position_embeddings: int = 512, + num_attention_heads: int = 4, + num_key_value_heads: int = 2, + _attn_implementation: str = "eager", # needed value to remove graph breaks + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model. + See `PhiConfig + `_. + The parameters are chosen for a unit test configuration from `test_modeling_phi.py + `_. + """ + from transformers import PhiConfig + from transformers.models.phi.modeling_phi import PhiModel + + dynamic_shapes = {0: {0: "batch", 1: "length"}} + if with_mask: + dynamic_shapes.update({1: {0: "batch", 1: "length"}}) + + config = PhiConfig( + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + vocab_size=vocab_size, + intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + ) + if _attn_implementation: + config._attn_implementation = _attn_implementation # pylint: disable=protected-access + + if with_mask: + + class PhiModelWrapper(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = PhiModel(config) + + def forward(self, input_ids, attention_mask): + model_output = self.model( + input_ids, attention_mask=attention_mask, use_cache=False + ) + return model_output.to_tuple() + + def generate_example_inputs(batch: int, seq: int, vocab_size: int): + ( + input_ids, + _, # token_type_ids, + input_mask, + _, # sequence_labels, + _, # token_labels, + _, # choice_labels, + ) = _prepare_config_and_inputs( + batch_size=batch, + seq_length=seq, + vocab_size=vocab_size, + use_input_mask=True, + ) + return input_ids, input_mask + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append(generate_example_inputs(b, s, vocab_size)) + + return PhiModelWrapper(config), example_args_collection, dynamic_shapes + + # no mask + + class PhiModelWrapperNoMask(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = PhiModel(config) + + def forward(self, input_ids): + model_output = self.model(input_ids, use_cache=False) + return model_output.to_tuple() + + def generate_example_inputs_no_mask(batch: int, seq: int, vocab_size: int): + ( + input_ids, + _, # token_type_ids, + _, # input_mask, + _, # sequence_labels, + _, # token_labels, + _, # choice_labels, + ) = _prepare_config_and_inputs( + batch_size=batch, + seq_length=seq, + vocab_size=vocab_size, + use_input_mask=True, + ) + return (input_ids,) + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append(generate_example_inputs_no_mask(b, s, vocab_size)) + + return PhiModelWrapperNoMask(config), example_args_collection, dynamic_shapes + + +def get_phi_model_from_config( + warmup: int = 5, + repeat: int = 10, + config: str = "small", + num_hidden_layers: int = 1, + implementation: str = "eager", + dynamic_shapes: bool = False, + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model Phi to test or benchmark. + + Args: + warmup: Number of inputs to generate. + repeat: Number of inputs to generate for repeat. + config: small, medium or large + num_hidden_layers: number of hidden layers + implementation: eager or sdpa + with_mask: One or two inputs. + dynamic_shapes: dynamic shapes or not + + Returns: + Model and list of inputs. + """ + if config == "small": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=32, + num_hidden_layers=num_hidden_layers, + vocab_size=99, + intermediate_size=16, + max_position_embeddings=512, + num_attention_heads=4, + num_key_value_heads=2, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config == "medium": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=1024, + num_hidden_layers=num_hidden_layers, + vocab_size=1024, + intermediate_size=1024, + num_attention_heads=4, + num_key_value_heads=4, + max_position_embeddings=1024, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config in ("large", "default"): + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=2048, + num_hidden_layers=num_hidden_layers, + vocab_size=51200, + intermediate_size=8192, + num_attention_heads=32, + num_key_value_heads=None, + max_position_embeddings=2048, + _attn_implementation=implementation, + with_mask=with_mask, + ) + else: + raise ValueError(f"Unexpected configuration {config!r}.") + + return get_phi_model(**conf_dict) # type: ignore[arg-type] diff --git a/onnxscript/tools/transformers_models/phi3.py b/onnxscript/tools/transformers_models/phi3.py new file mode 100644 index 0000000000..f5bf7beb54 --- /dev/null +++ b/onnxscript/tools/transformers_models/phi3.py @@ -0,0 +1,259 @@ +# Copyright (c) Microsoft Corporation +# Licensed under the MIT License. +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any, Sequence + +import torch + +import onnxscript.tools.transformers_models + + +def has_phi3() -> bool: + """Tells if package *transformers* contains the phi3 model.""" + try: + from transformers import Phi3Config + + assert Phi3Config + except ImportError: + return False + return True + + +def _prepare_config_and_inputs( + batch_size: int, + seq_length: int, + vocab_size: int, + type_sequence_label_size: int = 2, + type_vocab_size: int = 16, + num_labels: int = 3, + num_choices: int = 4, + use_input_mask: bool = False, + use_token_type_ids: bool = False, + use_labels: bool = False, +) -> tuple[Any, ...]: + input_ids = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], vocab_size + ) + + input_mask = None + if use_input_mask: + input_mask = torch.tril(torch.ones(batch_size, seq_length)) + + token_type_ids = None + if use_token_type_ids: + assert type_vocab_size > 0, "type_vocab_size is null" + token_type_ids = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], type_vocab_size + ) + + sequence_labels = None + token_labels = None + choice_labels = None + if use_labels: + assert type_sequence_label_size > 0, "type_sequence_label_size is null" + assert num_labels > 0, "num_labels is null" + assert num_choices > 0, "num_choices is null" + sequence_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size], type_sequence_label_size + ) + token_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size, seq_length], num_labels + ) + choice_labels = onnxscript.tools.transformers_models.ids_tensor( + [batch_size], num_choices + ) + + return ( + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) + + +def get_phi3_model( + input_dims: Sequence[tuple[int, int]] = ((13, 7), (14, 7), (15, 8)), + hidden_size: int = 32, + num_hidden_layers: int = 2, + vocab_size: int = 99, + intermediate_size: int = 16, + max_position_embeddings: int = 512, + num_attention_heads: int = 4, + num_key_value_heads: int = 2, + _attn_implementation: str = "eager", # needed value to remove graph breaks + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model. + See `PhiConfig + `_. + The parameters are chosen for a unit test configuration from `test_modeling_phi.py + `_. + """ + from transformers import Phi3Config, Phi3Model + + dynamic_shapes = {0: {0: "batch", 1: "length"}} + if with_mask: + dynamic_shapes.update({1: {0: "batch", 1: "length"}}) + + config = Phi3Config( + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + vocab_size=vocab_size, + intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + pad_token_id=min(32000, vocab_size - 1), + ) + if _attn_implementation: + config._attn_implementation = _attn_implementation # pylint: disable=protected-access + + if with_mask: + + class Phi3ModelWrapperNoMask(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = Phi3Model(config) + + def forward(self, input_ids, attention_mask): + model_output = self.model( + input_ids, attention_mask=attention_mask, use_cache=False + ) + return model_output.to_tuple() + + def generate_example_inputs_no_mask(batch: int, seq: int, vocab_size: int): + ( + input_ids, + _, # token_type_ids, + input_mask, + _, # sequence_labels, + _, # token_labels, + _, # choice_labels, + ) = _prepare_config_and_inputs( + batch_size=batch, + seq_length=seq, + vocab_size=vocab_size, + use_input_mask=True, + ) + return input_ids, input_mask + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append(generate_example_inputs_no_mask(b, s, vocab_size)) + + return Phi3ModelWrapperNoMask(config), example_args_collection, dynamic_shapes + + # no mask + + class Phi3ModelWrapper(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = Phi3Model(config) + + def forward(self, input_ids): + model_output = self.model(input_ids, use_cache=False) + return model_output.to_tuple() + + def generate_example_inputs(batch: int, seq: int, vocab_size: int): + ( + input_ids, + *_, + # token_type_ids, + # input_mask, + # sequence_labels, + # token_labels, + # choice_labels, + ) = _prepare_config_and_inputs( + batch_size=batch, + seq_length=seq, + vocab_size=vocab_size, + use_input_mask=True, + ) + return (input_ids,) + + example_args_collection = [] + for b, s in input_dims: + example_args_collection.append(generate_example_inputs(b, s, vocab_size)) + + return Phi3ModelWrapper(config), example_args_collection, dynamic_shapes + + +def get_phi3_model_from_config( + warmup: int = 5, + repeat: int = 10, + config: str = "small", + num_hidden_layers: int = 1, + implementation: str = "eager", + dynamic_shapes: bool = False, + with_mask: bool = True, +) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]: + """ + Returns a model Phi to test or benchmark. + + Args: + warmup: Number of inputs to generate. + repeat: Number of inputs to generate for repeat. + config: small, medium or large + num_hidden_layers: number of hidden layers + implementation: eager or sdpa + with_mask: One or two inputs. + dynamic_shapes: dynamic shapes or not + + Returns: + Model and list of inputs. + """ + if config == "small": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=32, + num_hidden_layers=num_hidden_layers, + vocab_size=99, + intermediate_size=16, + max_position_embeddings=512, + num_attention_heads=4, + num_key_value_heads=2, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config == "medium": + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=1024, + num_hidden_layers=num_hidden_layers, + vocab_size=1024, + intermediate_size=1024, + num_attention_heads=4, + num_key_value_heads=4, + max_position_embeddings=1024, + _attn_implementation=implementation, + with_mask=with_mask, + ) + elif config in ("large", "default"): + conf_dict = dict( + input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm( + dynamic_shapes, warmup, repeat + ), + hidden_size=2048, + num_hidden_layers=num_hidden_layers, + vocab_size=51200, + intermediate_size=8192, + num_attention_heads=32, + num_key_value_heads=None, + max_position_embeddings=2048, + _attn_implementation=implementation, + with_mask=with_mask, + ) + else: + raise ValueError(f"Unexpected configuration {config!r}.") + + return get_phi3_model(**conf_dict) # type: ignore[arg-type] diff --git a/onnxscript/tools/transformers_models/phi3_test.py b/onnxscript/tools/transformers_models/phi3_test.py new file mode 100644 index 0000000000..db47b7d1f1 --- /dev/null +++ b/onnxscript/tools/transformers_models/phi3_test.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pylint: disable=not-callable + +import sys +import unittest + +import numpy as np +import onnxruntime +import torch + +import onnxscript.tools.transformers_models +import onnxscript.tools.transformers_models.phi3 +from onnxscript._internal.version_utils import ( + has_transformers, + ignore_warnings, + torch_older_than, +) + +has_phi3 = onnxscript.tools.transformers_models.phi3.has_phi3 + + +class TestExportPhi3(unittest.TestCase): + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(not has_phi3(), reason="transformers is not recent enough") + @unittest.skipIf(torch_older_than("2.4"), reason="fails to export") + @ignore_warnings(UserWarning) + def test_phi3_export_cpu(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.phi3.get_phi3_model() + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(not has_phi3(), reason="transformers is not recent enough") + @unittest.skipIf(torch_older_than("2.5"), reason="fails to export") + @ignore_warnings(UserWarning) + def test_phi3_export_cpu_export_api(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.phi3.get_phi3_model() + ) + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx( + model, *input_tensors, export_api=True + ) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(not has_phi3(), reason="transformers is not recent enough") + @ignore_warnings(UserWarning) + def test_phi3_export_cuda(self): + model, input_tensors_many, _ = ( + onnxscript.tools.transformers_models.phi3.get_phi3_model() + ) + input_tensors_cpu = input_tensors_many[0] + model = model.to("cuda") + input_tensors = [i.to("cuda") for i in input_tensors_cpu] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CUDAExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/tools/transformers_models/phi_test.py b/onnxscript/tools/transformers_models/phi_test.py new file mode 100644 index 0000000000..9b88203084 --- /dev/null +++ b/onnxscript/tools/transformers_models/phi_test.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pylint: disable=not-callable + +import sys +import unittest + +import numpy as np +import onnxruntime +import torch + +import onnxscript.tools.transformers_models +import onnxscript.tools.transformers_models.phi +from onnxscript._internal.version_utils import ( + has_transformers, + ignore_warnings, + torch_older_than, +) + + +class TestExportPhi(unittest.TestCase): + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.6"), reason="fails to export") + @ignore_warnings(UserWarning) + def test_phi_export_cpu(self): + model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model() + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.6"), reason="fails to export") + @ignore_warnings(UserWarning) + def test_phi_export_cpu_export_api(self): + model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model() + input_tensors = input_tensors_many[0] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx( + model, *input_tensors, export_api=True + ) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CPUExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().numpy(), results[0], atol=1e-5) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @ignore_warnings(UserWarning) + def test_phi_export_cuda(self): + model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model() + input_tensors_cpu = input_tensors_many[0] + model = model.to("cuda") + input_tensors = [i.to("cuda") for i in input_tensors_cpu] + expected = model(*input_tensors) + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) + names = [i.name for i in proto.graph.input] + np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors] + feeds = dict(zip(names, np_input_tensors)) + sess = onnxruntime.InferenceSession( + proto.SerializeToString(), providers=["CUDAExecutionProvider"] + ) + results = sess.run(None, feeds) + np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 53b640ab71..fb7b8a370d 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import collections @@ -65,7 +63,22 @@ def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeTy # A sorted list of all type strings used in an OpSchema ALL_TENSOR_TYPE_STRINGS = tuple( - sorted(tensor_type.to_string() for tensor_type in onnx_types.tensor_type_registry.values()) + sorted( + tensor_type.to_string() + for tensor_type in onnx_types.tensor_type_registry.values() + # Skip FLOAT4E2M1 for versions older than 1.18, and FLOAT8E8M0 for versions older than 1.19 + # TODO(after onnx requirement bump): Remove this check + if ( + not ( + not hasattr(onnx.TensorProto, "FLOAT4E2M1") + and tensor_type == onnx_types.FLOAT4E2M1 + ) + and not ( + not hasattr(onnx.TensorProto, "FLOAT8E8M0") + and tensor_type == onnx_types.FLOAT8E8M0 + ) + ) + ) ) diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 18728ae761..4104eb51dd 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import unittest from typing import Any, List, Optional, Sequence, TypeVar, Union diff --git a/onnxscript/utils/evaluation_utils.py b/onnxscript/utils/evaluation_utils.py index eb93b79cb0..b981fe6708 100644 --- a/onnxscript/utils/evaluation_utils.py +++ b/onnxscript/utils/evaluation_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import pathlib diff --git a/onnxscript/utils/timing_utils.py b/onnxscript/utils/timing_utils.py index 6805a7e19c..98c48dc6da 100644 --- a/onnxscript/utils/timing_utils.py +++ b/onnxscript/utils/timing_utils.py @@ -1,18 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import time import onnx from onnxscript import optimizer -# from onnxscript.rewriter.rules import all_rules - def timeit(f, message): def timed(*args, **kw): ts = time.time() result = f(*args, **kw) te = time.time() - print(f"{message} time: {te-ts}") + print(f"{message} time: {te - ts}") return result return timed diff --git a/onnxscript/utils/utils.py b/onnxscript/utils/utils.py index 26ef525b1c..39457e7ab5 100644 --- a/onnxscript/utils/utils.py +++ b/onnxscript/utils/utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations from typing import Any diff --git a/onnxscript/values.py b/onnxscript/values.py index 31ebe3000d..1897ae14d5 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -1,10 +1,9 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import dataclasses +import functools import inspect import logging import types @@ -12,19 +11,28 @@ from enum import IntFlag from typing import ( # type: ignore[attr-defined] Any, + Callable, ClassVar, + Generic, Optional, Protocol, Sequence, + TypeVar, _GenericAlias, ) import onnx import onnx.defs +from typing_extensions import ParamSpec from onnxscript import converter as converter_module from onnxscript import irbuilder, sourceinfo, type_annotation from onnxscript._internal import ast_utils, deprecation +from onnxscript.ir import _schemas + +_R = TypeVar("_R") +_P = ParamSpec("_P") + _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, @@ -108,8 +116,6 @@ def __getattr__(self, attr: str): raise AttributeError(f"Attribute {attr} not found.") from exc def add_function_def(self, fun): - if fun.name in self.function_defs: - logger.warning("%s: Already defined.", fun.name) self.function_defs[fun.name] = fun def _prepare_inputs(self, _: onnx.defs.OpSchema, *inputs): @@ -120,7 +126,7 @@ def _prepare_inputs(self, _: onnx.defs.OpSchema, *inputs): # TODO: validate the op schema as 'None' values are removed? input_list = list(inputs) while input_list and input_list[-1] is None: - del input_list[-1] + input_list.pop() return input_list @@ -170,10 +176,10 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any: """Get the default value of an ONNX attribute.""" if attr_proto.type == onnx.AttributeProto.UNDEFINED: return _EmptyDefault - return onnx.helper.get_attribute_value(attr_proto) + return onnx.helper.get_attribute_value(attr_proto) # noqa: TID251 -def param_schemas_from_op_schema( +def _param_schemas_from_op_schema( op_schema: onnx.defs.OpSchema, ) -> tuple[ParamSchema, ...]: """Get the parameter schemas from an ONNX OpSchema.""" @@ -222,7 +228,7 @@ def _param_schema_from_function_ir_attr(attr: irbuilder.IRAttributeParameter): ) -def param_schemas_from_function_ir( +def _param_schemas_from_function_ir( function_ir: irbuilder.IRFunction, ) -> tuple[ParamSchema, ...]: """Get the parameter schemas from a FunctionIR.""" @@ -259,7 +265,8 @@ def opset(self) -> Opset: ... @property def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... - def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: ... + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: ... class Op(OpLike): @@ -274,18 +281,19 @@ class Op(OpLike): """ def __init__( - self, opset: Opset, opname: str, op_schema: Optional[onnx.defs.OpSchema] = None + self, opset: Opset, name: str, op_schema: Optional[onnx.defs.OpSchema] = None ) -> None: self._opset = opset - self._name = opname - self._op_schema = op_schema or opset[opname] + self._name = name + self._op_schema = op_schema or opset[name] + self._signature: Optional[_schemas.OpSignature] = None self._param_schemas: Optional[tuple[ParamSchema, ...]] = None if self._op_schema is None: logger.debug( "An OpSchema was not provided for Op '%s' and " "there is not one found in opset '%s'.", - opname, + name, opset, ) @@ -316,6 +324,22 @@ def has_schema(self) -> bool: """Returns True if this op has an OpSchema.""" return self.op_schema is not None + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_op_schema(self.op_schema) + return self._signature + + @op_signature.setter + def op_signature(self, value: _schemas.OpSignature): + self._signature = value + def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: """Returns the parameter schemas for this op, if it has one.""" if self._param_schemas is not None: @@ -325,7 +349,7 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: if op_schema is None: return None - self._param_schemas = param_schemas_from_op_schema(op_schema) + self._param_schemas = _param_schemas_from_op_schema(op_schema) return self._param_schemas @@ -362,7 +386,7 @@ def as_tuple(self) -> tuple[str, list[str], str]: return (self.name, self.allowed_types, self.description) -def op_schema_from_function_ir( +def _op_schema_from_function_ir( function_ir: irbuilder.IRFunction, opset: Opset ) -> onnx.defs.OpSchema: """Construct an ONNX OpSchema from an IRFunction.""" @@ -437,7 +461,7 @@ def op_schema_from_function_ir( ) -class OnnxFunction(Op): +class OnnxFunction(Op, Generic[_P, _R]): """Represents an ONNX op for which a function-body has been defined in onnxscript. Attributes: @@ -453,7 +477,7 @@ class OnnxFunction(Op): def __init__( self, opset: Optional[Opset], - pyfun: types.FunctionType, + pyfun: Callable, irfun: irbuilder.IRFunction, source: str, kwargs: dict[str, Any], @@ -477,13 +501,16 @@ def __init__( self._param_schemas: Optional[tuple[ParamSchema, ...]] = None self._op_schema: Optional[onnx.defs.OpSchema] = None + # Allow the object to be inspected as a function + functools.update_wrapper(self, pyfun) + # Experimental fields - self.experimental_traceable = False + self.traceable = False @property @deprecation.deprecated( since="0.1", - removed_in="0.3", + removed_in="the future", instructions="use '.name' instead", ) def opname(self) -> str: @@ -497,10 +524,28 @@ def op_schema(self) -> Optional[onnx.defs.OpSchema]: if self._op_schema is not None: return self._op_schema - self._op_schema = op_schema_from_function_ir(self.function_ir, self.opset) + self._op_schema = _op_schema_from_function_ir(self.function_ir, self.opset) return self._op_schema + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_function( + self.function, domain=self.function_ir.domain, name=self.name + ) + return self._signature + + @op_signature.setter + def op_signature(self, value: _schemas.OpSignature): + self._signature = value + def __getitem__(self, instance): """Returns a lambda to evaluate function using given evaluator instance. @@ -518,12 +563,15 @@ def fun(*args, **kwargs): return fun - def __call__(self, *args, **kwargs): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """Implements an eager-mode execution of an onnxscript function.""" # FIXME(after #225): Move import to the top of the file. from onnxscript import evaluator # pylint: disable=import-outside-toplevel - return evaluator.default().eval_function(self, args, kwargs) + return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.function!r})" def param_schemas(self) -> tuple[ParamSchema, ...]: """Returns the parameter schemas of this function.""" @@ -533,7 +581,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - self._param_schemas = param_schemas_from_function_ir(self.function_ir) + self._param_schemas = _param_schemas_from_function_ir(self.function_ir) return self._param_schemas def to_function_proto(self) -> onnx.FunctionProto: @@ -566,10 +614,13 @@ class TracedOnnxFunction(Op): func: Function. """ - def __init__(self, opset: Opset, func: types.FunctionType): + def __init__(self, opset: Opset, func: Callable): super().__init__(opset, func.__name__) self.func = func + # Allow the object to be inspected as a function + functools.update_wrapper(self, func) + def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) @@ -603,10 +654,28 @@ def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema # FIXME(justinchuby): outputs are empty. Need to fix. - self._op_schema = op_schema_from_function_ir(self.function_ir, self._opset) + self._op_schema = _op_schema_from_function_ir(self.function_ir, self._opset) return self._op_schema + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_function( + self.func, domain="_traced", name=self.name + ) + return self._signature + + @op_signature.setter + def op_signature(self, value: _schemas.OpSignature): + self._signature = value + def param_schemas(self) -> tuple[ParamSchema, ...]: """Returns the parameter schemas of this function.""" if self._param_schemas is not None: @@ -615,7 +684,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - self._param_schemas = param_schemas_from_function_ir(self.function_ir) + self._param_schemas = _param_schemas_from_function_ir(self.function_ir) return self._param_schemas diff --git a/onnxscript/values_test.py b/onnxscript/values_test.py index ed21ff2775..c33e623334 100644 --- a/onnxscript/values_test.py +++ b/onnxscript/values_test.py @@ -1,3 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import inspect +import typing import unittest import onnxscript @@ -15,6 +21,48 @@ def function(input1, input2, attr1: int, attr2: int = 1): self.assertEqual(traced_function.name, function.__name__) self.assertEqual(traced_function.func, function) + def test_param_schemas_in_correct_order_with_mixed_inputs_and_attrs(self): + opset = values.Opset("test", 1) + + def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"): + return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3) + + traced_function = values.TracedOnnxFunction(opset, function) + param_schemas = traced_function.param_schemas() + expected_ordered_param_names = [ + "input1", + "input2", + "attr1", + "attr2", + "input3", + "attr3", + ] + self.assertEqual(len(param_schemas), len(expected_ordered_param_names)) + for i, param_schema in enumerate(param_schemas): + self.assertEqual(param_schema.name, expected_ordered_param_names[i]) + + def test_it_preserves_the_function_signature(self): + opset = values.Opset("test", 1) + + def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"): + return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3) + + traced_function = values.TracedOnnxFunction(opset, function) + signature = inspect.signature(traced_function) + self.assertEqual(signature.parameters["input1"].name, "input1") + self.assertEqual(signature.parameters["input2"].name, "input2") + self.assertEqual(signature.parameters["attr1"].name, "attr1") + self.assertEqual(signature.parameters["attr2"].name, "attr2") + self.assertEqual(signature.parameters["input3"].name, "input3") + self.assertEqual(signature.parameters["attr3"].name, "attr3") + + annotations = typing.get_type_hints(traced_function) + self.assertEqual(annotations["attr1"], int) + self.assertEqual(annotations["attr2"], float) + self.assertEqual(annotations["attr3"], str) + + +class OnnxFunctionTest(unittest.TestCase): def test_param_schemas_in_correct_order_with_mixed_inputs_and_attrs(self): opset = values.Opset("test", 1) @@ -34,3 +82,27 @@ def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "def self.assertEqual(len(param_schemas), len(expected_ordered_param_names)) for i, param_schema in enumerate(param_schemas): self.assertEqual(param_schema.name, expected_ordered_param_names[i]) + + def test_it_preserves_the_function_signature(self): + opset = values.Opset("test", 1) + + @onnxscript.script(default_opset=opset) + def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"): + return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3) + + signature = inspect.signature(function) + self.assertEqual(signature.parameters["input1"].name, "input1") + self.assertEqual(signature.parameters["input2"].name, "input2") + self.assertEqual(signature.parameters["attr1"].name, "attr1") + self.assertEqual(signature.parameters["attr2"].name, "attr2") + self.assertEqual(signature.parameters["input3"].name, "input3") + self.assertEqual(signature.parameters["attr3"].name, "attr3") + + annotations = typing.get_type_hints(function) + self.assertEqual(annotations["attr1"], int) + self.assertEqual(annotations["attr2"], float) + self.assertEqual(annotations["attr3"], str) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py new file mode 100644 index 0000000000..b0831a00f9 --- /dev/null +++ b/onnxscript/version_converter/__init__.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +__all__ = [ + "ConvertVersionPass", + "convert_version", +] + +import logging + +import onnx +import onnx_ir.passes.common as common_passes + +from onnxscript import ir +from onnxscript.version_converter import _c_api_utils, _version_converter + +logger = logging.getLogger(__name__) + + +class ConvertVersionPass(ir.passes.InPlacePass): + """Convert the model to the specified ONNX opset version. + + This pass leverages the onnxscript version converter to convert the model. If + the conversion is not supported, it falls back to the onnx C API to convert + the model. This pass is in-place. + + The pass is an no-op if the c-api fails. + + Attributes: + target_version: The target ONNX opset version to convert the model to. + fallback: Whether to fallback to the onnx version converter if the + target version is not supported. Default is False. + """ + + def __init__(self, target_version: int, fallback: bool = False) -> None: + super().__init__() + self.target_version = target_version + self.fallback = fallback + self.convert_pass = ir.passes.Sequential( + common_passes.InlinePass(), + _ConvertVersionPassRequiresInline( + target_version=target_version, + fallback=fallback, + ), + common_passes.RemoveUnusedNodesPass(), + common_passes.RemoveUnusedFunctionsPass(), + common_passes.RemoveUnusedOpsetsPass(), + ) + + def call(self, model: ir.Model) -> ir.passes.PassResult: + return self.convert_pass(model) + + +class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass): + """Convert the model to the specified ONNX opset version. + + This pass leverages the onnxscript version converter to convert the model. If + the conversion is not supported, it falls back to the onnx C API to convert + the model. This pass is in-place. + + The pass is an no-op if the c-api fails. + + Attributes: + target_version: The target ONNX opset version to convert the model to. + fallback: Whether to fallback to the onnx version converter if the + target version is not supported. + """ + + def __init__(self, target_version: int, fallback: bool) -> None: + super().__init__() + self.target_version = target_version + self.fallback = fallback + + def call(self, model: ir.Model) -> ir.passes.PassResult: + if model.functions: + raise ValueError( + "The model contains functions. The version conversion pass does not support " + "functions. Please use `common_passes.InlinePass` to inline the " + f"functions before applying this pass ({self.__class__.__name__})." + ) + if "" in model.graph.opset_imports: + onnx_opset_version = model.graph.opset_imports[""] + if onnx_opset_version == self.target_version: + # No need to convert the version + return ir.passes.PassResult(model, False) + + # When fallback is disabled, always use the onnxscript version converter; + # When fallback is enabled, use the onnxscript version converter + # if the target version is supported. Otherwise, use the onnx C API + # to convert the model. + if not self.fallback or _version_converter.version_supported( + model, self.target_version + ): + _version_converter.convert_version( + model, + target_version=self.target_version, + ) + return ir.passes.PassResult(model, True) + + if not self.fallback: + logger.warning( + "The model version conversion is not supported by the onnxscript version converter " + "and fallback is disabled. The model was not modified" + " (target version: %d). " + "Set fallback=True to enable fallback to the onnx c-api version converter.", + self.target_version, + ) + return ir.passes.PassResult(model, False) + else: + logger.warning( + "The model version conversion is not supported by the onnxscript version converter " + "and fallback is enabled. The model will be converted using the onnx C API " + "(target version: %d).", + self.target_version, + ) + + # If the onnxscript version converter does not support the conversion, + # we can use the onnx C API to convert the model + def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: + """Partial function to check the model.""" + return onnx.version_converter.convert_version( + proto, target_version=self.target_version + ) + + try: + converted_proto = _c_api_utils.call_onnx_api( + func=_partial_convert_version, model=model + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to convert the model to the target version %d using the ONNX C API. " + "The model was not modified", + self.target_version, + exc_info=e, + ) + return ir.passes.PassResult(model, False) + + converted_model = ir.from_proto(converted_proto) + + # Recover the initializers in the converted model + for input in converted_model.graph.inputs: + if input.name in model.graph.initializers: + input.const_value = model.graph.initializers[input.name].const_value + converted_model.graph.register_initializer(input) + user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)] + converted_model.graph.inputs.clear() + converted_model.graph.inputs.extend(user_inputs) + + # Return the converted graph to the original model to keep the pass in-place + model.graph = converted_model.graph + return ir.passes.PassResult(model, True) + + +def convert_version( + model: ir.Model | onnx.ModelProto, target_version: int, fallback=None +) -> None: + """Convert the model to the specified ONNX opset version. + + Args: + model: The model to convert. + target_version: The target ONNX opset version. + fallback: Whether to fallback to the onnx version converter if the + target version is not supported. Default is False. + """ + if isinstance(model, onnx.ModelProto): + model_proto = model + model = ir.from_proto(model) + else: + model_proto = None + + assert isinstance(model, ir.Model) + ConvertVersionPass(target_version=target_version, fallback=fallback)(model) + + if model_proto is not None: + # Update the model proto in-place + model_proto.graph.Clear() + del model_proto.functions[:] + model_proto.graph.CopyFrom(ir.to_proto(model.graph)) diff --git a/onnxscript/version_converter/_c_api_utils.py b/onnxscript/version_converter/_c_api_utils.py new file mode 100644 index 0000000000..7f9ac687f4 --- /dev/null +++ b/onnxscript/version_converter/_c_api_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Utilities for interfacing with onnx C APIs.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Callable, TypeVar + +from onnxscript import ir + +if TYPE_CHECKING: + import onnx + + +logger = logging.getLogger(__name__) +# Temporarily remove initializers larger than this size to keep model size down +# for the onnx.shape_inference call because it needs to serialize the model +_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB +_R = TypeVar("_R") + + +def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R: + """Call an ONNX C API function by temporarily removing initializers. + + This is necessary because the ONNX C API does not support large models + with initializers that have large tensor values. The input model is left + unchanged no matter the call succeeds or not. + + Args: + func: Partially applied function that takes a model proto and returns anything. + model: The IR model to pass to the API function. + + Returns: + The resulting ModelProto that contains the result of the API call. + """ + + # Store the original initializer values so they can be restored + initializer_values = tuple(model.graph.initializers.values()) + tensors = {v.name: v.const_value for v in initializer_values} + original_inputs_len = len(model.graph.inputs) + + # Turn the initializers into inputs and clear the initializers + # to limit the model size + for initializer in initializer_values: + # Make sure the initializer has its shape/type set + assert initializer.const_value is not None + if initializer.shape is None: + initializer.shape = initializer.const_value.shape # type: ignore[assignment] + if initializer.dtype is None: + initializer.dtype = initializer.const_value.dtype + if initializer not in model.graph.inputs: + model.graph.inputs.append(initializer) + if initializer.const_value.size > _BIG_TENSOR_SIZE_LIMIT: + # Temporarily remove the initializer value to reduce model size + # for onnx.shape_inference + initializer.const_value = None + assert initializer.name is not None + model.graph.initializers.pop(initializer.name) + + proto = ir.serde.serialize_model(model) + + try: + # Call the ONNX C API function + result = func(proto) + finally: + # Restore the original initializer values so the model is unchanged + for initializer in initializer_values: + initializer.const_value = tensors[initializer.name] + model.graph.register_initializer(initializer) + + # Restore the original inputs + inputs = model.graph.inputs[:original_inputs_len] + model.graph.inputs.clear() + model.graph.inputs.extend(inputs) + + return result diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py new file mode 100644 index 0000000000..dddf11150c --- /dev/null +++ b/onnxscript/version_converter/_version_converter.py @@ -0,0 +1,339 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Convert the model to the specified ONNX opset version.""" + +from __future__ import annotations + +import dataclasses +import functools +import logging +from typing import Callable, Sequence, Union + +import onnx_ir.convenience as ir_convenience + +import onnxscript.ir._tape as _tape +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +SUPPORTED_MAX_ONNX_OPSET = 23 +SUPPORTED_MIN_ONNX_OPSET = 18 + + +def _get_onnx_opset_version(model: ir.Model) -> int | None: + """Get the ONNX opset version imported by the model.""" + model_version1 = model.opset_imports.get("") + model_version2 = model.opset_imports.get("ai.onnx") + if model_version1 is not None and model_version2 is not None: + if model_version1 != model_version2: + raise ValueError( + f"Model imports multiple onnx opsets: {model_version1} and {model_version2}." + ) + return model_version1 or model_version2 + + +def _set_onnx_opset_version(model: ir.Model, version: int) -> None: + """Set the ONNX opset version imported by the model.""" + if "ai.onnx" in model.opset_imports: + del model.opset_imports["ai.onnx"] + model.opset_imports[""] = version + + +class VersionConverterError(RuntimeError): + """Raised when an node's version cannot be upgraded/downgraded successfully.""" + + +@dataclasses.dataclass +class Replacement: + """A replacement for a node in the graph.""" + + new_outputs: Sequence[ir.Value] + new_nodes: Sequence[ir.Node] + + +# A version-adapter function takes a node, a RewriterContext and returns +# a Replacement for the node or None (if no replacement is needed). + +RewriterContext = _tape.Builder +ReturnValue = Union[Sequence[ir.Value], ir.Value, None] +AdapterFunction = Callable[[ir.Node, RewriterContext], ReturnValue] + + +def version_supported(model: ir.Model, target_version: int) -> bool: + """Check if the target version is supported by the current version.""" + if "" in model.graph.opset_imports: + current_version = model.graph.opset_imports[""] + else: + return True + return ( + SUPPORTED_MIN_ONNX_OPSET + <= current_version + <= target_version + <= SUPPORTED_MAX_ONNX_OPSET + ) + + +class AdapterRegistry: + """A class that maintains a registry of adapters for ops.""" + + def __init__(self): + self.op_adapters: dict[tuple[str, str, int, bool], AdapterFunction] = {} + + def lookup_adapters( + self, + domain: str, + opname: str, + original_version: int, + up_conversion: bool = True, + ) -> AdapterFunction | None: + adapter_func = self.op_adapters.get((domain, opname, original_version, up_conversion)) + if adapter_func is not None: + return adapter_func + return None + + def register( + self, opname: str, domain: str = "", node_version=None, up_conversion=True + ) -> Callable[[AdapterFunction], AdapterFunction]: + """Register an adapter based on the domain, operator type, node version and whether to upgrade/downgrade node version""" + + def decorator(function: AdapterFunction) -> AdapterFunction: + @functools.wraps(function) + def wrapped_function(*args, **kwargs): + return function(*args, **kwargs) + + self.op_adapters[(domain, opname, node_version, up_conversion)] = function + return wrapped_function + + return decorator + + +registry: AdapterRegistry = AdapterRegistry() + +register = registry.register + + +def _get_input(node: ir.Node, index: int) -> ir.Value | None: + if index < len(node.inputs): + return node.inputs[index] + return None + + +def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None: + if name in node.attributes: + attr = node.attributes[name] + if not isinstance(attr, ir.Attr): + return None + attr_val = attr.value + if isinstance(attr_val, int): + return attr_val + # This is an invalid model: attribute has invalid/unexpected type. + # For now, we just return None. We could raise an error too. + return None + return default + + +def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> str | None: + if name in node.attributes: + attr = node.attributes[name] + if not isinstance(attr, ir.Attr): + return None + attr_val = attr.value + if isinstance(attr_val, str): + return attr_val + # This is an invalid model: attribute has invalid/unexpected type. + # For now, we just return None. We could raise an error too. + return None + return default + + +## Op-specific adapters + +# Opset 19 -> 20 + + +@register("DFT", node_version=19, up_conversion=True) +def dft_19_20(node: ir.Node, op): + input = node.inputs[0] + inverse = _get_int_attribute(node, "inverse", 0) + onesided = _get_int_attribute(node, "onesided", 0) + axis = _get_int_attribute(node, "axis", None) + if axis is not None: + axis_value = op.Constant(value_int=axis) + return op.DFT(input, axis_value, inverse=inverse, onesided=onesided) + return None + + +@register("GridSample", node_version=19, up_conversion=True) +def gridsample_19_20(node: ir.Node, op): + x = node.inputs[0] + grid = node.inputs[1] + align_corners = _get_int_attribute(node, "align_corners", 0) + mode = _get_str_attribute(node, "mode", "linear") + padding_mode = _get_str_attribute(node, "padding_mode", "zeros") + if mode == "bilinear": + return op.GridSample( + x, grid, align_corners=align_corners, mode="linear", padding_mode=padding_mode + ) + elif mode == "bicubic": + return op.GridSample( + x, grid, align_corners=align_corners, mode="cubic", padding_mode=padding_mode + ) + return None + + +# Opset 20 -> 21 + + +@register("GroupNormalization", node_version=20, up_conversion=True) +def groupnormalization_20_21(node: ir.Node, op): + x = _get_input(node, 0) + scale = _get_input(node, 1) + bias = _get_input(node, 2) + if x is None or scale is None or bias is None: + raise VersionConverterError(f"Missing input for {node}") + + x_shape = x.shape + if x_shape is None: + raise VersionConverterError(f"Missing required shape for {x}") + num_channels = x_shape[1] + if not isinstance(num_channels, int): + return None + + scale_shape = scale.shape + bias_shape = bias.shape + if scale_shape is None or bias_shape is None: + return None + if not isinstance(scale_shape[0], int) or not isinstance(bias_shape[0], int): + return None + + num_groups = _get_int_attribute(node, "num_groups", None) + if num_groups is None: + raise VersionConverterError("Missing required attribute: num_groups") + if ( + num_groups != num_channels + and num_groups == scale_shape[0] + and num_groups == bias_shape[0] + ): + reshape_1_sizes = op.Constant(value_ints=[-1, 1]) + reshape_2_sizes = op.Constant(value_ints=[-1]) + c_div = int(num_channels / num_groups) + expand_sizes = op.Constant(value_ints=[1, c_div]) + + # Modify scale input + scale_reshape_1 = op.Reshape(scale, reshape_1_sizes) + scale_expand = op.Expand(scale_reshape_1, expand_sizes) + scale_reshape_2 = op.Reshape(scale_expand, reshape_2_sizes) + + # Modify bias input + bias_reshape_1 = op.Reshape(bias, reshape_1_sizes) + bias_expand = op.Expand(bias_reshape_1, expand_sizes) + bias_reshape_2 = op.Reshape(bias_expand, reshape_2_sizes) + + return op.GroupNormalization(x, scale_reshape_2, bias_reshape_2, num_groups=num_groups) + return None + + +class _VersionConverter: + def __init__(self, target_version: int): + self._target_version = target_version + + def process_node( + self, node: ir.Node, from_version: int, up_conversion: bool = True + ) -> Replacement | None: + assert node.domain == "" + adapter = registry.lookup_adapters( + node.domain, node.op_type, from_version, up_conversion + ) + if adapter is None: + return None + context = RewriterContext() + output = adapter(node, context) + if output is not None: + if isinstance(output, ir.Value): + output = [output] + return Replacement(output, context.nodes) + return None + + def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None: + logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) + + ir_convenience.replace_nodes_and_values( + root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs + ) + + def visit_attribute(self, attr: ir.Attr) -> None: + if attr.is_ref(): + return + if attr.type == ir.AttributeType.GRAPH: + self.visit_graph(attr.as_graph()) + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.as_graphs(): + self.visit_graph(graph) + + def visit_node( + self, + node: ir.Node, + root: ir.Graph | ir.Function, + from_version: int, + up_conversion: bool = True, + ) -> None: + if up_conversion: + to_version = from_version + 1 + else: + to_version = from_version - 1 + replacement = self.process_node(node, from_version, up_conversion) + if replacement is None: + # No change. Process attributes. + for attr in node.attributes.values(): + self.visit_attribute(attr) + node.version = to_version + else: + for new_node in replacement.new_nodes: + # TODO: control-flow + new_node.version = to_version + self.replace_node(node, replacement, root) + + def visit_graph(self, graph: ir.Graph) -> None: + for node in graph: + if node.domain != "": + continue + node_version = node.version or self._default_onnx_opset + if node_version is None: + raise VersionConverterError(f"Node {node} has no version.") + # Iterate each node from current node version -> target version + # and updating node based on the correct adapter + # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1] + # TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted + if self._target_version < node_version: + raise VersionConverterError( + f"Target opset: {self._target_version} less than node version: {node.version}, " + "downstream version conversion not currently handled." + ) + for from_version in range(node_version, self._target_version): + try: + self.visit_node(node, graph, from_version, up_conversion=True) + except VersionConverterError as e: + logger.warning( + "Skipping version conversion for node %s due to exception: %s", + node.op_type, + e, + ) + + def visit_model(self, model: ir.Model) -> None: + self._default_onnx_opset = _get_onnx_opset_version(model) + self.visit_graph(model.graph) + _set_onnx_opset_version(model, self._target_version) + + +def convert_version(model: ir.Model, target_version: int) -> None: + """Convert the model to the specified ONNX opset version.""" + if (target_version > SUPPORTED_MAX_ONNX_OPSET) or ( + target_version < SUPPORTED_MIN_ONNX_OPSET + ): + raise ValueError( + f"Target opset version {target_version} is not supported. " + f"Supported range: {SUPPORTED_MIN_ONNX_OPSET} to {SUPPORTED_MAX_ONNX_OPSET}." + ) + version_converter = _VersionConverter(target_version=target_version) + version_converter.visit_model(model) diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py new file mode 100644 index 0000000000..cf6507196b --- /dev/null +++ b/onnxscript/version_converter/_version_converter_test.py @@ -0,0 +1,322 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx.defs +import pytest + +from onnxscript import ir, version_converter + + +class AdapterCoverageTest(unittest.TestCase): + def get_all_unique_schema_versions(self) -> dict[str, list]: + """Collect all unique versions of ONNX standard domain ops""" + op_version_dict = {} + all_schemas = onnx.defs.get_all_schemas_with_history() + for schema in all_schemas: + if schema.name not in op_version_dict: + op_version_dict[schema.name] = [schema.since_version] + else: + if schema.since_version not in op_version_dict[schema.name]: + op_version_dict[schema.name].append(schema.since_version) + return op_version_dict + + # TODO(shubhambhokare1) : Using existing onnx testing suite to verify operator adapter's functionality + def test_upstream_coverage(self): + op_version_dict = self.get_all_unique_schema_versions() + op_upgrades = [] + for op_type in op_version_dict: # pylint: disable=consider-using-dict-items + for opset_version in op_version_dict[op_type]: + op_upgrades.append((op_type, opset_version)) + + adapter_list = version_converter._version_converter.registry.op_adapters # pylint: disable=protected-access + for adapter_sig in adapter_list: + adapter_info = list(adapter_sig) + domain, name, upgrade_version = ( + adapter_info[0], + adapter_info[1], + adapter_info[2] + 1, + ) + self.assertEqual(domain, "") + self.assertIn((name, upgrade_version), op_upgrades) + + @pytest.mark.xfail(reason="TODO: Cleanup error status API.") + def test_version_convert_no_source_version(self): + model = ir.from_onnx_text( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (gridsample, shape_c) + } + """ + ) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") + + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + + +class VersionConverter18to17Test(unittest.TestCase): + @pytest.mark.xfail(strict=True, reason="Version downgrade not yet supported.") + def test_version_convert_compatible(self): + model = ir.from_onnx_text( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + target_version = 17 + version_converter.convert_version(model, target_version=target_version) + + +class VersionConverter18to19Test(unittest.TestCase): + def test_version_convert_compatible(self): + model = ir.from_onnx_text( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + target_version = 19 + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 19) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 19) + self.assertEqual(model.graph.node(4).op_type, "MatMul") + self.assertEqual(model.graph.node(4).version, 19) + + +class VersionConverter19to20Test(unittest.TestCase): + def test_version_convert_compatible(self): + model = ir.from_onnx_text( + """ + + agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + dft = DFT (reshape_x) + shape_c = Constant() + output = Reshape (dft, shape_c) + } + """ + ) + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(2).op_type, "Constant") + self.assertEqual(model.graph.node(3).version, 20) + self.assertEqual(model.graph.node(3).op_type, "DFT") + self.assertEqual(model.graph.node(3).version, 20) + self.assertEqual(len(model.graph.node(3).inputs), 2) + + def test_version_convert_gridsample_linear(self): + model = ir.from_onnx_text( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (gridsample, shape_c) + } + """ + ) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") + + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).version, 20) + self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") + + def test_version_convert_gridsample_cubic(self): + model = ir.from_onnx_text( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (gridsample, shape_c) + } + """ + ) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).attributes["mode"].value, "bicubic") + + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).version, 20) + self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic") + + def test_version_convert_inline(self): + model = ir.from_onnx_text( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + output = foo(gridsample) + } + + + foo (x) => (dft) { + dft = DFT (x) + } + """ + ) + target_version = 20 + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + self.assertEqual(model.graph.node(0).op_type, "Constant") + self.assertEqual(model.graph.node(0).version, 20) + self.assertEqual(model.graph.node(1).op_type, "Reshape") + self.assertEqual(model.graph.node(1).version, 20) + self.assertEqual(model.graph.node(4).op_type, "GridSample") + self.assertEqual(model.graph.node(4).version, 20) + self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") + self.assertEqual(model.graph.node(6).op_type, "DFT") + self.assertEqual(model.graph.node(6).version, 20) + self.assertEqual(len(model.graph.node(6).inputs), 2) + + +class VersionConverter20to21Test(unittest.TestCase): + def test_version_groupnorm(self): + model = ir.from_onnx_text( + """ + + agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output) + { + groupnorm = GroupNormalization (input_x, scale, bias) + shape_c = Constant() + output = Reshape (groupnorm, shape_c) + } + """ + ) + target_version = 21 + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + self.assertEqual(model.graph.node(3).op_type, "Reshape") + self.assertEqual(model.graph.node(3).version, 21) + self.assertEqual(model.graph.node(4).op_type, "Expand") + self.assertEqual(model.graph.node(4).version, 21) + self.assertEqual(model.graph.node(5).op_type, "Reshape") + self.assertEqual(model.graph.node(5).version, 21) + self.assertEqual(model.graph.node(6).op_type, "Reshape") + self.assertEqual(model.graph.node(6).version, 21) + self.assertEqual(model.graph.node(7).op_type, "Expand") + self.assertEqual(model.graph.node(7).version, 21) + self.assertEqual(model.graph.node(8).op_type, "Reshape") + self.assertEqual(model.graph.node(8).version, 21) + self.assertEqual(model.graph.node(9).op_type, "GroupNormalization") + self.assertEqual(model.graph.node(9).version, 21) + + def test_version_groupnorm_no_bias(self): + model = ir.from_onnx_text( + """ + + agraph (float[1, 4, 512, 512] input_x, float[2] scale) => (float[4, 512, 512] output) + { + groupnorm = GroupNormalization (input_x, scale) + shape_c = Constant() + output = Reshape (groupnorm, shape_c) + } + """ + ) + target_version = 21 + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + self.assertEqual(model.graph.node(0).op_type, "GroupNormalization") + self.assertEqual(model.graph.node(0).version, 20) + + +class VersionConverter23to24Test(unittest.TestCase): + @pytest.mark.xfail(strict=True, reason="Version upgrade beyond 23 not yet supported.") + def test_version_convert_compatible(self): + model = ir.from_onnx_text( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + target_version = 24 + version_converter.convert_version(model, target_version=target_version) + + +if __name__ == "__main__": + unittest.main() diff --git a/opgen/README.md b/opgen/README.md new file mode 100644 index 0000000000..af6b7bbebc --- /dev/null +++ b/opgen/README.md @@ -0,0 +1,17 @@ +# Generator for onnx_opset + +Use this module the generate onnx_opset implementations when new opsets are introduced with new ONNX versions. + +## Generate + +```sh +python opgen +``` + +Run + +```sh +python opgen -h +``` + +for more information. diff --git a/opgen/__main__.py b/opgen/__main__.py index 081ee5da64..400408465c 100644 --- a/opgen/__main__.py +++ b/opgen/__main__.py @@ -1,7 +1,9 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- +"""Main entry point for generating the onnx_opset modules. + +Example Usage: python opgen --exclude ai.onnx.preview.training/1 +""" import argparse import shutil @@ -9,7 +11,7 @@ import textwrap from pathlib import Path -from opgen.onnx_opset_builder import ( +from onnx_opset_builder import ( OpsetId, OpsetsBuilder, format_opsetid, diff --git a/opgen/onnx_opset_builder.py b/opgen/onnx_opset_builder.py index 8e528d5c15..f5c3c0daab 100644 --- a/opgen/onnx_opset_builder.py +++ b/opgen/onnx_opset_builder.py @@ -1,7 +1,5 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations @@ -9,16 +7,15 @@ from textwrap import dedent from typing import Annotated, Any, Iterable, Optional, Set, TextIO +import onnx +import pygen as cg from onnx.defs import ( - AttributeProto, OpSchema, get_all_schemas_with_history, onnx_opset_version, ) from onnx.helper import get_attribute_value -import opgen.pygen as cg - __all__ = [ "OpsetId", "parse_opsetid", @@ -61,8 +58,7 @@ def __init__(self, domain: str, name: str, version: int): def __repr__(self) -> str: return ( - f"QualOpName(domain={self.domain!r}, " - f"version={self.version!r}, name={self.name!r})" + f"QualOpName(domain={self.domain!r}, version={self.version!r}, name={self.name!r})" ) def __str__(self) -> str: @@ -142,14 +138,12 @@ def _write_header(self, writer: TextIO): writer.write("# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️ \n") writer.write("# ⚙️ Generated by 'python -m opgen'\n") writer.write(dashline) - writer.write("# Copyright (c) Microsoft Corporation. ") - writer.write("All rights reserved.\n") + writer.write("# Copyright (c) Microsoft Corporation.\n") writer.write("# Licensed under the MIT License.\n") writer.write(dashline) writer.write("# pylint: disable=W0221,W0222,R0901,W0237\n") writer.write("# mypy: disable-error-code=override\n") - writer.write("# ruff: noqa: N801,E741\n") - writer.write("# ruff: noqa: D214,D402,D405,D411,D412,D416,D417\n") + writer.write("# ruff: noqa: N801,E741,RUF036,D214,D402,D405,D411,D412,D416,D417\n") writer.write(dashline) writer.write("\n") writer.write("from __future__ import annotations\n") @@ -346,7 +340,7 @@ def constraint_is_compatible( for existing_constraints in input_constraints, output_constraints: if (existing := existing_constraints.get(constraint_name, None)) is not None: if len(existing) != len(constraint_types): - return False # differing number of constraints, can't be compatible + return False # differing number of constraints, can't be compatible for a, b in zip(existing, constraint_types): if str(a) != str(b): return False # a constrained type does not match @@ -561,7 +555,7 @@ def _make_function_input_args(self, schema: OpSchema) -> Iterable[cg.Arg]: def _make_function_attr_args(self, schema: OpSchema) -> Iterable[cg.Arg]: generate_kwonly_sentinel = True - for attr in schema.attributes.values(): + for attr in sorted(schema.attributes.values(), key=lambda a: a.name): attr_type = parse_attr_type(attr.type) default_value = None @@ -679,33 +673,33 @@ def error(message: Optional[str] = None): def parse_attr_type(type) -> cg.TypeRef: - if type == AttributeProto.FLOAT: + if type == onnx.AttributeProto.FLOAT: return cg.FloatTypeRef() - if type == AttributeProto.INT: + if type == onnx.AttributeProto.INT: return cg.IntTypeRef() - if type == AttributeProto.STRING: + if type == onnx.AttributeProto.STRING: return cg.StrTypeRef() - if type == AttributeProto.TENSOR: + if type == onnx.AttributeProto.TENSOR: return cg.TypeRef(MODULE_ONNX, "TensorProto") - if type == AttributeProto.SPARSE_TENSOR: + if type == onnx.AttributeProto.SPARSE_TENSOR: return cg.TypeRef(MODULE_ONNX, "SparseTensorProto") - if type == AttributeProto.GRAPH: + if type == onnx.AttributeProto.GRAPH: return cg.TypeRef(MODULE_ONNX, "GraphProto") - if type == AttributeProto.TYPE_PROTO: + if type == onnx.AttributeProto.TYPE_PROTO: return cg.TypeRef(MODULE_ONNX, "TypeProto") - if type == AttributeProto.FLOATS: + if type == onnx.AttributeProto.FLOATS: return cg.TypingRefs.Sequence(cg.FloatTypeRef()) - if type == AttributeProto.INTS: + if type == onnx.AttributeProto.INTS: return cg.TypingRefs.Sequence(cg.IntTypeRef()) - if type == AttributeProto.STRINGS: + if type == onnx.AttributeProto.STRINGS: return cg.TypingRefs.Sequence(cg.StrTypeRef()) - if type == AttributeProto.TENSORS: + if type == onnx.AttributeProto.TENSORS: return cg.TypingRefs.Sequence(cg.TypeRef(MODULE_ONNX, "TensorProto")) - if type == AttributeProto.SPARSE_TENSORS: + if type == onnx.AttributeProto.SPARSE_TENSORS: return cg.TypingRefs.Sequence(cg.TypeRef(MODULE_ONNX, "SparseTensorProto")) - if type == AttributeProto.GRAPHS: + if type == onnx.AttributeProto.GRAPHS: return cg.TypingRefs.Sequence(cg.TypeRef(MODULE_ONNX, "GraphProto")) - if type == AttributeProto.TYPE_PROTOS: + if type == onnx.AttributeProto.TYPE_PROTOS: return cg.TypingRefs.Sequence(cg.TypeRef(MODULE_ONNX, "TypeProto")) raise NotImplementedError(f"attribute type not implemented: {type}") diff --git a/opgen/pygen.py b/opgen/pygen.py index ffc412f9ec..bea7431186 100644 --- a/opgen/pygen.py +++ b/opgen/pygen.py @@ -367,7 +367,7 @@ def accept(self, visitor: Visitor): self._dispatch_visit(visitor.visit_constant) -class ExprList(Expr, Generic[TExpr], ABC): +class ExprList(Expr, ABC, Generic[TExpr]): class Roles: Elements = Role("ExprList.Elements") diff --git a/pyproject.toml b/pyproject.toml index 545fd21082..4f7edc9bf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,14 @@ [build-system] -requires = ["setuptools>=61.0.0"] +requires = ["setuptools>=70.0.0"] build-backend = "setuptools.build_meta" [project] name = "onnxscript" -dynamic = ["version"] +dynamic = ["version", "urls"] description = "Naturally author ONNX functions and models using a subset of Python" authors = [{ name = "Microsoft Corporation", email = "onnx@microsoft.com" }] readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { file = "LICENSE" } classifiers = [ "Development Status :: 4 - Beta", @@ -17,14 +17,21 @@ classifiers = [ "Operating System :: POSIX", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", ] -dependencies = ["numpy", "onnx>=1.16", "typing_extensions"] +dependencies = [ + "ml_dtypes", + "numpy", + "onnx_ir>=0.1.10,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx>=1.16", + "packaging", + "typing_extensions>=4.10", +] [tool.setuptools.packages.find] include = ["onnxscript*"] @@ -34,7 +41,6 @@ onnxscript = ["py.typed"] onnx = ["py.typed"] [tool.pytest.ini_options] -filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"] addopts = "-rsfEX --tb=short --color=yes" [tool.mypy] @@ -73,52 +79,6 @@ module = [ ] ignore_errors = true -# FIXME(#1378): Remove this overrides section -[[tool.mypy.overrides]] -module = [ - "onnxrewriter.rewriter.generic_pattern_test.*", -] -check_untyped_defs = false -disable_error_code = 'override,import-untyped,no-untyped-def,assignment' -disallow_incomplete_defs = true -disallow_untyped_defs = true -disallow_untyped_decorators = true -show_column_numbers = true -strict_optional = true -warn_incomplete_stub = true -warn_no_return = true -warn_unused_configs = true -warn_unused_ignores = false - -# FIXME(#1378): Remove this overrides section -[[tool.mypy.overrides]] -module = [ - "onnxrewriter.rewriter.generic_pattern.*", -] -check_untyped_defs = false -disable_error_code = 'override,import-untyped,no-untyped-def,assignment,union-attr,func-returns-value,annotation-unchecked,arg-type,index,name-defined,attr-defined' -disallow_incomplete_defs = true -disallow_untyped_defs = true -disallow_untyped_decorators = true -show_column_numbers = true -strict_optional = true -warn_incomplete_stub = true -warn_no_return = true -warn_unused_configs = true -warn_unused_ignores = false - -[tool.black] -target-version = ["py38", "py39", "py310", "py311"] -# Black's extend-exclude needs to be a regex string -extend-exclude = "/tests/models|/tests/onnx_backend_test_code" -line-length = 95 - -[tool.isort] -profile = "black" -extend_skip_glob = [ - "tests/onnx_backend_test_code/*.py", -] - [tool.pylint.messages_control] # NOTE: This list is for vscode. Add new disables in pyproject_pylint.toml for lintrunner # Exclude patterns should be modified in .lintrunner.toml @@ -137,7 +97,10 @@ disable = [ convention = "google" [tool.ruff] +line-length = 95 target-version = "py38" + +[tool.ruff.lint] select = [ "B", # flake8-bugbear "C4", # flake8-comprehensions @@ -163,7 +126,13 @@ select = [ "W", # pycodestyle "YTT", # flake8-2020 ] +# Select preview rules +preview = true +extend-select = [ + "CPY001", # Copyright header +] ignore = [ + "B9", # Opinionated bugbear rules "C408", # Sometimes it is preferable when we construct kwargs "D1", # D1 is for missing docstrings, which is not yet enforced. "D202", # D202 Too strict. "No blank lines allowed after function docstring" @@ -172,7 +141,9 @@ ignore = [ "D400", "D401", # First line of docstring should be in imperative mood "D415", # D415 Not yet enforced. "First line should end with a period, question mark, or exclamation point" + "E1", "E2", "E3", # Pycodestyle formatting rules that conflicts with the formatter "E501", # Line length. Not enforced because black will handle formatting + "SIM103", # "Return the condition directly" obscures logic sometimes "N802", # Nxx: ONNX Script function sometimes use upper case for names. "N803", "N806", @@ -181,6 +152,9 @@ ignore = [ "PERF203", # try-except in loops sometimes necessary "PERF401", # List comprehension is not always readable "PYI041", # int | float is more clear + "RUF022", # We don't need to sort __all__ for elements to be grouped + "RUF031", # Parentheses for tuple in subscripts is more readable + "RUF052", # Variables with `_` prefix may not be dummy variables in all cases "SIM102", # Collapible if statements are not always more readable "SIM108", # We don't always encourage ternary operators "SIM114", # Don't always combine if branches for debugability @@ -188,22 +162,28 @@ ignore = [ "TRY003", # Messages can be constructed in the exception "UP006", # keep-runtime-typing "UP007", # keep-runtime-typing + "UP045", # TODO: Support new style type annotations ] -line-length = 95 ignore-init-module-imports = true [tool.ruff.lint.flake8-tidy-imports.banned-api] "pathlib".msg = "Using pathlib can impact performance. Use os.path instead" +"onnx.helper".msg = "onnx helpers tend to be protobuf-y and slow. Consider using ir.tensor, ir.DataType and related methods instead" +"onnx.numpy_helper".msg = "onnx numpy helpers tend to be slow. Consider using ir.tensor, ir.DataType and related methods instead" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "__init__.py" = ["TID252"] # Allow relative imports in init files "setup.py" = ["TID251"] # pathlib is allowed in supporting code -"**/{examples,tests,docs,tools,utils,opgen}/*" = ["TID251"] # pathlib is allowed in supporting code +"**/{examples,tests,docs,tools,utils,opgen,_framework_apis}/*" = ["TID251"] # pathlib is allowed in supporting code "**/*_test.py" = ["TID251"] # pathlib is allowed in tests +"onnxscript/onnx_opset/_impl/*.py" = ["RUF036"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] # Disallow all relative imports. ban-relative-imports = "all" -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" + +[tool.ruff.lint.flake8-copyright] +notice-rgx = "(?i)Copyright \\(c\\) Microsoft Corporation" diff --git a/pyproject_pylint.toml b/pyproject_pylint.toml index e90adccb23..a764937fb5 100644 --- a/pyproject_pylint.toml +++ b/pyproject_pylint.toml @@ -2,6 +2,7 @@ [tool.pylint.messages_control] disable = [ + "arguments-differ", # TODO: abstract methods in Rewriter "attribute-defined-outside-init", # TODO: mostly in onnxscript/converter.py "cell-var-from-loop", # Bugbear B023 "consider-using-from-import", @@ -18,11 +19,13 @@ disable = [ "no-name-in-module", "redefined-builtin", # TODO: should we avoid redefined-builtin? "too-few-public-methods", + "too-many-ancestors", "too-many-arguments", "too-many-branches", "too-many-instance-attributes", "too-many-lines", "too-many-locals", + "too-many-positional-arguments", "too-many-public-methods", "too-many-return-statements", "too-many-statements", # TODO: we should work on these: too-many-xxx series diff --git a/requirements-dev.txt b/requirements-dev.txt index dfbe51ac23..b689d9bad5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,22 +2,20 @@ setuptools>=61.0.0 numpy onnx-weekly>=1.17.0.dev20240325 onnxruntime>=1.17.0 -typing_extensions +typing_extensions>=4.10 rich>=13.7.1 # Docs site furo jax[cpu] matplotlib -ml_dtypes myst-parser[linkify] sphinx-copybutton sphinx-exec-code sphinx-gallery sphinx>=6 - -# Torch lib -beartype!=0.16.0 +myst_nb +chardet # Testing expecttest==0.1.6 @@ -29,8 +27,9 @@ pytest-subtests pytest-xdist pytest!=7.1.0 pyyaml -torch>=2.1 -pyinstrument +torch>=2.3 +torchvision>=0.18.0 +transformers>=4.37.2 # Lint lintrunner>=0.10.7 diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index 3d562a116d..e005031603 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.17.0.dev20240415 +onnx-weekly==1.20.0.dev20251006 diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt index 349b61034e..cb16597719 100644 --- a/requirements/ci/requirements-ort-nightly.txt +++ b/requirements/ci/requirements-ort-nightly.txt @@ -1,3 +1,3 @@ -# https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/ort-nightly/overview +# https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -ort-nightly==1.18.0.dev20240329005 +onnxruntime==1.23.1 diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 30a03a03bb..f95977610e 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,11 +1,11 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.4.3 +ruff==0.13.2 # MYPY -mypy==1.9.0 -types-PyYAML==6.0.12.11 +mypy==1.10.1 +types-PyYAML==6.0.12.20250915 # PYLINT -pylint==2.17.6 +pylint==3.3.9 # EDITORCONFIG-CHECKER -editorconfig-checker==2.7.3 +editorconfig-checker==3.4.0 diff --git a/setup.py b/setup.py index 32d496b7a5..f253346046 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- """NOTE: Put all metadata in pyproject.toml. Do not include complex logic in setup.py.""" import datetime @@ -17,7 +15,7 @@ version = VERSION_FILE.read_text().strip() project_urls = { - "Homepage": "https://onnxscript.ai/", + "Homepage": "https://microsoft.github.io/onnxscript/", "Repository": "https://github.com/microsoft/onnxscript", } if os.environ.get("ONNX_SCRIPT_RELEASE") != "1": diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx index e0d380b46b..77cfc7709c 100644 --- a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:06d78f841f26ec59cea1d15dd2c2a086cb907d6644ef8dac15e6d366935413e8 -size 43087292 +oid sha256:6dcf6976f8e324c497b0b74b2b9733c4b454f2c259488f5544bbc1aaaf57714c +size 43091738 diff --git a/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx b/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx index 2eede96c91..69a9c4c073 100644 --- a/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx +++ b/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a336102b11d8439daa2c1a164a851f34414529a5610a046943fd869b1b44336f -size 14665355 +oid sha256:ba424976b53bf2f141bfd001b48c0cc1c5c798b49def51f39a72f17e1f74e3a2 +size 14673089 diff --git a/testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx b/testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx index 61122be18a..a5433b830e 100644 --- a/testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx +++ b/testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:31fbebb580ff85ed8eefa7fb95d4e2cbda41fe267afeaae2d4f4177264d1f4e7 -size 46918368 +oid sha256:12d24be13a03ea8ddebcc5ea229390d49fb0da08ad1df896b03703c664e2def1 +size 46921843 diff --git a/testdata/e2e_models/torchscript_model/torchscript_model.onnx b/testdata/e2e_models/torchscript_model/torchscript_model.onnx index 7d450d2b8b..dd9bd08100 100644 --- a/testdata/e2e_models/torchscript_model/torchscript_model.onnx +++ b/testdata/e2e_models/torchscript_model/torchscript_model.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:efd167b736106103235f42b480027c28c798dd46117526ca49067a2bdbc7b327 -size 311182 +oid sha256:6519a87ecf89132a9d39c59c47a442ae5833faf14811575d0b2323e8d13e30a8 +size 313873 diff --git a/tests/__init__.py b/tests/__init__.py index 862c45ce31..59e481eb93 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,2 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 4c57480645..8099de9f12 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -1 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Shared components for testing.""" diff --git a/tests/common/onnx_script_test_case.py b/tests/common/onnx_script_test_case.py index 5608f415dc..3a46a870a0 100644 --- a/tests/common/onnx_script_test_case.py +++ b/tests/common/onnx_script_test_case.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import copy @@ -192,6 +190,7 @@ def run_converter_test( onnx_case_model: Optional[onnx.ModelProto] = None, *, ir_version: int = 9, + rtol: Optional[float] = None, ): # FIXME(justinchuby): Defaulting to ir_version 9 because ONNX Runtime supports # up to IR version 9 as of 4/2/2024. We should have a better mechanism to @@ -252,7 +251,7 @@ def run_converter_test( raise AssertionError(f"Unable to load model\n{model}") from e # input['input_2'] = None actual = session.run(None, input) - np.testing.assert_allclose(actual, param.output, rtol=self.rtol) + np.testing.assert_allclose(actual, param.output, rtol=rtol or self.rtol) def run_eager_test( self, diff --git a/tests/common/testutils.py b/tests/common/testutils.py index c0dafbff1b..2a2697b240 100644 --- a/tests/common/testutils.py +++ b/tests/common/testutils.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations import functools @@ -11,10 +9,11 @@ import numpy as np import onnx +import onnx_ir as ir import onnxruntime +import torch from onnxscript import optimizer -from onnxscript._legacy_ir import visitor from onnxscript.rewriter import onnxruntime as ort_rewriter from onnxscript.utils import evaluation_utils @@ -31,7 +30,7 @@ def skip_if_no_cuda(reason: str): def skip_dec(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - if not onnxruntime.get_device() == "GPU": + if not torch.cuda.is_available() or not onnxruntime.get_device() == "GPU": raise unittest.SkipTest(f"GPU is not available. {reason}") return func(self, *args, **kwargs) @@ -40,20 +39,6 @@ def wrapper(self, *args, **kwargs): return skip_dec -class OpTypeAnalysisVisitor(visitor.ProtoVisitorCore): - def __init__(self): - super().__init__() - self.op_types = set() - - def visit_model(self, model: onnx.ModelProto): - self.op_types = set() - super().visit_model(model) - - def process_node(self, node: onnx.NodeProto): - self.op_types.add((node.domain, node.op_type, getattr(node, "overload", ""))) - return super().process_node(node) - - def test_onnxruntime_rewrite( model_basename: str, model_count: int, @@ -85,10 +70,11 @@ def test_onnxruntime_rewrite( # onnx.save(rewritten, model_dir / f"{model_name}_opt.onnx") # Check expected operator is found. - optype_analysis = OpTypeAnalysisVisitor() - optype_analysis.visit_model(rewritten) + op_types = set() + for node in ir.from_proto(model).graph.all_nodes(): + op_types.add((node.domain, node.op_type, node.overload)) for domain, op_type, overload in expected_optypes: - if (domain, op_type, overload) not in optype_analysis.op_types: + if (domain, op_type, overload) not in op_types: raise AssertionError( f"Expected op type {domain}:{op_type}:{overload} not found in rewritten model." ) diff --git a/tests/eager_mode_test.py b/tests/eager_mode_test.py index b8ea940dae..566169f223 100644 --- a/tests/eager_mode_test.py +++ b/tests/eager_mode_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import unittest diff --git a/tests/eager_test.py b/tests/eager_test.py index ffed8be5f8..e8dd5c2e74 100644 --- a/tests/eager_test.py +++ b/tests/eager_test.py @@ -1,11 +1,12 @@ -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=import-outside-toplevel +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import itertools import unittest import numpy as np import parameterized +import torch from tests.common import onnx_script_test_case from tests.models import signal_dft @@ -82,10 +83,6 @@ def _stft( onesided=False, hop_length=None, ): - try: - import torch - except ImportError as e: - raise ImportError("torch is not installed.") from e ft = torch.stft( torch.from_numpy(x), n_fft=fft_length, @@ -166,9 +163,9 @@ def test_dft_cfft_last_axis(self): np.testing.assert_allclose(expected1, expected2) with self.subTest( c_shape=c.shape, - le=list(le), + le=le.tolist(), expected_shape=expected1.shape, - weights=we, + weights=we.tolist(), ): case = onnx_script_test_case.FunctionTestParams( signal_dft.dft_last_axis, [x, le, False], [expected1] @@ -195,7 +192,7 @@ def test_dft_rfft(self, x_, s: int): nax = np.array([ax], dtype=np.int64) with self.subTest( x_shape=x.shape, - le=list(le), + le=le.tolist(), ax=ax, expected_shape=expected.shape, ): @@ -233,7 +230,7 @@ def test_dft_cfft(self, x, y): np.testing.assert_allclose(expected1, expected2) with self.subTest( c_shape=c.shape, - le=list(le), + le=le.tolist(), ax=ax, expected_shape=expected1.shape, ): @@ -259,7 +256,7 @@ def test_dft_rifft(self, x_): nax = np.array([ax], dtype=np.int64) with self.subTest( x_shape=x.shape, - le=list(le), + le=le.tolist(), ax=str(ax), expected_shape=expected.shape, ): @@ -298,7 +295,7 @@ def test_dft_cifft(self, x, y): np.testing.assert_allclose(expected1, expected2) with self.subTest( c_shape=c.shape, - le=list(le), + le=le.tolist(), ax=str(ax), expected_shape=expected1.shape, ): diff --git a/tests/external_tensor_test.py b/tests/external_tensor_test.py index d908ba6cfb..f12e5720cd 100644 --- a/tests/external_tensor_test.py +++ b/tests/external_tensor_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import os import tempfile import unittest diff --git a/tests/function_libs/torch_lib/README.md b/tests/function_libs/torch_lib/README.md index 129b23adce..b8264dda87 100644 --- a/tests/function_libs/torch_lib/README.md +++ b/tests/function_libs/torch_lib/README.md @@ -1,16 +1,19 @@ -# Test op correctness by comparing with PyTorch results +# Test op correctness by comparing with PyTorch results using OpInfo + +`OpInfo` is PyTorch's standard mechanism for composing test data for operators. +Read more about them on https://github.com/pytorch/pytorch/blob/ce4a097bf769d753712a1fd969b446c59e29d8b9/torch/testing/_internal/opinfo/core.py#L362. ## Usage ```bash # All -pytest onnxscript/tests/function_libs/torch_lib/ops_test.py +python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py # To run tests on a specific operator (e.g. torch.ceil): -pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil +python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil # To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention): -pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention +python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention ``` ### Environment variables @@ -25,4 +28,53 @@ in onnxruntime by running the inference sessions in a separate process. ## How to add a new operator test -See _usage_ in [ops_test_data.py](./ops_test_data.py) +See _usage_ in [`ops_test_data.py`](./ops_test_data.py) + +## How to add custom OpInfo tests + +Sometimes, there is no existing OpInfo that fits our need to test an operator. You want to create a custom OpInfo for it. + +Follow the steps below to create new OpInfo tests: + +1. Use the implementation for `ops.aten.slice_scatter` as a reference (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2412-L2418) to declare an OpInfo in [`extra_opinfo.py`](./extra_opinfo.py) + + ```py + opinfo_core.OpInfo( + "ops.aten.slice_scatter", + aten_name="slice_scatter", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_slice_scatter, + supports_out=False, + ), + ``` + + - The first argument should be the operator name under the `torch.ops` namespace. For example, if you want to test the `prims.var` op, then put `"ops.prims.var"`. It should almost always start with `ops.`. + - Follow existing examples to specify the `dtypes` you want to test the op on. + - Specify `op=` if the target operator is not the same as the OpInfo name (first arg). For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2065-L2068. + + ```py + opinfo_core.OpInfo( + "ops.aten.bernoulli.p_deterministic", + op=torch.ops.aten.bernoulli.p, + ``` + + The op is `torch.ops.aten.bernoulli.p`, which is different from the name `ops.aten.bernoulli.p_deterministic`. OpInfo names need to be globally unique in a test suite. When `op` is not specified, it will look for the op in `torch.` using its name. + +2. Implement the `sample_inputs_func`. (Ref: https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L1242-L1268) + 1. Copy the function and decide what the input shapes should be. Use `make_arg` to generate a torch.Tensor. Alternatively you could also use `torch.tensor` to generate the tensor yourself. Be sure to double check the dtype and device. Finally yield each test cases with + + ```py + yield opinfo_core.SampleInput(input, args=(...), kwargs={...}) + ``` + + `input` is the first arg. The rest of the args are in `args`. +3. Enable the test case in [`ops_test_data.py`](./ops_test_data.py) + 1. Add a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list. (For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L2116) + + ```py + TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter) + ``` + + You can additionally specify dtype tolerance (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L539) or conditional skips (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L586-L590). + +Now that the test is added, you may run the test like mentioned above. Set `CREATE_REPRODUCTION_REPORT=1` to get markdown reports and view failing input combinations should any test case fails. diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py new file mode 100644 index 0000000000..754f5e2a25 --- /dev/null +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo + +import unittest + +import torch +from torch.onnx._internal.exporter import _testing + + +class TorchLibe2eTest(unittest.TestCase): + def test_investigate_one_particular_model(self): + """This test can be used to investigate a particular issue.""" + red, include, stype = "amin", False, "int32" + dtype = getattr(torch, stype) + + class Model(torch.nn.Module): + def __init__(self, include, red): + super().__init__() + self.include = include + self.red = red + + def forward(self, x, indices, updates): + x = x.clone() + return x.scatter_reduce( + 0, indices, updates, self.red, include_self=self.include + ) + + model = Model(include, red) + xs = ( + torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype), + torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64), + torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype), + ) + onnx_program = torch.onnx.export(model, xs, dynamo=True) + _testing.assert_onnx_program(onnx_program) + + def test_pow_tensor_scalar_int_float(self): + class PowModel(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x**0.5 + + onnx_program = torch.onnx.export( + PowModel(), (torch.tensor(2),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_pow_tensor_scalar_int_int(self): + class PowModel(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x**2 + + onnx_program = torch.onnx.export( + PowModel(), (torch.tensor(2),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_pow_tensor_scalar_float16_int(self): + class PowModel(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x**2 + + onnx_program = torch.onnx.export( + PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_pow_tensor_scalar_float16_float(self): + class PowModel(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x**0.5 + + onnx_program = torch.onnx.export( + PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_integer_1(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_integer_2(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind, dim=0) + + onnx_program = torch.onnx.export( + Model(), + ( + torch.arange(6, dtype=torch.float32).reshape((2, 3)), + torch.tensor([1, 2], dtype=torch.int64), + ), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor_none(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind) + + inputs = ( + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.tensor([1, 2, 3, 2], dtype=torch.int64), + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + dynamo=True, + optimize=False, + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + input_names=["x", "ind"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_symbolic_tensor(self): + class Model(torch.nn.Module): + def forward(self, x, y): + return torch.repeat_interleave(x, y.shape[1], dim=1) * torch.repeat_interleave( + y, x.shape[1], dim=1 + ) + + inputs = ( + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.arange(6, dtype=torch.float32).reshape((2, 3)), + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + input_names=["x", "y"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_sdpa_with_bool_attn_mask(self): + class ScaledDotProductAttention(torch.nn.Module): + def forward(self, query, key, value, attn_mask): + return torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable + query, key, value, attn_mask=attn_mask + ) + + model = ScaledDotProductAttention() + attn_mask = torch.ones(2, 4, 8, 8).bool() # boolean mask for attention + attn_mask[0, 0, 0, :] = False # masking an entire row (padding token) + query = key = value = torch.randn(2, 4, 8, 16) + + onnx_program = torch.onnx.export( + model, + (query, key, value, attn_mask), + input_names=["query", "key", "value", "attn_mask"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_dynamic_paddings(self): + class Model(torch.nn.Module): + def forward(self, x): + height = x.size(2) # height is SymInt + x = torch.nn.functional.pad(x, (0, 0, 0, height), mode="replicate") + return x + + onnx_program = torch.onnx.export( + Model(), + (torch.rand(1, 1, 1, 1),), + dynamo=True, + dynamic_shapes=({2: torch.export.Dim("H")},), + ) + _testing.assert_onnx_program(onnx_program) + + def test_enable_gqa_in_attention(self): + class Model(torch.nn.Module): + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable + q, + k, + v, + enable_gqa=True, + ) + + model = Model() + + query = torch.randn(2, 4, 8, 16) + key = torch.randn(2, 2, 8, 16) + value = torch.randn(2, 2, 8, 16) + + onnx_program = torch.onnx.export( + model, + ( + query, + key, + value, + ), + input_names=["query", "key", "value"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_bitwise_and_scalar(self): + class Model(torch.nn.Module): + def forward(self, x): + return x & 3 + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([1, 2, 3, 4, 5]),), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/function_libs/torch_lib/error_reproduction.py b/tests/function_libs/torch_lib/error_reproduction.py index 5448666469..1eac88c48a 100644 --- a/tests/function_libs/torch_lib/error_reproduction.py +++ b/tests/function_libs/torch_lib/error_reproduction.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import difflib @@ -198,7 +200,7 @@ def create_reproduction_report( ) # Turn test name into a valid file name - markdown_file_name = f'{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md' + markdown_file_name = f"{short_test_name.replace('/', '-').replace(':', '-')}-{str(time.time()).replace('.', '_')}.md" markdown_file_path = save_error_report(markdown_file_name, markdown) print(f"Created reproduction report at {markdown_file_path}") @@ -245,7 +247,7 @@ def create_mismatch_report( error_stack=error_stack, ) - markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md' + markdown_file_name = f"mismatch-{short_test_name.replace('/', '-').replace(':', '-')}-{str(time.time()).replace('.', '_')}.md" markdown_file_path = save_error_report(markdown_file_name, markdown) print(f"Created reproduction report at {markdown_file_path}") diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 8c935c72e6..5d7deb1695 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ Test data for aten operators which don't exist in PyTorch file: pytorch/torch/testing/_internal/common_methods_invocations.py. @@ -35,6 +37,37 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(item, dtype=dtype) +def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for bilinear operation.""" + del op_info + del kwargs + + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + # Test cases: (batch_size, in1_features, in2_features, out_features) + cases = [ + (2, 3, 4, 5), # Basic case + (1, 2, 2, 1), # Minimal case + (3, 5, 7, 4), # Different dimensions + (2, 1, 1, 3), # Single input features + ] + + for batch_size, in1_features, in2_features, out_features in cases: + input1 = make_arg((batch_size, in1_features)) + input2 = make_arg((batch_size, in2_features)) + weight = make_arg((out_features, in1_features, in2_features)) + bias = make_arg((out_features,)) + + # Test with bias + yield opinfo_core.SampleInput(input1, args=(input2, weight, bias)) + + # Test without bias (only for first case to avoid too many tests) + if batch_size == 2: + yield opinfo_core.SampleInput(input1, args=(input2, weight, None)) + + def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs): del op_info @@ -85,6 +118,35 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) +def sample_inputs_broadcast_in_dim(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + # cases: (input_shape, target_shape, broadcast_dimensions) + # broadcast_dimensions maps each input dim to an axis in target_shape + cases = ( + # scalar -> 1-D tensor + ((), (3,), ()), + # identity (no-op broadcast) + ((3,), (3,), (0,)), + # rank-preserving broadcast where singleton dims expand + ((1, 3, 1), (2, 3, 4), (0, 1, 2)), + # input rank 2 -> output rank 3, input dims map to trailing axes + ((3, 1), (2, 3, 4), (1, 2)), + # add leading broadcast axis + ((3, 4), (1, 3, 4), (1, 2)), + # insert broadcasting in middle axis + ((3,), (2, 3, 1), (1,)), + ) + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + for shape, target_shape, broadcast_dimensions in cases: + tensor = make_arg(shape) + yield opinfo_core.SampleInput(tensor, args=(target_shape, broadcast_dimensions)) + + def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): del op_info # input_shape, output_size, kernal, dilation, padding, stride @@ -237,6 +299,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs): "groups": 1, }, ), + ( + (1, 3, 224, 224), + (32, 3, 3, 3), + None, + { + "stride": (2,), + "padding": (1,), + "dilation": (1,), + "transposed": False, + "output_padding": (0, 0), + "groups": 1, + }, + ), ( (1, 3, 3, 224, 224), (32, 3, 3, 3, 3), @@ -250,21 +325,19 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs): "groups": 1, }, ), - # FIXME(jiz): Uncomment out these test data once - # torch 2.0 is released. - # ( - # (1, 3, 224, 224, 224), - # (32, 3, 3, 3, 3), - # (32,), - # { - # "stride": (2, 2, 2), - # "padding": (1, 1, 1), - # "dilation": (1, 1, 1), - # "transposed": False, - # "output_padding": (0, 0, 0), - # "groups": 1, - # }, - # ), + ( + (1, 3, 224, 224, 224), + (32, 3, 3, 3, 3), + (32,), + { + "stride": (2, 2, 2), + "padding": (1, 1, 1), + "dilation": (1, 1, 1), + "transposed": False, + "output_padding": (0, 0, 0), + "groups": 1, + }, + ), ( (2, 4, 6, 6), (4, 1, 3, 3), @@ -671,24 +744,38 @@ def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_): def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): del self # Unused - oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad) - + real_dtype = torch.float + if dtype == torch.complex128: + real_dtype = torch.double + oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, real_dtype, requires_grad) + oned_tensor_result = oned_tensor() + nd_tensor_result = nd_tensor() + complex_oned_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access + oned_tensor_result, [0], normalization=0, onesided=False + ) + # for normalization in (0, 1, 2): for normalization in (0, 1, 2): # 1-D yield opinfo_core.SampleInput( - oned_tensor(), dim=(0,), normalization=normalization, last_dim_size=12 + complex_oned_tensor, + dim=(0,), + normalization=normalization, + last_dim_size=31, ) # N-D for dim in [ (0,), (1,), (2,), - (1, 2), - (0, 1), - (0, 1, 2), ]: + complex_nd_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access + nd_tensor_result, dim, normalization=0, onesided=False + ) yield opinfo_core.SampleInput( - nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6 + complex_nd_tensor, + dim=dim, + normalization=normalization, + last_dim_size=complex_nd_tensor.shape[dim[-1]], ) @@ -788,20 +875,63 @@ def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs - data = torch_testing.make_tensor( - (10, 3), - device=device, - dtype=dtype, - requires_grad=requires_grad, - ) - indices = (torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4)),) - values = torch_testing.make_tensor( - (2, 4, 3), - device=device, - dtype=dtype, - requires_grad=requires_grad, + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad ) - yield opinfo_core.SampleInput(data, indices, values) + + cases = [ + # Cases: one None + ((1, 3, 4), [None, torch.arange(2, device=device), None], (1, 2, 4)), + ((10, 3, 4), [torch.arange(5, device=device), None, None], (5, 3, 4)), + ((10, 3, 4, 6), [None, None, None, torch.arange(3, device=device)], (10, 3, 4, 3)), + # Cases: two None + ( + (10, 3, 4), + [None, torch.arange(3, device=device), torch.arange(3, device=device)], + (10, 3), + ), + ( + (10, 3, 4, 6), + [ + torch.arange(2, device=device), + None, + torch.arange(2, device=device), + torch.arange(2, device=device), + ], + (2, 3), + ), + ( + (10, 3, 4), + [torch.arange(2, device=device), torch.arange(2, device=device), None], + (2, 4), + ), + # Cases: Single indexing + ((10, 3, 4), [None, None, torch.tensor([0], device=device)], (10, 3, 1)), + ((10, 3, 4), [torch.tensor([0], device=device), None, None], (1, 3, 4)), + ((10, 3, 4, 6), [None, torch.tensor([0], device=device), None, None], (10, 1, 4, 6)), + # Cases: Single element + ( + (10, 3, 4), + [ + torch.tensor([0], device=device), + torch.tensor([0], device=device), + torch.tensor([0], device=device), + ], + (1,), + ), + # Cases: Multidimensional index + ( + (10, 3), + [torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))], + (2, 4, 3), + ), + ] + + for data_shape, indices, values_shape in cases: # type: ignore[misc] + data = make_arg(data_shape) + values = make_arg(values_shape) # type: ignore[has-type] + + yield opinfo_core.SampleInput(data, indices, values) def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs): @@ -850,18 +980,6 @@ def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): ((S, S), {}), ((0, S, 0), {}), ((S,), {}), - ] - for shape, kwargs in inputs: - t = torch_testing.make_tensor( - shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad - ) - yield opinfo_core.SampleInput(t, **kwargs) - - -def sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - del self # Unused - - inputs = [ ((S,), {"dtype": dtype}), # Hard-code some dtypes/devices. We want to test cases where the # (dtype, device) is different from the input's (dtype, device) @@ -1163,26 +1281,6 @@ def sample_inputs_rand_like(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(make_arg(shape)) -def sample_inputs_rand_like_dtype(op_info, device, dtype, requires_grad, **kwargs): - del op_info # Unused - del kwargs # Unused - - make_arg = functools.partial( - torch_testing.make_tensor, - device=device, - dtype=torch.float32, - requires_grad=requires_grad, - ) - shapes = ( - (M,), - (S, S), - (S, S, S), - ) - - for shape in shapes: - yield opinfo_core.SampleInput(make_arg(shape), kwargs=dict(dtype=dtype)) - - def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): high = 10 @@ -1210,14 +1308,6 @@ def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) -def sample_inputs_randint_like_dtype(self, device, dtype, requires_grad, **kwargs): - high = 10 - - for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput(sample.input, high, *sample.args, **sample.kwargs) - - def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **kwargs): low = 2 high = 10 @@ -1227,15 +1317,6 @@ def sample_inputs_randint_like_low_dtype(self, device, dtype, requires_grad, **k yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) -def sample_inputs_randint_like_low_dtype_dtype(self, device, dtype, requires_grad, **kwargs): - low = 2 - high = 10 - - for sample in sample_inputs_like_fns_dtype(self, device, dtype, requires_grad, **kwargs): - # With low and high - yield opinfo_core.SampleInput(sample.input, low, high, *sample.args, **sample.kwargs) - - def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): del op # Unused del device # Unused @@ -1315,6 +1396,109 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(input_, args=(src, *args)) +def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Basic test cases for scatter.src + cases = [ + # (self_shape, index_shape, src_shape, dim) + ((5, 5), (2, 3), (2, 3), 0), # 2D scatter on dim=0 + ((5, 5), (3, 2), (3, 2), 1), # 2D scatter on dim=1 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 0), # 3D scatter on dim=0 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 1), # 3D scatter on dim=1 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 2), # 3D scatter on dim=2 + ((10,), (3,), (3,), 0), # 1D scatter + ] + + for self_shape, index_shape, src_shape, dim in cases: + self_tensor = make_arg(self_shape) + # Create valid indices for the given dimension without duplication + index_buffer_shape = list(index_shape) + index_buffer_shape[dim] = self_shape[dim] + index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ + tuple(slice(None, d, None) for d in index_shape) + ] + src_tensor = make_arg(src_shape) + yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor)) + + # Additional test cases for scalar and single-element tensor combinations with dim=0 + # Test case: scalar index, scalar src (dim_size=5) + dim_size = 5 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + scalar_src = make_arg(()) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, scalar_src)) + + # Test case: single-element tensor index, scalar src (dim_size=7) + dim_size = 7 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + scalar_src = make_arg(()) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, scalar_src)) + + # Test case: scalar index, single-element tensor src (dim_size=3) + dim_size = 3 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + src_1d = make_arg((1,)) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, src_1d)) + + # Test case: single-element tensor index, single-element tensor src (dim_size=10) + dim_size = 10 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + src_1d = make_arg((1,)) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, src_1d)) + + +def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Basic test cases for scatter.value + cases = [ + # (self_shape, index_shape, dim, value) + ((5, 5), (2, 3), 0, 1.0), # 2D scatter on dim=0 with scalar value + ((5, 5), (3, 2), 1, -2.5), # 2D scatter on dim=1 with scalar value + ((3, 4, 5), (2, 2, 3), 0, False), # 3D scatter on dim=0 with scalar value + ((3, 4, 5), (2, 2, 3), 1, 3.14), # 3D scatter on dim=1 with scalar value + ((3, 4, 5), (2, 2, 3), 2, -1), # 3D scatter on dim=2 with scalar value + ((10,), (3,), 0, 5.0), # 1D scatter with scalar value + ] + + for self_shape, index_shape, dim, value in cases: + self_tensor = make_arg(self_shape) + # Create valid indices for the given dimension without duplication + index_buffer_shape = list(index_shape) + index_buffer_shape[dim] = self_shape[dim] + index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ + tuple(slice(None, d, None) for d in index_shape) + ] + yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value)) + + # Additional test cases for scalar and single-element tensor combinations with dim=0 + # Test case: scalar index with scalar value (dim_size=6, value_type=torch.long) + dim_size = 6 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + random_value = torch.randint(0, 10, (), device=device, dtype=torch.long).item() + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, random_value)) + + # Test case: single-element tensor index with scalar value (dim_size=8, value_type=torch.float) + dim_size = 8 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + random_value = torch.rand((), device=device, dtype=torch.float).item() + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, random_value)) + + def sample_inputs__scaled_dot_product_flash_attention( op_info, device, dtype, requires_grad, **kwargs ): @@ -1368,12 +1552,13 @@ def sample_inputs__scaled_dot_product_efficient_attention( make = opinfo_core.partial( opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad ) - batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 + batch, seq_q, seq_kv, num_heads, head_dim = 2, 3, 6, 4, 8 dim_4_q_shape = (batch, num_heads, seq_q, head_dim) dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] + samples = [] for qkv_shape, is_causal, dropout_p, compute_log_sumexp in opinfo_core.product( qkv_shapes, [True, False], [0.0], [True, False] @@ -1384,7 +1569,7 @@ def sample_inputs__scaled_dot_product_efficient_attention( make(shape_q), make(shape_kv), make(shape_kv), - attn_bias=None, + attn_bias=None, # TODO: Add attn_bias is_causal=is_causal, dropout_p=dropout_p, compute_log_sumexp=compute_log_sumexp, @@ -1422,6 +1607,30 @@ def sample_inputs__softmax( yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs) +def sample_inputs_prims_std_var(op_info, device, dtype, requires_grad, **kwargs): + del op_info # Unused + del kwargs # Unused + tensor_nd = functools.partial( + opinfo_core.make_tensor, + (S, S, S), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + tensor_1d = functools.partial( + opinfo_core.make_tensor, (S,), device=device, dtype=dtype, requires_grad=requires_grad + ) + + yield opinfo_core.SampleInput(tensor_nd(), dims=(1,), correction=0) + yield opinfo_core.SampleInput(tensor_1d(), dims=(0,), correction=0) + yield opinfo_core.SampleInput(tensor_1d(), dims=(0,), correction=1) + + yield opinfo_core.SampleInput(tensor_nd(), dims=(1,), correction=1) + yield opinfo_core.SampleInput(tensor_nd(), dims=(1,), correction=S // 2) + yield opinfo_core.SampleInput(tensor_nd(), dims=(), correction=0) + # Negative indices are not supported + + def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -1585,7 +1794,7 @@ def shape(size, rank, with_batch_channel=True): None, # output_size align_corners, ), - kwargs=dict(scale_factors=(1.7, 1.7)), + kwargs=dict(scale_factors=[1.7, 1.7]), ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1593,7 +1802,7 @@ def shape(size, rank, with_batch_channel=True): None, # if this is None, the scalar must be list align_corners, ), - kwargs=dict(scale_factors=(0.6, 0.6)), + kwargs=dict(scale_factors=[0.6, 0.6]), ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1601,7 +1810,7 @@ def shape(size, rank, with_batch_channel=True): None, # if this is None, the scalar must be list align_corners, ), - kwargs=dict(scale_factors=(0.6, 4.2)), + kwargs=dict(scale_factors=[0.6, 4.2]), ) @@ -1651,7 +1860,6 @@ def sample_inputs_upsample_nearest1d(op_info, device, dtype, requires_grad, **kw N, C = 2, 3 D = 4 - SS = 3 L = 5 rank = 1 @@ -1670,8 +1878,6 @@ def shape(size, rank, with_batch_channel=True): high=1, ) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) - yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(S, rank, False), @@ -1680,15 +1886,53 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(L, rank, False), ) + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(S, rank, False), # output_size + # [1.7], # scaler + # ) + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(S, rank, False), # if this is None, the scalar must be list + # [0.6], + # ) + + +def sample_inputs_upsample_nearest1d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + L = 5 + + rank = 1 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), None, # output_size - (1.7,), # scaler + scale_factors=(1.7,), ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), - None, # if this is None, the scalar must be list - (0.6,), + None, + scale_factors=(0.6,), ) @@ -1698,7 +1942,6 @@ def sample_inputs_upsample_nearest2d(op_info, device, dtype, requires_grad, **kw N, C = 2, 3 D = 4 - SS = 3 L = 5 rank = 2 @@ -1717,8 +1960,6 @@ def shape(size, rank, with_batch_channel=True): high=1, ) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) - yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(S, rank, False), @@ -1727,26 +1968,62 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(L, rank, False), ) - # ONNX don't support below cases: both output_size and scaler are not None # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 1.7, # scaler + # 1.7, 2.0, # scaler # ) # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 0.6, + # 0.6, 0.4, # ) +def sample_inputs_upsample_nearest2d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + L = 5 + + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(1.7, 2.0), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(0.6, 0.4), + ) + + def sample_inputs_upsample_nearest3d(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs N, C = 2, 3 D = 4 - SS = 3 L = 5 rank = 3 @@ -1765,8 +2042,6 @@ def shape(size, rank, with_batch_channel=True): high=1, ) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) - yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(S, rank, False), @@ -1775,19 +2050,56 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(L, rank, False), ) - # ONNX don't support below cases: both output_size and scaler are not None # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 1.7, # scaler + # 1.7, 1.5, 2.0, # scaler # ) # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 0.6, + # 0.6, 0.3, 0.5, # ) +def sample_inputs_upsample_nearest3d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + L = 5 + + rank = 3 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(1.7, 1.5, 2.0), # scaler + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(0.6, 0.3, 0.5), + ) + + def sample_inputs_upsample_trilinear3d(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -1826,6 +2138,97 @@ def shape(size, rank, with_batch_channel=True): ) +def sample_inputs__unique(op_info, device, dtype, requires_grad, **kwargs): + for sample in common_methods_invocations.sample_inputs_unique( + op_info, device, dtype, requires_grad, **kwargs + ): + return_counts = sample.kwargs.pop("return_counts", None) + dim = sample.kwargs.pop("dim", None) + # take only those samples that do not ask for counts or a dim + if not return_counts and dim is None: + yield sample + + +def sample_inputs__unique2(op_info, device, dtype, requires_grad, **kwargs): + for sample in common_methods_invocations.sample_inputs_unique( + op_info, device, dtype, requires_grad, **kwargs + ): + # take only those samples that do not ask for a dim + if sample.kwargs.pop("dim", None) is None: + yield sample + + +def sample_inputs_unique_dim(op_info, device, dtype, requires_grad, **kwargs): + for sample in common_methods_invocations.sample_inputs_unique( + op_info, device, dtype, requires_grad, **kwargs + ): + # take only those samples that ask for a dim + if sample.kwargs.get("dim") is not None: + yield sample + + +def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 3 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True, None) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners, None + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(None, align_corners), + kwargs=dict(scale_factors=(1.7, 1.7, 1.7)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(None, align_corners), + kwargs=dict(scale_factors=(0.6, 0.6, 0.6)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(None, align_corners), + kwargs=dict(scale_factors=(0.6, 1.7, 4.2)), + ) + + +def sample_inputs_window_functions(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + del device + del requires_grad + + for window_length in [2, 3, 7, 10, 32]: + yield opinfo_core.SampleInput(window_length, kwargs=dict(dtype=dtype)) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -1911,6 +2314,13 @@ def __init__(self): # To avoid name duplication, it is possible to rename the OpInfo and specify # the `op` field explicitly. OP_DB: List[opinfo_core.OpInfo] = [ + opinfo_core.OpInfo( + "bilinear", + op=torch.nn.functional.bilinear, + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs_bilinear, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.bernoulli.p", aten_name="bernoulli.p", @@ -1928,6 +2338,13 @@ def __init__(self): sample_inputs_func=sample_inputs_bernoulli_p_deterministic, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.blackman_window", + aten_name="blackman_window", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_window_functions, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.col2im", aten_name="col2im", @@ -1991,6 +2408,26 @@ def __init__(self): sample_inputs_func=sample_inputs__fft_r2c, supports_out=False, ), + opinfo_core.BinaryUfuncInfo( + "ops.aten.floor_divide", + aten_name="floor_divide", + dtypes=common_dtype.all_types_and_half(), + rhs_make_tensor_kwargs=dict(exclude_zero=True), + ), + opinfo_core.OpInfo( + "ops.aten.hamming_window", + aten_name="hamming_window", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_window_functions, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.hann_window", + aten_name="hann_window", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_window_functions, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.index.Tensor", aten_name="index.Tensor", @@ -2032,7 +2469,7 @@ def __init__(self): opinfo_core.OpInfo( "ops.aten._local_scalar_dense", aten_name="_local_scalar_dense", - dtypes=common_dtype.all_types(), + dtypes=common_dtype.all_types_and(torch.bool), sample_inputs_func=sample_inputs__local_scalar_dense, supports_out=False, ), @@ -2147,14 +2584,6 @@ def __init__(self): sample_inputs_func=sample_inputs_rand_like, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.rand_like__dtype", - op=torch.ops.aten.rand_like, - aten_name="rand_like", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_rand_like_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.randint", aten_name="randint", @@ -2176,14 +2605,6 @@ def __init__(self): sample_inputs_func=sample_inputs_randint_like, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.randint_like__dtype", - op=torch.ops.aten.randint_like, - aten_name="randint_like", - dtypes=common_dtype.integral_types(), - sample_inputs_func=sample_inputs_randint_like_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.randint_like.low_dtype", aten_name="randint_like.low_dtype", @@ -2191,14 +2612,6 @@ def __init__(self): sample_inputs_func=sample_inputs_randint_like_low_dtype, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.randint_like.low_dtype__dtype", - op=torch.ops.aten.randint_like.low_dtype, - aten_name="randint_like.low_dtype", - dtypes=common_dtype.integral_types(), - sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.randn", aten_name="randn", @@ -2213,14 +2626,6 @@ def __init__(self): sample_inputs_func=sample_inputs_like_fns, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.randn_like_dtype", - op=torch.ops.aten.randn_like, - aten_name="randn", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_like_fns_dtype, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.reflection_pad1d", aten_name="ops.aten.reflection_pad1d", @@ -2269,6 +2674,22 @@ def __init__(self): sample_inputs_func=sample_inputs_slice_scatter, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.scatter.src", + op=torch.ops.aten.scatter.src, + aten_name="scatter.src", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_scatter_src, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.scatter.value", + op=torch.ops.aten.scatter.value, + aten_name="scatter.value", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_scatter_value, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten._softmax", op=torch.ops.aten._softmax, # pylint: disable=protected-access @@ -2314,6 +2735,30 @@ def __init__(self): sample_inputs_func=sample_inputs_unfold, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._unique.default", + aten_name="_unique.default", + dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8), + sample_inputs_func=sample_inputs__unique, + supports_out=False, + supports_autograd=False, + ), + opinfo_core.OpInfo( + "ops.aten._unique2.default", + aten_name="_unique2.default", + dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8), + sample_inputs_func=sample_inputs__unique2, + supports_out=False, + supports_autograd=False, + ), + opinfo_core.OpInfo( + "ops.aten.unique_dim.default", + aten_name="unique_dim.default", + dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8), + sample_inputs_func=sample_inputs_unique_dim, + supports_out=False, + supports_autograd=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_bicubic2d.default", aten_name="upsample_bicubic2d", @@ -2321,6 +2766,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_2d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._upsample_bicubic2d_aa", + aten_name="_upsample_bicubic2d_aa", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_bicubic2d.vec", aten_name="upsample_bicubic2d.vec", @@ -2335,6 +2787,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_2d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._upsample_bilinear2d_aa", + aten_name="_upsample_bilinear2d_aa", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_bilinear2d.vec", aten_name="upsample_bilinear2d.vec", @@ -2356,6 +2815,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_nearest1d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest1d.vec", + aten_name="upsample_nearest1d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest1d_vec, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_nearest2d", aten_name="upsample_nearest2d", @@ -2363,6 +2829,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_nearest2d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest2d.vec", + aten_name="upsample_nearest2d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest2d_vec, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_nearest3d", aten_name="upsample_nearest3d", @@ -2371,12 +2844,45 @@ def __init__(self): supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.upsample_trilinear3d", + "ops.aten.upsample_nearest3d.vec", + aten_name="upsample_nearest3d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest3d_vec, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_trilinear3d.default", aten_name="upsample_trilinear3d", dtypes=common_dtype.floating_types_and(torch.bfloat16), sample_inputs_func=sample_inputs_upsample_trilinear3d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_trilinear3d.vec", + aten_name="upsample_trilinear3d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_trilinear3d_vec, + supports_out=False, + ), + opinfo_core.ReductionOpInfo( + "ops.prims.broadcast_in_dim.default", + op=torch.ops.prims.broadcast_in_dim.default, + dtypes=common_dtype.all_types(), + sample_inputs_func=sample_inputs_broadcast_in_dim, + supports_out=False, + ), + opinfo_core.ReductionOpInfo( + "ops.prims.var.default", + nan_policy="propagate", + supports_out=True, + promotes_int_to_float=True, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + dtypes=common_dtype.floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_prims_std_var, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index cf29a8b804..a45050fb22 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Test op correctness by comparing with PyTorch results. Usage: @@ -37,7 +39,7 @@ from torch.utils import _pytree as pytree import onnxscript -import onnxscript.evaluator +from onnxscript._internal import version_utils from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -96,50 +98,24 @@ def _should_skip_xfail_test_sample( class TestFunctionValidity(unittest.TestCase): - def test_all_script_functions_are_onnx_functions(self): - for info in ops_test_data.TESTED_TORCHLIB_OPS: - if info.trace_only: - continue - with self.subTest(name=info.op_info_name): - func = info.op - if not isinstance(func, onnxscript.OnnxFunction): - raise TypeError( - f"'{func}' is not an OnnxFunction. Was it decorated with '@torch_op'? " - "If the function is trace_only, please specify trace_only=True " - "in the TorchLibOpInfo entry." - ) - - def test_all_trace_only_functions_are_not_onnx_functions(self): - for info in ops_test_data.TESTED_TORCHLIB_OPS: - if not info.trace_only: - continue - with self.subTest(name=info.op_info_name): - func = info.op - if not isinstance(func, onnxscript.TracedOnnxFunction): - raise TypeError( - f"'{func.name}' is not a TracedOnnxFunction. " - "If the function is not trace_only, please remove trace_only=True " - "in the TorchLibOpInfo entry." - ) - @parameterized.parameterized.expand( - [ - (info.op.name, info) - for info in ops_test_data.TESTED_TORCHLIB_OPS - if not info.trace_only - ] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_script_function_passes_checker( self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo ): + if not isinstance(torchlib_op_info.op, onnxscript.OnnxFunction): + self.skipTest("Traced functions does not have a function proto") function_proto = torchlib_op_info.op.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo): func = torchlib_op_info.op + if not hasattr(func, "op_schema"): + raise AssertionError(f"Function {func.__name__} does not have op_schema attribute") schema = func.op_schema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) @@ -227,8 +203,7 @@ def run_test_output_match( reference_torch_outputs, _ = pytree.tree_flatten(torch_output) if ( op.name.startswith("split") - or op.name.startswith("chunk") - or op.name.startswith("unbind") + or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7")) or op.name in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"} ): @@ -294,68 +269,6 @@ def run_test_output_match( raise -class TestOutputConsistencyEager(unittest.TestCase): - """Test output consistency between the ONNX op run with ONNX eager mode and PyTorch eager mode. - - This is a parameterized test suite. - """ - - def setUp(self) -> None: - torch.manual_seed(42) - np.random.seed(42) - ort.set_seed(42) - - @ops_test_common.add_decorate_info( - ops_test_data.OPS_DB, - "TestOutputConsistencyEager", - "test_output_match_opinfo_", - skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS, - ) - @common_device_type.ops( # type: ignore[misc] - [info for info in ops_test_data.OPS_DB if info.name in ops_test_data.TESTED_OPS], - allowed_dtypes=TESTED_DTYPES, - ) - def test_output_match_opinfo_( - self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo - ): - # Base test method for testing each op with the eager executor, used by instantiate_device_type_tests. - run_test_output_match( - self, - device, - dtype, - op, - ops_test_common.eager_executor, - ops_test_data.TORCHLIB_OPINFO_MAPPING, - ) - - @ops_test_common.add_decorate_info( - ops_test_data.OPS_DB, - "TestOutputConsistencyEager", - "test_complex_output_match_opinfo_", - skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS, - ) - @common_device_type.ops( # type: ignore[misc] - [ - info - for info in ops_test_data.OPS_DB - if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING - ], - allowed_dtypes=COMPLEX_TYPES, - ) - def test_complex_output_match_opinfo_( - self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo - ): - """Base test method for testing each op with the eager executor, used by instantiate_device_type_tests.""" - run_test_output_match( - self, - device, - dtype, - op, - ops_test_common.eager_executor, - ops_test_data.COMPLEX_FUNCTION_MAPPING, - ) - - class TestOutputConsistencyFullGraph(unittest.TestCase): """Test output consistency between exported ONNX op run as a graph and PyTorch eager mode. @@ -418,10 +331,6 @@ def test_complex_output_match_opinfo_( ) -common_device_type.instantiate_device_type_tests( - TestOutputConsistencyEager, globals(), only_for=["cpu", "cuda"] -) - common_device_type.instantiate_device_type_tests( TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"] ) diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index 34f5b58446..decaddddf4 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Common utils for testing operators.""" from __future__ import annotations @@ -8,6 +10,7 @@ import multiprocessing import os import pprint +import sys import unittest import warnings from typing import ( @@ -23,15 +26,18 @@ import numpy as np import onnx +import onnx_ir.passes.common as common_passes import onnxruntime as ort import onnxruntime.capi.onnxruntime_pybind11_state import pytest import torch +from torch.onnx._internal.exporter import _building, _tensors from torch.testing._internal.opinfo import core as opinfo_core import onnxscript import onnxscript.evaluator -from onnxscript.function_libs.torch_lib import graph_building +from onnxscript import ir +from onnxscript.function_libs.torch_lib.ops import common as common_ops from tests.function_libs.torch_lib import error_reproduction T = TypeVar("T") @@ -56,6 +62,7 @@ ) TEST_OPSET_VERSION = 18 +IS_MACOS = sys.platform.startswith("darwin") IS_WINDOWS = os.name == "nt" @@ -172,9 +179,9 @@ def add_decorate_info( # If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo # because it could be an OpInfo that is in torch-nightly but not older versions. continue - assert ( - opinfo is not None - ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?" + assert opinfo is not None, ( + f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?" + ) decorators = list(opinfo.decorators) new_decorator = opinfo_core.DecorateInfo( decorate_meta.decorator, @@ -249,7 +256,7 @@ def duplicate_opinfo_for_prims( raise RuntimeError(f"OpInfo '{name}' not found in the database.") -TORCH_TYPE_TO_ONNX = { +_TORCH_TYPE_TO_ONNX = { torch.bool: onnx.TensorProto.BOOL, torch.uint8: onnx.TensorProto.UINT8, torch.int8: onnx.TensorProto.INT8, @@ -263,6 +270,27 @@ def duplicate_opinfo_for_prims( torch.complex128: onnx.TensorProto.COMPLEX128, torch.bfloat16: onnx.TensorProto.BFLOAT16, } +_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { + torch.bfloat16: ir.DataType.BFLOAT16, + torch.bool: ir.DataType.BOOL, + torch.complex128: ir.DataType.COMPLEX128, + torch.complex64: ir.DataType.COMPLEX64, + torch.float16: ir.DataType.FLOAT16, + torch.float32: ir.DataType.FLOAT, + torch.float64: ir.DataType.DOUBLE, + torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, + torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, + torch.float8_e5m2: ir.DataType.FLOAT8E5M2, + torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, + torch.int16: ir.DataType.INT16, + torch.int32: ir.DataType.INT32, + torch.int64: ir.DataType.INT64, + torch.int8: ir.DataType.INT8, + torch.uint8: ir.DataType.UINT8, + torch.uint16: ir.DataType.UINT16, + torch.uint32: ir.DataType.UINT32, + torch.uint64: ir.DataType.UINT64, +} def convert_tensor_to_numpy(input: Any) -> Any: @@ -273,7 +301,7 @@ def convert_tensor_to_numpy(input: Any) -> Any: return input.detach().cpu().numpy() if isinstance(input, complex): return torch.view_as_real(torch.tensor(input)).detach().cpu().numpy() - if isinstance(input, (tuple, list)): + if isinstance(input, list): if len(input) == 0: return np.array((), dtype=np.int64) if any(isinstance(x, torch.Tensor) for x in input): @@ -298,7 +326,7 @@ def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]: if key == "device": continue if key == "dtype": - value = TORCH_TYPE_TO_ONNX[value] + value = _TORCH_TYPE_TO_ONNX[value] if isinstance(value, torch.Tensor): value = np.array(value.cpu()) new_kwargs[key] = value @@ -365,12 +393,7 @@ def _safe_ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any] def _format_model_and_input_information(onnx_model, inputs): - return ( - f"Inputs:\n" - f"{pprint.pformat(inputs)}\n" - f"Model:\n" - f"{onnx.printer.to_text(onnx_model)}" - ) + return f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}" TORCH_DTYPE_TO_ONNX_STRING = { @@ -389,6 +412,19 @@ def _format_model_and_input_information(onnx_model, inputs): } +def add_torchlib_common_imports(model: ir.Model) -> None: + """Hack to add torchlib common imports to the model.""" + + model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 + rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) + is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto()) + model.functions[rank_func.identifier()] = rank_func + model.functions[is_scalar_func.identifier()] = is_scalar_func + removal_pass = common_passes.RemoveUnusedFunctionsPass() + assert removal_pass.in_place + removal_pass(model) + + def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool: """Checks if the dtype is compatible with the schema. @@ -458,19 +494,33 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, """Captures the graph of a function and evaluates it using TorchScriptEvaluator.""" # Initialize the ONNX graph - onnxscript_graph = graph_building.TorchScriptGraph() - tracer = graph_building.TorchScriptTracingEvaluator(onnxscript_graph) + graph = ir.Graph( + (), + (), + nodes=(), + opset_imports={ + "": 18, + "pkg.torch.onnx": 1, + "pkg.onnxscript.torch_lib.common": 1, + "pkg.onnxscript.torch_lib": 1, + }, + name="main_graph", + ) + opset = onnxscript.opset18 + tracer = _building.OpRecorder(opset, {}) ort_inputs = {} onnxscript_args: list[Any] = [] onnxscript_kwargs = {} for i, arg in enumerate(args): if isinstance(arg, np.ndarray): input_name = f"input_{i}" - input = onnxscript_graph.add_input( - input_name, - torch.tensor(arg).shape, - torch.tensor(arg).dtype, + input = _tensors.SymbolicTensor( + opset=opset, + name=input_name, + shape=ir.Shape(arg.shape), + type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(arg).dtype]), ) + graph.inputs.append(input) onnxscript_args.append(input) ort_inputs[input_name] = arg elif isinstance(arg, (list, tuple)): @@ -480,11 +530,13 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, if isinstance(subarg, np.ndarray): input_name = f"input_{i}_{j}" tensor = torch.tensor(subarg) - input = onnxscript_graph.add_input( - input_name, - tensor.shape, - tensor.dtype, + input = _tensors.SymbolicTensor( + opset=opset, + name=input_name, + shape=ir.Shape(tensor.shape), + type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[tensor.dtype]), ) + graph.inputs.append(input) sequence_input.append(input) ort_inputs[input_name] = subarg else: @@ -496,11 +548,13 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, onnxscript_args.append(arg) for key, value in kwargs.items(): if isinstance(value, np.ndarray): - input = onnxscript_graph.add_input( - key, - torch.tensor(value).shape, - torch.tensor(value).dtype, + input = _tensors.SymbolicTensor( + opset=opset, + name=key, + shape=ir.Shape(torch.tensor(value).shape), + type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(value).dtype]), ) + graph.inputs.append(input) ort_inputs[key] = value onnxscript_kwargs[key] = input else: @@ -514,38 +568,48 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, # We need to set the size of the output tensors for the ONNX model to be valid for output, symbolic_output in zip(outputs, symbolic_outputs): if isinstance(output, Sequence): - # Output is a sequence, skip setting the type and leave it - # for ONNX shape_inference to handle + # Output is a sequence + elem_dtype = _TORCH_DTYPE_TO_ONNX[output[0].dtype] + symbolic_output.type = ir.SequenceType(ir.TensorType(elem_dtype)) continue output = ( output if isinstance(output, torch.Tensor) else torch.tensor(output, device="cpu") ) - symbolic_output.shape = output.shape - symbolic_output.dtype = output.dtype - - onnxscript_graph.register_outputs(symbolic_outputs) - - onnx_model = onnxscript_graph.to_model_proto(TEST_OPSET_VERSION) - onnx_model = onnx.shape_inference.infer_shapes(onnx_model, data_prop=True) + symbolic_output.shape = ir.Shape(output.shape) + symbolic_output.dtype = _TORCH_DTYPE_TO_ONNX[output.dtype] + + graph.outputs.extend(symbolic_outputs) + graph.extend(tracer.nodes) + onnx_model = ir.Model(graph, ir_version=10, producer_name="torch_test") + for identifier, onnxscript_function in tracer.functions.items(): + if identifier in onnx_model.functions: + continue + if isinstance(onnxscript_function, ir.Function): + ir_function = onnxscript_function + else: + # TODO: Get IR function directly when onnxscript is updated + proto = onnxscript_function.to_function_proto() + ir_function = ir.serde.deserialize_function(proto) + onnx_model.functions[identifier] = ir_function + add_torchlib_common_imports(onnx_model) # Make sure the model is valid + model_proto = ir.to_proto(onnx_model) try: - onnx.checker.check_model(onnx_model, full_check=True) + onnx.checker.check_model(model_proto, full_check=True) except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e: - raise AssertionError( - f"ONNX model is invalid. Model:\n{onnx.printer.to_text(onnx_model)}" - ) from e - + raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e + model_proto = onnx.shape_inference.infer_shapes(model_proto, data_prop=True) try: if ( os.environ.get("CATCH_ORT_SEGFAULT") == "1" or os.environ.get("CREATE_REPRODUCTION_REPORT") == "1" ): # Use an individual process to run ONNX Runtime to catch segfaults - return _safe_ort_session_run(onnx_model.SerializeToString(), ort_inputs) + return _safe_ort_session_run(model_proto.SerializeToString(), ort_inputs) - return _ort_session_run(onnx_model.SerializeToString(), ort_inputs) + return _ort_session_run(model_proto.SerializeToString(), ort_inputs) except ( # pylint: disable=c-extension-no-member onnxruntime.capi.onnxruntime_pybind11_state.Fail, @@ -557,26 +621,26 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, ) as e: if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": error_reproduction.create_reproduction_report( - test_name, onnx_model, ort_inputs, e + test_name, model_proto, ort_inputs, e ) raise RuntimeError( "ONNX Runtime failed to evaluate:\n" - + _format_model_and_input_information(onnx_model, ort_inputs) + + _format_model_and_input_information(model_proto, ort_inputs) ) from e except OrtAbortedError as e: if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": # Save the model and inputs to a file for reproduction error_reproduction.create_reproduction_report( - test_name, onnx_model, ort_inputs, e + test_name, model_proto, ort_inputs, e ) raise OrtAbortedError( "ONNX Runtime aborted:\n" - + _format_model_and_input_information(onnx_model, ort_inputs) + + _format_model_and_input_information(model_proto, ort_inputs) ) from e except Exception as e: if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": error_reproduction.create_reproduction_report( - test_name, onnx_model, ort_inputs, e + test_name, model_proto, ort_inputs, e ) raise diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 5a4cb195cc..b60fd8cf31 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Test op correctness by comparing with PyTorch results. ## Usage @@ -11,8 +13,7 @@ 1. To enable test cases for an operator Add a `TorchLibOpInfo` entry to `TORCH_LIB_OPINFO` in `ops_test_data.py`. - Explicitly specify `trace_only` if the op is trace_only. Specify `complex` - if the function is designed for complex inputs. + Specify `complex` if the function is designed for complex inputs. The `op_info_name` in `TorchLibOpInfo` needs to be unique in the TORCH_LIB_OPINFO list, but complex=True ops can share the same name with non-complex ops @@ -38,6 +39,7 @@ import copy import dataclasses import functools +import sys from typing import Any, Callable, Collection, Optional import numpy as np @@ -46,12 +48,12 @@ from torch.testing._internal.opinfo import definitions as opinfo_definitions from typing_extensions import Self -from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib import _flags from onnxscript.function_libs.torch_lib.ops import core as core_ops from onnxscript.function_libs.torch_lib.ops import fft as fft_ops from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops from onnxscript.function_libs.torch_lib.ops import nn as nn_ops +from onnxscript.function_libs.torch_lib.ops import prims as prims_ops from onnxscript.function_libs.torch_lib.ops import special as special_ops from onnxscript.function_libs.torch_lib.ops import vision as vision_ops from tests.function_libs.torch_lib import extra_opinfo, ops_test_common @@ -72,8 +74,6 @@ class TorchLibOpInfo: op_info_name: str # The torchlib ONNX Function to test op: Callable[..., Any] - # Explicitly specify when the op is trace_only - trace_only: bool = False # The input wrangler function to adjust the input to fit the aten signature input_wrangler: Optional[ Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]] @@ -255,10 +255,8 @@ def _embedding_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: """Remove arguments not present in the aten op signature.""" - if "max_norm" in kwargs: - del kwargs["max_norm"] - if "norm_type" in kwargs: - del kwargs["norm_type"] + kwargs.pop("max_norm", None) + kwargs.pop("norm_type", None) return args, kwargs @@ -266,16 +264,7 @@ def _empty_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: """Remove arguments not present in the aten op signature.""" - if "requires_grad" in kwargs: - del kwargs["requires_grad"] - return args, kwargs - - -def _flip_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Make the dims as tensor - kwargs["dims"] = np.array(kwargs["dims"], dtype=np.int64) + kwargs.pop("requires_grad", None) return args, kwargs @@ -292,12 +281,39 @@ def _grid_sample_input_wrangler( return args, kwargs -def _linalg_vector_norm_input_wrangler( +def _im2col_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: - # Make the dims as tensor - if "dim" in kwargs: - kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64) + # Move kernel_size, dilation, padding and stride from args to kwargs + if len(args) == 5: + # Handle stride + stride = args.pop() + if isinstance(stride, np.ndarray): # convert stride to list[int] + stride = stride.tolist() + kwargs["stride"] = stride + # Handle padding + padding = args.pop() + if isinstance(padding, np.ndarray): # convert padding to list[int] + padding = padding.tolist() + kwargs["padding"] = padding + # Handle dilation + dilation = args.pop() + if isinstance(dilation, np.ndarray): # convert dilation to list[int] + dilation = dilation.tolist() + kwargs["dilation"] = dilation + # Handle kernel_size + kernel_size = args.pop() + if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int] + kernel_size = kernel_size.tolist() + kwargs["kernel_size"] = kernel_size + + return args, kwargs + + +def _index_put_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + args[1] = [np.array(elem) for elem in args[1]] return args, kwargs @@ -305,8 +321,7 @@ def _max_pool_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: # Remove return_indices argument because this op doesn't accept it - if "return_indices" in kwargs: - del kwargs["return_indices"] + kwargs.pop("return_indices", None) return args, kwargs @@ -344,18 +359,7 @@ def _nll_loss_input_wrangler( def _nonzero_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: - if "as_tuple" in kwargs: - del kwargs["as_tuple"] - return args, kwargs - - -def _permute_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Change the dims argument back to a list because ONNX Transpose does not - # support dynamic perms - kwargs["dims"] = args.pop() - kwargs["dims"] = kwargs["dims"].tolist() + kwargs.pop("as_tuple", None) return args, kwargs @@ -392,17 +396,20 @@ def _roll_input_wrangler( dims = args.pop(2) kwargs["dims"] = [] kwargs["dims"].append(dims) - if len(args) >= 2: - if isinstance(args[1], int): # convert shift to tensor - args[1] = np.array([args[1]], dtype=np.int64) + if isinstance(args[1], np.ndarray): # convert shift to list[int] + shifts = args.pop(1) + kwargs["shifts"] = shifts.tolist() + elif isinstance(args[1], int): + shifts = args.pop(1) + kwargs["shifts"] = [] + kwargs["shifts"].append(shifts) return args, kwargs def _scalar_tensor_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: - if "requires_grad" in kwargs: - del kwargs["requires_grad"] + kwargs.pop("requires_grad", None) return args, kwargs @@ -422,13 +429,6 @@ def _sum_input_wrangler( return args, kwargs -def _unflatten_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - args[1] = np.array(args[1], dtype=np.int64) - return args, kwargs - - def _where_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -445,74 +445,37 @@ def _where_input_wrangler( "ops.aten._fft_c2c", # Custom from extra_opinfo fft_ops.aten__fft_c2c, tolerance={torch.complex64: (3e-3, 1.8e-4)}, - trace_only=True, complex=True, ), TorchLibOpInfo( "ops.aten._fft_c2r", # Custom from extra_opinfo fft_ops.aten__fft_c2r, tolerance={torch.complex64: (3e-3, 1.8e-4)}, - trace_only=True, complex=True, - ).xfail( - dtypes=(torch.complex64,), - reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926", ), TorchLibOpInfo( "ops.aten._fft_r2c", # Custom from extra_opinfo fft_ops.aten__fft_r2c, tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, - trace_only=True, - ), - TorchLibOpInfo( - "ops.aten._local_scalar_dense", - core_ops.aten__local_scalar_dense, ), - TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax), + TorchLibOpInfo("ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense), TorchLibOpInfo( - "ops.aten._log_softmax_half", - core_ops.aten__log_softmax_half, - trace_only=True, + "ops.aten._log_softmax", + core_ops.aten__log_softmax, tolerance={torch.float16: (1e-3, 1e-3)}, - ) - .xfail( - reason="PyTorch does not implement _log_softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ), - TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax, trace_only=True), - TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half, trace_only=True) - .xfail( - reason="PyTorch does not implement _softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", ), + TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", ), - TorchLibOpInfo("all_dims", core_ops.aten_all_dims, trace_only=True).skip( + TorchLibOpInfo("all_dims", core_ops.aten_all_dims).skip( matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple), reason="this overload requires dim to be a tuple", ), TorchLibOpInfo("allclose", core_ops.aten_allclose), - TorchLibOpInfo( - "all", - core_ops.aten_all, - ).skip( + TorchLibOpInfo("all", core_ops.aten_all).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -521,7 +484,7 @@ def _where_input_wrangler( TorchLibOpInfo("acos", core_ops.aten_acos), TorchLibOpInfo("acosh", core_ops.aten_acosh), TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}), - TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True), + TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True), TorchLibOpInfo( "addbmm", core_ops.aten_addbmm, @@ -530,14 +493,6 @@ def _where_input_wrangler( TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv, tolerance={torch.float16: (3e-2, 1e-3)}), TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}), TorchLibOpInfo("addmm", core_ops.aten_addmm) - .xfail( - "decomposed", - reason=( - "The float attributes alpha/beta come in as int in the test cases, which breaks" - "eager mode. We don't need to care about this as long as the full graph tests pass" - ), - test_class_name="TestOutputConsistencyEager", - ) .xfail( dtypes=(torch.int16, torch.int32, torch.int64), reason="ONNX Runtime does not support int inputs to Gemm", @@ -547,138 +502,85 @@ def _where_input_wrangler( dtypes=(torch.int16, torch.int32, torch.int64), reason="ONNX Runtime does not support int inputs to Gemm", ) - .xfail( + .skip( "decomposed", matcher=lambda sample: torch.numel(sample.input) == 0 or torch.numel(sample.args[0]) == 0 or torch.numel(sample.args[1]) == 0, - reason="ONNX Runtime does not support zero sized inputs", + reason="zero sized inputs cannot be compared", ), - TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (1e-3, 1e-2)}), - TorchLibOpInfo( - "addr", - core_ops.aten_addr, - tolerance={torch.float16: (3e-3, 4e-3)}, - ), - TorchLibOpInfo( - "amax", - core_ops.aten_amax, - input_wrangler=_amin_amax_input_wrangler, - ).skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to ReduceMax-18. https://github.com/microsoft/onnxruntime/issues/16492", - ), - TorchLibOpInfo( - "amin", - core_ops.aten_amin, - input_wrangler=_amin_amax_input_wrangler, - ).skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to ReduceMin-18. https://github.com/microsoft/onnxruntime/issues/16492", - ), - TorchLibOpInfo( - "any", - core_ops.aten_any, - ).skip( + TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}), + TorchLibOpInfo("addr", core_ops.aten_addr, tolerance={torch.float16: (3e-3, 4e-3)}), + TorchLibOpInfo("amax", core_ops.aten_amax, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("amin", core_ops.aten_amin, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("any", core_ops.aten_any).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), - TorchLibOpInfo( - "any_dim", - core_ops.aten_any_dim, - ).skip( + TorchLibOpInfo("any_dim", core_ops.aten_any_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", ), - TorchLibOpInfo("any_dims", core_ops.aten_any_dims, trace_only=True).skip( + TorchLibOpInfo("any_dims", core_ops.aten_any_dims).skip( matcher=lambda sample: not isinstance(sample.kwargs.get("dim"), tuple), reason="this overload requires dim to be a tuple", ), TorchLibOpInfo("asin", core_ops.aten_asin), TorchLibOpInfo("asinh", core_ops.aten_asinh), TorchLibOpInfo("atan", core_ops.aten_atan), - TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}), + TorchLibOpInfo("atan2", core_ops.aten_atan2), TorchLibOpInfo("atanh", core_ops.aten_atanh), TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_1d_Sequence", - core_ops.aten_atleast_1d_sequence, - ) + TorchLibOpInfo("atleast_1d_Sequence", core_ops.aten_atleast_1d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_2d", core_ops.aten_atleast_2d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_2d_Sequence", - core_ops.aten_atleast_2d_sequence, - ) + TorchLibOpInfo("atleast_2d_Sequence", core_ops.aten_atleast_2d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_3d", core_ops.aten_atleast_3d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_3d_Sequence", - core_ops.aten_atleast_3d_sequence, - ) + TorchLibOpInfo("atleast_3d_Sequence", core_ops.aten_atleast_3d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}), TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True), + TorchLibOpInfo( + "bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)} + ), TorchLibOpInfo( # This string is a unique ID. In extra_opinfo.py, we # also define test data for this ID with @@ -690,79 +592,70 @@ def _where_input_wrangler( ), TorchLibOpInfo("ops.aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p), TorchLibOpInfo("bitwise_and", core_ops.aten_bitwise_and), - TorchLibOpInfo("bitwise_left_shift_int16", core_ops.aten_bitwise_left_shift_int16), - TorchLibOpInfo("bitwise_left_shift_int32", core_ops.aten_bitwise_left_shift_int32), - TorchLibOpInfo("bitwise_left_shift_int64", core_ops.aten_bitwise_left_shift_int64), - TorchLibOpInfo("bitwise_left_shift_int8", core_ops.aten_bitwise_left_shift_int8), + TorchLibOpInfo("bitwise_left_shift", core_ops.aten_bitwise_left_shift), TorchLibOpInfo("bitwise_not", core_ops.aten_bitwise_not), TorchLibOpInfo("bitwise_or", core_ops.aten_bitwise_or), - TorchLibOpInfo("bitwise_right_shift_int16", core_ops.aten_bitwise_right_shift_int16), - TorchLibOpInfo("bitwise_right_shift_int32", core_ops.aten_bitwise_right_shift_int32), - TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64), - TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8), + TorchLibOpInfo("bitwise_right_shift", core_ops.aten_bitwise_right_shift), TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), + TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window), TorchLibOpInfo("bmm", core_ops.aten_bmm), TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to), TorchLibOpInfo("cat", core_ops.aten_cat).skip( - matcher=lambda sample: sample.input[0].equal(torch.tensor([])), + matcher=lambda sample: sample.input[0].equal( + torch.tensor([]).to(sample.input[0].device) + ), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), - TorchLibOpInfo("cat", core_ops.aten_cat_complex, trace_only=True, complex=True).skip( - matcher=lambda sample: sample.input[0].equal(torch.tensor([])), + TorchLibOpInfo("cat", core_ops.aten_cat_complex, complex=True).skip( + matcher=lambda sample: sample.input[0].equal( + torch.tensor([]).to(sample.input[0].device) + ), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("ceil", core_ops.aten_ceil), - TorchLibOpInfo( - "chunk", - core_ops.aten_chunk, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + TorchLibOpInfo("chunk", core_ops.aten_chunk), + TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( + reason="Size 0 inputs are not handled by design", + matcher=lambda sample: sample.input.numel() == 0, ), - TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max).skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ), - TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min).skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", + TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min_tensor).skip( + reason="Size 0 inputs are not handled by design", + matcher=lambda sample: sample.input.numel() == 0, ), TorchLibOpInfo("clone", core_ops.aten_clone), - TorchLibOpInfo("complex", core_ops.aten_complex, trace_only=True), - TorchLibOpInfo("concat", core_ops.aten_concat).skip( - matcher=lambda sample: sample.input[0].equal(torch.tensor([])), + TorchLibOpInfo("complex", core_ops.aten_complex), + TorchLibOpInfo("concat", core_ops.aten_cat).skip( + matcher=lambda sample: sample.input[0].equal( + torch.tensor([]).to(sample.input[0].device) + ), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), - TorchLibOpInfo("concatenate", core_ops.aten_concatenate).skip( - matcher=lambda sample: sample.input[0].equal(torch.tensor([])), + TorchLibOpInfo("concatenate", core_ops.aten_cat).skip( + matcher=lambda sample: sample.input[0].equal( + torch.tensor([]).to(sample.input[0].device) + ), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("conj", core_ops.aten_conj), - TorchLibOpInfo("conj", core_ops.aten_conj_complex, complex=True, trace_only=True), + TorchLibOpInfo("conj", core_ops.aten_conj_complex, complex=True), TorchLibOpInfo("constant_pad_nd", core_ops.aten_constant_pad_nd), # TorchLibOpInfo("copy", core_ops.aten_copy), # copy is not in OPS_DB TorchLibOpInfo("cos", core_ops.aten_cos), TorchLibOpInfo("cosh", core_ops.aten_cosh), - TorchLibOpInfo("cross", core_ops.aten_cross, tolerance={torch.float16: (6e-3, 3e-3)}), + TorchLibOpInfo("cross", core_ops.aten_cross, tolerance={torch.float16: (6e-2, 2e-1)}).skip( + dtypes=(torch.float16 if sys.platform != "linux" else torch.complex64,), + reason="fixme: test is failing on windows and torch nightly", + ), TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB - TorchLibOpInfo("diagonal", core_ops.aten_diagonal, trace_only=True), - TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool, trace_only=True), + TorchLibOpInfo("diagonal", core_ops.aten_diagonal), TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", ), TorchLibOpInfo("true_divide", core_ops.aten_div), TorchLibOpInfo("true_divide", core_ops.aten_div_complex, complex=True), - TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True) + TorchLibOpInfo("div_mode", core_ops.aten_div_mode) .skip( variant_name="no_rounding_mode", reason="this variation requires the rounding_mode argument", @@ -772,16 +665,6 @@ def _where_input_wrangler( dtypes=(torch.float16,), # Numbers match sometimes but not other times reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", - ) - .xfail( - variant_name="floor_rounding", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - reason="fixme: off-by-one and inverted inf. https://github.com/microsoft/onnxscript/issues/989", - ), - TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int, trace_only=True).skip( - variant_name="no_rounding_mode", - reason="this variation requires the rounding_mode argument", ), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( @@ -790,12 +673,9 @@ def _where_input_wrangler( input_wrangler=_empty_input_wrangler, nondeterministic=True, ), - TorchLibOpInfo( - "einsum", core_ops.aten_einsum, trace_only=True, input_wrangler=_einsum_input_wrangler - ) + TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler) .xfail( - reason="fixme: PyTorch produces int64 output with int32 input", - dtypes=(torch.int32,), + reason="fixme: PyTorch produces int64 output with int32 input", dtypes=(torch.int32,) ) .xfail( reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739", @@ -810,62 +690,54 @@ def _where_input_wrangler( TorchLibOpInfo("expand_as", core_ops.aten_expand_as), TorchLibOpInfo("erf", special_ops.aten_special_erf), TorchLibOpInfo( - "erfc", special_ops.aten_special_erfc, tolerance={torch.float16: (1e-2, 2e-4)} + "erfc", special_ops.aten_special_erfc, tolerance={torch.float16: (5e-1, 2e-4)} + ), + TorchLibOpInfo( + "expm1", special_ops.aten_special_expm1, tolerance={torch.float16: (1e-2, 2e-4)} ), TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail( reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223" ), TorchLibOpInfo("fill", core_ops.aten_fill), - TorchLibOpInfo("flip", core_ops.aten_flip, input_wrangler=_flip_input_wrangler), - TorchLibOpInfo("floor", core_ops.aten_floor), - TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide).xfail( - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - reason="fixme: off-by-one issue due to numerical precision. https://github.com/microsoft/onnxscript/issues/989", + TorchLibOpInfo("flip", core_ops.aten_flip).skip( + reason="fixme: size 0 inputs are not handled yet", + matcher=lambda sample: sample.input.numel() == 0, ), + TorchLibOpInfo("flatten", core_ops.aten_flatten), + TorchLibOpInfo("floor", core_ops.aten_floor), + TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), - TorchLibOpInfo( - "full_like_dtype", - core_ops.aten_full_like_dtype, - ).skip( - matcher=lambda sample: "dtype" not in sample.kwargs, - reason="this Aten overload only support dtype in kwargs", + TorchLibOpInfo("full_like", core_ops.aten_full_like).skip( + enabled_if=ops_test_common.IS_MACOS, reason="fixme: memory allocation issue on CI" ), - TorchLibOpInfo( - "full_like", - core_ops.aten_full_like, - ).skip( - matcher=lambda sample: ("dtype" in sample.kwargs), - reason="this Aten overload only support dtype not in kwargs", + TorchLibOpInfo("gather", core_ops.aten_gather).skip( + matcher=lambda sample: sample.input.numel() == 0 or sample.args[1].numel() == 0, + reason="fixme: ORT does not support empty tensors as input", ), - TorchLibOpInfo("gather", core_ops.aten_gather), TorchLibOpInfo("ge", core_ops.aten_ge), - TorchLibOpInfo("ge_bool", core_ops.aten_ge_bool), TorchLibOpInfo("gt", core_ops.aten_gt), - TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB - TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index, trace_only=True), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool, trace_only=True), + TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), + TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool), TorchLibOpInfo( "index_put_bool", core_ops.aten_index_put_bool, + input_wrangler=_index_put_input_wrangler, ).skip( matcher=lambda sample: sample.args[0][0].dtype != torch.bool, reason="this Aten overload only supports tensor(bool) as indices", ), TorchLibOpInfo( - "index_put", - core_ops.aten_index_put, + "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler ) .skip( matcher=lambda sample: sample.args[0][0].dtype != torch.int64, reason="this Aten overload only supports tensor(int) as indices", ) .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.19"), dtypes=(torch.float16,), matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", @@ -884,9 +756,7 @@ def _where_input_wrangler( TorchLibOpInfo( "linalg.vector_norm", linalg_ops.aten_linalg_vector_norm, - trace_only=True, tolerance={torch.float16: (2e-3, 2e-3)}, - input_wrangler=_linalg_vector_norm_input_wrangler, ).skip( matcher=lambda sample: sample.kwargs.get("ord") == 6, dtypes=(torch.float16,), @@ -895,28 +765,24 @@ def _where_input_wrangler( TorchLibOpInfo( "linspace", core_ops.aten_linspace, - trace_only=True, tolerance={torch.float16: (2e-2, 2e-3)}, ) .xfail( dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", ) - .xfail( - variant_name="tensor_overload", - dtypes=(torch.int64, torch.int32), + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - enabled_if=not version_utils.torch_older_than("2.2"), ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), - TorchLibOpInfo("le_bool", core_ops.aten_le_bool), + TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}), TorchLibOpInfo("log10", core_ops.aten_log10), TorchLibOpInfo("log1p", core_ops.aten_log1p), TorchLibOpInfo( "log_softmax", special_ops.aten_special_log_softmax, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (4e-4, 6e-3)}, ) .xfail( @@ -950,80 +816,43 @@ def _where_input_wrangler( TorchLibOpInfo("logdet", core_ops.aten_logdet), TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp), TorchLibOpInfo("lt", core_ops.aten_lt), - TorchLibOpInfo("lt_bool", core_ops.aten_lt_bool), TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", ), + TorchLibOpInfo("masked_scatter", core_ops.aten_masked_scatter), TorchLibOpInfo( "matmul", core_ops.aten_matmul, # Windows requires a more relaxed tolerance - tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (2e-3, 2e-2)}, + tolerance={torch.float32: (2e-5, 2e-5), torch.float16: (1e-2, 2e-2)}, ).skip( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), - TorchLibOpInfo("maximum", core_ops.aten_maximum).skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ), - TorchLibOpInfo("maximum_bool", core_ops.aten_maximum_bool), - TorchLibOpInfo( - "mean", - core_ops.aten_mean, - input_wrangler=_mean_input_wrangler, - ).skip( + TorchLibOpInfo("maximum", core_ops.aten_maximum), + TorchLibOpInfo("mean", core_ops.aten_mean, input_wrangler=_mean_input_wrangler).skip( matcher=lambda sample: sample.kwargs.get("dim") is not None, reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo( - "mean_dim", - core_ops.aten_mean_dim, - input_wrangler=_mean_input_wrangler, + "mean_dim", core_ops.aten_mean_dim, input_wrangler=_mean_input_wrangler ).skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="this Aten overload can accept 2 inputs:(self, dim)", ), TorchLibOpInfo("mH", core_ops.aten_mH), - TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True, trace_only=True), - TorchLibOpInfo("min_dim", core_ops.aten_min_dim) - .skip( - variant_name="reduction_with_dim", - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ) - .xfail( - variant_name="reduction_with_dim", - dtypes=(torch.int64,), - reason="fixme: ORT did not implement Min for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ) - .xfail( - variant_name="reduction_with_dim", - reason="fixme: ORT Graph attribute inferencing failed https://github.com/onnx/onnx/issues/4986", - test_class_name="TestOutputConsistencyFullGraph", - enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, - ) - .xfail( + TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True), + TorchLibOpInfo("min_dim", core_ops.aten_min_dim).xfail( matcher=lambda sample: len(sample.args) == 0 or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "min", - core_ops.aten_min, - ).skip( + TorchLibOpInfo("min", core_ops.aten_min).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), - TorchLibOpInfo("minimum", core_ops.aten_minimum).skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ), - TorchLibOpInfo("minimum_bool", core_ops.aten_minimum_bool), + TorchLibOpInfo("minimum", core_ops.aten_minimum), TorchLibOpInfo("mm", core_ops.aten_mm).skip( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", @@ -1032,119 +861,19 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), + TorchLibOpInfo("mv", core_ops.aten_mv, tolerance={torch.float16: (3e-2, 1e-2)}), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), TorchLibOpInfo("neg", core_ops.aten_neg), + TorchLibOpInfo("new_empty", core_ops.aten_new_empty, nondeterministic=True), TorchLibOpInfo( - "new_empty_dtype", - core_ops.aten_new_empty_dtype, - nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="this Aten overload must have 3 inputs:(self, size, dtype)", - ), - TorchLibOpInfo( - "new_empty", - core_ops.aten_new_empty, - nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="this Aten overload only accept 2 inputs:(self, size)", - ), - TorchLibOpInfo( - "new_empty_strided_dtype", - core_ops.aten_new_empty_strided_dtype, - nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="this Aten overload must have 4 inputs:(self, size, stride, dtype)", - ), - TorchLibOpInfo( - "new_empty_strided", - core_ops.aten_new_empty_strided, - nondeterministic=True, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="this Aten overload only accept 3 inputs:(self, size, stride)", - ), - TorchLibOpInfo( - "new_full_dtype", - core_ops.aten_new_full_dtype, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="this Aten overload must have 4 inputs:(self, size, fill_value, dtype)", - ), - TorchLibOpInfo( - "new_full", - core_ops.aten_new_full, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="this Aten overload only accept 3 inputs:(self, size, fill_value)", - ), - TorchLibOpInfo( - "new_ones_dtype", - core_ops.aten_new_ones_dtype, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="", - ), - TorchLibOpInfo( - "new_ones", - core_ops.aten_new_ones, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="", - ), - TorchLibOpInfo( - "new_zeros_dtype", - core_ops.aten_new_zeros_dtype, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is None, - reason="", - ), - TorchLibOpInfo( - "new_zeros", - core_ops.aten_new_zeros, - ).skip( - matcher=lambda sample: sample.kwargs.get("dtype") is not None, - reason="", - ), - TorchLibOpInfo( - "nn.functional.adaptive_avg_pool1d", - nn_ops.aten_adaptive_avg_pool1d, - ) - .xfail( - # Shape should be [N, C, D1] - matcher=lambda sample: sample.args[0] not in {1, (1,)}, - reason="only global pooling is supported; only batched inputs are supported", - ) - .xfail( - reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - ), - TorchLibOpInfo( - "nn.functional.adaptive_avg_pool2d", - nn_ops.aten_adaptive_avg_pool2d, - ).xfail( - matcher=lambda sample: sample.args[0] != (1, 1), - reason="only global pooling is supported; only batched inputs are supported", - ), - TorchLibOpInfo( - "nn.functional.adaptive_avg_pool3d", - nn_ops.aten_adaptive_avg_pool3d, - ) - .xfail( - matcher=lambda sample: sample.args[0] != (1, 1, 1), - reason="only global pooling is supported; only batched inputs are supported", - ) - .xfail( - dtypes=(torch.float16,), - reason="fixme: RuntimeError: ORT inference error GlobalAveragePool. https://github.com/microsoft/onnxruntime/issues/16449", + "new_empty_strided", core_ops.aten_new_empty_strided, nondeterministic=True ), + TorchLibOpInfo("new_full", core_ops.aten_new_full), + TorchLibOpInfo("new_ones", core_ops.aten_new_ones), + TorchLibOpInfo("new_zeros", core_ops.aten_new_zeros), TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu), - TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted), TorchLibOpInfo( "nn.functional.cross_entropy", # use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) @@ -1157,9 +886,7 @@ def _where_input_wrangler( reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[target] as int type", ), TorchLibOpInfo( - "nn.functional.dropout", - core_ops.aten_dropout, - input_wrangler=_dropout_input_wrangler, + "nn.functional.dropout", core_ops.aten_dropout, input_wrangler=_dropout_input_wrangler ).skip( matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0, reason="dropout is random so the result not match", @@ -1168,14 +895,12 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.embedding_bag", core_ops.aten_embedding_bag, - tolerance={torch.float16: (1e-2, 1e-2)}, - trace_only=True, + tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), - ), + ).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, - trace_only=True, tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), ), @@ -1200,30 +925,36 @@ def _where_input_wrangler( tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (8e-2, 4e-4)}, ), TorchLibOpInfo("nn.functional.mish", nn_ops.aten_mish), - TorchLibOpInfo( - "nn.functional.nll_loss_weight", - nn_ops.aten_nll_loss_weight, - tolerance={torch.float16: (5e-2, 1e-2)}, - input_wrangler=_nll_loss_input_wrangler, - ).skip( - matcher=lambda sample: "weight" not in sample.kwargs, - reason="this Aten overload need weight as kwargs", - ), TorchLibOpInfo( "nn.functional.nll_loss", nn_ops.aten_nll_loss, input_wrangler=_nll_loss_input_wrangler, - ).skip( - matcher=lambda sample: "weight" in sample.kwargs, - reason="this Aten overload doesn't accept weight as kwargs", + tolerance={torch.float16: (5e-2, 1e-2)}, + ), + TorchLibOpInfo("nn.functional.pad", nn_ops.aten_pad) + .skip(variant_name="circular", reason="fixme: ORT does not support the circular mode") + .skip( + variant_name="replicate_negative", + reason="fixme: The implementation for negative paddings is not correct", ), TorchLibOpInfo( - "ops.aten.reflection_pad1d", - nn_ops.aten_reflection_pad1d, + "nn.functional.pixel_shuffle", + core_ops.aten_pixel_shuffle, ).xfail( - dtypes=(torch.int64,), - reason="Torch not implement reflection_pad1d for int64.", + dtypes=(torch.int32, torch.int64), + reason="fixme: ONNX Runtime does not support int32/64 inputs", + ), + TorchLibOpInfo( + "nn.functional.pixel_unshuffle", + core_ops.aten_pixel_unshuffle, + ).xfail( + dtypes=(torch.int32, torch.int64), + reason="fixme: ONNX Runtime does not support int32/64 inputs", ), + TorchLibOpInfo( + "ops.aten.reflection_pad1d", + nn_ops.aten_reflection_pad1d, + ).xfail(dtypes=(torch.int64,), reason="Torch not implement reflection_pad1d for int64."), TorchLibOpInfo( "nn.functional.reflection_pad2d", nn_ops.aten_reflection_pad2d, @@ -1232,38 +963,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "reflect"), reason="this Aten overload need args[1] == 'reflect' for pad mode", ), - TorchLibOpInfo( - "nn.functional.relu", - nn_ops.aten_relu, - ) - .xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ) - .xfail( - dtypes=(torch.int64,), - test_class_name="TestOutputConsistencyEager", - reason="fixme: ORT fails with 'Could not find an implementation for Relu(14) node'", - ), - TorchLibOpInfo( - "nn.functional.relu6", - nn_ops.aten_relu6, - ) - .xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ) - .xfail( - dtypes=(torch.int64,), - test_class_name="TestOutputConsistencyEager", - reason="fixme: ORT fails with 'Could not find an implementation for Relu(14) node'", - ), - TorchLibOpInfo( - "ops.aten.replication_pad1d", - nn_ops.aten_replication_pad1d, - ), + TorchLibOpInfo("nn.functional.relu", nn_ops.aten_relu), + TorchLibOpInfo("nn.functional.relu6", nn_ops.aten_relu6), + TorchLibOpInfo("ops.aten.replication_pad1d", nn_ops.aten_replication_pad1d), TorchLibOpInfo( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d, @@ -1273,10 +975,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", ) - .xfail( + .skip( variant_name="replicate_negative", - enabled_if=not version_utils.torch_older_than("2.2"), - reason="fixme: negative padding is not implemented yet", + reason="fixme: The implementation for negative paddings is not correct. Potentially an ORT issue", ), TorchLibOpInfo( "nn.functional.replication_pad3d", @@ -1292,15 +993,9 @@ def _where_input_wrangler( ), TorchLibOpInfo("nn.functional.selu", core_ops.aten_selu), TorchLibOpInfo( - "nn.functional.mse_loss", - nn_ops.aten_mse_loss, - input_wrangler=_mse_loss_input_wrangler, + "nn.functional.mse_loss", nn_ops.aten_mse_loss, input_wrangler=_mse_loss_input_wrangler ), - TorchLibOpInfo( - "nonzero", - core_ops.aten_nonzero, - input_wrangler=_nonzero_input_wrangler, - ) + TorchLibOpInfo("nonzero", core_ops.aten_nonzero, input_wrangler=_nonzero_input_wrangler) .xfail( matcher=lambda sample: sample.kwargs.get("as_tuple"), reason="as_tuple=True is not supported", @@ -1314,17 +1009,6 @@ def _where_input_wrangler( matcher=lambda sample: len(sample.args) > 0 and not isinstance(sample.args[0], float), reason="ORT only accept float type for args[0] 'mean'", ) - .xfail( - reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - ) - .xfail( - variant_name="number_mean", - reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - ) .xfail( variant_name="number_mean", reason="This variant does not support dtype as an argument", @@ -1334,122 +1018,105 @@ def _where_input_wrangler( "ops.aten.normal.float_Tensor", core_ops.aten_normal_float_tensor, nondeterministic=True, - ).xfail( - reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", ), TorchLibOpInfo( "ops.aten.normal.Tensor_float", core_ops.aten_normal_tensor_float, nondeterministic=True, - ).xfail( - reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", ), TorchLibOpInfo( "ops.aten.normal.Tensor_Tensor", core_ops.aten_normal_tensor_tensor, nondeterministic=True, - ).xfail( - reason="ORT fails on a cast node it inserts for float16. https://github.com/microsoft/onnxruntime/issues/16449", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", ), TorchLibOpInfo("ones", core_ops.aten_ones), - TorchLibOpInfo( - "permute", - core_ops.aten_permute, - input_wrangler=_permute_input_wrangler, - trace_only=True, - ), + TorchLibOpInfo("permute", core_ops.aten_permute), TorchLibOpInfo("polar", core_ops.aten_polar), TorchLibOpInfo("pow", core_ops.aten_pow), + TorchLibOpInfo("prod", core_ops.aten_prod).skip( + matcher=lambda sample: sample.kwargs.get("dim") is not None + or sample.kwargs.get("keepdim") is not None + or sample.kwargs.get("dtype") != -1, + reason="this Aten overload only accept 1 inputs: self", + ), + TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip( + matcher=lambda sample: ( + sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None + ) + or sample.kwargs.get("dtype") != -1, + reason="this Aten overload can accept 3 inputs:(self, dim, keepdim)", + ), + TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu), TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True), TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True), - TorchLibOpInfo( - "ops.aten.rand_like__dtype", core_ops.aten_rand_like_dtype, nondeterministic=True - ), TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True), TorchLibOpInfo("ops.aten.randint.low", core_ops.aten_randint_low, nondeterministic=True), TorchLibOpInfo("ops.aten.randint_like", core_ops.aten_randint_like, nondeterministic=True), - TorchLibOpInfo( - "ops.aten.randint_like__dtype", core_ops.aten_randint_like_dtype, nondeterministic=True - ), TorchLibOpInfo( "ops.aten.randint_like.low_dtype", core_ops.aten_randint_like_low_dtype, nondeterministic=True, ), - TorchLibOpInfo( - "ops.aten.randint_like.low_dtype__dtype", - core_ops.aten_randint_like_low_dtype_dtype, - nondeterministic=True, - ), TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( - dtypes=(torch.float16,), - reason="fixme: Shape inference error", + dtypes=(torch.float16,), reason="fixme: Shape inference error" ), TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), - TorchLibOpInfo( - "ops.aten.randn_like_dtype", core_ops.aten_randn_like_dtype, nondeterministic=True - ), TorchLibOpInfo("rad2deg", core_ops.aten_rad2deg), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), - TorchLibOpInfo( - "remainder", - core_ops.aten_remainder, - ).xfail( - dtypes=(torch.float16,), - reason="Eager mode failed on case(self=7.75,other=0.1582) due to precision loss", - test_class_name="TestOutputConsistencyEager", - ), + TorchLibOpInfo("remainder", core_ops.aten_remainder), TorchLibOpInfo("repeat", core_ops.aten_repeat), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) + .skip( + matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is a Tensor"), + ) + .skip(dtypes=(torch.bool,), reason="bool not supported") + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor) + .skip( + matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is an int"), + ) + .skip(dtypes=(torch.bool,), reason="bool not supported") + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), - TorchLibOpInfo("round", core_ops.aten_round) - .xfail( - variant_name="decimals_0", - reason="This variant does not accept decimals", - test_class_name="TestOutputConsistencyEager", - ) - .xfail( - variant_name="decimals_3", - reason="This variant does not accept decimals", - ) - .xfail( - variant_name="decimals_neg_3", - reason="This variant does not accept decimals", + TorchLibOpInfo("round", core_ops.aten_round).skip( + matcher=lambda sample: sample.kwargs.get("decimals") is not None, + reason="this Aten overload only support one tensor as input and one int as args by design", ), TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals), TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt), - TorchLibOpInfo("rsub", core_ops.aten_rsub), - TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True, trace_only=True), TorchLibOpInfo( "scalar_tensor", core_ops.aten_scalar_tensor, input_wrangler=_scalar_tensor_input_wrangler, - trace_only=True, ), TorchLibOpInfo( "scalar_tensor", core_ops.aten_scalar_tensor, input_wrangler=_scalar_tensor_input_wrangler, - trace_only=True, complex=True, ), TorchLibOpInfo( - "ops.aten.scalar_tensor", - core_ops.aten_scalar_tensor_complex, - trace_only=True, - complex=True, + "ops.aten.scalar_tensor", core_ops.aten_scalar_tensor_complex, complex=True ), - TorchLibOpInfo( - "scatter_add", - core_ops.aten_scatter_add, - ) + TorchLibOpInfo("scatter_add", core_ops.aten_scatter_add) .xfail( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch. https://github.com/onnx/onnx/issues/4986", @@ -1462,6 +1129,7 @@ def _where_input_wrangler( TorchLibOpInfo("select_scatter", core_ops.aten_select_scatter), TorchLibOpInfo("sigmoid", core_ops.aten_sigmoid), TorchLibOpInfo("sign", core_ops.aten_sign), + TorchLibOpInfo("nn.functional.silu", nn_ops.aten_silu), TorchLibOpInfo("sin", core_ops.aten_sin), TorchLibOpInfo( "sinc", special_ops.aten_special_sinc, tolerance={torch.float16: (1e-2, 6e-4)} @@ -1470,7 +1138,6 @@ def _where_input_wrangler( TorchLibOpInfo( "softmax", core_ops.aten_softmax, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (3e-4, 4e-4)}, ) .xfail( @@ -1485,140 +1152,90 @@ def _where_input_wrangler( test_class_name="TestOutputConsistencyFullGraph", ) .skip( - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: SoftMax does not support empty tensor as input", - ) - .skip( - variant_name="with_dtype", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: SoftMax does not support empty tensor as input", - ), - TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail( - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449", - test_class_name="TestOutputConsistencyEager", - ), - TorchLibOpInfo( - "split_with_sizes", - core_ops.aten_split_with_sizes, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), - TorchLibOpInfo( - "split", - core_ops.aten_split, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + matcher=lambda sample: len(sample.input.shape) == 0, + reason="fixme: SoftMax does not support empty tensor as input", ) - .xfail( - variant_name="list_args", - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + .skip( + variant_name="with_dtype", + matcher=lambda sample: len(sample.input.shape) == 0, + reason="fixme: SoftMax does not support empty tensor as input", ), + TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus), + TorchLibOpInfo("sort", core_ops.aten_sort).xfail( + dtypes=(torch.float16,), + reason="fixme: Tensor-likes are not close. Tests pass for float32.", + ), + TorchLibOpInfo("split_with_sizes", core_ops.aten_split_with_sizes), + TorchLibOpInfo("split", core_ops.aten_split), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim, - ).skip( + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) + .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", + ) + .skip( + matcher=lambda sample: len(sample.input.shape) != 0 + and sample.input.shape[sample.args[0]] != 1, + reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim_complex, - complex=True, - trace_only=True, - ).skip( + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True) + .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", + ) + .skip( + matcher=lambda sample: len(sample.input.shape) != 0 + and sample.input.shape[sample.args[0]] != 1, + reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze", - core_ops.aten_squeeze, - ).skip( + TorchLibOpInfo("squeeze", core_ops.aten_squeeze).skip( matcher=lambda sample: len(sample.args) != 0, reason="this Aten overload only support one tensor as input by design", ), TorchLibOpInfo("stack", core_ops.aten_stack), - TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True, trace_only=True), - TorchLibOpInfo("sub", core_ops.aten_sub), - TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True, trace_only=True), + TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True), + TorchLibOpInfo("sub", core_ops.aten_sub, tolerance={torch.float16: (2e-3, 1e-3)}), + TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB - TorchLibOpInfo( - "t", - core_ops.aten_t, - ).xfail( + TorchLibOpInfo("t", core_ops.aten_t).xfail( enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, reason="fixme: ORT Graph attribute inferencing failed on rank-1 input. https://github.com/onnx/onnx/issues/4986", test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("tan", core_ops.aten_tan), TorchLibOpInfo("tanh", core_ops.aten_tanh), - TorchLibOpInfo( - "tile", - core_ops.aten_tile, - ).skip( + TorchLibOpInfo("tile", core_ops.aten_tile).skip( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) or not sample.input.shape, reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", ), - TorchLibOpInfo("topk", core_ops.aten_topk).xfail( + TorchLibOpInfo("topk", core_ops.aten_topk) + .xfail( dtypes=(torch.int64, torch.int32), enabled_if=not ops_test_common.IS_WINDOWS, reason="fixme: result mismatch. https://github.com/microsoft/onnxscript/issues/853", + ) + .skip( + dtypes=(torch.float16,), + reason="fixme: result mismatch. https://github.com/microsoft/onnxscript/issues/853", + ) + .skip( + matcher=lambda sample: len(sample.input.shape) == 0 or sample.input.numel() == 0, + reason="scalar inputs or empty inputs are not handled", ), TorchLibOpInfo("tril", core_ops.aten_tril).xfail( - dtypes=(torch.int32, torch.bool), - reason="fixme: ORT does not have an implementation of Trilu for int32 or bool.", + dtypes=(torch.int32,), + reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("triu", core_ops.aten_triu).xfail( - dtypes=(torch.int32, torch.bool), - reason="fixme: ORT does not have an implementation of Trilu for int32 or bool.", + dtypes=(torch.int32,), + reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("trunc", core_ops.aten_trunc), - TorchLibOpInfo( - "unbind", - core_ops.aten_unbind, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), - TorchLibOpInfo( - "unflatten", - core_ops.aten_unflatten, - input_wrangler=_unflatten_input_wrangler, - ).xfail( - matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), - reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", - ), - TorchLibOpInfo("unfold", core_ops.aten_unfold, trace_only=True), - TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold, trace_only=True), + TorchLibOpInfo("unbind", core_ops.aten_unbind), + TorchLibOpInfo("unflatten", core_ops.aten_unflatten), + TorchLibOpInfo("unfold", core_ops.aten_unfold), + TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold), TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze), TorchLibOpInfo("view", core_ops.aten_view), TorchLibOpInfo("view", core_ops.aten_view_complex, complex=True), @@ -1628,130 +1245,65 @@ def _where_input_wrangler( TorchLibOpInfo("view_as_real", core_ops.aten_view_as_real, complex=True), TorchLibOpInfo("view_as_real_copy", core_ops.aten_view_as_real_copy, complex=True), TorchLibOpInfo("view_copy", core_ops.aten_view_copy), - TorchLibOpInfo( - "vstack", - core_ops.aten_vstack, - ).xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x62afb00_rank). https://github.com/microsoft/onnxscript/issues/960", - ), TorchLibOpInfo("where", core_ops.aten_where, input_wrangler=_where_input_wrangler).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", ), TorchLibOpInfo("xlogy", special_ops.aten_special_xlogy), TorchLibOpInfo("zeros", core_ops.aten_zeros), - TorchLibOpInfo( - "arange_start_step", - core_ops.aten_arange_start_step, - trace_only=True, - ).xfail( + TorchLibOpInfo("arange_start_step", core_ops.aten_arange_start_step) + .skip( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="dtype needs to be specified for non-float tensors", + dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange_start", - core_ops.aten_arange_start, - trace_only=True, - ).skip( + TorchLibOpInfo("arange_start", core_ops.aten_arange_start) + .skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", - ), - TorchLibOpInfo( - "arange", - core_ops.aten_arange, - trace_only=True, ) + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="dtype needs to be specified for non-float tensors", + dtypes=(torch.float16, torch.int64, torch.int32), + ), + TorchLibOpInfo("arange", core_ops.aten_arange) .xfail( dtypes=(torch.int32,), reason="fixme: output shape mismatch in edge cases. https://github.com/microsoft/onnxscript/issues/974", ) - .xfail( + .skip( matcher=lambda sample: len(sample.args) != 0, reason="arange overload takes single argument", ) .xfail( matcher=lambda sample: sample.kwargs.get("end") is not None, reason="arange overload does not support positional 'end' argument", - ), - TorchLibOpInfo("argmax", core_ops.aten_argmax) - .skip( - matcher=lambda sample: "dim" in sample.kwargs, - reason="this overload does not support the 'dim' attribute by design", - ) - .skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ) - .xfail( - dtypes=(torch.int64,), - reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo("argmax_dim", core_ops.aten_argmax_dim) - .xfail( - matcher=lambda sample: "dim" not in sample.kwargs, - reason="this overload requires the 'dim' attribute by design", - ) - .skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ) - .xfail( - dtypes=(torch.int64,), - reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo("argmin", core_ops.aten_argmin) - .skip( - matcher=lambda sample: "dim" in sample.kwargs, - reason="this overload does not support the 'dim' attribute by design", - ) - .skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ) - .xfail( - dtypes=(torch.int64,), - reason="fixme: ORT did not implement ArgMin for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo("argmin_dim", core_ops.aten_argmin_dim) - .xfail( - matcher=lambda sample: "dim" not in sample.kwargs, - reason="this overload requires the 'dim' attribute by design", ) .skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ) - .xfail( - dtypes=(torch.int64,), - reason="fixme: ORT did not implement ArgMin for int64. https://github.com/microsoft/onnxruntime/issues/16654", + matcher=lambda sample: sample.kwargs.get("dtype") is None, + reason="dtype needs to be specified for non-float tensors", + dtypes=(torch.float16, torch.int64, torch.int32), ), + TorchLibOpInfo("argmax", core_ops.aten_argmax), + TorchLibOpInfo("argmin", core_ops.aten_argmin), TorchLibOpInfo( "as_strided", core_ops.aten_as_strided, - trace_only=True, - ).xfail( - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor", - ), - TorchLibOpInfo("clamp", core_ops.aten_clamp, trace_only=True).skip( - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", - ), + ).xfail(variant_name="partial_views", reason="ONNX doesn't have partial view for tensor"), + TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor), TorchLibOpInfo( "ops.aten.col2im", nn_ops.aten_col2im, - trace_only=True, ).xfail( dtypes=(torch.float16,), reason="fixme: Tensor-likes are not close. https://github.com/microsoft/onnxruntime/issues/16007", ), - TorchLibOpInfo("cumsum", core_ops.aten_cumsum, trace_only=True).xfail( + TorchLibOpInfo("cumsum", core_ops.aten_cumsum).xfail( dtypes=(torch.int32,), reason="fixme: torch.cumsum with int32 inputs uses int64 as the output type", ), @@ -1759,34 +1311,36 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.convolution", core_ops.aten_convolution, - trace_only=True, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + tolerance={torch.float32: (2e-4, 9e-4)}, ), - TorchLibOpInfo( - "empty_like", core_ops.aten_empty_like, nondeterministic=True, trace_only=True - ), - TorchLibOpInfo( - "grid_sampler_2d", - core_ops.aten_grid_sampler_2d, - trace_only=True, - ).skip( + TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), + TorchLibOpInfo("grid_sampler_2d", core_ops.aten_grid_sampler_2d) + .skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.args[1] == 2, reason="fixme: 'bicubic' mode in ORT implemented differently with Torch", - ), - TorchLibOpInfo("heaviside", core_ops.aten_heaviside), + ) + .skip(dtypes=(torch.float16,), reason="fixme: Accuracy is not high enough"), TorchLibOpInfo( - "hstack", - core_ops.aten_hstack, + "nn.functional.group_norm", + nn_ops.aten_group_norm, + tolerance={torch.float16: (1e-2, 7e-3)}, ).xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme: RUNTIME_EXCEPTION : Exception during initialization: Invalid tensor data type 0. https://github.com/microsoft/onnxscript/issues/960", + matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), + reason="Using op.InstanceNormalization to simulate GroupNorm, which does not support 0-dim input", + ), + TorchLibOpInfo( + "ops.aten.hamming_window", + core_ops.aten_hamming_window, + tolerance={torch.float32: (8e-2, 6e-3)}, ), + TorchLibOpInfo("ops.aten.hann_window", core_ops.aten_hann_window), + TorchLibOpInfo("heaviside", core_ops.aten_heaviside), TorchLibOpInfo( "nn.functional.grid_sample", core_ops.aten_grid_sampler, input_wrangler=_grid_sample_input_wrangler, - trace_only=True, + tolerance={torch.float16: (8e-2, 2e-3)}, ).skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.kwargs.get("mode") == "bicubic" @@ -1796,27 +1350,26 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.layer_norm", core_ops.aten_layer_norm, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}, - ).xfail( - dtypes=(torch.int64,), - reason="fixme: ORT `LayerNormKernelImpl` not implemented for int64", - ), - TorchLibOpInfo( - "logit", core_ops.aten_logit, trace_only=True, tolerance={torch.float16: (1e-1, 7e-4)} - ), - TorchLibOpInfo("max_dim", core_ops.aten_max_dim) - .skip( - variant_name="reduction_with_dim", - matcher=lambda sample: len(sample.input.shape) == 0, - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", ) .xfail( - variant_name="reduction_with_dim", dtypes=(torch.int64,), - reason="fixme: ORT did not implement Max for int64. https://github.com/microsoft/onnxruntime/issues/16654", + reason="fixme: ORT `LayerNormKernelImpl` not implemented for int64", + ) + .skip( + matcher=lambda sample: sample.input.shape[-1] <= 1, + reason="fixme: onnxruntime fail when no reduction is needed", ) + .skip( + dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,), + reason="fixme: test is unstable on macosx, windows", + ), + TorchLibOpInfo("logical_and", core_ops.aten_logical_and), + TorchLibOpInfo("logical_not", core_ops.aten_logical_not), + TorchLibOpInfo("logical_or", core_ops.aten_logical_or), + TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor), + TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}), + TorchLibOpInfo("max_dim", core_ops.aten_max_dim) .xfail( variant_name="reduction_with_dim", reason="fixme: ORT Graph attribute inferencing failed https://github.com/onnx/onnx/issues/4986", @@ -1828,10 +1381,7 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "max", - core_ops.aten_max, - ).skip( + TorchLibOpInfo("max", core_ops.aten_max).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), @@ -1840,18 +1390,15 @@ def _where_input_wrangler( # Custom from extra_opinfo "ops.aten.max_pool1d", nn_ops.aten_max_pool1d, - trace_only=True, ), TorchLibOpInfo( # Custom from extra_opinfo "ops.aten.max_pool2d", nn_ops.aten_max_pool2d, - trace_only=True, ), TorchLibOpInfo( "ops.aten.max_pool3d", # Custom from extra_opinfo nn_ops.aten_max_pool3d, - trace_only=True, ).xfail( variant_name="empty_strides", reason="fixme: 'shape' do not match: torch.Size([2, 3, 4, 3]) != torch.Size([2, 3, 4, 2]). https://github.com/microsoft/onnxscript/issues/975", @@ -1859,7 +1406,6 @@ def _where_input_wrangler( TorchLibOpInfo( "native_batch_norm", core_ops.aten_native_batch_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, ) .skip( @@ -1871,26 +1417,33 @@ def _where_input_wrangler( device_type="cpu", dtypes=(torch.float16,), reason="native_batch_norm outputs different dtypes on CPU and CUDA. Our implematation is based on that for CUDA", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("training") is True + or sample.args[-3] is True, + reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit", core_ops.aten_native_batch_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, - ).skip( + ) + .skip( device_type="cpu", matcher=lambda sample: sample.kwargs.get("training") is False, reason="native_batch_norm outputs different shapes on CPU and CUDA when training is False. Our implematation is based on that for CUDA", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("training") is True + or sample.args[-3] is True, + reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( - "ops.aten._native_batch_norm_legit.no_stats", - core_ops.aten__native_batch_norm_no_stats, - trace_only=True, + "ops.aten._native_batch_norm_legit.no_stats", core_ops.aten__native_batch_norm_no_stats ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit_functional", core_ops.aten__native_batch_norm_legit_functional, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, ) .skip( @@ -1899,28 +1452,27 @@ def _where_input_wrangler( reason="native_batch_norm outputs different results on CPU and CUDA when training is False. Our implematation is based on that for CUDA", ) .skip( - dtypes=(torch.float16,), - device_type="cuda", - matcher=lambda sample: sample.kwargs.get("training") is True, - test_class_name="TestOutputConsistencyEager", - reason="fixme: output 4 (new_running_var) does not match the gpu output sometimes", + matcher=lambda sample: sample.kwargs.get("training") is True + or sample.args[-3] is True, + reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 7e-3)}, - ).xfail( - dtypes=(torch.float16,), - reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly", - enabled_if=version_utils.torch_older_than("2.2"), ), TorchLibOpInfo( "native_layer_norm", core_ops.aten_native_layer_norm, - trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4), torch.float16: (1e-1, 7e-4)}, - ).skip( + ) + .skip( + dtypes=(torch.float32,), + matcher=lambda sample: sample.input.shape[-1] <= 1, + # enabled_if=ops_test_common.IS_MACOS, + reason="fixme: result mismatch. https://github.com/microsoft/onnxruntime/issues/20676", + ) + .skip( dtypes=(torch.float16,), device_type="cpu", reason="native_layer_norm outputs different dtypes on CPU and CUDA. Our implematation is based on that for CUDA", @@ -1929,7 +1481,6 @@ def _where_input_wrangler( "nn.functional.avg_pool1d", nn_ops.aten_avg_pool1d, input_wrangler=_avg_pool_input_wrangler, - trace_only=True, ) .xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) @@ -1950,7 +1501,6 @@ def _where_input_wrangler( "nn.functional.avg_pool2d", nn_ops.aten_avg_pool2d, input_wrangler=_avg_pool_input_wrangler, - trace_only=True, ).xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) or (sample.kwargs.get("divisor_override") is not None), @@ -1960,7 +1510,6 @@ def _where_input_wrangler( "nn.functional.avg_pool3d", nn_ops.aten_avg_pool3d, input_wrangler=_avg_pool_input_wrangler, - trace_only=True, ) .xfail( matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) @@ -1974,7 +1523,6 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.conv1d", core_ops.aten_conv1d, - trace_only=True, ).xfail( matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str), reason="String padding is not accepted by aten::conv1d", @@ -1982,7 +1530,6 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.conv2d", core_ops.aten_conv2d, - trace_only=True, tolerance={torch.float32: (2e-5, 3e-5)}, ).xfail( matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str), @@ -1991,40 +1538,29 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.instance_norm", core_ops.aten_instance_norm, - trace_only=True, tolerance={torch.float16: (1e-2, 1e-3)}, ), TorchLibOpInfo( - "ops.aten.conv3d", - core_ops.aten_conv3d, - trace_only=True, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + "ops.aten.conv3d", core_ops.aten_conv3d, tolerance={torch.float32: (3.7e-5, 1.8e-4)} ), + TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu), + TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), TorchLibOpInfo( - "nn.functional.gelu", - nn_ops.aten_gelu, - trace_only=True, - tolerance={torch.float16: (8e-2, 1e-4)}, - ), - TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip( - # input: input, args: weight, bias; so len(args) == 2 means bias is provided - matcher=lambda sample: len(sample.args) != 1, - reason="this overload is implemented for bias=None", + "nn.functional.linear", nn_ops.aten_linear, tolerance={torch.float16: (1e-2, 1e-3)} ), TorchLibOpInfo( - "nn.functional.linear_bias", - nn_ops.aten_linear_bias, - tolerance={torch.float16: (2e-1, 4e-4)}, - ).skip( - # input: input, args: weight, bias; so len(args) == 2 means bias is provided - matcher=lambda sample: len(sample.args) != 2, - reason="this overload is implemented for bias!=None", + "nn.functional.unfold", + nn_ops.aten_im2col, + input_wrangler=_im2col_input_wrangler, + ).xfail( + matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) + or not sample.input.shape, + reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", ), TorchLibOpInfo( "nn.functional.max_pool1d", nn_ops.aten_max_pool1d, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is True, reason="this aten overload assume return_indices=False", @@ -2033,7 +1569,6 @@ def _where_input_wrangler( "nn.functional.max_pool1d_with_indices", nn_ops.aten_max_pool1d_with_indices, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is False, reason="this aten overload assume return_indices=True", @@ -2042,7 +1577,6 @@ def _where_input_wrangler( "nn.functional.max_pool2d", nn_ops.aten_max_pool2d, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is True, reason="this aten overload assume return_indices=False", @@ -2051,7 +1585,6 @@ def _where_input_wrangler( "nn.functional.max_pool2d_with_indices", nn_ops.aten_max_pool2d_with_indices, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ).skip( matcher=lambda sample: sample.kwargs.get("return_indices") is False, reason="this aten overload assume return_indices=True", @@ -2060,7 +1593,6 @@ def _where_input_wrangler( "nn.functional.max_pool3d", nn_ops.aten_max_pool3d, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ) .skip( matcher=lambda sample: sample.kwargs.get("ceil_mode") is True @@ -2075,7 +1607,6 @@ def _where_input_wrangler( "nn.functional.max_pool3d_with_indices", nn_ops.aten_max_pool3d_with_indices, input_wrangler=_max_pool_input_wrangler, - trace_only=True, ) .skip( matcher=lambda sample: sample.kwargs.get("ceil_mode") is True @@ -2089,14 +1620,8 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.scaled_dot_product_attention", nn_ops.aten_scaled_dot_product_attention, - trace_only=True, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype == torch.bool, - reason="this overload takes a non-boolean mask", - ) .skip( matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, reason="dropout is random so the results do not match", @@ -2107,187 +1632,138 @@ def _where_input_wrangler( test_class_name="TestOutputConsistencyFullGraph", ) .xfail( - reason="fixme: ORT fails on type mismatch in Add", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", + matcher=lambda sample: len(sample.input.shape) != 4 + or len(sample.args[0].shape) != 4 + or len(sample.args[1].shape) != 4, + reason="torch sdpa is expected to pass in 4d q, k, and v.", ), TorchLibOpInfo( "ops.aten._scaled_dot_product_flash_attention", nn_ops.aten__scaled_dot_product_flash_attention, - trace_only=True, tolerance={torch.float32: (3e-4, 1.5e-5)}, # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( - device_type="cpu", - reason="_scaled_dot_product_flash_attention only supports CUDA", - ), + ).skip(device_type="cpu", reason="_scaled_dot_product_flash_attention only supports CUDA"), TorchLibOpInfo( "ops.aten._scaled_dot_product_efficient_attention", nn_ops.aten__scaled_dot_product_efficient_attention, - trace_only=True, tolerance={torch.float32: (3e-4, 1.5e-5)}, # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( + ).skip( enabled_if=not torch.cuda.is_available(), reason="_scaled_dot_product_efficient_attention only supports CUDA", ), - TorchLibOpInfo( - "nn.functional.scaled_dot_product_attention_bool_mask", - nn_ops.aten_scaled_dot_product_attention_bool_mask, - trace_only=True, - tolerance={torch.float32: (3e-4, 1.5e-5)}, - ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype != torch.bool, - reason="this overload takes a boolean mask", - ) - .skip( - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ) - .xfail( - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ) - .xfail( - reason="fixme: ORT fails on type mismatch in Add", - dtypes=(torch.float16,), - test_class_name="TestOutputConsistencyEager", - ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, - trace_only=True, ).xfail( matcher=lambda sample: sample.args[1] is False and sample.kwargs.get("scales_h") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo( - "ops.aten.upsample_bilinear2d.vec", - nn_ops.aten_upsample_bilinear2d_vec, - trace_only=True, + "ops.aten._upsample_bilinear2d_aa", + nn_ops.aten__upsample_bilinear2d_aa, + # ONNX and PyTorch use different anti-aliasing algorithms, so numerical results differ. + # However, the implementation is verified correct because: + # 1. The function correctly passes antialias=1 to ONNX Resize operation + # 2. Shape validation ensures the operation works correctly + # 3. Additional validation in test_aa_upsample_validation.py confirms correctness + # Shape-only comparison is the appropriate testing approach for this case. + compare_shape_only_for_output=(0,), ), + TorchLibOpInfo("ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec), TorchLibOpInfo( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, - trace_only=True, ).xfail( matcher=lambda sample: sample.args[1] is False and sample.kwargs.get("scales_h") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo( - "ops.aten.upsample_bicubic2d.vec", - nn_ops.aten_upsample_bicubic2d_vec, - trace_only=True, + "ops.aten._upsample_bicubic2d_aa", + nn_ops.aten__upsample_bicubic2d_aa, + # ONNX and PyTorch use different anti-aliasing algorithms, so numerical results differ. + # However, the implementation is verified correct because: + # 1. The function correctly passes antialias=1 to ONNX Resize operation + # 2. Shape validation ensures the operation works correctly + # 3. Additional validation in test_aa_upsample_validation.py confirms correctness + # Shape-only comparison is the appropriate testing approach for this case. + compare_shape_only_for_output=(0,), ), + TorchLibOpInfo("ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, - trace_only=True, ).xfail( matcher=lambda sample: sample.args[1] is False and sample.kwargs.get("scales") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d", - nn_ops.aten_upsample_nearest1d, - trace_only=True, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d", - nn_ops.aten_upsample_nearest2d, - trace_only=True, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d", - nn_ops.aten_upsample_nearest3d, - trace_only=True, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d", - nn_ops.aten_upsample_trilinear3d, - trace_only=True, - ), - TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True), + TorchLibOpInfo("ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d), + TorchLibOpInfo("ops.aten.upsample_nearest1d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d), + TorchLibOpInfo("ops.aten.upsample_nearest2d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d), + TorchLibOpInfo("ops.aten.upsample_nearest3d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), + TorchLibOpInfo("ones_like", core_ops.aten_ones_like), TorchLibOpInfo( "roll", core_ops.aten_roll, - trace_only=True, input_wrangler=_roll_input_wrangler, ), TorchLibOpInfo( "roll", core_ops.aten_roll_complex, input_wrangler=_roll_input_wrangler, - trace_only=True, complex=True, ), TorchLibOpInfo( "scatter_reduce", core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, - trace_only=True, ) + .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") .xfail( - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ) - .skip( - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", + variant_name="prod", + dtypes=(torch.float16, torch.float64), + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'", ) .xfail( - variant_name="amax", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'", + variant_name="sum", + dtypes=(torch.float16, torch.float64), + reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", ) .xfail( - variant_name="amin", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'", + variant_name="mean", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not support ml_dtypes.bfloat16", ) .xfail( variant_name="prod", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'prod'", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not support ml_dtypes.bfloat16", ) .xfail( variant_name="sum", - reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'", - ), - TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter, trace_only=True), - TorchLibOpInfo("slice", core_ops.aten_slice, trace_only=True), - TorchLibOpInfo( - "ops.aten.stft", # Custom from extra_opinfo - core_ops.aten_stft, - trace_only=True, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, - ).xfail( - dtypes=(torch.float16,), - reason="RuntimeError: MKL FFT doesn't support tensors of type: Half", + dtypes=(torch.bfloat16,), + reason="onnxruntime does not support ml_dtypes.bfloat16", ), + TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), + TorchLibOpInfo("ops.aten.scatter.src", core_ops.aten_scatter_src), + TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value), + TorchLibOpInfo("slice", core_ops.aten_slice), + TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True), TorchLibOpInfo( "sum", core_ops.aten_sum_dim_IntList, input_wrangler=_sum_input_wrangler, - trace_only=True, ).xfail( dtypes=(torch.int32,), reason="fixme: torch.sum uses int64 as the accumulator for int32 inputs", @@ -2302,128 +1778,38 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.tensor.int", core_ops.aten_tensor_int ), # Custom from extra_opinfo - TorchLibOpInfo("transpose", core_ops.aten_transpose, trace_only=True), - TorchLibOpInfo( - "transpose", core_ops.aten_transpose_complex, trace_only=True, complex=True - ), - TorchLibOpInfo( - "var_mean", - core_ops.aten_var_mean, - trace_only=True, - ).xfail( - # kwargs is empty - matcher=lambda sample: len(sample.kwargs) > 0, - reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs", - ), - TorchLibOpInfo( - "var_mean_dim", - core_ops.aten_var_mean_dim, - trace_only=True, - ).xfail( - # kwargs["dim"] must exist, kwargs["correction"] must not exist - matcher=lambda sample: not ( - sample.kwargs.get("dim", None) is not None - and sample.kwargs.get("correction", None) is None - ), - reason="this Aten overload only support with 'dim' argument and without 'correction' argument", - ), - TorchLibOpInfo( - "var_mean_correction", - core_ops.aten_var_mean_correction, - trace_only=True, - ).skip( - # Don't accept input[1]=bool and 'correction' must be in kwargs - matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, - reason="this Aten overload only support when correction attribute exists", - ), - TorchLibOpInfo( - "var", - core_ops.aten_var, - trace_only=True, - ).xfail( - # kwargs must be empty - matcher=lambda sample: len(sample.kwargs) > 0, - reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs", - ), - TorchLibOpInfo( - "var_dim", - core_ops.aten_var_dim, - trace_only=True, - ).xfail( - # kwargs["dim"] must exist, kwargs["correction"] must not exist - matcher=lambda sample: not ( - sample.kwargs.get("dim", None) is not None - and sample.kwargs.get("correction", None) is None + TorchLibOpInfo("transpose", core_ops.aten_transpose), + TorchLibOpInfo("transpose", core_ops.aten_transpose_complex, complex=True), + TorchLibOpInfo("ops.aten._unique.default", core_ops.aten__unique), + TorchLibOpInfo("ops.aten._unique2.default", core_ops.aten__unique2), + TorchLibOpInfo("ops.aten.unique_dim.default", core_ops.aten_unique_dim).skip( + device_type="cpu", + reason=( + "ops.aten.unique_dim.default returns different shapes for optional outputs on CPU/CUDA. " + "Our implementation is based on that for CUDA" ), - reason="this Aten overload only support with 'dim' argument and without 'correction' argument", ), + TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim), TorchLibOpInfo( - "var_correction", - core_ops.aten_var_correction, - trace_only=True, - ).skip( - # Don't accept input[1]=bool and 'correction' must be in kwargs - matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, - reason="this Aten overload only support when correction attribute exists", + "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)} ), - TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True), + TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like), TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms), ) ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims")) ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims")) ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step")) -ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_left_shift", - ( - "bitwise_left_shift_int8", - "bitwise_left_shift_int16", - "bitwise_left_shift_int32", - "bitwise_left_shift_int64", - ), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_right_shift", - ( - "bitwise_right_shift_int8", - "bitwise_right_shift_int16", - "bitwise_right_shift_int32", - "bitwise_right_shift_int64", - ), -) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) -ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int")) -ops_test_common.duplicate_opinfo(OPS_DB, "full_like", ("full_like_dtype",)) -ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "new_empty", ("new_empty_dtype",)) -ops_test_common.duplicate_opinfo(OPS_DB, "new_empty_strided", ("new_empty_strided_dtype",)) -ops_test_common.duplicate_opinfo(OPS_DB, "new_full", ("new_full_dtype",)) -ops_test_common.duplicate_opinfo(OPS_DB, "new_ones", ("new_ones_dtype",)) -ops_test_common.duplicate_opinfo(OPS_DB, "new_zeros", ("new_zeros_dtype",)) -ops_test_common.duplicate_opinfo( - OPS_DB, "nn.functional.linear", ("nn.functional.linear_bias",) -) -ops_test_common.duplicate_opinfo( - OPS_DB, "nn.functional.nll_loss", ("nn.functional.nll_loss_weight",) -) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad", @@ -2433,24 +1819,9 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.scaled_dot_product_attention", - ("nn.functional.scaled_dot_product_attention_bool_mask",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.celu", - ("nn.functional.celu_type_promoted",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) -) -ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) +ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction")) -ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",)) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_real", ("view_as_real_copy",)) @@ -2582,7 +1953,6 @@ def _where_input_wrangler( "signbit", "sin", "sinh", - "slice", "sqrt", "squeeze", "sub", @@ -2593,7 +1963,6 @@ def _where_input_wrangler( "transpose", "trunc", "uniform", - "var", "where", ) @@ -2619,6 +1988,6 @@ def _where_input_wrangler( ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) # Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB" -assert NONDETERMINISTIC_OPS.issubset( - TESTED_OPS -), f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS" +assert NONDETERMINISTIC_OPS.issubset(TESTED_OPS), ( + f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS" +) diff --git a/tests/functions/gemmgelu.py b/tests/functions/gemmgelu.py index 0269488584..32a326aab3 100644 --- a/tests/functions/gemmgelu.py +++ b/tests/functions/gemmgelu.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript import script from onnxscript.onnx_opset import opset15 as op diff --git a/tests/functions/gemmgelu_test.py b/tests/functions/gemmgelu_test.py index 3b38e6023b..6de6f131fc 100644 --- a/tests/functions/gemmgelu_test.py +++ b/tests/functions/gemmgelu_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import unittest @@ -59,7 +57,7 @@ def test_gemmgelu(self): onnx_script_test_case.FunctionTestParams(gemmgelu.gemmgelu, [a, w, b], [expected]) ] for case in cases: - self.run_converter_test(case) + self.run_converter_test(case, rtol=1e-6) self.run_eager_test(case) diff --git a/tests/functions/if_test.py b/tests/functions/if_test.py index bc80179ca8..0887b296fa 100644 --- a/tests/functions/if_test.py +++ b/tests/functions/if_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import unittest diff --git a/tests/functions/onnxfns1A_test.py b/tests/functions/onnxfns1A_test.py index 7f19ebaf75..36d12e4b4a 100644 --- a/tests/functions/onnxfns1A_test.py +++ b/tests/functions/onnxfns1A_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import pytest diff --git a/tests/functions/onnxfns2_test.py b/tests/functions/onnxfns2_test.py index 3cf067dbd7..ce1164357b 100644 --- a/tests/functions/onnxfns2_test.py +++ b/tests/functions/onnxfns2_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest from tests.common import onnx_script_test_case diff --git a/tests/functions/onnxfns_test.py b/tests/functions/onnxfns_test.py index 1057214597..1e9e10d300 100644 --- a/tests/functions/onnxfns_test.py +++ b/tests/functions/onnxfns_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import unittest diff --git a/tests/functions/ort_custom_ops.py b/tests/functions/ort_custom_ops.py index 2ce6fa57ef..1df3a0f109 100644 --- a/tests/functions/ort_custom_ops.py +++ b/tests/functions/ort_custom_ops.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import math from onnxscript import script diff --git a/tests/if_test.py b/tests/if_test.py index 346334c09c..2a1e759b82 100644 --- a/tests/if_test.py +++ b/tests/if_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import unittest diff --git a/tests/ir/graph_view_test.py b/tests/ir/graph_view_test.py index 699ce4c685..83a51cdaa1 100644 --- a/tests/ir/graph_view_test.py +++ b/tests/ir/graph_view_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import pathlib import unittest diff --git a/tests/ir/serde_roundtrip_test.py b/tests/ir/serde_roundtrip_test.py index 2507350059..69d23d69e2 100644 --- a/tests/ir/serde_roundtrip_test.py +++ b/tests/ir/serde_roundtrip_test.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pylint: disable=import-outside-toplevel from __future__ import annotations import pathlib @@ -6,7 +9,6 @@ import onnx import onnx.backend.test import parameterized -import pyinstrument import onnxscript.testing from onnxscript import ir @@ -23,12 +25,6 @@ class SerdeTest(unittest.TestCase): - def setUp(self) -> None: - self.profiler = pyinstrument.Profiler() - - def tearDown(self) -> None: - self.profiler.reset() - @parameterized.parameterized.expand(test_args) def test_serialization_deserialization_produces_same_model( self, _: str, model_path: pathlib.Path @@ -39,13 +35,8 @@ def test_serialization_deserialization_produces_same_model( onnx.checker.check_model(model) # Profile the serialization and deserialization process - self.profiler.start() ir_model = ir.serde.deserialize_model(model) serialized = ir.serde.serialize_model(ir_model) - self.profiler.stop() - profile_path = pathlib.Path(__file__).parent / "serde_test_profiles" - profile_path.mkdir(exist_ok=True) - self.profiler.write_html(profile_path / f"{self.id().split('.')[-1]}.html") onnxscript.testing.assert_onnx_proto_equal(serialized, model) onnx.checker.check_model(serialized) diff --git a/tests/loop_test.py b/tests/loop_test.py index 0be895c08f..698457b9de 100644 --- a/tests/loop_test.py +++ b/tests/loop_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import unittest import numpy as np diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 862c45ce31..59e481eb93 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -1,4 +1,2 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- diff --git a/tests/models/attrref.py b/tests/models/attrref.py index 352b8f87eb..c321229e98 100644 --- a/tests/models/attrref.py +++ b/tests/models/attrref.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript.onnx_opset import opset15 as op diff --git a/tests/models/cast_like.py b/tests/models/cast_like.py index 5f53806921..fa5b47a4f6 100644 --- a/tests/models/cast_like.py +++ b/tests/models/cast_like.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- # Test cases for automatic introduction of CastLike around constants: diff --git a/tests/models/different_opset.py b/tests/models/different_opset.py index 737588478d..62438d9f87 100644 --- a/tests/models/different_opset.py +++ b/tests/models/different_opset.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnx import TensorProto from onnx.helper import make_tensor diff --git a/tests/models/dropout.py b/tests/models/dropout.py index fc3ac96d2c..b756d41b93 100644 --- a/tests/models/dropout.py +++ b/tests/models/dropout.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript.onnx_opset import opset15 as op diff --git a/tests/models/eager_op.py b/tests/models/eager_op.py index bc41a4f63e..86c6c6d13b 100644 --- a/tests/models/eager_op.py +++ b/tests/models/eager_op.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript import script diff --git a/tests/models/eg1.py b/tests/models/eg1.py index 13dd49f7f0..09e09d2b47 100644 --- a/tests/models/eg1.py +++ b/tests/models/eg1.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript import opset15 as op from onnxscript.onnx_types import FLOAT diff --git a/tests/models/getitem.py b/tests/models/getitem.py index ae7da82701..091febbb92 100644 --- a/tests/models/getitem.py +++ b/tests/models/getitem.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from __future__ import annotations diff --git a/tests/models/graph_attr.py b/tests/models/graph_attr.py index 69eff59a13..f7ee361361 100644 --- a/tests/models/graph_attr.py +++ b/tests/models/graph_attr.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript import graph, script from onnxscript.onnx_opset import opset15 as op diff --git a/tests/models/identity.py b/tests/models/identity.py index fabd6dcca5..18ab6e6f66 100644 --- a/tests/models/identity.py +++ b/tests/models/identity.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- # Test cases for automatic introduction of Identity (copy) diff --git a/tests/models/if_statement.py b/tests/models/if_statement.py index 2188ff41e1..509dd1ca7f 100644 --- a/tests/models/if_statement.py +++ b/tests/models/if_statement.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnx import TensorProto from onnx.helper import make_tensor diff --git a/tests/models/loops_break.py b/tests/models/loops_break.py index 77807c67c2..b9cd4e6dfa 100644 --- a/tests/models/loops_break.py +++ b/tests/models/loops_break.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnx import TensorProto from onnx.helper import make_tensor diff --git a/tests/models/loops_while.py b/tests/models/loops_while.py index 724f56a16f..93e2b98c7a 100644 --- a/tests/models/loops_while.py +++ b/tests/models/loops_while.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnx import TensorProto from onnx.helper import make_tensor diff --git a/tests/models/m1.py b/tests/models/m1.py index 127e53a97a..fe5e55838f 100644 --- a/tests/models/m1.py +++ b/tests/models/m1.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript.onnx_opset import opset15 as op from onnxscript.onnx_types import FLOAT diff --git a/tests/models/multi.py b/tests/models/multi.py index d4f13793dc..c79a775635 100644 --- a/tests/models/multi.py +++ b/tests/models/multi.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript.onnx_opset import opset15 as op from onnxscript.onnx_types import FLOAT diff --git a/tests/models/onnxfns1.py b/tests/models/onnxfns1.py index ae04e70775..84a2ba636d 100644 --- a/tests/models/onnxfns1.py +++ b/tests/models/onnxfns1.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- # Features included: # Overloaded operators such as <=, +, / diff --git a/tests/models/onnxfns1A.py b/tests/models/onnxfns1A.py index 4a23aba358..14be3cbbb8 100644 --- a/tests/models/onnxfns1A.py +++ b/tests/models/onnxfns1A.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- # Same functions as in onnxfns1.py, using autocast and default-attribute-values diff --git a/tests/models/onnxfns2.py b/tests/models/onnxfns2.py index 84ea9d53cc..3ab5a64e34 100644 --- a/tests/models/onnxfns2.py +++ b/tests/models/onnxfns2.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript import script from onnxscript.onnx_opset import opset15 as op diff --git a/tests/models/renaming.py b/tests/models/renaming.py index 1bc28bbf97..4f99be8dac 100644 --- a/tests/models/renaming.py +++ b/tests/models/renaming.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript.onnx_opset import opset15 as op from onnxscript.onnx_types import FLOAT diff --git a/tests/models/sequences.py b/tests/models/sequences.py index 8b41c7c63f..8a50791855 100644 --- a/tests/models/sequences.py +++ b/tests/models/sequences.py @@ -1,11 +1,8 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript import script from onnxscript.onnx_opset import opset15 as op -from onnxscript.onnx_types import FLOAT @script() diff --git a/tests/models/subfunction.py b/tests/models/subfunction.py index 2e30e8cdef..b1e4bbe7b8 100644 --- a/tests/models/subfunction.py +++ b/tests/models/subfunction.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnxscript import script from onnxscript.onnx_opset import opset15 as op diff --git a/tests/models/type_double.py b/tests/models/type_double.py index 6fd62e4d87..eee03b30be 100644 --- a/tests/models/type_double.py +++ b/tests/models/type_double.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- from onnx import TensorProto from onnx.helper import make_tensor diff --git a/tests/onnx_types_test.py b/tests/onnx_types_test.py index 8e9a96eb5d..1f7a98cc12 100644 --- a/tests/onnx_types_test.py +++ b/tests/onnx_types_test.py @@ -13,7 +13,7 @@ from parameterized import parameterized -from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry +from onnxscript.onnx_types import DOUBLE, FLOAT, TensorType, tensor_type_registry class TestOnnxTypes(unittest.TestCase): @@ -26,7 +26,7 @@ def test_instantiation(self): FLOAT[...]() @parameterized.expand(tensor_type_registry.items()) - def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]): + def test_type_properties(self, dtype: int, tensor_type: type[TensorType]): self.assertEqual(tensor_type.dtype, dtype) self.assertIsNone(tensor_type.shape) self.assertEqual(tensor_type[...].shape, ...) # type: ignore[index] @@ -35,7 +35,7 @@ def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]): self.assertEqual(tensor_type[1, 2, 3].dtype, dtype) # type: ignore[index] @parameterized.expand([(dtype,) for dtype in tensor_type_registry]) - def test_dtype_bound_to_subclass(self, dtype: DType): + def test_dtype_bound_to_subclass(self, dtype: int): with self.assertRaises(ValueError): type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype) diff --git a/tests/operator_test.py b/tests/operator_test.py index e88026a100..8ff193ce4a 100644 --- a/tests/operator_test.py +++ b/tests/operator_test.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- import unittest diff --git a/tests/optimizer/test_models.py b/tests/optimizer/test_models.py index ce78a8ac38..ec09ac8841 100644 --- a/tests/optimizer/test_models.py +++ b/tests/optimizer/test_models.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from __future__ import annotations import pathlib @@ -36,7 +38,7 @@ def test_model_runs_and_matches_accuracy_after_optimization(self, model_name): if not model_path.exists(): self.skipTest(f"Model {model_name!r} does not exist") model = onnx.load(model_path) - model = optimizer.optimize(model, onnx_shape_inference=False) + model = optimizer.optimize(model) with tempfile.TemporaryDirectory() as tmp_folder: tmp_folder = pathlib.Path(tmp_folder) diff --git a/tests/version_converter/version_conversion_test.py b/tests/version_converter/version_conversion_test.py new file mode 100644 index 0000000000..c012007d12 --- /dev/null +++ b/tests/version_converter/version_conversion_test.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import pathlib +import unittest + +from onnxscript import ir, version_converter + +model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" + + +class ModelTest(unittest.TestCase): + def test_model_runs_and_matches_accuracy_after_conversion_fallback_true(self): + model_path = model_folder_path / "e2e_models/torchscript_model/torchscript_model.onnx" + model = ir.load(model_path) + + # Down convert the model with the onnx version converter + version_converter.convert_version(model, target_version=16, fallback=True) + self.assertEqual(model.opset_imports[""], 16) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/diagnostics/gen_diagnostics.py b/tools/diagnostics/gen_diagnostics.py deleted file mode 100644 index b30b44d6e3..0000000000 --- a/tools/diagnostics/gen_diagnostics.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 - -"""Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentations. -The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml. - -Usage: - -python -m tools.onnx.gen_diagnostics \ - torch/onnx/_internal/diagnostics/rules.yaml \ - torch/onnx/_internal/diagnostics \ - torch/csrc/onnx/diagnostics/generated \ - torch/docs/source -""" - -import argparse -import os -import string -import subprocess -import textwrap -from typing import Any, Mapping, Sequence - -import yaml -from torchgen import utils as torchgen_utils -from torchgen.yaml_utils import YamlLoader - -_RULES_GENERATED_COMMENT = """\ -GENERATED CODE - DO NOT EDIT DIRECTLY -This file is generated by gen_diagnostics.py. -See tools/onnx/gen_diagnostics.py for more information. - -Diagnostic rules for PyTorch ONNX export. -""" - -_PY_RULE_CLASS_COMMENT = """\ -GENERATED CODE - DO NOT EDIT DIRECTLY -The purpose of generating a class for each rule is to override the `format_message` -method to provide more details in the signature about the format arguments. -""" - -_PY_RULE_CLASS_TEMPLATE = """\ -class _{pascal_case_name}(infra.Rule): - \"\"\"{short_description}\"\"\" - def format_message( # type: ignore[override] - self, - {message_arguments} - ) -> str: - \"\"\"Returns the formatted default message of this Rule. - - Message template: {message_template} - \"\"\" - return self.message_default_template.format({message_arguments_assigned}) - - def format( # type: ignore[override] - self, - level: infra.Level, - {message_arguments} - ) -> Tuple[infra.Rule, infra.Level, str]: - \"\"\"Returns a tuple of (Rule, Level, message) for this Rule. - - Message template: {message_template} - \"\"\" - return self, level, self.format_message({message_arguments_assigned}) - -""" - -_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\ -{snake_case_name}: _{pascal_case_name} = dataclasses.field( - default=_{pascal_case_name}.from_sarif(**{sarif_dict}), - init=False, -) -\"\"\"{short_description}\"\"\" -""" - -_CPP_RULE_TEMPLATE = """\ -/** - * @brief {short_description} - */ -{name}, -""" - -_RuleType = Mapping[str, Any] - - -def _kebab_case_to_snake_case(name: str) -> str: - return name.replace("-", "_") - - -def _kebab_case_to_pascal_case(name: str) -> str: - return "".join(word.capitalize() for word in name.split("-")) - - -def _format_rule_for_python_class(rule: _RuleType) -> str: - pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) - short_description = rule["short_description"]["text"] - message_template = rule["message_strings"]["default"]["text"] - field_names = [ - field_name - for _, field_name, _, _ in string.Formatter().parse(message_template) - if field_name is not None - ] - for field_name in field_names: - assert isinstance( - field_name, str - ), f"Unexpected field type {type(field_name)} from {field_name}. " - "Field name must be string.\nFull message template: {message_template}" # pylint: disable=pointless-string-statement - assert not field_name.isnumeric(), f"Unexpected numeric field name {field_name}. " - "Only keyword name formatting is supported.\nFull message template: {message_template}" # pylint: disable=pointless-string-statement - message_arguments = ", ".join(field_names) - message_arguments_assigned = ", ".join( - [f"{field_name}={field_name}" for field_name in field_names] - ) - return _PY_RULE_CLASS_TEMPLATE.format( - pascal_case_name=pascal_case_name, - short_description=short_description, - message_template=repr(message_template), - message_arguments=message_arguments, - message_arguments_assigned=message_arguments_assigned, - ) - - -def _format_rule_for_python_field(rule: _RuleType) -> str: - snake_case_name = _kebab_case_to_snake_case(rule["name"]) - pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) - short_description = rule["short_description"]["text"] - - return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format( - snake_case_name=snake_case_name, - pascal_case_name=pascal_case_name, - sarif_dict=rule, - short_description=short_description, - ) - - -def _format_rule_for_cpp(rule: _RuleType) -> str: - name = f"k{_kebab_case_to_pascal_case(rule['name'])}" - short_description = rule["short_description"]["text"] - return _CPP_RULE_TEMPLATE.format(name=name, short_description=short_description) - - -def gen_diagnostics_python( - rules: Sequence[_RuleType], out_py_dir: str, template_dir: str -) -> None: - rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules] - rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules] - - fm = torchgen_utils.FileManager( - install_dir=out_py_dir, template_dir=template_dir, dry_run=False - ) - fm.write_with_template( - "_rules.py", - "rules.py.in", - lambda: { - "generated_comment": _RULES_GENERATED_COMMENT, - "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT, - "rule_classes": "\n".join(rule_class_lines), - "rules": textwrap.indent("\n".join(rule_field_lines), " " * 4), - }, - ) - _lint_file(os.path.join(out_py_dir, "_rules.py")) - - -def gen_diagnostics_cpp( - rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str -) -> None: - rule_lines = [_format_rule_for_cpp(rule) for rule in rules] - rule_names = [f'"{_kebab_case_to_snake_case(rule["name"])}",' for rule in rules] - - fm = torchgen_utils.FileManager( - install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False - ) - fm.write_with_template( - "rules.h", - "rules.h.in", - lambda: { - "generated_comment": textwrap.indent( - _RULES_GENERATED_COMMENT, - " * ", - predicate=lambda x: True, # Don't ignore empty line - ), - "rules": textwrap.indent("\n".join(rule_lines), " " * 2), - "py_rule_names": textwrap.indent("\n".join(rule_names), " " * 4), - }, - ) - _lint_file(os.path.join(out_cpp_dir, "rules.h")) - - -def gen_diagnostics_docs( - rules: Sequence[_RuleType], # pylint: disable=unused-argument - out_docs_dir: str, # pylint: disable=unused-argument - template_dir: str, # pylint: disable=unused-argument -) -> None: - # TODO: Add doc generation in a follow-up PR. - pass - - -def _lint_file(file_path: str) -> None: - with subprocess.Popen(["lintrunner", "-a", file_path]) as p: - p.wait() - - -def gen_diagnostics( - rules_path: str, - out_py_dir: str, - out_cpp_dir: str, - out_docs_dir: str, -) -> None: - with open(rules_path, encoding="utf-8") as f: - rules = yaml.load(f, Loader=YamlLoader) - - template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates") - - gen_diagnostics_python( - rules, - out_py_dir, - template_dir, - ) - - gen_diagnostics_cpp( - rules, - out_cpp_dir, - template_dir, - ) - - gen_diagnostics_docs(rules, out_docs_dir, template_dir) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files") - parser.add_argument("rules_path", metavar="RULES", help="path to rules.yaml") - parser.add_argument( - "out_py_dir", - metavar="OUT_PY", - help="path to output directory for Python", - ) - parser.add_argument( - "out_cpp_dir", - metavar="OUT_CPP", - help="path to output directory for C++", - ) - parser.add_argument( - "out_docs_dir", - metavar="OUT_DOCS", - help="path to output directory for docs", - ) - args = parser.parse_args() - gen_diagnostics( - args.rules_path, - args.out_py_dir, - args.out_cpp_dir, - args.out_docs_dir, - ) - - -if __name__ == "__main__": - main() diff --git a/tools/diagnostics/gen_diagnostics.sh b/tools/diagnostics/gen_diagnostics.sh deleted file mode 100644 index 1785fdee32..0000000000 --- a/tools/diagnostics/gen_diagnostics.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -# Run this script inside its folder to generate PyTorch ONNX Export Diagnostic rules -# for C++, Python and documentations. -# The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml. - -set -e -x -ROOT="${PWD}/../../" -pushd "$ROOT" -( -python -m tools.onnx.gen_diagnostics \ - torch/onnx/_internal/diagnostics/rules.yaml \ - torch/onnx/_internal/diagnostics \ - torch/csrc/onnx/diagnostics/generated \ - torch/docs/source -) -popd diff --git a/tools/diagnostics/sarif/code-gen-hints.json b/tools/diagnostics/sarif/code-gen-hints.json deleted file mode 100644 index 14c7041831..0000000000 --- a/tools/diagnostics/sarif/code-gen-hints.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "SarifLog.$schema": [ - { - "kind": "PropertyNameHint", - "arguments": { - "pythonPropertyName": "schemaUri" - } - } - ] -} diff --git a/tools/diagnostics/sarif/gen_sarif.sh b/tools/diagnostics/sarif/gen_sarif.sh deleted file mode 100644 index a7e6ce0f6a..0000000000 --- a/tools/diagnostics/sarif/gen_sarif.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -# Run this script inside its folder to generate the SARIF python object model files -# from the SARIF schema. -# e.g. ./gen_sarif.sh -# -# This script requires the jschema_to_python package to be installed. -# To install it, run: -# pip install jschema_to_python - -set -e -x -ROOT="${PWD}/../../.." -SARIF_DIR="torch/onnx/_internal/diagnostics/infra/sarif" - -# SARIF version -SARIF_VERSION="2.1.0" -SARIF_SCHEMA_LINK="https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json" - -# Download SARIF schema -tmp_dir="$(mktemp -d)" -sarif_schema_file_path="${tmp_dir}/sarif-schema-${SARIF_VERSION}.json" -curl -L -o "$sarif_schema_file_path" "$SARIF_SCHEMA_LINK" - -# TODO: A private branch of jschema_to_python was used to enable -# the generation to dataclasses and support annotation. -python -m jschema_to_python \ - --schema-path "$sarif_schema_file_path" \ - --module-name torch.onnx._internal.diagnostics.infra.sarif \ - --output-directory "${ROOT}/${SARIF_DIR}" \ - --root-class-name SarifLog \ - --hints-file-path code-gen-hints.json \ - --force \ - --library dataclasses \ - -vv - -# Generate SARIF version file -echo "from typing import Final" > "${ROOT}/${SARIF_DIR}/version.py" -echo "SARIF_VERSION: Final = \"${SARIF_VERSION}\"" >> "${ROOT}/${SARIF_DIR}/version.py" -echo "SARIF_SCHEMA_LINK: Final = \"${SARIF_SCHEMA_LINK}\"" >> "${ROOT}/${SARIF_DIR}/version.py" - -pushd "$ROOT" -( - # Hack to have flake8 not complain about generated code. - set +x - while IFS= read -r -d '' file; do - echo "# flake8: noqa" >> "$file" - done < <(find "$SARIF_DIR" -name '*.py' -print0) - set -x - - lintrunner "${SARIF_DIR}/"** -a -) -popd diff --git a/tools/diagnostics/templates/rules.h.in b/tools/diagnostics/templates/rules.h.in deleted file mode 100644 index 4c81806524..0000000000 --- a/tools/diagnostics/templates/rules.h.in +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -/** -${generated_comment} - */ - -namespace torch { -namespace onnx { -namespace diagnostics { - -enum class Rule : uint32_t { -${rules} -}; - -static constexpr const char* const kPyRuleNames [] = { -${py_rule_names} -}; - -} // namespace diagnostics -} // namespace onnx -} // namespace torch diff --git a/tools/diagnostics/templates/rules.py.in b/tools/diagnostics/templates/rules.py.in deleted file mode 100644 index 19b1e08d50..0000000000 --- a/tools/diagnostics/templates/rules.py.in +++ /dev/null @@ -1,21 +0,0 @@ -""" -${generated_comment} -""" - -import dataclasses -from typing import Tuple - -# flake8: noqa -from torch.onnx._internal.diagnostics import infra - -""" -${generated_rule_class_comment} -""" - -${rule_classes} - -@dataclasses.dataclass -class _POERules(infra.RuleCollection): -${rules} - -rules = _POERules() diff --git a/tools/function_rewriter_testing/function_unittest_producer.py b/tools/function_rewriter_testing/function_unittest_producer.py deleted file mode 100644 index cf1b54cf63..0000000000 --- a/tools/function_rewriter_testing/function_unittest_producer.py +++ /dev/null @@ -1,450 +0,0 @@ -"""Fuction fusion unittest producer. - -Takes in a full model, function keyword, and example inputs, produces unit model protos -that contains only a single node calling the target function proto. - -- All initializers are lifted as model inputs. -- Example inputs and outputs are saved as test data for each unit model proto. -""" - -from __future__ import annotations - -import argparse -import itertools -import logging -import os -import sys -from typing import Dict, List, Tuple - -import numpy as np -import onnx -import onnx.inliner -import onnxruntime -from onnx import helper as onnx_helper -from onnx import numpy_helper - -from onnxscript import _legacy_ir as ir -from onnxscript._legacy_ir import visitor -from onnxscript.utils import evaluation_utils, utils - -logger = logging.getLogger(__name__) - - -# Copied from common.py from pytorch torchbench -def save_tensor_data(numpy_tensor, output_path: str): - proto_tensor = numpy_helper.from_array(numpy_tensor) - with open(output_path, "wb") as f: - f.write(proto_tensor.SerializeToString()) - - -class FunctionToKeepVisitor(visitor.ProtoVisitorCore): - def __init__(self, function_keyword): - self.function_keyword = function_keyword - self.functions_to_keep = [] - self.in_target_function = False - self._functions = {} - super().__init__() - - def visit_function_node(self, node: onnx.NodeProto): - prev_in_target_function = self.in_target_function - function_id = ir.get_function_id_from_node(node) - function = self._functions[function_id] - if node.op_type.find(self.function_keyword) != -1: - self.functions_to_keep.append(function_id) - self.in_target_function = True - elif prev_in_target_function: - self.functions_to_keep.append(function_id) - - for subnode in function.node: - self.visit_node(subnode) - - self.in_target_function = prev_in_target_function - - def process_node(self, node: onnx.NodeProto): - if visitor.is_local_function_node(node, self._functions): - return self.visit_function_node(node) - return None - - def visit_model(self, model: onnx.ModelProto) -> None: - for function in model.functions: - self._functions[ir.get_function_id(function)] = function - super().visit_model(model) - - -FunctionMetaDict = Dict[Tuple[str, str], Tuple[List[str], List[str]]] - - -class TargetFunctionMetaVisitor(visitor.ProtoVisitorCore): - def __init__(self, function_keyword): - self.function_keyword = function_keyword - # Map from (domain, name) to (actual_input_names, actual_output_names) - self.function_meta: FunctionMetaDict = {} - self._functions = {} - super().__init__() - - def visit_function_node(self, node: onnx.NodeProto): - function = self._functions[ir.get_function_id_from_node(node)] - if node.op_type.find(self.function_keyword) != -1: - self.function_meta[(function.domain, function.name)] = ( - node.input, - node.output, - ) - for subnode in function.node: - self.visit_node(subnode) - - def process_node(self, node: onnx.NodeProto): - if visitor.is_local_function_node(node, self._functions): - return self.visit_function_node(node) - return None - - def visit_model(self, model: onnx.ModelProto) -> None: - for function in model.functions: - self._functions[ir.get_function_id(function)] = function - super().visit_model(model) - - -class FunctionProtoProducerWithData(visitor.ProtoVisitor): - """Fuction fusion unittest producer. - - Creates unit model proto for selected function, as well as example inputs and outputs. - - Utilizes ORT fetch feature. - - Steps as follows: - - - Identify the target function, and all functions called within. - - Call onnx.inliner to inline all other functions. - - Identity inputs and outputs to target function calls, construct ort fetch. - - Run the model with ort fetch to receive example inputs and outputs. - - For each target function call, construct a unit model proto with example inputs and outputs from previous step. - """ - - def __init__(self, function_keyword: str, model_path: str, output_dir: str): - self.function_keyword = function_keyword - self.model_path = model_path - self.output_dir = output_dir - self.output_model_basename = function_keyword - self._functions: dict[ir.FunctionId, onnx.FunctionProto] = {} - self._unit_model_protos: list[onnx.ModelProto] = [] - self._unit_model_inputs = [] # type: ignore[var-annotated] - self._unit_model_outputs = [] # type: ignore[var-annotated] - # Example intermediate data values - self._named_values: dict[str, np.ndarray] = {} - super().__init__() - - @property - def unit_model_protos(self) -> list[onnx.ModelProto]: - return self._unit_model_protos - - @property - def unit_model_inputs(self): - return self._unit_model_inputs - - @property - def unit_model_outputs(self): - return self._unit_model_outputs - - def find_all_called_function_protos( - self, function: onnx.FunctionProto - ) -> list[onnx.FunctionProto]: - result: dict[ir.FunctionId, onnx.FunctionProto] = { - ir.get_function_id(function): function - } - for node in function.node: - if visitor.is_local_function_node(node, self._functions): - sub_function = self._functions[ir.get_function_id_from_node(node)] - result.update( - { - ir.get_function_id(func): func - for func in self.find_all_called_function_protos(sub_function) - } - ) - return result.values() # type: ignore[return-value] - - def _generate_value_info_for_function_value( - self, value: str, function: onnx.FunctionProto - ) -> onnx.ValueInfoProto | None: - value_ir = self.function_shape_env.lookup(function, value) - if value_ir is None: - return None - return self.function_shape_env.save_to_value_info( - value_ir, *ir.get_function_id(function) - ) - - def _generate_value_info_for_function_values( - self, function: onnx.FunctionProto - ) -> list[onnx.ValueInfoProto]: - value_infos = [] - values = { - *function.input, - *function.output, - *itertools.chain((*node.input, *node.output) for node in function.node), - } - - for value in values: - value_info = self._generate_value_info_for_function_value(value, function) - if value_info is not None: - value_infos.append(value_info) - return value_infos - - def create_unit_model_proto( - self, - function_proto: onnx.FunctionProto, - actual_input_value_infos: list[ir.Value | None], - actual_output_value_infos: list[ir.Value | None], - ) -> onnx.ModelProto | None: - unit_model_proto = onnx.ModelProto() - unit_model_proto.ir_version = self._model_proto.ir_version - unit_model_proto.producer_name = self._model_proto.producer_name - unit_model_proto.producer_version = self._model_proto.producer_version - unit_model_proto.domain = self._model_proto.domain - unit_model_proto.model_version = self._model_proto.model_version - unit_model_proto.opset_import.extend(self._model_proto.opset_import) - graph_proto = unit_model_proto.graph - - for actual_input_value_info, formal_input in zip( - actual_input_value_infos, function_proto.input - ): - if actual_input_value_info is None: - logger.error( - "Value info for input %s is not found. Skip model proto creation for function %s::%s", - formal_input, - function_proto.domain, - function_proto.name, - ) - return None - if actual_input_value_info.type is None: - logger.error( - "Value info for input %s has no type. Skip model proto creation for function %s::%s", - formal_input, - function_proto.domain, - function_proto.name, - ) - - value_info = onnx.ValueInfoProto() - value_info.name = actual_input_value_info.name - value_info.type.CopyFrom(actual_input_value_info.type) - graph_proto.input.append(value_info) - - for actual_output_value_info, formal_output in zip( - actual_output_value_infos, function_proto.output - ): - if actual_output_value_info is None: - logger.error( - "Value info for output %s is not found. Skip model proto creation for function %s::%s", - formal_output, - function_proto.domain, - function_proto.name, - ) - return None - if actual_output_value_info.type is None: - logger.error( - "Value info for output %s has no type. Skip model proto creation for function %s::%s", - formal_output, - function_proto.domain, - function_proto.name, - ) - - value_info = onnx.ValueInfoProto() - value_info.name = actual_output_value_info.name - value_info.type.CopyFrom(actual_output_value_info.type) - graph_proto.output.append(value_info) - - new_function_node = onnx.NodeProto() - new_function_node.op_type = function_proto.name - new_function_node.domain = function_proto.domain - new_function_node.input.extend([input.name for input in actual_input_value_infos]) # type: ignore[union-attr] - new_function_node.output.extend([output.name for output in actual_output_value_infos]) # type: ignore[union-attr] - # TODO: Producing function node attribute is not supported yet. - - graph_proto.node.append(new_function_node) - called_function_protos = self.find_all_called_function_protos(function_proto) - for called_function_proto in called_function_protos: - graph_proto.value_info.extend( - self._generate_value_info_for_function_values(called_function_proto) - ) - unit_model_proto.functions.extend(called_function_protos) - return unit_model_proto - - def process_initializer(self, init: onnx.TensorProto): - self.bind( - init.name, - ir.Value(name=init.name, type=utils.get_initializer_type(init)), - ) - - def lookup(self, name: str) -> ir.Value | None: - """Override unit model proto inputs & outputs value infos with value info derived from actual example data. - - This step is required because onnx FunctionProto does not contain value info. - The experimental solution from exporter writes value infos under root GraphProto, and associate them with - FunctionProto by name mangling. This is lost during onnx.inliner because of the structural and value name - changes. - - This step is not necessary once value info is natively supported in FunctionProto. - - This step by design cannot support dynamic shape. - """ - if name in self._named_values: - return ir.Value( - name=name, - type=onnx_helper.make_tensor_type_proto( - onnx_helper.np_dtype_to_tensor_dtype(self._named_values[name].dtype), - self._named_values[name].shape, - ), - ) - return super().lookup(name) - - def visit_model(self, model: onnx.ModelProto): - functions_to_keep_visitor = FunctionToKeepVisitor(self.function_keyword) - functions_to_keep_visitor.visit_model(model) - functions_to_keep = functions_to_keep_visitor.functions_to_keep - # TODO: bug report: IsScalar function inside if subgraph is not part of functions_to_keep. - # Yet it is also not inlined. But its function_proto is removed by inliner. - # To unblock us, we manually add it to functions_to_keep. - functions_to_keep.append(("pkg.onnxscript.torch_lib.common", "IsScalar")) - # TODO: Post ONNX 1.16, overload will be introduced. - functions_to_keep = [function_id[:2] for function_id in functions_to_keep] - inlined_model_proto = onnx.inliner.inline_selected_functions( - model, functions_to_keep, exclude=True - ) - target_function_meta_visitor = TargetFunctionMetaVisitor(self.function_keyword) - target_function_meta_visitor.visit_model(inlined_model_proto) - target_function_meta = target_function_meta_visitor.function_meta - - fetch_outputs = [] # type: ignore[var-annotated] - for inputs, outputs in target_function_meta.values(): - fetch_outputs.extend((*inputs, *outputs)) - - fetch_output_value_infos = [] - for fetch_output in fetch_outputs: - value_info = onnx.ValueInfoProto() - value_info.name = fetch_output - fetch_output_value_infos.append(value_info) - - inlined_model_proto.graph.output.extend(fetch_output_value_infos) - inlined_model_proto = onnx.shape_inference.infer_shapes(inlined_model_proto) - - self._model_proto = inlined_model_proto - - model_path = self.model_path - model_dir = os.path.dirname(model_path) - inputs, _ = evaluation_utils.load_test_data( # type: ignore[assignment] - model_dir, [i.name for i in model.graph.input] - ) - tmp_model_path = f"{model_dir}/tmp_model.onnx" - onnx.save(inlined_model_proto, tmp_model_path) - - sess = onnxruntime.InferenceSession( - tmp_model_path, providers=["CUDAExecutionProvider"] - ) - outputs = sess.run(fetch_outputs, inputs) - assert ( - len(outputs) == len(fetch_outputs) - ), f"Number of outputs mismatch. outputs: {len(outputs)}, fetch_outputs: {len(fetch_outputs)}" - - self._named_values = dict(zip(fetch_outputs, outputs)) # type: ignore[arg-type] - for inputs, outputs in target_function_meta.values(): - named_inputs = [(i, self._named_values[i]) for i in inputs] - named_outputs = [(o, self._named_values[o]) for o in outputs] - self._unit_model_inputs.append(named_inputs) - self._unit_model_outputs.append(named_outputs) - - for function in inlined_model_proto.functions: - self._functions[ir.get_function_id(function)] = function - - super().visit_model(inlined_model_proto) - - def process_function(self, function: onnx.FunctionProto): - if function.name.find(self.function_keyword) == -1: - return - - try: - actual_input_value_infos = [self.lookup(input) for input in function.input] - actual_output_value_infos = [self.lookup(output) for output in function.output] - except ValueError as e: - raise ValueError( - "Cannot create ModelProto unittest for function. " - f"Failed to find value info for function {function.domain}::{function.name}" - ) from e - unit_model_proto = self.create_unit_model_proto( - function, actual_input_value_infos, actual_output_value_infos - ) - if unit_model_proto is not None: - self._unit_model_protos.append(unit_model_proto) - - -def produce_function_proto_unittest( - model_path: str, - function_keyword: str, - output_dir: str, -) -> tuple[ - list[onnx.ModelProto], - list[list[tuple[str, np.ndarray]]], - list[list[tuple[str, np.ndarray]]], -]: - model_proto = onnx.load(model_path, load_external_data=False) - - # model_proto = optimizer.optimize(model_proto, onnx_shape_inference=False) - - producer = FunctionProtoProducerWithData( - function_keyword, - model_path, - output_dir, - ) - - producer.visit_model(model_proto) - return ( - producer.unit_model_protos, - producer.unit_model_inputs, - producer.unit_model_outputs, - ) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-path", "--model_path", type=str) - parser.add_argument("--function", type=str) - parser.add_argument("--output-dir", "--output_dir", type=str) - parser.add_argument("--max-outputs", "--max_outputs", type=int, default=sys.maxsize) - parser.add_argument("--name", type=str) - - args = parser.parse_args() - model_path = args.model_path - function = args.function - output_dir = args.output_dir - max_outputs = args.max_outputs - name = args.name - - ( - unit_model_protos, - named_inputs_list, - named_outputs_list, - ) = produce_function_proto_unittest(model_path, function, output_dir) - - for i, unit_model_proto in enumerate(unit_model_protos[:max_outputs]): - if logger.level <= logging.DEBUG: - logger.debug("unit model proto %d:", i) - # logger.debug(onnx.printer.to_text(unit_model_proto)) - output_model_dir = f"{output_dir}/{name}_{i}/" - os.makedirs(output_model_dir, exist_ok=True) - onnx.save(unit_model_proto, f"{output_model_dir}/{name}_{i}.onnx") - # save test data - test_data_dir = f"{output_model_dir}/test_data_set_0/" - os.makedirs(test_data_dir, exist_ok=True) - named_inputs = named_inputs_list[i] - for j, (_, input) in enumerate(named_inputs): - save_tensor_data(input, f"{test_data_dir}/input_{j}.pb") - named_outputs = named_outputs_list[i] - for j, (_, output) in enumerate(named_outputs): - save_tensor_data(output, f"{test_data_dir}/output_{j}.pb") - - print( - f"{len(unit_model_protos[:max_outputs])} unit model protos and test data are saved to {output_dir}." - ) - - -if __name__ == "__main__": - # python tools/function_rewriter_testing/function_unittest_producer.py \ - # --model_path tools/ort_rewriter_profiling/onnx_models/stable_diffusion_unet/dynamo/stable_diffusion_unet_dynamo.onnx \ - # --function GEGLU --output-dir testdata/unittest_models/ --max_outputs 4 --name geglu_stable_diffusion_unet - main() diff --git a/tools/ir/model_zoo_test/model_zoo_test.py b/tools/ir/model_zoo_test/model_zoo_test.py index de3410a49b..82d7a54026 100644 --- a/tools/ir/model_zoo_test/model_zoo_test.py +++ b/tools/ir/model_zoo_test/model_zoo_test.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Test IR roundtrip with ONNX model zoo. Usage: @@ -16,6 +18,7 @@ import traceback import onnx +import onnxruntime as ort import tqdm from onnx import hub @@ -40,8 +43,12 @@ def test_model(model_info: hub.ModelInfo) -> float: ir_model = ir.serde.deserialize_model(model) serialized = ir.serde.serialize_model(ir_model) end = time.time() - onnxscript.testing.assert_onnx_proto_equal(serialized, model) + onnxscript.testing.assert_onnx_proto_equal( + serialized, model, ignore_initializer_value_proto=True + ) onnx.checker.check_model(serialized) + # Check the model can be loaded with onnxruntime + ort.InferenceSession(serialized.SerializeToString()) return end - start diff --git a/tools/onnx2external.py b/tools/onnx2external.py new file mode 100644 index 0000000000..1685458251 --- /dev/null +++ b/tools/onnx2external.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os + +import onnx +import onnx.external_data_helper + + +def convert2external(input_file_name: str) -> None: + dir_name = os.path.dirname(input_file_name) + base_name, _suffix = os.path.splitext(os.path.basename(input_file_name)) + model = onnx.load(input_file_name) + os.makedirs(os.path.join(dir_name, base_name), exist_ok=True) + onnx.external_data_helper.convert_model_to_external_data( + model, location="external_data.onnx", size_threshold=128 + ) + onnx.save(model, os.path.join(dir_name, base_name, "model.onnx")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert ONNX model file to external data format" + ) + parser.add_argument("input", help="ONNX model file to convert") + args = parser.parse_args() + + convert2external(args.input) diff --git a/tools/onnx2script.py b/tools/onnx2script.py index 24556e755b..7b57bf91d6 100644 --- a/tools/onnx2script.py +++ b/tools/onnx2script.py @@ -1,7 +1,5 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- """ onnx2script.py @@ -30,11 +28,14 @@ def convert2script( - input_file_name: str, output_file_name: Optional[str], verbose: bool + input_file_name: str, output_file_name: Optional[str], verbose: bool, initializers: bool ) -> None: model = onnx.load(input_file_name, load_external_data=False) python_code = onnxscript.proto2python( - model, use_operators=not verbose, inline_const=not verbose + model, + use_operators=not verbose, + inline_const=not verbose, + skip_initializers=not initializers, ) # If output file name is not provided, use the input file name with .py extension @@ -57,6 +58,13 @@ def convert2script( help="Verbose mode, suppresses use of overloaded operators and inline constants", default=False, ) + parser.add_argument( + "-i", + "--initializers", + action="store_true", + help="Include initializers in the generated script", + default=False, + ) args = parser.parse_args() - convert2script(args.input, args.output, args.verbose) + convert2script(args.input, args.output, args.verbose, args.initializers) diff --git a/tools/optimize.py b/tools/optimize.py new file mode 100644 index 0000000000..276cda8901 --- /dev/null +++ b/tools/optimize.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Utility for optimizing ONNX models. + +Usage: + python optimize.py model.onnx optimized_model.onnx +""" + +import argparse +import os + +import onnx +import onnx.inliner + +import onnxscript + + +def main(args) -> None: + path = args.path + output_path = args.output_path + + model = onnx.load(path, load_external_data=False) + # Hack: Change the working directory to the model directory so the optimizer + # can load external data files with relative paths. + # TODO: Remove this hack by fixing the optimizer to handle external data files properly. + pwd = os.getcwd() + model_dir = os.path.dirname(path) + os.chdir(model_dir) + model = onnxscript.optimizer.optimize(model) + model = onnx.inliner.inline_local_functions(model) + # Optimize again in case inlining created new opportunities. + model = onnxscript.optimizer.optimize(model) + + os.chdir(pwd) + onnx.save(model, output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Optimize an ONNX model.") + parser.add_argument("path", type=str, help="Path to the ONNX model.") + parser.add_argument("output_path", type=str, help="Path to save the optimized model.") + main(parser.parse_args()) diff --git a/tools/ort_rewriter_profiling/README.md b/tools/ort_rewriter_profiling/README.md index 66f3af36bd..1696ebf9b0 100644 --- a/tools/ort_rewriter_profiling/README.md +++ b/tools/ort_rewriter_profiling/README.md @@ -127,17 +127,6 @@ 5. Develop optimization code. - `onnx-script/onnxscript/optimizer`: Optimizations such as constant folding, inlining, dead code elimination etc. - `onnx-script/onnxscript/rewriter`: Pattern based fusions. - - `onnx-script/onnxscript/rewriter/onnxruntime`: Onnxruntime specific pattern based fusions. - - `onnx-script/onnxscript/rewriter/onnxruntime/transformers`: Onnxruntime specific function based fusions. - - Use function unittest producer tool to create function fusion unittest. Example command to distill 4 unittests for function `LlamaSdpaAttention` from `llama_v2_7b` `dynamo` model. The unittest models are named with prefix `sdpa_llama2`: - ``` - # Under onnx-script/onnxscript/rewriter/transformers - CUDA_VISIBLE_DEVICES="3" python tools/function_unittest_producer.py --model-path ../../../tools/onnx_models/llama_v2_7b_16h/dynamo_ort_rewritten/llama_v2_7b_16h_dynamo_ort_rewritten.onnx --function LlamaSdpaAttention --output-dir ../../testing/rewriter/transformers/unittest_models/ --max-outputs 4 --name sdpa_llama2 - ``` - - Create new testcase under `onnx-script/onnxscript/rewriter/transformers` with the generated unittest models. - ```python - def test_sdpa_llama2(self): - common.test_function_rewrite("sdpa_llama2", 4) - ``` + - `onnx-script/onnxscript/rewriter/ort_fusions`: Onnxruntime specific pattern based fusions. 6. Repeat step 3 to step 5 to verify performance improvement as well as parity after new optimization. diff --git a/tools/ort_rewriter_profiling/bench_model.py b/tools/ort_rewriter_profiling/bench_model.py index 14402da317..082e951432 100644 --- a/tools/ort_rewriter_profiling/bench_model.py +++ b/tools/ort_rewriter_profiling/bench_model.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Lite benchmark script comparing perf between different onnx model of the same torch model. Folders are expected to be in the following format: diff --git a/tools/ort_rewriter_profiling/nsys_profile.py b/tools/ort_rewriter_profiling/nsys_profile.py index 98d463ed38..86b27726dc 100644 --- a/tools/ort_rewriter_profiling/nsys_profile.py +++ b/tools/ort_rewriter_profiling/nsys_profile.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """This script is an e2e tool to start a model run and profile the run. It parses the analysis produced by onnxruntime/nsys profiling and prints out the result. diff --git a/tools/ort_rewriter_profiling/ort_rewrite.py b/tools/ort_rewriter_profiling/ort_rewrite.py index 3fe1e54246..b92681ecd6 100644 --- a/tools/ort_rewriter_profiling/ort_rewrite.py +++ b/tools/ort_rewriter_profiling/ort_rewrite.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Runs onnxruntime rewriter to optimize on the given onnx model. Input: diff --git a/tools/ort_rewriter_profiling/profile_analysis.py b/tools/ort_rewriter_profiling/profile_analysis.py index 47a9c3cb03..3c79a3414e 100644 --- a/tools/ort_rewriter_profiling/profile_analysis.py +++ b/tools/ort_rewriter_profiling/profile_analysis.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """This script analyzes the profile file generated by onnxruntime/nsys profiling. It creates an in memory report of the per operator duration profile and prints it out. @@ -135,7 +137,7 @@ def _construct_tabulate_dict( comp_compiler_perf_header: comp_perf, } - ## Every op type + # Every op type tabulate_data = sorted( [ _construct_tabulate_dict( @@ -230,10 +232,10 @@ def compare_node_reports( base_report: ModelProfile, comp_report: ModelProfile, ): - ## Every op type + # Every op type print(tabulate_diff(base_report, comp_report)) - ## Matmul family + Add + # Matmul family + Add matmul_core_op_types = { "MatMul", "Gemm",